[XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes) #6872
Conversation
FYI, the 2nd PR is being rolled back as we speak, unfortunately. I'll have more details tomorrow (or maybe @cheshire can look into it if he has time tomorrow).
Thanks for the info. Let me know once you have more details.
@akuegel Hi Adrian, do we have any other items that need to be done for this PR, now that the back-pointer change is merged?
@akuegel Hi Adrian, I am not able to see why
I also get that error when I click the link, and I am not sure where I would be able to see the results. But I will see the results from presubmit testing once this CL is imported, and if any error related to this PR shows up, I will let you know.
[XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)

Imported from GitHub PR #6872. This is the 3rd PR of splitting #5910, containing only the rewriter changes. 1st PR #6293 merged. 2nd PR #6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add flash attention support cuDNN version guard
* Add tests for flash attention rewriter/e2e

Copybara import of the project:

-- deaf693 by cjkkkk <ske@nvidia.com>: init flash attention rewriter
-- 2756b51 by cjkkkk <ske@nvidia.com>: use while body back pointer to find causal mask
-- 5fee679 by cjkkkk <ske@nvidia.com>: add gpu backend to fmha e2e tests && address some format issues
-- f18ff4b by cjkkkk <ske@nvidia.com>: fix rebase error
-- e5937e6 by cjkkkk <ske@nvidia.com>: Use GPT3_5B model pre rewriter HLO
-- 9e56735 by cjkkkk <ske@nvidia.com>: add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported
-- 72404a0 by cjkkkk <ske@nvidia.com>: fix case with no bias but also no causal_mask
-- 3162207 by cjkkkk <ske@nvidia.com>: remove unused branch
-- 712d86f by cjkkkk <ske@nvidia.com>: rebased and address some comments
-- 48ffd45 by cjkkkk <ske@nvidia.com>: address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout
-- 542fe9f by cjkkkk <ske@nvidia.com>: make causal mask/bias both optional
-- 1804894 by cjkkkk <ske@nvidia.com>: add flash attention cross attention with cuDNN > 8.9.4
-- 28071dc by cjkkkk <ske@nvidia.com>: fix fwd graph restore
-- 9de2496 by cjkkkk <ske@nvidia.com>: fix rebase error
-- 3b9ff9b by cjkkkk <ske@nvidia.com>: add guard for optional dropout_rate

Merging this change closes #6872

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 3b9ff9b
PiperOrigin-RevId: 596851183
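One of the bullets above is a cuDNN version guard for flash attention. The sketch below illustrates the kind of version gating involved; the helper name and tuple representation are hypothetical, not XLA's actual API. The only threshold taken from the PR itself is the note that flash cross attention requires cuDNN newer than 8.9.4.

```python
# Hypothetical sketch of the version gating described in the PR notes.
# Only the "> 8.9.4" cross-attention threshold comes from the PR; the
# function name and version-tuple shape are illustrative.

def supports_flash_cross_attention(cudnn_version: tuple) -> bool:
    """True if flash cross attention is available (cuDNN > 8.9.4)."""
    # Tuples compare lexicographically, so (8, 9, 5) > (8, 9, 4), etc.
    return cudnn_version > (8, 9, 4)
```

Lexicographic tuple comparison keeps the check readable while handling major, minor, and patch bumps uniformly.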
[XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)

Imported from GitHub PR openxla/xla#6872. This is the 3rd PR of splitting openxla/xla#5910, containing only the rewriter changes. 1st PR openxla/xla#6293 merged. 2nd PR openxla/xla#6657 merged.

* Add pattern match for causal mask
* Add paxml dropout pattern match
* Add flash attention fusion
* Add flash attention support cuDNN version guard
* Add tests for flash attention rewriter/e2e

Copybara import of the project:

-- 4b1505c6beafb1d2e7ee526001169778ba2cec4f by cjkkkk <ske@nvidia.com>: init flash attention rewriter
-- 745c2ad8dd0637e1c5ee71febcc9d40b53ff282d by cjkkkk <ske@nvidia.com>: use while body back pointer to find causal mask
-- 5f95a054b19d8205ef90db35a4fac01dc359f523 by cjkkkk <ske@nvidia.com>: add gpu backend to fmha e2e tests && address some format issues
-- 4bbd7df031b5c5c88de4dac5f2cdd32f2d4a6c68 by cjkkkk <ske@nvidia.com>: fix rebase error
-- d1d9f2c87748c8eee84bff58a27c3a6968749f6b by cjkkkk <ske@nvidia.com>: Use GPT3_5B model pre rewriter HLO
-- 84e9ff3886e68f380c45c7999916f9888e547546 by cjkkkk <ske@nvidia.com>: add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported
-- 39699759e6cb0feaa903848098da2cf94b6625aa by cjkkkk <ske@nvidia.com>: fix case with no bias but also no causal_mask
-- f529af2cb2443c742211e7273092c92cdf8dbc47 by cjkkkk <ske@nvidia.com>: remove unused branch
-- 9af8953ad1d32d700380fe76d2373b2ad6ee0748 by cjkkkk <ske@nvidia.com>: rebased and address some comments
-- 73928b4aa50778757944a1bc72a3e3c4ea70d9d8 by cjkkkk <ske@nvidia.com>: make causal mask/bias both optional
-- 7c573dc2f508213a6e8e77930fc014471d4c09a1 by cjkkkk <ske@nvidia.com>: address some comments and fix wrong layout for softmax stat if O is not [batch, num_heads, seq, head] layout
-- 7390db9ba53a1621d98d0c82420d8af1b894c10c by cjkkkk <ske@nvidia.com>: add flash attention cross attention with cuDNN > 8.9.4
-- f8db19b902b59c13e2a4063525c7378f482da94e by cjkkkk <ske@nvidia.com>: fix fwd graph restore
-- 031a11a878380486c0a1230bb147adba4c9fc101 by cjkkkk <ske@nvidia.com>: fix rebase error
-- f5c6b0347279062ab988b6b5c748fdc74edf1dcf by cjkkkk <ske@nvidia.com>: add guard for optional dropout_rate
-- 42d55c85de6eef6716831a0c2f4601d31dbd6ed5 by cjkkkk <ske@nvidia.com>: fix gpubackend config

Merging this change closes #6872
PiperOrigin-RevId: 598554223
Adding the JAX T5x FMHA E2E system test to check FMHA lowering support. The test implements the following steps: the FMHA lowering flag is now enabled by default, so HLO dumping is enabled to track FMHA forward and backward instructions. The test was added to the _ci.yaml file, along with a nightly workflow file for it. We will add this test to performance benchmarking later and add its HLO to the baseline. Also included is a correction to the decoder sequence length (it should be a multiple of 64). The test was previously failing with an error related to CUDNN_STATUS_BAD_PARAM; the fix landed upstream in [PR](openxla/xla#6872), which is now merged, and the test passes. [Bug](https://nvbugspro.nvidia.com/bug/4409713) for this error. Run for these changes: [workflow run link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7894631992) --------- Co-authored-by: Terry Kong <terryk@nvidia.com>
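The decoder sequence-length constraint mentioned above (a multiple of 64) can be enforced by rounding up. A minimal sketch, with a hypothetical helper name; the multiple-of-64 requirement is the only detail taken from the test fix:

```python
def round_up_seq_len(seq_len: int, multiple: int = 64) -> int:
    """Round a sequence length up to the next multiple (64 per the test fix)."""
    # Ceiling division via negation, then scale back up to the multiple.
    return -(-seq_len // multiple) * multiple
```

For example, a decoder length of 100 would be padded to 128 before running the FMHA path.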
This is the 3rd PR of splitting #5910, with only rewriter changes.
1st PR #6293 merged.
2nd PR #6657 merged.
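The causal mask that the rewriter pattern-matches is the standard lower-triangular attention mask. A minimal NumPy sketch of the computation such a mask represents (the helper name and shapes are illustrative, not the rewriter's actual matching code):

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    # Query position i may attend to key position j only when j <= i,
    # i.e. a lower-triangular boolean matrix built from two iotas.
    rows = np.arange(seq_len)[:, None]  # iota along the query dimension
    cols = np.arange(seq_len)[None, :]  # iota along the key dimension
    return cols <= rows
```

In HLO terms this corresponds to comparing two broadcast iota operands, which is the shape of pattern a rewriter can recognize and replace with fused flash attention's built-in causal masking.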