
Add Flash Decoding #1151

Closed
zfang opened this issue Oct 14, 2023 · 5 comments

@zfang

zfang commented Oct 14, 2023

Feature request

See https://pytorch.org/blog/flash-decoding/#:~:text=Flash%2DDecoding%20works%20in%203,exp%20of%20the%20attention%20values.

Motivation

Flash-Decoding further improves the attention mechanism over FlashAttention v2 for long-context inference.
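
For reference, the linked blog post describes Flash-Decoding in three steps: split the KV cache along the sequence dimension, compute attention and the log-sum-exp for each split in parallel, and then rescale and sum the partial outputs. Below is a minimal plain-PyTorch sketch of that reduction; it is illustrative only, not TGI's or flash-attn's actual kernels, and the function name and tensor layout are assumptions.

```python
import torch

def split_kv_decode_attention(q, k, v, num_splits=4):
    """q: [heads, dim]; k, v: [seq_len, heads, dim]. Returns [heads, dim]."""
    scale = q.shape[-1] ** -0.5
    partial_outs, partial_lses = [], []
    # Step 1: split the KV cache along the sequence dimension.
    for k_chunk, v_chunk in zip(k.chunk(num_splits, dim=0), v.chunk(num_splits, dim=0)):
        # Step 2: per-chunk attention plus its log-sum-exp; chunks are independent.
        scores = torch.einsum("hd,shd->hs", q, k_chunk) * scale   # [heads, chunk]
        partial_lses.append(torch.logsumexp(scores, dim=-1))      # [heads]
        probs = torch.softmax(scores, dim=-1)
        partial_outs.append(torch.einsum("hs,shd->hd", probs, v_chunk))
    # Step 3: rescale each partial output by its share of the global softmax mass.
    lse = torch.stack(partial_lses)                               # [splits, heads]
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)             # [splits, heads, 1]
    return (weights * torch.stack(partial_outs)).sum(dim=0)
```

The merged result matches computing softmax over the full sequence at once, which is why the per-split work can run in parallel even when the batch size is 1.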

Your contribution

None

@ssmi153
Contributor

ssmi153 commented Oct 16, 2023

+1!

It looks like this might be included in FlashAttention v2.2. It's not clear from the blog whether any inference code needs to be changed to see the benefits of this.

@taishan1994

+1

@OlivierDehaene
Member

It's not clear whether it is superior to paged attention: all the benchmarks I've seen are against native Transformers, which we know is not optimised.
The kernel fusion is nice though. I will run some tests and report back on whether we want this or not.

@dongs0104
Contributor

@OlivierDehaene PagedAttention V2 was announced to implement a similar idea, boosting performance when the batch size or the number of attention heads per GPU is small.

@OlivierDehaene
Member

See #1183 instead.
