
[XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes) #6872

Closed
Cjkkkk wants to merge 16 commits from Cjkkkk:flash_attention_rewriter

Conversation

@Cjkkkk (Contributor) commented Nov 8, 2023

This is the 3rd PR from splitting #5910, containing only the rewriter changes.
1st PR #6293 merged.
2nd PR #6657 merged.

  • Add pattern match for causal mask (an illustrative sketch of this kind of computation follows the list below)
  • Add paxml dropout pattern match
  • Add flash attention fusion
  • Add cuDNN version guard for flash attention support
  • Add tests for the flash attention rewriter and e2e tests
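
For orientation, here is a minimal JAX sketch, not code from this PR, of the kind of causal-mask attention computation the rewriter is intended to recognize and fuse into a cuDNN flash attention call; the function name, layout, and masking details are illustrative assumptions.

```python
# Hedged illustration only: causal-mask attention written with plain JAX ops.
# The rewriter pattern-matches HLO produced by computations of this shape
# (dot, causal mask via where/select, softmax, dot); names here are assumptions.
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    seq_len = scores.shape[-1]
    # Causal mask: query position i may only attend to key positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)
```

When compiled for GPU with the fused-MHA rewriter enabled, a graph like this is a candidate for fusion into a single cuDNN flash attention call; whether it is actually rewritten also depends on the cuDNN version guard added in this PR.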

@ddunl (Member) commented Nov 8, 2023

FYI, the 2nd PR is getting rolled back as we speak, unfortunately. I'll have more details tomorrow (or maybe @cheshire can look into it if he has time tomorrow).

@Cjkkkk changed the title from "[XLA:GPU ] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)" to "[XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)" on Nov 8, 2023
@Cjkkkk (Contributor, Author) commented Nov 8, 2023

> FYI, the 2nd PR is getting rolled back as we speak, unfortunately. I'll have more details tomorrow (or maybe @cheshire can look into it if he has time tomorrow).

Thanks for the info. Let me know once you have more details.

@Cjkkkk (Contributor, Author) commented Nov 22, 2023

@akuegel Hi Adrian, are there any other items that need to be done for this PR now that the back pointer change is merged?

Resolved review threads on:
  • xla/service/gpu/cudnn_fused_mha_rewriter.cc (4 threads, outdated)
  • xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc (outdated)
  • xla/service/gpu/tests/gpu_fused_mha_test.cc
@Cjkkkk (Contributor, Author) commented Nov 30, 2023

@akuegel Hi Adrian, I am not able to see why the XLA Linux GPU with NVCC check failed; it shows a 404 on my side. Could you share some info if you have access to it?

@akuegel (Member) commented Nov 30, 2023

> @akuegel Hi Adrian, I am not able to see why the XLA Linux GPU with NVCC check failed; it shows a 404 on my side. Could you share some info if you have access to it?

I also get that error when I click the link. I am not sure where I would be able to see the results. But I will see the results from presubmit testing once this CL is imported, and if any error related to this PR shows up, I will let you know.

copybara-service bot pushed a commit that referenced this pull request Jan 12, 2024
…with only rewriter changes)

Imported from GitHub PR #6872

This is the 3rd PR of splitting #5910 with only rewriter changes
1st PR #6293 merged.
2nd PR #6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add cuDNN version guard for flash attention support
* Add tests for flash attention rewriter/e2e
Copybara import of the project:

--
deaf693 by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
2756b51 by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask

--
5fee679 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
f18ff4b by cjkkkk <ske@nvidia.com>:

fix rebase error

--
e5937e6 by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
9e56735 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported

--
72404a0 by cjkkkk <ske@nvidia.com>:

fix case with no bias but also no causal_mask

--
3162207 by cjkkkk <ske@nvidia.com>:

remove unused branch

--
712d86f by cjkkkk <ske@nvidia.com>:

rebased and address some comments

--
48ffd45 by cjkkkk <ske@nvidia.com>:

make causal mask/bias both optional

--
542fe9f by cjkkkk <ske@nvidia.com>:

address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout

--
1804894 by cjkkkk <ske@nvidia.com>:

add flash attention cross attention with cuDNN > 8.9.4

--
28071dc by cjkkkk <ske@nvidia.com>:

fix fwd graph restore

--
9de2496 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
3b9ff9b by cjkkkk <ske@nvidia.com>:

add guard for optional dropout_rate

Merging this change closes #6872

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 3b9ff9b
PiperOrigin-RevId: 596851183
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jan 15, 2024
…with only rewriter changes)

Imported from GitHub PR openxla/xla#6872

This is the 3rd PR of splitting openxla/xla#5910 with only rewriter changes
1st PR openxla/xla#6293 merged.
2nd PR openxla/xla#6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add cuDNN version guard for flash attention support
* Add tests for flash attention rewriter/e2e
Copybara import of the project:

--
4b1505c6beafb1d2e7ee526001169778ba2cec4f by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
745c2ad8dd0637e1c5ee71febcc9d40b53ff282d by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask

--
5f95a054b19d8205ef90db35a4fac01dc359f523 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
4bbd7df031b5c5c88de4dac5f2cdd32f2d4a6c68 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
d1d9f2c87748c8eee84bff58a27c3a6968749f6b by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
84e9ff3886e68f380c45c7999916f9888e547546 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported

--
39699759e6cb0feaa903848098da2cf94b6625aa by cjkkkk <ske@nvidia.com>:

fix case with no bias but also no causal_mask

--
f529af2cb2443c742211e7273092c92cdf8dbc47 by cjkkkk <ske@nvidia.com>:

remove unused branch

--
9af8953ad1d32d700380fe76d2373b2ad6ee0748 by cjkkkk <ske@nvidia.com>:

rebased and address some comments

--
73928b4aa50778757944a1bc72a3e3c4ea70d9d8 by cjkkkk <ske@nvidia.com>:

make causal mask/bias both optional

--
7c573dc2f508213a6e8e77930fc014471d4c09a1 by cjkkkk <ske@nvidia.com>:

address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout

--
7390db9ba53a1621d98d0c82420d8af1b894c10c by cjkkkk <ske@nvidia.com>:

add flash attention cross attention with cuDNN > 8.9.4

--
f8db19b902b59c13e2a4063525c7378f482da94e by cjkkkk <ske@nvidia.com>:

fix fwd graph restore

--
031a11a878380486c0a1230bb147adba4c9fc101 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
f5c6b0347279062ab988b6b5c748fdc74edf1dcf by cjkkkk <ske@nvidia.com>:

add guard for optional dropout_rate

--
42d55c85de6eef6716831a0c2f4601d31dbd6ed5 by cjkkkk <ske@nvidia.com>:

fix gpubackend config

Merging this change closes #6872

PiperOrigin-RevId: 598554223
terrykong added a commit to NVIDIA/JAX-Toolbox that referenced this pull request Mar 8, 2024
Adds the JAX T5x FMHA E2E system test to check for FMHA lowering support. The test implements the following steps:

The FMHA lowering flag is now enabled by default; HLO dumping is enabled to track the FMHA forward and backward instructions (see the sketch below for what this looks like from the JAX side).
The test is added to the _ci.yaml file, and a nightly workflow file is added for it. We will add this test to performance benchmarking later and add the HLO to the baseline.
The decoder sequence length is also corrected (it should be a multiple of 64).
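
As a rough, hedged sketch of what enabling the FMHA lowering and HLO dumping might look like from the JAX side: the flag names and dump path below are assumptions based on XLA's debug options, not taken from this commit.

```python
import os

# Hedged sketch, not from this commit: enable cuDNN fMHA lowering and dump HLO
# so the fused fMHA forward/backward custom calls can be inspected afterwards.
# The flag names and the dump path are assumptions.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_cudnn_fmha=true",  # assumed name of the fMHA lowering flag
    "--xla_dump_to=/tmp/t5x_hlo_dump",   # write HLO dumps here for inspection
])

import jax  # import after setting XLA_FLAGS so the XLA backend picks them up
```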

The test was previously failing with an error related to CUDNN_STATUS_BAD_PARAM. The fix for this is added in the upstream [PR](openxla/xla#6872), which is now merged, and the test passes.
See the [bug](https://nvbugspro.nvidia.com/bug/4409713) filed for this error.

Workflow run for these changes: [workflow run link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7894631992)

---------

Co-authored-by: Terry Kong <terryk@nvidia.com>