forked from hpcaitech/ColossalAI
[inference] add llama mlp for smoothquant (hpcaitech#4854)
* add llama mlp for smoothquant
* fix down out scale
* remove duplicate lines
* add llama mlp check
* delete useless code
Showing 4 changed files with 182 additions and 3 deletions.
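The diff adds a fused int8 linear + SiLU layer and a test that exercises a LlamaSmoothquantMLP module. The MLP module itself is not part of this excerpt, so purely for orientation, the sketch below shows how a SmoothQuant-style LLaMA MLP is typically assembled from the pieces that do appear here: a fused gate projection, an int8 up projection, a requantization step governed by down_proj_input_scale, and an int8 down projection. The class name and the requantization details are assumptions, not the code added by this commit.

import torch

# Hypothetical sketch, NOT the LlamaSmoothquantMLP added by this commit.
# It only illustrates the data flow that the test below relies on:
#   gate_proj: int8 GEMM with SiLU fused in the kernel, fp32 output
#   up_proj:   int8 GEMM, fp32 output
#   the fp32 product is requantized to int8 using down_proj_input_scale
#   down_proj: int8 GEMM, fp32 output
class SketchSmoothquantMLP(torch.nn.Module):
    def __init__(self, gate_proj, up_proj, down_proj, down_proj_input_scale=1.0):
        super().__init__()
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj
        self.down_proj_input_scale = down_proj_input_scale

    def forward(self, x):
        gate_out = self.gate_proj(x)  # SiLU already applied inside the kernel
        up_out = self.up_proj(x)
        inter = gate_out * up_out     # fp32 intermediate activation
        inter_q = torch.clamp(
            torch.round(inter / self.down_proj_input_scale), -127, 127
        ).to(torch.int8)              # requantize before the down projection
        return self.down_proj(inter_q)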
@@ -0,0 +1,49 @@
import torch
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except ImportError:
    HAS_SMOOTHQUANT_CUDA = False
    raise ImportError("CUDA smoothquant linear is not installed")


# int8 weights and activations, fp32 bias and output, with SiLU fused into the CUDA kernel.
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
        )
        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False))
        self.register_buffer("a", torch.tensor(alpha))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        # Fused int8 GEMM + SiLU; "a" is the dequantization scale (input_scale * weight_scale).
        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        # Quantize a float linear layer to int8 with per-tensor absmax scaling.
        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        int8_module.bias.data.copy_(module.bias.to(torch.float))
        int8_module.a = alpha
        return int8_module
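For reference, a small usage sketch of the layer above: converting a calibrated float gate projection into the fused int8 SiLU linear via from_float. The input_scale value and the shapes are illustrative assumptions; in practice the scale comes from an activation calibration pass, and W8A8BFP32O32LinearSiLU is the class defined above.

import torch

# Usage sketch with made-up values; W8A8BFP32O32LinearSiLU is defined above.
float_gate = torch.nn.Linear(256, 512, bias=True).cuda()
input_scale = 0.02  # assumed: max |activation| / 127 from a calibration pass

int8_gate = W8A8BFP32O32LinearSiLU.from_float(float_gate, input_scale).cuda()

x_int8 = torch.randint(-127, 127, (4, 256), dtype=torch.int8, device="cuda")
y_fp32 = int8_gate(x_int8)  # fp32 output of the fused int8 GEMM + SiLU kernel
print(y_fp32.shape)         # torch.Size([4, 512])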
@@ -0,0 +1,84 @@
import warnings

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


try:
    from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP

    HAS_TORCH_INT = True
except:
    HAS_TORCH_INT = False
    warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
    # Float reference for the LLaMA MLP: SiLU(x @ gate) * (x @ up) @ down.
    gate_out = torch.mm(x, gate_proj)
    silu = torch.nn.SiLU()
    gate_out = silu(gate_out)
    up_out = torch.mm(x, up_proj)

    o_out = gate_out * up_out

    # Track the absolute range of the intermediate activation so the test
    # can derive the requantization scale for the down projection.
    max_up = torch.max(torch.abs(o_out))
    min_up = torch.min(torch.abs(o_out))

    torch_out = torch.mm(o_out, down_proj)

    return (torch_out, max_up, min_up)


@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
    reason="smoothquant linear not installed properly or torch_int is not installed",
)
def test_llama_mlp():
    hidden_size = 256
    intermediate_size = 512

    smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)

    smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")

    smooth_mlp.up_proj.weight = torch.randint(
        -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
    )
    smooth_mlp.down_proj.weight = torch.randint(
        -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
    )

    x = torch.ones((1, 256), dtype=torch.int8, device="cuda")

    # Float reference: dequantize the int8 weights with the same per-tensor
    # scales that the quantized module uses below.
    torch_out, max_inter, min_inter = torch_llama_mlp(
        smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
        smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
        smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
        x.to(torch.float),
    )

    smooth_mlp.down_proj_input_scale = max_inter.item() / 127
    smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
    smooth_mlp.up_proj.a = torch.tensor(1 / 127)
    smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))

    smooth_out = smooth_mlp(x)

    assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)


if __name__ == "__main__":
    test_llama_mlp()
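The scales assigned in the test are plain per-tensor absmax factors: each projection's a is input_scale * weight_scale (matching from_float above), and down_proj_input_scale = max|intermediate| / 127 requantizes the gate/up product before the down projection. Below is a minimal, self-contained sketch of that bookkeeping, using generic int8 math that is independent of the CUDA kernels under test.

import torch

# Per-tensor absmax int8 GEMM versus a float reference (illustrative only).
torch.manual_seed(0)
w = torch.randn(512, 256)
x = torch.randn(4, 256)

w_scale = w.abs().max() / 127  # weight scale
x_scale = x.abs().max() / 127  # activation ("input") scale
w_q = torch.round(w / w_scale).clamp(-127, 127).to(torch.int8)
x_q = torch.round(x / x_scale).clamp(-127, 127).to(torch.int8)

# int8 GEMM accumulated in int32, then rescaled by x_scale * w_scale;
# that product of scales is the "a" attribute set on each projection above.
y_int32 = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
y_deq = y_int32.to(torch.float) * (x_scale * w_scale)
y_ref = x @ w.t()
rel_err = ((y_deq - y_ref).norm() / y_ref.norm()).item()
print(f"relative error: {rel_err:.4f}")  # typically on the order of 1e-2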