
PR #6872: [XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes) #7593

Closed

Conversation

copybara-service[bot]

PR #6872: [XLA:GPU] add cuDNN flash attention support in XLA (3rd PR with only rewriter changes)

Imported from GitHub PR #6872

This is the 3rd PR of splitting #5910, with only rewriter changes.
1st PR #6293 merged.
2nd PR #6657 merged.

* Add pattern match for causal mask (a hedged sketch follows this list)
* Add paxml dropout pattern match
* Add flash attention fusion
* Add cuDNN version guard for flash attention support
* Add tests for flash attention rewriter/e2e
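
For illustration, a minimal sketch of what a causal-mask pattern match can look like with XLA's pattern matcher. The helper name and the exact pattern shape are assumptions for illustration, not the PR's actual code:

```cpp
// Sketch only: recognize a causal (lower-triangular) mask built as
// select(compare(iota, iota, GE), broadcast(scalar), broadcast(scalar)).
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/pattern_matcher.h"

namespace m = xla::match;

// Hypothetical helper, not from the PR.
bool IsCausalMaskCandidate(xla::HloInstruction* mask) {
  return xla::Match(
      mask,
      m::Select(m::Compare(m::Iota(), m::Iota())
                    .WithComparisonDirection(xla::ComparisonDirection::kGe),
                m::Broadcast(m::ConstantScalar()),
                m::Broadcast(m::ConstantScalar())));
}
```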
Copybara import of the project:

--
490d0a3 by cjkkkk <ske@nvidia.com>:

init flash attention rewriter

--
90e765f by cjkkkk <ske@nvidia.com>:

use while body back pointer to find causal mask
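
A minimal sketch of the idea (the helper name and tuple-index convention are assumptions): from a while-body computation, the rewriter can follow the call graph back to the enclosing kWhile instruction and inspect the value fed into the loop state, e.g. to check whether the mask entering the loop is causal:

```cpp
// Sketch only: walk from a while-body computation back to its kWhile
// caller and return the instruction fed into the loop state at
// `tuple_index` (e.g. the attention mask).
#include <cstdint>

#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/call_graph.h"

const xla::HloInstruction* OperandFedIntoLoop(
    const xla::CallGraph& call_graph, const xla::HloComputation* while_body,
    int64_t tuple_index) {
  for (const xla::CallSite& site :
       call_graph.GetNode(while_body).caller_callsites()) {
    const xla::HloInstruction* caller = site.instruction();
    if (caller->opcode() != xla::HloOpcode::kWhile) continue;
    const xla::HloInstruction* init = caller->operand(0);
    if (init->opcode() == xla::HloOpcode::kTuple) {
      return init->operand(tuple_index);
    }
  }
  return nullptr;  // Not a while body, or the loop state is opaque.
}
```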

--
a3e5905 by cjkkkk <ske@nvidia.com>:

add gpu backend to fmha e2e tests && address some format issues

--
c82c064 by cjkkkk <ske@nvidia.com>:

fix rebase error

--
2f30df0 by cjkkkk <ske@nvidia.com>:

Use GPT3_5B model pre-rewriter HLO

--
47aceb1 by cjkkkk <ske@nvidia.com>:

add flash attention cuDNN version check && restore fwd graph if dbias/mask is not supported
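
A hedged sketch of what such a version guard can look like; the threshold and helper name below are assumptions, not the values this commit uses:

```cpp
// Sketch only: gate the flash-attention rewrite on the cuDNN version
// reported by StreamExecutor. The minimum version below is an assumption.
#include <tuple>

#include "xla/stream_executor/dnn.h"

bool FlashAttentionSupported(const stream_executor::dnn::VersionInfo& v) {
  return std::make_tuple(v.major_version(), v.minor_version(), v.patch()) >=
         std::make_tuple(8, 9, 1);  // assumed minimum cuDNN release
}
```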

Merging this change closes #6872

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6872 from Cjkkkk:flash_attention_rewriter 47aceb1

PiperOrigin-RevId: 588714363