PR #6872: [XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)

Imported from GitHub PR #6872

This is the 3rd PR of splitting #5910, with only rewriter changes.
1st PR: #6293 (merged).
2nd PR: #6657 (merged).

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add cuDNN version guard for flash attention support (see the sketch after this list)
* Add tests for flash attention rewriter/e2e
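
As a quick illustration of the version-guard bullet above, here is a minimal, self-contained C++ sketch. The CudnnVersion struct, the IsFlashAttentionSupported name, and the 8.9.4 threshold are assumptions made for illustration only, not the exact API or bound this PR enforces.

// Sketch of a cuDNN version guard for flash attention; names and the
// 8.9.4 threshold are illustrative assumptions, not XLA's actual API.
struct CudnnVersion {
  int major;
  int minor;
  int patch;
};

bool IsFlashAttentionSupported(const CudnnVersion& v) {
  // Older cuDNN releases lack the flash attention kernels, so a rewriter
  // guarded like this would keep the original (unfused) attention graph.
  if (v.major != 8) return v.major > 8;
  if (v.minor != 9) return v.minor > 9;
  return v.patch >= 4;
}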
Copybara import of the project:

--
490d0a3 by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
90e765f by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask (see the sketch after this commit log)

--
a3e5905 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
c82c064 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
2f30df0 by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
47aceb1 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported
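
A schematic sketch of the "while body back pointer" walk mentioned in the commits above, written against XLA's HloInstruction API: when attention runs inside a while loop, the mask reaching the matched pattern is a get-tuple-element of the loop parameter, so its real producer must be looked up in the while op's init tuple. The helper name and traversal below are illustrative assumptions, not the PR's actual code.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"

namespace xla {

// Schematic only: follow a value used inside a while body back to its
// producer in the enclosing computation.
//   body:   value = get-tuple-element(parameter(0), i)
//   caller: while_op = while(tuple(..., mask, ...))
// The i-th element of the while op's init tuple is where the causal
// mask, if any, was actually constructed.
const HloInstruction* ResolveThroughWhileBody(const HloInstruction* value,
                                              const HloInstruction* while_op) {
  if (value->opcode() != HloOpcode::kGetTupleElement) return value;
  if (value->operand(0)->opcode() != HloOpcode::kParameter) return value;
  const HloInstruction* init = while_op->operand(0);
  if (init->opcode() != HloOpcode::kTuple) return value;
  return init->operand(value->tuple_index());
}

}  // namespace xla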

Merging this change closes #6872

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 47aceb1
PiperOrigin-RevId: 588714363
Cjkkkk authored and copybara-github committed Dec 7, 2023
1 parent 6bcdd3a commit 633fc7f
Showing 13 changed files with 1,689 additions and 315 deletions.
84 changes: 0 additions & 84 deletions third_party/llvm/generated.patch
@@ -1,85 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,57 +5548,12 @@
}
};

-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
-/// permutes a unit dim from the result of the shape_cast.
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transpOp,
- PatternRewriter &rewriter) const override {
- Value transposeSrc = transpOp.getVector();
- auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return rewriter.notifyMatchFailure(
- transpOp, "TransposeOp source is not ShapeCastOp");
-
- auto sourceType = transpOp.getSourceVectorType();
- auto resultType = transpOp.getResultVectorType();
-
- auto filterUnitDims = [](VectorType type) {
- return llvm::make_filter_range(
- llvm::zip_equal(type.getShape(), type.getScalableDims()),
- [&](auto dim) {
- auto [size, isScalable] = dim;
- return size != 1 || isScalable;
- });
- };
-
- auto sourceWithoutUnitDims = filterUnitDims(sourceType);
- auto resultWithoutUnitDims = filterUnitDims(resultType);
-
- // If this transpose just permutes a unit dim, then we can fold it into the
- // shape_cast.
- for (auto [srcDim, resDim] :
- llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
- if (srcDim != resDim)
- return rewriter.notifyMatchFailure(transpOp,
- "TransposeOp permutes non-unit dim");
- }
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
- shapeCastOp.getSource());
-
- return success();
- };
-};
-
} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
- context);
+ TransposeFolder, FoldTransposeSplat>(context);
}

//===----------------------------------------------------------------------===//
diff -ruN --strip-trailing-cr a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,18 +67,6 @@

// -----

-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
- // CHECK-NOT: vector.transpose
- %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
- %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %1 : vector<1x[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: extract_from_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
4 changes: 2 additions & 2 deletions third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "651a49c4b6bdef81c8deddbe653258c066867a58"
LLVM_SHA256 = "75885391fa9f6479801953ddfc8d3f241dcd9bff6145d11c3fbcbf031ade8ae1"
LLVM_COMMIT = "565dddec6396d84befa122aa69634b055a60da17"
LLVM_SHA256 = "6f72624a5a3473dceac5399cdd126a2f8a0009e353f7d819a3909ce4e4f195f7"

tf_http_archive(
name = name,
16 changes: 16 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -2302,6 +2302,22 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c
+} // namespace experimental
+} // namespace stablehlo
+} // namespace mlir
diff --ruN a/stablehlo/stablehlo/tests/infer_chlo.mlir b/stablehlo/stablehlo/tests/infer_chlo.mlir
--- stablehlo/stablehlo/tests/infer_chlo.mlir
+++ stablehlo/stablehlo/tests/infer_chlo.mlir
@@ -120,10 +120,10 @@
// -----
// CHECK-LABEL: @broadcast_select_reify
func.func @broadcast_select_reify(%arg0: tensor<2xi1>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<1xindex> {
- // CHECK: %0 = shape.const_shape [2] : tensor<1xindex>
+ // CHECK: %0 = shape.shape_of %arg0 : tensor<2xi1> -> tensor<1xindex>
// CHECK-NEXT: %1 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<1xindex>
// CHECK-NEXT: %2 = shape.shape_of %arg2 : tensor<?xi32> -> tensor<1xindex>
- // CHECK-NEXT: %3 = shape.broadcast %1, %2, %0 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex>
+ // CHECK-NEXT: %3 = shape.broadcast %0, %1, %2 : tensor<1xindex>, tensor<1xindex>, tensor<1xindex> -> tensor<1xindex>
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
%1 = "hlo_test_infer.reify_return_type_shapes"(%0) : (tensor<?xi32>) -> tensor<1xindex>
return %1: tensor<1xindex>
diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
--- stablehlo/stablehlo/transforms/Passes.h
+++ stablehlo/stablehlo/transforms/Passes.h
84 changes: 0 additions & 84 deletions third_party/tsl/third_party/llvm/generated.patch
@@ -1,85 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,57 +5548,12 @@
}
};

-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
-/// permutes a unit dim from the result of the shape_cast.
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transpOp,
- PatternRewriter &rewriter) const override {
- Value transposeSrc = transpOp.getVector();
- auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return rewriter.notifyMatchFailure(
- transpOp, "TransposeOp source is not ShapeCastOp");
-
- auto sourceType = transpOp.getSourceVectorType();
- auto resultType = transpOp.getResultVectorType();
-
- auto filterUnitDims = [](VectorType type) {
- return llvm::make_filter_range(
- llvm::zip_equal(type.getShape(), type.getScalableDims()),
- [&](auto dim) {
- auto [size, isScalable] = dim;
- return size != 1 || isScalable;
- });
- };
-
- auto sourceWithoutUnitDims = filterUnitDims(sourceType);
- auto resultWithoutUnitDims = filterUnitDims(resultType);
-
- // If this transpose just permutes a unit dim, then we can fold it into the
- // shape_cast.
- for (auto [srcDim, resDim] :
- llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
- if (srcDim != resDim)
- return rewriter.notifyMatchFailure(transpOp,
- "TransposeOp permutes non-unit dim");
- }
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
- shapeCastOp.getSource());
-
- return success();
- };
-};
-
} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
- context);
+ TransposeFolder, FoldTransposeSplat>(context);
}

//===----------------------------------------------------------------------===//
diff -ruN --strip-trailing-cr a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,18 +67,6 @@

// -----

-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
- // CHECK-NOT: vector.transpose
- %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
- %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %1 : vector<1x[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: extract_from_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "651a49c4b6bdef81c8deddbe653258c066867a58"
LLVM_SHA256 = "75885391fa9f6479801953ddfc8d3f241dcd9bff6145d11c3fbcbf031ade8ae1"
LLVM_COMMIT = "565dddec6396d84befa122aa69634b055a60da17"
LLVM_SHA256 = "6f72624a5a3473dceac5399cdd126a2f8a0009e353f7d819a3909ce4e4f195f7"

tf_http_archive(
name = name,