Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support W8A8 inference in vllm #1508

Closed
wants to merge 110 commits into from
Closed

Support W8A8 inference in vllm #1508

wants to merge 110 commits into from

Conversation

AniZpZ
Copy link

@AniZpZ AniZpZ commented Oct 30, 2023

We have implemented W8A8 inference in vLLM, which can achieve a 30% improvement in throughput. W4A16 quantization methods require weights to be dequantized into fp16 before compute and lead to a throughput drop under heavier load. This PR is part of #1112. We have split the huge PR into two independent parts for easier review.
The usage of w8a8 inference is simple(support llama for now):

python ./vllm/entrypoints/api_server.py --model=/path/to/quantized/model --tokenizer=/path/to/tokenizer --max-num-batched-tokens=70000 --block-size=16 --swap-space=20 --quantization smoothquant

17th Jan 2024 Updates!!!
We release a repository for implementing SmoothQuant in various models. You can try out export quantized model weights with this repo AutoSmoothQuant

CUDA Graph is not compatible for now~

Updates!!!
We have update the quant method to per token quant for o_proj and down_proj of LLama. Please use the lastest llama-dev branch of smoothquant and per_token_quant branch of torch-int to generate int8 model!!!

You can find more details like how to gernerate int8 weight in original PR #1112
You can use the method with int8 kv cache quant #1507 for best throughput

@HandH1998
Copy link
Contributor

@shiqingzhangCSU, I don't think it can help, as our .cu files all get the current stream like what awq and gpt do. But thank you all the same.
image

@HandH1998
Copy link
Contributor

@AniZpZ @HandH1998 I would like to test SmoothQuant, do you have branch that merges both W8A8 and KV cache quant?

Hi, you can try this branch https://github.com/AniZpZ/vllm/tree/vllmq.

@Hongbosherlock
Copy link

Hi @AniZpZ @HandH1998 , thanks for your great work. I'm curious about the difference between W8A8BFP32OFP32LinearWithSFactor and W8A8BFP32OFP32Linear

@HandH1998
Copy link
Contributor

@Hongbosherlock, W8A8BFP32OFP32Linear is applied in qkv_proj and gate_up_proj, W8A8BFP32OFP32LinearWithSFactor is applied in out_proj and down_proj. We fused quant_scale in LayerNorm before qkv_proj and gate_up_proj, so the two layers use W8A8BFP32OFP32Linear which has no the scale factor. For out_porj and gate_up_proj, they have to use W8A8BFP32OFP32LinearWithSFactor to carry the scales themselves.

@Hongbosherlock
Copy link

@Hongbosherlock, W8A8BFP32OFP32Linear is applied in qkv_proj and gate_up_proj, W8A8BFP32OFP32LinearWithSFactor is applied in out_proj and down_proj. We fused quant_scale in LayerNorm before qkv_proj and gate_up_proj, so the two layers use W8A8BFP32OFP32Linear which has no the scale factor. For out_porj and gate_up_proj, they have to use W8A8BFP32OFP32LinearWithSFactor to carry the scales themselves.

thanks, I have sent you an email, pls check it.

@Hongbosherlock
Copy link

Hongbosherlock commented Feb 3, 2024

@Hongbosherlock, W8A8BFP32OFP32Linear is applied in qkv_proj and gate_up_proj, W8A8BFP32OFP32LinearWithSFactor is applied in out_proj and down_proj. We fused quant_scale in LayerNorm before qkv_proj and gate_up_proj, so the two layers use W8A8BFP32OFP32Linear which has no the scale factor. For out_porj and gate_up_proj, they have to use W8A8BFP32OFP32LinearWithSFactor to carry the scales themselves.

by the way, in AWQ they smooth the o_proj and down_proj here for llama before quantization :
https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/auto_scale.py#L202

# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
    scales_list.append(_auto_get_scale(
        prev_op=module.self_attn.v_proj,
        layers=[module.self_attn.o_proj],
         inp=input_feat['self_attn.o_proj'],
))

but in Smoothquant(llama-dev) , it is not the same as that, what is the consideration behind this?

@MingLin-home
Copy link

This is really a nice work!

