
[Speculative decoding] [Help wanted] [Performance] Optimize draft-model speculative decoding #4630

Closed
6 of 8 tasks
cadedaniel opened this issue May 6, 2024 · 17 comments
Labels: help wanted, performance, speculative-decoding


@cadedaniel
Collaborator

cadedaniel commented May 6, 2024

Proposal to improve performance

With the end-to-end correctness tests merged in #3951, we will now optimize the implementation to get a ~50% speedup on a 70B model at temperature 1.0.
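At temperature 1.0, verification cannot simply compare each draft token with the target's argmax. It needs the rejection-sampling rule from the speculative sampling literature, which keeps the output distribution identical to sampling from the target alone. Below is a minimal pure-Python sketch of that rule for top-1 (chain) speculation. It assumes per-position probability rows `q` from the draft model and `p` from the target model, with one extra target row for the bonus token. `verify_draft` and `sample_from` are hypothetical names for illustration, not vLLM's implementation.

```python
import random
from typing import List, Sequence


def sample_from(weights: Sequence[float], rng: random.Random) -> int:
    """Draw a token id in proportion to (possibly unnormalized) weights."""
    r = rng.random() * sum(weights)
    cumulative = 0.0
    for token_id, w in enumerate(weights):
        cumulative += w
        if r < cumulative:
            return token_id
    return len(weights) - 1  # guard against floating-point rounding


def verify_draft(draft_tokens: List[int],
                 q: List[List[float]],
                 p: List[List[float]],
                 rng: random.Random) -> List[int]:
    """Speculative rejection sampling for one sequence.

    q[i] is the draft distribution the i-th draft token was sampled from; p has
    len(draft_tokens) + 1 rows of target probabilities, the last for the bonus token.
    Returns the tokens emitted by this verification step.
    """
    emitted = []
    for i, x in enumerate(draft_tokens):
        # Accept draft token x with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[i][x] / q[i][x]):
            emitted.append(x)
            continue
        # Rejected: sample a replacement from the residual max(0, p - q), then stop.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p[i], q[i])]
        emitted.append(sample_from(residual, rng))
        return emitted
    # Every draft token was accepted: sample one bonus token from the target.
    emitted.append(sample_from(p[len(draft_tokens)], rng))
    return emitted


rng = random.Random(0)
q = [[0.7, 0.3], [0.2, 0.8]]
p = [[0.5, 0.5], [0.4, 0.6], [0.9, 0.1]]
print(verify_draft([0, 1], q, p, rng))
```

If all k draft tokens are accepted, the step emits k + 1 tokens. Otherwise it emits the accepted prefix plus one corrected token.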

Work required:

P0/P1 -- priority
(Small/Medium/Large) -- relative size estimate

FAQ

What should the target configuration be for a 50% speedup?

In the Anyscale fork we saw a 50% speedup at batch size 8 with a 68M-parameter draft model on TP1 and a 70B target model on TP8, and with a 7B draft model on TP1 or TP8 and a 70B target model on TP8. This was with the optimizations listed above as "P0".

Note that we can do much better than this with multi-query scoring (P1), GQA for target-model scoring, and a dynamic speculation policy. This is just the starting point!

Why not implement Medusa / tree-attention?

We should implement this! The work here will lay the foundation for future improvements in speculative decoding. For example, Eagle uses the Medusa approach (fine-tuned heads plus tree attention) and even claims to beat Medusa. But for Eagle to work well in vLLM we need to optimize the sampler as listed above.

The north star should be a configurable tree size (top-k down to top-1) that uses multi-query attention for scoring (no batch expansion). This issue is about optimizing vLLM in the top-1 speculation case to get a 50% speedup with draft models.

@cadedaniel added the performance and speculative-decoding labels on May 6, 2024
@youkaichao
Member

youkaichao commented May 6, 2024

> Support draft model on different tensor-parallel-size than target model

This should be doable. We just need to figure out the UX change, i.e. how users will configure it.

Do spec workers and non-spec workers share processes/devices? For example, with tp=8 in the current code, if we add a tp=2 draft model for spec decoding, should the tp=2 group be two additional processes, or a subset of the tp=8 processes?

@cadedaniel
Collaborator Author

@youkaichao, see the code linked here: #4632. The spec worker and the non-spec workers share the same process.
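When the draft uses a smaller TP degree than the target, "share the same process" can be pictured as the draft shards living on a subset of the target's TP ranks, so no extra processes are launched. The sketch below hypothetically assumes the draft occupies the first `draft_tp` ranks. The function name and layout are illustrative only and are not taken from #4632.

```python
from typing import Dict, Optional


def draft_shard_for_rank(target_tp: int, draft_tp: int) -> Dict[int, Optional[int]]:
    """Map each target TP rank (one process per GPU) to the draft TP shard it hosts.

    Hypothetical layout: draft shard i runs on target rank i for i < draft_tp, so the
    draft workers reuse the target's processes; None means the rank hosts no draft
    shard and sits idle while the draft model proposes tokens.
    """
    if not 1 <= draft_tp <= target_tp:
        raise ValueError("draft_tp must be between 1 and target_tp")
    return {rank: (rank if rank < draft_tp else None) for rank in range(target_tp)}


# The example from the question above: a tp=2 draft model alongside a tp=8 target.
layout = draft_shard_for_rank(target_tp=8, draft_tp=2)
print(layout)  # {0: 0, 1: 1, 2: None, 3: None, 4: None, 5: None, 6: None, 7: None}
```

The trade-off in this layout is that ranks without a draft shard idle during proposal. In exchange, there is no separate set of processes to launch and coordinate.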

@KexinFeng

Regarding tree attention/Medusa/Eagle, one of the core pieces will be a tree attention mask in FlashAttention, which is not ready yet. I'd like to bring Dao-AILab/flash-attention#924 to your attention; if anyone would like to contribute to it, that would be great.
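Here is what the mask requirement looks like concretely. With tree speculation, each speculative token may attend to the accepted prefix and to its own ancestors in the tree, but not to sibling branches. Below is a minimal sketch of building that mask for the speculative block. It assumes a hypothetical `parents` array in topological order (each parent index smaller than its child). The prefix part is the usual causal attention and is omitted. This is illustrative only and says nothing about how flash-attention#924 would expose the feature.

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """Attention mask for the speculative tokens of one sequence.

    parents[i] is the index of token i's parent in the speculation tree, or -1 if
    token i attaches directly to the accepted prefix. Tokens must be in topological
    order (parents[i] < i). mask[i][j] is True when token i may attend to token j,
    i.e. when j is i itself or one of its ancestors. Attention to the prefix (KV
    cache) is always allowed and is not represented here.
    """
    n = len(parents)
    for i, parent in enumerate(parents):
        if not -1 <= parent < i:
            raise ValueError(f"parents[{i}]={parent} must be -1 or an earlier index")
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk from token i up to the root, marking each ancestor
            mask[i][j] = True
            j = parents[j]
    return mask


# Two top-2 candidates for the next position, each extended by one more token.
for row in tree_attention_mask([-1, -1, 0, 1]):
    print(["x" if allowed else "." for allowed in row])
```

With a top-1 chain such as `parents = [-1, 0, 1]`, the result is the ordinary lower-triangular causal mask. This is consistent with top-1 speculation not needing the tree-mask feature.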

@sighingnow
Contributor

> In the Anyscale fork we saw a 50% speedup at batch size 8 with a 68M-parameter draft model on TP1 and a 70B target model on TP8, and with a 7B draft model on TP1 or TP8 and a 70B target model on TP8. This was with the optimizations listed above as "P0".

Hi @cadedaniel, I tried the current main branch to evaluate the speedup from speculative decoding, but encountered the following assertion error:

```python
class RayGPUExecutor(DistributedGPUExecutor):
    def _init_executor(self) -> None:
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."
```

I'm wondering how the 50% speedup was measured. Are there further pending PRs? Also, since the draft model is so small (68m-sized), was the 50% speedup measured with greedy sampling or random sampling?

Thanks!

@cadedaniel
Collaborator Author

> Regarding tree attention/Medusa/Eagle, one of the core pieces will be a tree attention mask in FlashAttention, which is not ready yet. I'd like to bring Dao-AILab/flash-attention#924 to your attention; if anyone would like to contribute to it, that would be great.

@LiuXiaoxuanPKU has more on this.

@cadedaniel
Collaborator Author

@sighingnow This issue is for getting the 50% speedup. Once the P0s are done, we will get it with temperature 1.0.

@ChuanhongLi

> Hi @cadedaniel, I tried the current main branch to evaluate the speedup from speculative decoding, but encountered the following assertion error: "Speculative decoding not yet supported for RayGPU backend." [...]

I have hit the same problem. Is there a solution? Also, is there any documentation on how to measure the speedup from speculative decoding? Thanks!

@sighingnow
Contributor

> @sighingnow This issue is for getting the 50% speedup. Once the P0s are done, we will get it with temperature 1.0.

Could you share more about the acceptance rate behind the 50% speedup? Thanks!

@cadedaniel
Collaborator Author

cadedaniel commented May 10, 2024

> Could you share more about the acceptance rate behind the 50% speedup? Thanks!

With Llama 2 7B as the draft and Llama 2 70B as the target, the acceptance rate was around 80% (no fine-tuning). We trained a 68M draft model at Anyscale that gets a ~50% acceptance rate. By the way, you can run acceptance-rate experiments today (I will push a PR tomorrow for TP>1 support).

> I have hit the same problem. Is there a solution? Also, is there any documentation on how to measure the speedup from speculative decoding? Thanks!

Thanks @ChuanhongLi. FYI, there is no speedup yet; we'll share documentation once there is a useful one.
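If you are running acceptance-rate experiments in the meantime, the two quantities most often reported can be computed from per-step logs of how many proposed tokens were accepted. The log format and the `acceptance_stats` helper below are hypothetical. vLLM's own spec-decode metrics may be defined differently.

```python
from typing import Dict, Sequence


def acceptance_stats(accepted_per_step: Sequence[int], k: int) -> Dict[str, float]:
    """Summarize a run of top-1 speculative decoding.

    accepted_per_step[i] is how many of the k proposed draft tokens the target
    accepted at verification step i (hypothetical log format).
    """
    if k <= 0 or not accepted_per_step:
        raise ValueError("need k > 0 and at least one step")
    if any(a < 0 or a > k for a in accepted_per_step):
        raise ValueError("each step accepts between 0 and k tokens")
    steps = len(accepted_per_step)
    accepted = sum(accepted_per_step)
    return {
        # Fraction of proposed draft tokens that were accepted.
        "acceptance_rate": accepted / (k * steps),
        # Every step also emits one token from the target (bonus or correction).
        "tokens_per_step": (accepted + steps) / steps,
    }


print(acceptance_stats([5, 3, 0, 4], k=5))
# {'acceptance_rate': 0.6, 'tokens_per_step': 4.0}
```

A rejection discards every later draft token in that step. As a result, the token-level acceptance rate and the tokens emitted per step measure different things, so it matters which one a quoted number refers to.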

@sighingnow
Contributor

> With Llama 2 7B as the draft and Llama 2 70B as the target, the acceptance rate was around 80% (no fine-tuning). We trained a 68M draft model at Anyscale that gets a ~50% acceptance rate. By the way, you can run acceptance-rate experiments today (I will push a PR tomorrow for TP>1 support).

Thanks for the information! Looking forward to the complete speculative decoding support!

@ChuanhongLi

> Thanks for the information! Looking forward to the complete speculative decoding support!

Thanks for your reply!

@caddfa31434

I noticed there's a feature request related to Medusa/Eagle at #4669.

@Wanglongzhi2001

Wanglongzhi2001 commented Jul 3, 2024

> With Llama 2 7B as the draft and Llama 2 70B as the target, the acceptance rate was around 80% (no fine-tuning). We trained a 68M draft model at Anyscale that gets a ~50% acceptance rate. By the way, you can run acceptance-rate experiments today (I will push a PR tomorrow for TP>1 support).

@cadedaniel May I ask how you calculated the acceptance rate? For Llama 2 7B / Llama 2 70B, this acceptance rate seems rather high for only a 50% speedup.
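The analytical model from Leviathan et al. (2023) gives a rough sense of why a high acceptance rate does not translate one-for-one into speedup. If each draft token is accepted independently with probability α, a step emits on average (1 - α^(k+1)) / (1 - α) tokens. But each step also pays for k draft passes. In the sketch below, `k` (proposal length) and `c` (cost of one draft pass relative to one target pass) are illustrative parameters, not values from this thread. The model ignores scheduling, sampling, and scoring overheads.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Mean tokens emitted per target forward pass, assuming each of the k draft
    tokens is accepted independently with probability alpha; the target always
    contributes one extra (bonus or corrected) token."""
    if not 0.0 <= alpha <= 1.0 or k < 0:
        raise ValueError("need 0 <= alpha <= 1 and k >= 0")
    if alpha == 1.0:
        return k + 1.0
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def modeled_speedup(alpha: float, k: int, c: float) -> float:
    """Walltime improvement over plain decoding: tokens per step divided by the
    cost of one step (k draft passes at relative cost c, plus one target pass)."""
    return expected_tokens_per_step(alpha, k) / (k * c + 1.0)


for alpha in (0.5, 0.8):
    print(f"alpha={alpha}: {expected_tokens_per_step(alpha, 5):.2f} tokens/step, "
          f"{modeled_speedup(alpha, k=5, c=0.1):.2f}x modeled speedup")
```

For example, with k = 5 and c = 0.1, α = 0.8 gives about 3.7 tokens per step but a modeled speedup of only about 2.5x. A real system loses more to per-step overheads on top of that.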

@sighingnow
Contributor

> P1 (Large) Replace CPU-based batch expansion with multi-query attention kernel call

Hi @cadedaniel @LiuXiaoxuanPKU, I have pushed a multi-query scorer implementation in #6185. Could you please take a look and let me know what you think?

Thanks!
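For context on this P1 item, the two scoring strategies differ in what the target model is asked to run. Below is a minimal sketch with plain token lists standing in for tensors. It assumes, as in a normal decode step, that the last context token has not yet been run through the model. `batch_expand` and `multi_query_inputs` are hypothetical helpers that show the shape of the work, not vLLM's scorer classes or the implementation in #6185.

```python
from typing import List, Tuple


def batch_expand(context: List[int], proposal: List[int]) -> List[List[int]]:
    """Batch expansion: one sequence with k proposal tokens becomes k + 1 sequences,
    each ending at a different proposal prefix, so an ordinary single-token decode
    step over all of them yields the target's distribution at every position."""
    return [context + proposal[:j] for j in range(len(proposal) + 1)]


def multi_query_inputs(context: List[int],
                       proposal: List[int]) -> Tuple[List[int], List[int]]:
    """Multi-query scoring: keep one sequence and feed k + 1 query tokens (the last
    context token plus the k proposals) in a single pass with a causal mask among
    them; returns (tokens already in the KV cache, new query tokens)."""
    return context[:-1], [context[-1]] + proposal


context, proposal = [11, 12, 13], [21, 22]
print(batch_expand(context, proposal))        # [[11, 12, 13], [11, 12, 13, 21], [11, 12, 13, 21, 22]]
print(multi_query_inputs(context, proposal))  # ([11, 12], [13, 21, 22])
```

Batch expansion turns a batch of B sequences into B(k + 1) sequences, and the item above notes that this expansion was CPU-based. Multi-query scoring keeps the batch at B sequences and moves the extra positions into the attention kernel call.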

@cadedaniel
Collaborator Author

Thanks everyone for the help! We hit a 45% latency reduction. Big thanks to @sroy745 @alexm-neuralmagic @comaniac @wooyeonlee0 @zifeitong @LiuXiaoxuanPKU @rkooo567 @ruisearch42 and everyone else who has helped reduce vLLM overheads!

[Screenshot from 2024-08-05; image not included]

I expect more performance gains once we move the API server outside of the worker; we can re-run the evals then.
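The result can be converted between the two ways of stating it. Assume the 45% latency reduction reported above means that end-to-end latency fell to 55% of the baseline (an assumption, since the screenshot's numbers are not reproduced here). The equivalent speedup factor is then:

$$
t_{\text{spec}} = (1 - 0.45)\,t_{\text{base}} = 0.55\,t_{\text{base}}
\quad\Longrightarrow\quad
\frac{t_{\text{base}}}{t_{\text{spec}}} = \frac{1}{0.55} \approx 1.82
$$

Under that reading, the result corresponds to roughly an 82% speedup.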

@alexm-neuralmagic
Contributor

@cadedaniel thanks for leading this project!

@sroy745
Contributor

sroy745 commented Aug 5, 2024

@cadedaniel Thanks for leading this effort.

10 participants