From 8de55ab8c28fa66d4246edcffd74313525215a3a Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Wed, 30 Aug 2023 14:42:27 +0100 Subject: [PATCH] [MLIR][OpenMP] Prevent CSE and constant materialization from crossing some OpenMP region boundaries Operations this patch prevents constants and common subexpressions being extracted from: - omp.target - omp.teams - omp.parallel --- mlir/include/mlir/Interfaces/CSEInterfaces.h | 34 ++++++++++++++++++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 13 +++++++- mlir/lib/Transforms/CSE.cpp | 11 +++++-- 3 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 mlir/include/mlir/Interfaces/CSEInterfaces.h diff --git a/mlir/include/mlir/Interfaces/CSEInterfaces.h b/mlir/include/mlir/Interfaces/CSEInterfaces.h new file mode 100644 index 000000000000..dcf9642fd089 --- /dev/null +++ b/mlir/include/mlir/Interfaces/CSEInterfaces.h @@ -0,0 +1,34 @@ +//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INTERFACES_CSEINTERFACES_H_ +#define MLIR_INTERFACES_CSEINTERFACES_H_ + +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +/// Define an interface to allow for dialects to control specific aspects of +/// common subexpression elimination behavior for operations they define. +class DialectCSEInterface : public DialectInterface::Base { +public: + DialectCSEInterface(Dialect *dialect) : Base(dialect) {} + + /// Registered hook to check if an operation that is *not* isolated from + /// above, should allow common subexpressions to be extracted out of its + /// regions. + virtual bool subexpressionExtractionAllowed(Operation *op) const { + return true; + } +}; + +} // namespace mlir + +#endif // MLIR_INTERFACES_CSEINTERFACES_H_ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2ba5f1aca9cf..73b31cace313 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/CSEInterfaces.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "llvm/ADT/BitVector.h" @@ -46,12 +47,21 @@ struct PointerLikeModel } }; +struct OpenMPDialectCSEInterface : public DialectCSEInterface { + using DialectCSEInterface::DialectCSEInterface; + + bool subexpressionExtractionAllowed(Operation *op) const final { + // Avoid extracting common subexpressions across op boundaries + return !isa(op); + } +}; + struct OpenMPDialectFoldInterface : public DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; bool shouldMaterializeInto(Region *region) const final { // Avoid folding constants across target regions - return isa(region->getParentOp()); + return isa(region->getParentOp()); } }; } // namespace @@ -66,6 +76,7 @@ void OpenMPDialect::initialize() { #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" >(); + addInterface(); addInterface(); LLVM::LLVMPointerType::attachInterface< PointerLikeModel>(*getContext()); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 3affd88d158d..fba565aa6060 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/CSEInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" @@ -61,7 +62,8 @@ namespace { class CSEDriver { public: CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) - : rewriter(rewriter), domInfo(domInfo) {} + : rewriter(rewriter), domInfo(domInfo), + interfaces(rewriter.getContext()) {} /// Simplify all operations within the given op. void simplify(Operation *op, bool *changed = nullptr); @@ -122,6 +124,8 @@ class CSEDriver { DominanceInfo *domInfo = nullptr; MemEffectsCache memEffectsCache; + DialectInterfaceCollection interfaces; + // Various statistics. int64_t numCSE = 0; int64_t numDCE = 0; @@ -289,7 +293,10 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, // If this operation is isolated above, we can't process nested regions // with the given 'knownValues' map. This would cause the insertion of // implicit captures in explicit capture only regions. - if (op.mightHaveTrait()) { + const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op); + if (op.mightHaveTrait() || + LLVM_UNLIKELY(cseInterface && + !cseInterface->subexpressionExtractionAllowed(&op))) { ScopedMapTy nestedKnownValues; for (auto ®ion : op.getRegions()) simplifyRegion(nestedKnownValues, region);