I am trying to re-produce the W8A8 results. However, after using AutoSmoothQuant to export int8 llama-2-7b model, modifying the "architectures" in config.json from "Int8LlamaForCausalLM" to "LlamaForCausalLM", I load the model:

llm = LLM(model="~/models/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-smoothquant",
          tokenizer="~/models/Llama-2-7b-chat-hf",
          quantization="smoothquant",
          max_context_len_to_capture=256,
          tensor_parallel_size=1)

in vLLM. I get this error:

File "..... /vllm/model_executor/models/llama.py", line 449, in load_weights
    param = params_dict[name]
KeyError: 'model.layers.0.mlp.down_proj.quant_scale'

I manually checked the params_dict and cannot find this "quant_scale". Any idea how to fix? Many thanks ahead!

@HandH1998
Copy link
Contributor

@MingLin-home What is your quant_config.json when running AutoSmoothQuant? It should be { "qkv": "per-tensor", "out": "per-token", "fc1": "per-tensor", "fc2": "per-token" }, as the current w8a8 in vllm only supports this config.

@MingLin-home
Copy link

{ "qkv": "per-tensor", "out": "per-token", "fc1": "per-tensor", "fc2": "per-token" }

Thanks for the quick reply! Using this quant config, the error is gone. Great work!

@MingLin-home
Copy link

@AniZpZ We really like this PR! Do you have any plan on re-implement in triton? We are willing to re-implement a triton version as it is more user-friendly. Many thanks ahead!

@AniZpZ
Copy link
Author

AniZpZ commented Feb 21, 2024

We really like this PR! Do you have any plan on re-implement in triton? We are willing to re-implement a triton version as it is more user-friendly. Many thanks ahead!

We currently do not have plans to implement a Triton version. However, we do have a repository at https://github.com/AniZpZ/AutoSmoothQuant that facilitates easy quantization for LLMs. We would greatly appreciate it if you could reimplement a Triton version and contribute to our repo!

@MingLin-home
Copy link

We really like this PR! Do you have any plan on re-implement in triton? We are willing to re-implement a triton version as it is more user-friendly. Many thanks ahead!

We currently do not have plans to implement a Triton version. However, we do have a repository at https://github.com/AniZpZ/AutoSmoothQuant that facilitates easy quantization for LLMs. We would greatly appreciate it if you could reimplement a Triton version and contribute to our repo!

Thanks @AniZpZ ! We are using AutoSmoothQuant to quantize the model. Will keep you posted on our triton project. Please kindly let me know if you are aware of similar efforts such that we can join the development together.

@MingLin-home
Copy link

MingLin-home commented Feb 29, 2024

========== update ==========
I disable the cuda graph in vllm ("enforce_eager=True"). The output text is normal now.

========== old context =======
@AniZpZ sorry for bugging you again!

I was able to convert the Llama-2-7b model to int8 via AutoSmoothQuant. Loading and inference in vLLM without any RuntimeError too. However, I checked the generated text and it is filled with random meaningless tokens. I further confirm that the fp16 original model file works correctly, generating meaningful text.

To verify that the converted model via AutoSmoothQuant is good, I test the model in AutoSmoothQuant's test_model.py. The output is human readable.

In short, it looks like the vLLM-w8a8 model loading and/or inference are not working correctly.

@nivibilla nivibilla mentioned this pull request Mar 5, 2024
3 tasks
@xyfZzz
Copy link

xyfZzz commented Mar 7, 2024

@AniZpZ Hi, when I used dynamic_rope in the w8a8 branch to extend the length of the model after smoothquant, the following error occurred. How should I solve it?

