
Support int8 KVCache Quant in Vllm #1507

Status: Open · wants to merge 57 commits into base: main

Conversation


@AniZpZ AniZpZ commented Oct 30, 2023

Quantization of the KV cache can lift throughput with minimal loss in model performance. We implement int8 KV cache quantization, which achieves a 15% throughput improvement. This PR is part of #1112; we split the large PR into two independent parts for easier review.
Using int8 KV cache quantization is simple:

python ./vllm/entrypoints/api_server.py --model=/path/to/quantized/model --tokenizer=/path/to/tokenizer --max-num-batched-tokens=70000 --block-size=16 --swap-space=20 --kv-cache-dtype=int8 --kv-quant-params-path=/path/to/kv_params_dir
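Under the hood, `--kv-cache-dtype=int8` stores the KV cache as int8 together with quantization parameters (a scale and zero point, cf. `v_scale` / `v_zp` in the kernels) loaded from `--kv-quant-params-path`. As a rough illustration only, here is a minimal PyTorch sketch of that round trip (not the PR's CUDA kernel; the scale and zero-point values below are made up):

```python
import torch

def quant_int8(kv: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    """Quantize a fp16 KV cache tensor to int8 with a precomputed scale/zero point."""
    q = torch.round(kv.float() / scale + zero_point)
    return q.clamp(-128, 127).to(torch.int8)

def dequant_int8(q: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    """Recover an approximate fp16 tensor from the int8 cache."""
    return ((q.float() - zero_point) * scale).to(torch.float16)

# Toy round trip on a fake [num_tokens, num_heads, head_size] key tensor.
key = torch.randn(79, 48, 128, dtype=torch.float16)
scale, zp = 0.05, 0.0  # illustrative values; the real ones come from calibration
key_hat = dequant_int8(quant_int8(key, scale, zp), scale, zp)
print((key - key_hat).abs().max())  # per-element error is roughly bounded by scale/2
```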

The loss in model performance is minor. The table below shows our experiment results on the MMLU dataset:

| model | STEM | Social Sciences | Humanities | Other | Average |
| --- | --- | --- | --- | --- | --- |
| fp16 | 0.4056 | 0.4965 | 0.4750 | 0.4855 | 0.4657 |
| fp16 with int8 kv | 0.4037 | 0.4956 | 0.4748 | 0.4849 | 0.4648 |

You can find more details, such as how to generate the KV cache scales, in the original PR #1112.
You can combine this method with W8A8 inference (#1508) for the best throughput.
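The calibration script that produces these scales lives in #1112. Purely as an illustration of the idea (the function below is hypothetical, not the PR's code), an asymmetric int8 scale/zero-point pair can be derived from calibration-time KV activations like this:

```python
import torch

def kv_quant_params(samples: torch.Tensor) -> tuple[float, float]:
    """Derive an asymmetric int8 (scale, zero_point) pair from calibration KV activations."""
    lo, hi = samples.min().item(), samples.max().item()
    scale = max(hi - lo, 1e-8) / 255.0   # map the observed range onto 256 int8 levels
    zero_point = -128.0 - lo / scale     # so that `lo` maps to -128 after rounding
    return scale, zero_point
```

Presumably one such pair is collected per layer for keys and one for values, then written out in whatever format `--kv-quant-params-path` expects; see #1112 for the actual workflow.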

@ZiyueHuang

Hi,

Thanks for the wonderful PR. I have implemented FP8 (e5m2 / e4m3) KV cache quantization based on this PR in this branch; it does not require any calibration (unlike the INT8 case).

  • Throughput (ShareGPT on V100-32G): token/s increases by 62% for Qwen-14B, and 6% for Qwen-7B.

  • Accuracy: same accuracy for Qwen-7B on 5-shot MMLU-STEM, both achieved 42.02. I also tried a few hand-crafted prompts (including tool usage), and didn't observe any quality degradation on the generated text.

cc @zhuohan123 @WoosukKwon

@AniZpZ
Author

AniZpZ commented Nov 3, 2023

> Hi,
>
> Thanks for the wonderful PR. I have implemented FP8 (e5m2 / e4m3) KV cache quantization based on this PR in this branch; it does not require any calibration (unlike the INT8 case).
>
> • Throughput (ShareGPT on V100-32G): token/s increases by 62% for Qwen-14B, and 6% for Qwen-7B.
> • Accuracy: same accuracy for Qwen-7B on 5-shot MMLU-STEM, both achieved 42.02. I also tried a few hand-crafted prompts (including tool usage), and didn't observe any quality degradation on the generated text.
>
> cc @zhuohan123 @WoosukKwon

Awesome work!
I found that you directly cast fp32 to fp8 without scaling. Did you inspect the KV cache data distribution and find that the data can be represented in fp8 without loss, leading to that choice?

@ZiyueHuang

ZiyueHuang commented Nov 3, 2023

FP8 e5m2 has the same dynamic range as fp16, so I think it should be safe to cast, while FP8 e4m3 has more precision but may cost more cycles than e5m2 when casting from fp16. Meanwhile, the casting implementation seems to work in a similar way to the CUDA native satfinite instruction (clamping into the range of the destination format first, IIUC).

This idea comes from two sources: 1) I recently tried fp16/fp32/bf16 mixed-precision training on V100 (fp32 <-> bf16 is roughly analogous to fp16 <-> fp8 at a high level); more analysis and details are posted here. 2) I also came across this idea somewhere (a blog post, with no implementation provided; I couldn't find the exact reference now).
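For reference, the dynamic-range argument is easy to check with the float8 dtypes that recent PyTorch versions (2.1+) expose; this is just a numerical illustration, not the PR's CUDA cast:

```python
import torch

# e5m2 keeps fp16's 5-bit exponent, so its range is close to fp16's;
# e4m3 trades range for an extra mantissa bit.
print(torch.finfo(torch.float16).max)        # 65504.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0

# Direct cast of fp16 values to e5m2 and back: magnitudes survive,
# only mantissa precision is lost.
x = torch.randn(8, dtype=torch.float16) * 100
print(x)
print(x.to(torch.float8_e5m2).to(torch.float16))
```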

@zhaoyang-star
Contributor

I also implemented FP8 (e5m2 / e4m3) KV cache quantization, and my analysis shows the KV cache data distribution is well suited to FP8.
[figure: KV cache data distribution heatmap]

Above is the kv cache data distribution in finetuned WizardCoder.

| Dataset | Baseline (KV Cache FP16) | KV Cache FP8 E5M2 | KV Cache FP8 E4M3 |
| --- | --- | --- | --- |
| HumanEval-Python-EN | 68.293% | 65.854% (↓ 2.439%) | 67.683% (↓ 0.61%) |
| HumanEval-Python-CN | 59.146% | 59.146% (=) | 59.756% (↑ 0.61%) |

I think we could support both int8 and fp8 KV cache. FYI, TRT-LLM also supports both quantization methods for the KV cache.

@ZiyueHuang

> I also implemented FP8 (e5m2 / e4m3) KV cache quantization, and my analysis shows the KV cache data distribution is well suited to FP8. [...] I think we could support both int8 and fp8 KV cache. FYI, TRT-LLM also supports both quantization methods for the KV cache.

Thanks for sharing the experiments and the benchmarks. How should I read the plot (e.g., what are the x-axis and y-axis)?

@zhaoyang-star
Contributor

zhaoyang-star commented Nov 8, 2023

@ZiyueHuang I dumped the KV cache of each transformer layer and took the mean over all transformer layers.

The prompt length is 14 and the generation length is 65, so the total sequence length is 14+65=79. The hidden dim is 6144 and head_num=48, so head_dim=6144/48=128. The shape of each layer's key/value cache is therefore [b, seq_len, 128].

The x-axis is the KV dim ([0, 63] is the key cache and [64, 127] is the value cache) and the y-axis is the sequence length.
The key cache has several outliers while the value cache is smoother, so the value cache is easier to quantize, with almost no impact on accuracy, whereas quantizing the key cache does cost some accuracy.

Note that the KV cache data distribution depends on the model.
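As a hypothetical sketch (the tensor names and the exact dim slicing of the original figure are assumptions), a plot like that could be reproduced from dumped per-layer caches roughly as follows:

```python
import torch
import matplotlib.pyplot as plt

def plot_kv_distribution(key_caches, value_caches):
    """key_caches / value_caches: lists of per-layer tensors of shape [b, seq_len, head_dim]."""
    # Mean over layers, then take batch index 0 -> [seq_len, head_dim].
    key = torch.stack(key_caches).float().mean(dim=0)[0]
    value = torch.stack(value_caches).float().mean(dim=0)[0]
    # Left block of the x-axis: key cache dims; right block: value cache dims.
    kv = torch.cat([key, value], dim=-1).cpu().numpy()
    plt.imshow(kv, aspect="auto", cmap="coolwarm")
    plt.xlabel("kv dim (left: key cache, right: value cache)")
    plt.ylabel("sequence position")
    plt.colorbar()
    plt.show()
```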

```cpp
attn_dtype tgt_value = __ldg(&value[src_value_idx]);
value_cache[tgt_value_idx] = quant(tgt_value, v_scale, v_zp);
}
}
```
Contributor

There is a lot of redundant code in reshape_and_cache_quantized_kernel compared with reshape_and_cache_kernel. Would it be better to merge them into one function?

Author

We have merged the functions and eliminated the redundant code.

```cpp
v_scale,
v_zp);
});
}
```
Contributor

Would it be better to merge reshape_and_cache_quantized into reshape_and_cache?

Author

We have merged reshape_and_cache_quantized into reshape_and_cache.

Contributor

@zhaoyang-star zhaoyang-star left a comment

Great work! Several comments left.

@AniZpZ
Author

AniZpZ commented Nov 8, 2023

> Great work! Several comments left.

Thank you for taking the time to review our code. We will take your advice and respond as soon as possible.

@gesanqiu
Contributor

gesanqiu commented Nov 10, 2023

Thanks for your great work! I met compile errors and left some comments here, but it seems more modifications are needed; I couldn't run a kv_int8 model with this PR (I turned to #1112 in the end).

Exception in callback functools.partial(<function _raise_exception_on_finish at 0x7fe75458d120>, request_tracker=<vllm.engine.async_llm_engine.RequestTracker object at 0x7fe6c149b820>)
handle: <Handle functools.partial(<function _raise_exception_on_finish at 0x7fe75458d120>, request_tracker=<vllm.engine.async_llm_engine.RequestTracker object at 0x7fe6c149b820>)>
Traceback (most recent call last):
  File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 28, in _raise_exception_on_finish
    task.result()
  File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 351, in run_engine_loop
    has_requests_in_progress = await self.engine_step()
  File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 330, in engine_step
    request_outputs = await self.engine.step_async()
  File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 191, in step_async
    output = await self._run_workers_async(
  File "/workdir/vllm_smoothquant/vllm/vllm/engine/async_llm_engine.py", line 216, in _run_workers_async
    output = executor(*args, **kwargs)
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/worker/worker.py", line 369, in execute_model
    output = self.model(
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 310, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 268, in forward
    hidden_states = layer(
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 211, in forward
    hidden_states = self.self_attn(
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/models/llama.py", line 159, in forward
    attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
  File "/root/anaconda3/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/layers/attention.py", line 405, in forward
    return super().forward(
  File "/workdir/vllm_smoothquant/vllm/vllm/model_executor/layers/attention.py", line 296, in forward
    cache_ops.reshape_and_cache_quantized(
RuntimeError: expected scalar type Int but found Long

My application scenario is long context (~1000 tokens) with short outputs (~15 tokens), so KV cache quantization may not help me; I didn't see any improvement in the prefill or decode phase. Any ideas about this? Hoping to use W8A16 ASAP, thanks.

EDIT: FYI, the prefill and decode latency of the KV_INT8 model is close to the FP16 model at small batch sizes. In my test (1x A40), KV_INT8 got ~40% throughput improvement, W8A8 got ~45%, and W8A8+KV_INT8 got ~100%. Amazing, can't believe this.

@HandH1998
Contributor

> Currently int8 KV cache only supports llama. It would be better to make this explicit in int8_kv_cache.rst.

Fixed.

Contributor

@zhaoyang-star zhaoyang-star left a comment

Overall the PR LGTM now. @zhuohan123 could you please take some time to review?

@AniZpZ
Author

AniZpZ commented Feb 20, 2024

> Overall the PR LGTM now. @zhuohan123 could you please take some time to review?

Hi! @zhuohan123 @WoosukKwon We think this PR is ready. Could you please take some time to review it? Much appreciated!

@zhuohan123 zhuohan123 mentioned this pull request Mar 2, 2024
@zhaoyang-star
Contributor

@AniZpZ Could you please resolve the conflicts? We are planning to review it again and merge it if everything is OK.

@zhuohan123 zhuohan123 self-assigned this Mar 25, 2024
@AniZpZ
Author

AniZpZ commented Mar 25, 2024

> @AniZpZ Could you please resolve the conflicts? We are planning to review it again and merge it if everything is OK.

Sure, I will resolve the conflicts.

@AniZpZ
Author

AniZpZ commented Mar 26, 2024

> @AniZpZ Could you please resolve the conflicts? We are planning to review it again and merge it if everything is OK.

@zhuohan123 @zhaoyang-star Hi! The conflicts have been resolved.

@zhaoyang-star
Contributor

LGTM

@hmellor hmellor removed the v0.3.4 label Apr 20, 2024
@hmellor
Collaborator

hmellor commented Apr 20, 2024

@zhuohan123 this was included in the release tracker for 4.0.0, but ended up not being merged in time. Should it be added to the new release tracker?

@hikq123

hikq123 commented Jun 24, 2024

Hi, will you rebase your code on vllm 0.5.0?

4 similar comments

@hmellor
Collaborator

hmellor commented Aug 2, 2024

@AniZpZ @simon-mo @WoosukKwon
