Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
[MLIR][OpenMP] Prevent CSE and constant materialization from crossing…
Browse files Browse the repository at this point in the history
… some OpenMP region boundaries

Operations this patch prevents constants and common subexpressions being
extracted from:
  - omp.target
  - omp.teams
  - omp.parallel
  • Loading branch information
skatrak committed Aug 30, 2023
1 parent 0ba9c89 commit 8de55ab
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
34 changes: 34 additions & 0 deletions mlir/include/mlir/Interfaces/CSEInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_CSEINTERFACES_H_
#define MLIR_INTERFACES_CSEINTERFACES_H_

#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {

/// Define an interface to allow for dialects to control specific aspects of
/// common subexpression elimination behavior for operations they define.
class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
public:
DialectCSEInterface(Dialect *dialect) : Base(dialect) {}

/// Registered hook to check if an operation that is *not* isolated from
/// above, should allow common subexpressions to be extracted out of its
/// regions.
virtual bool subexpressionExtractionAllowed(Operation *op) const {
return true;
}
};

} // namespace mlir

#endif // MLIR_INTERFACES_CSEINTERFACES_H_
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/CSEInterfaces.h"
#include "mlir/Interfaces/FoldInterfaces.h"

#include "llvm/ADT/BitVector.h"
Expand Down Expand Up @@ -46,12 +47,21 @@ struct PointerLikeModel
}
};

struct OpenMPDialectCSEInterface : public DialectCSEInterface {
using DialectCSEInterface::DialectCSEInterface;

bool subexpressionExtractionAllowed(Operation *op) const final {
// Avoid extracting common subexpressions across op boundaries
return !isa<TargetOp, TeamsOp, ParallelOp>(op);
}
};

struct OpenMPDialectFoldInterface : public DialectFoldInterface {
using DialectFoldInterface::DialectFoldInterface;

bool shouldMaterializeInto(Region *region) const final {
// Avoid folding constants across target regions
return isa<TargetOp>(region->getParentOp());
return isa<TargetOp, TeamsOp, ParallelOp>(region->getParentOp());
}
};
} // namespace
Expand All @@ -66,6 +76,7 @@ void OpenMPDialect::initialize() {
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
>();

addInterface<OpenMPDialectCSEInterface>();
addInterface<OpenMPDialectFoldInterface>();
LLVM::LLVMPointerType::attachInterface<
PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/Transforms/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CSEInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -61,7 +62,8 @@ namespace {
class CSEDriver {
public:
CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
: rewriter(rewriter), domInfo(domInfo) {}
: rewriter(rewriter), domInfo(domInfo),
interfaces(rewriter.getContext()) {}

/// Simplify all operations within the given op.
void simplify(Operation *op, bool *changed = nullptr);
Expand Down Expand Up @@ -122,6 +124,8 @@ class CSEDriver {
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;

DialectInterfaceCollection<DialectCSEInterface> interfaces;

// Various statistics.
int64_t numCSE = 0;
int64_t numDCE = 0;
Expand Down Expand Up @@ -289,7 +293,10 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
// If this operation is isolated above, we can't process nested regions
// with the given 'knownValues' map. This would cause the insertion of
// implicit captures in explicit capture only regions.
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op);
if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
LLVM_UNLIKELY(cseInterface &&
!cseInterface->subexpressionExtractionAllowed(&op))) {
ScopedMapTy nestedKnownValues;
for (auto &region : op.getRegions())
simplifyRegion(nestedKnownValues, region);
Expand Down

0 comments on commit 8de55ab

Please sign in to comment.