-
Notifications
You must be signed in to change notification settings - Fork 406
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #6872: [XLA:GPU] add cuDNN flash attention support in XLA (3rd PR …
…with only rewriter changes) Imported from GitHub PR #6872 This is the 3rd PR of splitting #5910 with only rewriter changes 1st PR #6293 merged. 2nd PR #6657 merged. * Add pattern match for causal mask * Add paxml dropout pattern match * Add flash attention fusion * Add flash attention support cuDNN version guard * Add tests for flash attention rewriter/e2e Copybara import of the project: -- 490d0a3 by cjkkkk <ske@nvidia.com>: init flash attention rewriter -- 90e765f by cjkkkk <ske@nvidia.com>: use while body back pointer to find causal mask -- a3e5905 by cjkkkk <ske@nvidia.com>: add gpu backend to fmha e2e tests && address some format issues -- c82c064 by cjkkkk <ske@nvidia.com>: fix rebase error -- 2f30df0 by cjkkkk <ske@nvidia.com>: Use GPT3_5B model pre rewriter HLo -- 47aceb1 by cjkkkk <ske@nvidia.com>: add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported Merging this change closes #6872 FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 47aceb1 PiperOrigin-RevId: 588714363
- Loading branch information
1 parent
6bcdd3a
commit 633fc7f
Showing
13 changed files
with
1,689 additions
and
315 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1 @@ | ||
Auto generated patch. Do not edit or delete it, even if empty. | ||
diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
@@ -5548,57 +5548,12 @@ | ||
} | ||
}; | ||
|
||
-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just | ||
-/// permutes a unit dim from the result of the shape_cast. | ||
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> { | ||
- using OpRewritePattern::OpRewritePattern; | ||
- | ||
- LogicalResult matchAndRewrite(TransposeOp transpOp, | ||
- PatternRewriter &rewriter) const override { | ||
- Value transposeSrc = transpOp.getVector(); | ||
- auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>(); | ||
- if (!shapeCastOp) | ||
- return rewriter.notifyMatchFailure( | ||
- transpOp, "TransposeOp source is not ShapeCastOp"); | ||
- | ||
- auto sourceType = transpOp.getSourceVectorType(); | ||
- auto resultType = transpOp.getResultVectorType(); | ||
- | ||
- auto filterUnitDims = [](VectorType type) { | ||
- return llvm::make_filter_range( | ||
- llvm::zip_equal(type.getShape(), type.getScalableDims()), | ||
- [&](auto dim) { | ||
- auto [size, isScalable] = dim; | ||
- return size != 1 || isScalable; | ||
- }); | ||
- }; | ||
- | ||
- auto sourceWithoutUnitDims = filterUnitDims(sourceType); | ||
- auto resultWithoutUnitDims = filterUnitDims(resultType); | ||
- | ||
- // If this transpose just permutes a unit dim, then we can fold it into the | ||
- // shape_cast. | ||
- for (auto [srcDim, resDim] : | ||
- llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) { | ||
- if (srcDim != resDim) | ||
- return rewriter.notifyMatchFailure(transpOp, | ||
- "TransposeOp permutes non-unit dim"); | ||
- } | ||
- | ||
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType, | ||
- shapeCastOp.getSource()); | ||
- | ||
- return success(); | ||
- }; | ||
-}; | ||
- | ||
} // namespace | ||
|
||
void vector::TransposeOp::getCanonicalizationPatterns( | ||
RewritePatternSet &results, MLIRContext *context) { | ||
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast, | ||
- TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>( | ||
- context); | ||
+ TransposeFolder, FoldTransposeSplat>(context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
diff -ruN --strip-trailing-cr a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir | ||
--- a/mlir/test/Dialect/Vector/canonicalize.mlir | ||
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir | ||
@@ -67,18 +67,6 @@ | ||
|
||
// ----- | ||
|
||
-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast | ||
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> | ||
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> { | ||
- // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32> | ||
- // CHECK-NOT: vector.transpose | ||
- %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32> | ||
- %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> | ||
- return %1 : vector<1x[4]xf32> | ||
-} | ||
- | ||
-// ----- | ||
- | ||
// CHECK-LABEL: extract_from_create_mask | ||
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index | ||
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1 @@ | ||
Auto generated patch. Do not edit or delete it, even if empty. | ||
diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp | ||
@@ -5548,57 +5548,12 @@ | ||
} | ||
}; | ||
|
||
-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just | ||
-/// permutes a unit dim from the result of the shape_cast. | ||
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> { | ||
- using OpRewritePattern::OpRewritePattern; | ||
- | ||
- LogicalResult matchAndRewrite(TransposeOp transpOp, | ||
- PatternRewriter &rewriter) const override { | ||
- Value transposeSrc = transpOp.getVector(); | ||
- auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>(); | ||
- if (!shapeCastOp) | ||
- return rewriter.notifyMatchFailure( | ||
- transpOp, "TransposeOp source is not ShapeCastOp"); | ||
- | ||
- auto sourceType = transpOp.getSourceVectorType(); | ||
- auto resultType = transpOp.getResultVectorType(); | ||
- | ||
- auto filterUnitDims = [](VectorType type) { | ||
- return llvm::make_filter_range( | ||
- llvm::zip_equal(type.getShape(), type.getScalableDims()), | ||
- [&](auto dim) { | ||
- auto [size, isScalable] = dim; | ||
- return size != 1 || isScalable; | ||
- }); | ||
- }; | ||
- | ||
- auto sourceWithoutUnitDims = filterUnitDims(sourceType); | ||
- auto resultWithoutUnitDims = filterUnitDims(resultType); | ||
- | ||
- // If this transpose just permutes a unit dim, then we can fold it into the | ||
- // shape_cast. | ||
- for (auto [srcDim, resDim] : | ||
- llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) { | ||
- if (srcDim != resDim) | ||
- return rewriter.notifyMatchFailure(transpOp, | ||
- "TransposeOp permutes non-unit dim"); | ||
- } | ||
- | ||
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType, | ||
- shapeCastOp.getSource()); | ||
- | ||
- return success(); | ||
- }; | ||
-}; | ||
- | ||
} // namespace | ||
|
||
void vector::TransposeOp::getCanonicalizationPatterns( | ||
RewritePatternSet &results, MLIRContext *context) { | ||
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast, | ||
- TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>( | ||
- context); | ||
+ TransposeFolder, FoldTransposeSplat>(context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
diff -ruN --strip-trailing-cr a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir | ||
--- a/mlir/test/Dialect/Vector/canonicalize.mlir | ||
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir | ||
@@ -67,18 +67,6 @@ | ||
|
||
// ----- | ||
|
||
-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast | ||
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32> | ||
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> { | ||
- // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32> | ||
- // CHECK-NOT: vector.transpose | ||
- %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32> | ||
- %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> | ||
- return %1 : vector<1x[4]xf32> | ||
-} | ||
- | ||
-// ----- | ||
- | ||
// CHECK-LABEL: extract_from_create_mask | ||
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index | ||
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.