
implement faster RoPE embedding #238

Merged: 1 commit merged into unslothai:main on Mar 15, 2024

Conversation

@HuyNguyen-hust (Contributor) commented Mar 12, 2024

This PR proposes a small change to the current RoPE embedding kernel:

  • The current implementation launches one block per head on axis 1, so each block has to reload the same sin/cos values, which is inefficient.
  • Reorganize the grid so that on axis 1, instead of launching one block per head, I launch one block per group of heads (4-8 heads). This lets each block load sin/cos only once and reuse them to compute all the heads in that group (a minimal Triton sketch follows below).
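
To make the grid change concrete, here is a minimal Triton sketch of the grouped-heads idea, assuming queries laid out as (batch * seq_len, n_heads * head_dim) and precomputed (seq_len, head_dim // 2) cos/sin tables. The kernel name, argument layout, and launch grid are illustrative assumptions, not the exact code in this PR:

```python
import triton
import triton.language as tl

@triton.jit
def _rope_embedding_grouped(
    Q, Q_row_stride,          # queries, flat layout (batch * seq_len, n_heads * head_dim)
    cos, cos_row_stride,      # cos table, shape (seq_len, head_dim // 2)
    sin, sin_row_stride,      # sin table, shape (seq_len, head_dim // 2)
    seq_len,
    n_heads,
    head_dim: tl.constexpr,
    ROPE_GROUP_SIZE: tl.constexpr,   # heads handled by one block, e.g. 4-8
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)       # axis 0: one (batch, position) row
    group = tl.program_id(1)     # axis 1: one *group* of heads, not one head

    half_head_dim = head_dim // 2
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < half_head_dim

    # Load sin/cos once per block and reuse them for every head in the group.
    pos = row % seq_len
    cos1 = tl.load(cos + pos * cos_row_stride + col_offsets, mask=mask, other=0.0)
    sin1 = tl.load(sin + pos * sin_row_stride + col_offsets, mask=mask, other=0.0)

    head_start = group * ROPE_GROUP_SIZE
    head_end = tl.minimum(head_start + ROPE_GROUP_SIZE, n_heads)
    for head in range(head_start, head_end):
        offs_q1 = row * Q_row_stride + head * head_dim + col_offsets
        offs_q2 = offs_q1 + half_head_dim
        Q1 = tl.load(Q + offs_q1, mask=mask, other=0.0)
        Q2 = tl.load(Q + offs_q2, mask=mask, other=0.0)
        # Standard rotary update on the two halves of each head:
        # (Q1, Q2) -> (Q1*cos - Q2*sin, Q2*cos + Q1*sin)
        tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)
        tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)
```

A hypothetical launch would then use a grid of `(batch * seq_len, triton.cdiv(n_heads, ROPE_GROUP_SIZE))` instead of `(batch * seq_len, n_heads)`.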

Benchmark with batch_size=4, head_dim=128, n_heads=32 (// 2 means BLOCK_SIZE = head_dim // 2; otherwise BLOCK_SIZE = head_dim):
[Benchmark figure: RoPE kernel runtime for the current and grouped-heads implementations at BLOCK_SIZE = head_dim and BLOCK_SIZE = head_dim // 2]

The figure indicates that my implementation is more sensitive to BLOCK_SIZE.

@danielhanchen (Contributor)

Thanks a lot @HuyNguyen-hust! As per our discussion on Discord - I just want to say thank you again - super appreciate this! Will do some tests on my end and I'll expedite this PR!

@danielhanchen (Contributor)

@HuyNguyen-hust I tested the kernel! I can confirm RoPE itself should be faster. The effect on a full training run, though, is sadly less pronounced: per PyTorch's Profiler, RoPE now takes around 1% of the total runtime, with matrix multiplications taking the bulk of the time. DPO for example - with your RoPE fix: 1553 seconds; original: 1542 seconds. So within the margin of error. This was on a Colab T4, so I'm pretty sure A100s will show more noticeable effects.
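
For reference, a minimal sketch of how the per-kernel share can be checked with PyTorch's profiler; `training_step()` is a hypothetical stand-in for one forward/backward plus optimizer step and is not from this PR:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# `training_step` is a hypothetical stand-in for one training iteration.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        training_step()

# Sort by CUDA time to see how much the RoPE kernel contributes versus the matmuls.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```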

However, your kernel works absolute wonders when long sequence lengths come into play! The RoPE kernel does creep up to around 2-3% of the total runtime there, so the savings are well worth it!

Thanks so much for the wonderful contribution - added this in! :)

I'll probably play around with the group size - it seems like this might be an auto-tunable number!!!
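
If someone wants to experiment, here is one hedged way to sweep the group size by hand with `triton.testing.do_bench`, reusing the hypothetical `_rope_embedding_grouped` sketch from the earlier comment; the wrapper, shapes, and candidate values below are illustrative assumptions, and Triton's `@triton.autotune` could make the same selection automatically:

```python
import torch
import triton

# Hypothetical host-side wrapper that launches the grouped sketch kernel
# with a given ROPE_GROUP_SIZE; not part of this PR.
def rope_forward(Q, cos, sin, group_size, n_heads=32, head_dim=128):
    rows = Q.shape[0]
    grid = (rows, triton.cdiv(n_heads, group_size))
    _rope_embedding_grouped[grid](
        Q, Q.stride(0),
        cos, cos.stride(0),
        sin, sin.stride(0),
        cos.shape[0], n_heads,
        head_dim=head_dim,
        ROPE_GROUP_SIZE=group_size,
        BLOCK_SIZE=head_dim // 2,
    )

# Illustrative shapes: batch_size=4, seq_len=4096, n_heads=32, head_dim=128.
seq_len = 4096
Q = torch.randn(4 * seq_len, 32 * 128, device="cuda", dtype=torch.float16)
cos = torch.randn(seq_len, 64, device="cuda", dtype=torch.float16)
sin = torch.randn(seq_len, 64, device="cuda", dtype=torch.float16)

for group_size in (2, 4, 8, 16):
    ms = triton.testing.do_bench(lambda: rope_forward(Q, cos, sin, group_size))
    print(f"ROPE_GROUP_SIZE={group_size}: {ms:.3f} ms")
```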

@danielhanchen merged commit 809bdbe into unslothai:main on Mar 15, 2024
@chiennv2000

awesome @HuyNguyen-hust, congrats on your great work!

@hieule88 commented Apr 9, 2024

awesome @HuyNguyen-hust, congrats on your great work!

@mohsen202

thanks

@namnh194

cool :O