File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
  File "/app/xie/code/w8a8/vllm/vllm/model_executor/models/llama.py", line 276, in forward 
    hidden_states, scale = self.self_attn( 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_im 
pl 
    return self._call_impl(*args, **kwargs) 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
  File "/app/xie/code/w8a8/vllm/vllm/model_executor/models/llama.py", line 197, in forward 
    q, k, v = self.rotary_emb(positions, q, k, v, q_dequant_scale, 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_im 
pl 
    return self._call_impl(*args, **kwargs) 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
TypeError: forward() takes 6 positional arguments but 8 were given

@AniZpZ
Copy link
Author

AniZpZ commented Mar 8, 2024

@AniZpZ Hi, when I used dynamic_rope in the w8a8 branch to extend the length of the model after smoothquant, the following error occurred. How should I solve it?

File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
  File "/app/xie/code/w8a8/vllm/vllm/model_executor/models/llama.py", line 276, in forward 
    hidden_states, scale = self.self_attn( 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_im 
pl 
    return self._call_impl(*args, **kwargs) 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
  File "/app/xie/code/w8a8/vllm/vllm/model_executor/models/llama.py", line 197, in forward 
    q, k, v = self.rotary_emb(positions, q, k, v, q_dequant_scale, 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_im 
pl 
    return self._call_impl(*args, **kwargs) 
  File "/app/apps/anaconda3/envs/vllmw8a8/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl 
    return forward_call(*args, **kwargs) 
TypeError: forward() takes 6 positional arguments but 8 were given

It seems that you should modify dynamic_rope to accept extra args for dequantization operation.

@tlrmchlsmth
Copy link
Collaborator

Hi @AniZpZ, I'm starting to work on activation quantization (with a couple of other engineers at @neuralmagic) and wanted to point you to this RFC that I just posted #3975. We've been working with this PR as a starting point so are interested in getting your feedback and collaborating if you're interested.

@zhyncs
Copy link
Contributor

zhyncs commented Apr 24, 2024

Hi @AniZpZ, I'm starting to work on activation quantization (with a couple of other engineers at @neuralmagic) and wanted to point you to this RFC that I just posted #3975. We've been working with this PR as a starting point so are interested in getting your feedback and collaborating if you're interested.

Hi @tlrmchlsmth Our team proposed this PR very early last year. Later, for the convenience of review, the KV Cache Int8 and W8A8 were split into two PRs. But it seems that the vLLM team's focus are not on this. The review and merging progress is very slow. We will continue to make new quantitative attempts such as W4A8 on vLLM, but at present, our development of LLM Serving for production environment has been migrated to LMDeploy.

@babaozhouy5
Copy link

@shiqingzhangCSU, I don't think it can help, as our .cu files all get the current stream like what awq and gpt do. But thank you all the same. image

If bother anyone forgive me.
@AniZpZ Thanks for your great work! I thought maybe the reason why can not enable cuda graph is the cuda stream of cublasINT8MMWrapper is acquired at ctor:
image
when acquired at forward process for example:
image
I can enable cuda graph normally (not need --enforce-eager option)

@HandH1998
Copy link
Contributor

@babaozhouy5 Thank you for addressing the issue. I think I understand how you fix it. Change the cudastream parameter from the class cublasINT8MMWrapper's parameter to the function Gemm_'s parameter, so that the GEMM can capture cudastream dynamically. We will appreciate it if you can provide a PR when you have free time.

@babaozhouy5
Copy link

@babaozhouy5 Thank you for addressing the issue. I think I understand how you fix it. Change the cudastream parameter from the class cublasINT8MMWrapper's parameter to the function Gemm_'s parameter, so that the GEMM can capture cudastream dynamically. We will appreciate it if you can provide a PR when you have free time.

Sure!

@xinyinan9527
Copy link

xinyinan9527 commented Jul 30, 2024

您好,我正在使用您的方法进行int8量化,是mixtral模型,但是会提示
argument --quantization/-q: invalid choice: 'smoothquant' (choose from 'aqlm', 'awq', 'deepspeedfp', 'fp8', 'marlin', 'gptq_marlin_24', 'gptq_marlin', 'gptq', 'squeezellm', 'sparseml', None)

我使用的最新版本vllm。请问如何解决呢

@mgoin
Copy link
Sponsor Collaborator

mgoin commented Aug 29, 2024

Closing this PR as vLLM has supported INT8 W8A8 with custom CUTLASS kernels for a while now. See the documentation for pointers on how to find or make INT8 models! https://docs.vllm.ai/en/v0.5.5/quantization/int8.html

@mgoin mgoin closed this Aug 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.