[XLA:GPU] add cuDNN flash attention support in XLA (1st PR with only SE changes) #6293

Closed
Cjkkkk wants to merge 3 commits

Conversation

@Cjkkkk (Contributor) commented Oct 13, 2023

This is the 1st PR from splitting up #5910, containing only the StreamExecutor (SE) changes.

  • Added flash attention forward & backward graph generation (a naive reference sketch of the computation is shown below).
  • Added flash attention runner.
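
For readers unfamiliar with the fused op, the sketch below spells out the computation that the flash attention forward graph implements, written naively for a single head. This is a reference-only illustration with made-up dimensions (bias, mask, and dropout omitted), not the cuDNN graph or the runner added in this PR:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference-only sketch: O = softmax(scale * Q * K^T) * V for one head.
// Flash attention computes the same result tile by tile, keeping only
// per-row softmax statistics instead of materializing the full score matrix.
int main() {
  const int seq = 4, head_dim = 8;
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

  std::vector<float> q(seq * head_dim, 0.1f);
  std::vector<float> k(seq * head_dim, 0.2f);
  std::vector<float> v(seq * head_dim, 0.3f);
  std::vector<float> out(seq * head_dim, 0.0f);

  for (int i = 0; i < seq; ++i) {
    // BMM1: scores[j] = scale * dot(Q[i], K[j])
    std::vector<float> scores(seq);
    for (int j = 0; j < seq; ++j) {
      float dot = 0.0f;
      for (int d = 0; d < head_dim; ++d) {
        dot += q[i * head_dim + d] * k[j * head_dim + d];
      }
      scores[j] = scale * dot;
    }
    // Numerically stable softmax over the row.
    float row_max = scores[0];
    for (float s : scores) row_max = std::max(row_max, s);
    float denom = 0.0f;
    for (float& s : scores) {
      s = std::exp(s - row_max);
      denom += s;
    }
    // BMM2: O[i] = sum_j softmax(scores)[j] * V[j]
    for (int j = 0; j < seq; ++j) {
      for (int d = 0; d < head_dim; ++d) {
        out[i * head_dim + d] += (scores[j] / denom) * v[j * head_dim + d];
      }
    }
  }
  std::printf("out[0][0] = %f\n", out[0]);
  return 0;
}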

@@ -467,7 +467,8 @@ class Stream {
     std::optional<dnn::TensorDescriptor> activation_descriptor,
     std::optional<dnn::TensorDescriptor> mask_descriptor,
     std::optional<dnn::TensorDescriptor> bias_descriptor, double scale,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed) {
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    bool is_flash_attention) {
Member:

Having such a large number of arguments isn't great; maybe we could have an options struct instead?

@Cjkkkk (Contributor Author) replied on Oct 20, 2023

We have

struct FusedMHAOp {
  using Signature = FusedMHASignature;
  struct Config {
    FusedMHAKind kind;
    double scale;
    const MatmulTensorDescriptor& bmm1_lhs_descriptor;
    const MatmulTensorDescriptor& bmm1_rhs_descriptor;
    const MatmulTensorDescriptor& bmm2_rhs_descriptor;
    const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor;
    const TensorDescriptor& output_descriptor;
    std::optional<TensorDescriptor> bias_descriptor;
    std::optional<TensorDescriptor> mask_descriptor;
    std::optional<TensorDescriptor> activation_descriptor;
    std::optional<double> dropout_rate;
    std::optional<int64_t> seed;
    bool is_flash_attention;
  };
  // ...
};

in lazy_op_runner.h. Right now we pass each member in separately. In your opinion, should we just accept the Config instead?
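
For concreteness, this is roughly what the suggestion would look like at a call site; the types and function below are simplified, hypothetical stand-ins, not the real Stream/LazyOpRunner API:

#include <cstdint>
#include <optional>

// Hypothetical, simplified stand-ins to illustrate the "options struct" idea:
// bundle the long argument list into one config object, so that adding a field
// (like the new is_flash_attention flag) does not change every signature.
struct FusedMhaConfig {
  double scale = 1.0;
  std::optional<double> dropout_rate;
  std::optional<int64_t> seed;
  bool is_flash_attention = false;
};

// One parameter instead of a dozen; the tensor descriptors would live in the
// struct as well in the real code.
void RunFusedMha(const FusedMhaConfig& config) {
  // ... build the cuDNN graph / pick a runner using config.* ...
  (void)config;
}

int main() {
  FusedMhaConfig config;
  config.scale = 0.125;  // e.g. 1/sqrt(head_dim) for head_dim = 64
  config.dropout_rate = 0.1;
  config.seed = 42;
  config.is_flash_attention = true;
  RunFusedMha(config);
  return 0;
}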

@@ -3703,7 +3707,7 @@ tsl::StatusOr<cudnn_frontend::Tensor> CreateCudnnMaskTensor(
   // Create the mask output tensor
   TF_ASSIGN_OR_RETURN(
       auto mask_out_tensor,
-      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 300,
+      CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400,
Member:

Could you give some context for the magic constants?

@Cjkkkk (Contributor Author) replied:

Every tensor in the graph needs a unique ID; VIRTUAL_ID is just a large starting value, so the virtual tensor IDs do not overlap with the non-virtual tensors defined in enum CudnnfMHAUid. The tensor IDs in the scale/bias/mask/softmax/dropout subgraphs start at +200, +300, +400, +500, and +600 respectively.
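
As a compile-time illustration of that numbering scheme (the base value and the helper below are made up; only the per-subgraph offsets match the description above, and the real constants live in enum CudnnfMHAUid in cuda_dnn.cc):

#include <cstdint>

// Stand-in for VIRTUAL_ID: a large base so virtual (intermediate) tensor UIDs
// never collide with the non-virtual UIDs defined in enum CudnnfMHAUid.
constexpr int64_t kVirtualIdBase = 1000000;

// Per-subgraph offsets, as described above.
enum class Subgraph : int64_t {
  kScale = 200,
  kBias = 300,
  kMask = 400,
  kSoftmax = 500,
  kDropout = 600,
};

// UID of the i-th virtual tensor inside a given subgraph.
constexpr int64_t VirtualUid(Subgraph subgraph, int64_t i = 0) {
  return kVirtualIdBase + static_cast<int64_t>(subgraph) + i;
}

// The mask output tensor in the diff above falls in the +400 range.
static_assert(VirtualUid(Subgraph::kMask) == kVirtualIdBase + 400, "");
static_assert(VirtualUid(Subgraph::kDropout, 3) == kVirtualIdBase + 603, "");

int main() { return 0; }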

@Cjkkkk (Contributor Author) commented Oct 20, 2023

@cheshire I updated this to a version that contains just the SE changes plus a few lines in the runner to keep things compilable. Let's review this PR first; I will open two more PRs for the rewriter/thunk/MLIR op. Meanwhile, #5910 stays open for tracking additional bug fixes (if any) and tests. I will keep all three PRs in view and sync any required changes.

@tdanyluk (Member) commented:

PR rotation: @cheshire - gentle ping

@tdanyluk (Member) left a review:

Hi @Cjkkkk,
Sorry but clang-tidy surfaced a few errors and possible improvements, so I'm adding them in a review.

Files with review threads:
  • xla/service/gpu/gpu_fused_mha_runner.cc (1 thread, resolved)
  • xla/stream_executor/cuda/cuda_dnn.cc (9 threads, resolved; most now outdated)
@Cjkkkk (Contributor Author) commented Oct 27, 2023

@tdanyluk @cheshire Hi, all clang-tidy issues have been addressed. Could you take a look and see if there is anything else?

@tdanyluk (Member) commented:

Thanks, I'm trying to get it submitted internally; hopefully we'll get the approval in about a day.

copybara-service bot pushed a commit that referenced this pull request Oct 27, 2023
…with only SE changes)

Imported from GitHub PR #6293

This is the 1st PR of splitting #5910 with only SE changes.

* Added flash attention forward & backward graph generation.
* Added flash attention runner.
Copybara import of the project:

--
588269c by cjkkkk <ske@nvidia.com>:

add SE changes

--
cc79394 by cjkkkk <ske@nvidia.com>:

add bias && add causal mask as optional && fix bwd pattern match

Merging this change closes #6293

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6293 from Cjkkkk:flash_attention_SE 27275d8
PiperOrigin-RevId: 576782330
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 30, 2023
…with only SE changes)

Imported from GitHub PR openxla/xla#6293

This is the 1st PR of splitting openxla/xla#5910 with only SE changes.

* Added flash attention forward & backward graph generation.
* Added flash attention runner.
Copybara import of the project:

--
588269c556893bd1eb01f2fbde3733897ec1fa6b by cjkkkk <ske@nvidia.com>:

add SE changes

--
cc7939441b80eb217b624c2369acbe06e36d1ac1 by cjkkkk <ske@nvidia.com>:

add bias && add causal mask as optional && fix bwd pattern match

Merging this change closes #6293

PiperOrigin-RevId: 577772951
copybara-service bot pushed a commit that referenced this pull request Nov 8, 2023
… with only MLIR lowering and thunk/runtime)

Imported from GitHub PR #6657

This is the 2nd PR of splitting #5910 with only MLIR lowering and thunk/runtime
1st PR #6293 merged.

* Added MLIR lowering for flash attention.
* Added thunk/runner/runtime support for flash attention.
Copybara import of the project:

--
6f89a73 by cjkkkk <ske@nvidia.com>:

init mlir lowering and thunk runtime

--
f57b8be by cjkkkk <ske@nvidia.com>:

address some comments

Merging this change closes #6657

COPYBARA_INTEGRATE_REVIEW=#6657 from Cjkkkk:flash_attention_mhlo_lowering f57b8be
PiperOrigin-RevId: 580413629
copybara-service bot pushed a commit that referenced this pull request Dec 7, 2023
…with only rewriter changes)

Imported from GitHub PR #6872

This is the 3rd PR of splitting #5910 with only rewriter changes
1st PR #6293 merged.
2nd PR #6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add flash attention support cuDNN version guard
* Add tests for flash attention rewriter/e2e
Copybara import of the project:

--
490d0a3 by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
90e765f by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask

--
a3e5905 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
c82c064 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
2f30df0 by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
47aceb1 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported

Merging this change closes #6872

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 47aceb1
PiperOrigin-RevId: 588714363
copybara-service bot pushed a commit that referenced this pull request Jan 15, 2024
…with only rewriter changes)

Imported from GitHub PR #6872

This is the 3rd PR of splitting #5910 with only rewriter changes
1st PR #6293 merged.
2nd PR #6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add flash attention support cuDNN version guard
* Add tests for flash attention rewriter/e2e
Copybara import of the project:

--
4b1505c by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
745c2ad by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask

--
5f95a05 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
4bbd7df by cjkkkk <ske@nvidia.com>:

fix rebase error

--
d1d9f2c by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
84e9ff3 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported

--
3969975 by cjkkkk <ske@nvidia.com>:

fix case with no bias but also no causal_mask

--
f529af2 by cjkkkk <ske@nvidia.com>:

remove unused branch

--
9af8953 by cjkkkk <ske@nvidia.com>:

rebased and address some comments

--
73928b4 by cjkkkk <ske@nvidia.com>:

make causal mask/bias both optional

--
7c573dc by cjkkkk <ske@nvidia.com>:

address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout

--
7390db9 by cjkkkk <ske@nvidia.com>:

add flash attention cross attention with cuDNN > 8.9.4

--
f8db19b by cjkkkk <ske@nvidia.com>:

fix fwd graph restore

--
031a11a by cjkkkk <ske@nvidia.com>:

fix rebase error

--
f5c6b03 by cjkkkk <ske@nvidia.com>:

add guard for optional dropout_rate

--
42d55c8 by cjkkkk <ske@nvidia.com>:

fix gpubackend config

Merging this change closes #6872

COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 42d55c8
PiperOrigin-RevId: 598554223