[inference] add llama mlp for smoothquant (hpcaitech#4854)
* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code
Xu-Kai committed Oct 13, 2023
1 parent cea8b38 commit 1d1491e
Showing 4 changed files with 182 additions and 3 deletions.
2 changes: 1 addition & 1 deletion colossalai/inference/quant/smoothquant/models/__init__.py
@@ -9,4 +9,4 @@
)

if HAS_TORCH_INT:
-    from .smoothquant_layer import LLamaSmoothquantAttention
+    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
49 changes: 49 additions & 0 deletions colossalai/inference/quant/smoothquant/models/linear.py
@@ -0,0 +1,49 @@
import torch
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except ImportError:
    HAS_SMOOTHQUANT_CUDA = False
    raise ImportError("CUDA smoothquant linear is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
        )
        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False))
        self.register_buffer("a", torch.tensor(alpha))

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        x_shape = x.shape
        x = x.view(-1, x_shape[-1])
        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
        y = y.view(*x_shape[:-1], -1)
        return y

    @staticmethod
    def from_float(module: torch.nn.Linear, input_scale):
        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
        alpha = input_scale * weight_scale
        int8_module.weight = int8_weight
        int8_module.bias.data.copy_(module.bias.to(torch.float))
        int8_module.a = alpha
        return int8_module
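
For context, a minimal usage sketch of the new fused layer (not part of this commit). It assumes the class above plus the compiled smoothquant CUDA kernel, and input_scale stands in for a SmoothQuant-calibrated activation scale; shapes are illustrative.

import torch

fp_linear = torch.nn.Linear(256, 512)                 # float projection to be quantized (default bias=True)
input_scale = 0.02                                    # assumed calibrated activation scale
int8_silu = W8A8BFP32O32LinearSiLU.from_float(fp_linear, input_scale).to("cuda")

x_fp = torch.randn(4, 256)
x_int8 = (x_fp / input_scale).round().clamp(-128, 127).to(torch.int8).cuda()
y = int8_silu(x_int8)                                 # fp32 output, SiLU already fused in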
50 changes: 48 additions & 2 deletions colossalai/inference/quant/smoothquant/models/llama.py
@@ -6,10 +6,12 @@
from torch import nn
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

from colossalai.kernel.triton import int8_rotary_embedding_fwd

+from .linear import W8A8BFP32O32LinearSiLU


class LLamaSmoothquantAttention(nn.Module):
    def __init__(
@@ -100,7 +102,11 @@ def forward(
            self.rotary_output_scale,
        )
        int8_rotary_embedding_fwd(
-            key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale
+            key_states.view(-1, self.num_heads, self.head_dim),
+            cos,
+            sin,
+            self.k_output_scale,
+            self.rotary_output_scale,
        )

        if past_key_value is not None:
@@ -183,3 +189,43 @@ def forward(
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_probs_reshaped, past_key_value


class LlamaSmoothquantMLP(nn.Module):
    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
        self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
        self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
        self.down_proj_input_scale = 1.0

    def pack(
        self,
        mlp_module: LlamaMLP,
        gate_proj_input_scale: float,
        up_proj_input_scale: float,
        down_proj_input_scale: float,
    ):
        int8_module = LlamaSmoothquantMLP(
            mlp_module.intermediate_size,
            mlp_module.hidden_size,
        )

        int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
        int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
        int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
        int8_module.down_proj_input_scale = down_proj_input_scale
        return int8_module

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        x_shape = hidden_states.shape
        gate_out = self.gate_proj(hidden_states)
        up_out = self.up_proj(hidden_states)
        inter_out = gate_out * up_out
        inter_out = inter_out.div_(self.down_proj_input_scale).round().clamp(-128, 127).to(torch.int8)
        down_out = self.down_proj(inter_out)
        down_out = down_out.view(*x_shape[:-1], -1)
        return down_out
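
For reference, a plain-PyTorch emulation (not part of this commit) of what the quantized forward pass above computes, assuming zero biases; each a_* factor is the usual input_scale * weight_scale product, with a_down folding in down_proj_input_scale.

import torch
import torch.nn.functional as F

def float_reference_mlp(x_int8, w_gate_int8, w_up_int8, w_down_int8,
                        a_gate, a_up, a_down, down_proj_input_scale):
    # gate branch: int8 matmul emulated in fp32, rescaled by a_gate, SiLU fused in
    gate = F.silu(x_int8.float() @ w_gate_int8.float().t() * a_gate)
    # up branch: int8 matmul rescaled by a_up, fp32 output
    up = x_int8.float() @ w_up_int8.float().t() * a_up
    # elementwise product, requantized to the int8 range with down_proj_input_scale
    inter = (gate * up).div(down_proj_input_scale).round().clamp(-128, 127)
    # down projection; a_down = down_proj_input_scale * w_down_scale restores fp32 magnitudes
    return inter @ w_down_int8.float().t() * a_down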
84 changes: 84 additions & 0 deletions tests/test_smoothquant/test_llama_mlp.py
@@ -0,0 +1,84 @@
import warnings

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


try:
    from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP

    HAS_TORCH_INT = True
except:
    HAS_TORCH_INT = False
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
    gate_out = torch.mm(x, gate_proj)
    silu = torch.nn.SiLU()
    gate_out = silu(gate_out)
    up_out = torch.mm(x, up_proj)

    o_out = gate_out * up_out

    max_up = torch.max(torch.abs(o_out))
    min_up = torch.min(torch.abs(o_out))

    torch_out = torch.mm(o_out, down_proj)

    return (torch_out, max_up, min_up)


@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
reason="smoothquant linear not installed properly or not install torch_int",
)
def test_llama_mlp():
    hidden_size = 256
    intermediate_size = 512

    smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)

    smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")

    smooth_mlp.up_proj.weight = torch.randint(
        -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
    )
    smooth_mlp.down_proj.weight = torch.randint(
        -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
    )

    x = torch.ones((1, 256), dtype=torch.int8, device="cuda")

    torch_out, max_inter, min_inter = torch_llama_mlp(
        smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
        smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
        smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
        x.to(torch.float),
    )

    smooth_mlp.down_proj_input_scale = max_inter.item() / 127
    smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
    smooth_mlp.up_proj.a = torch.tensor(1 / 127)
    smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))

    smooth_out = smooth_mlp(x)

    assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)


if __name__ == "__main__":
    test_llama_mlp()
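
As a side note, the scale choices in this test can be read as input_scale * weight_scale products; a small illustrative helper (not part of the commit) spells this out:

def reference_scales(hidden_size, max_inter):
    # Illustrative only: how the test above derives its scales.
    x_scale = 1.0                           # the all-ones int8 input stands for float ones
    w_gate_scale = 1.0 / hidden_size        # int8 gate weight 1 stands for float 1 / hidden_size
    w_up_scale = w_down_scale = 1.0 / 127   # int8 up/down weights stand for float weight / 127
    inter_scale = max_inter / 127           # requantization scale for gate_out * up_out
    return {
        "gate_proj.a": x_scale * w_gate_scale,
        "up_proj.a": x_scale * w_up_scale,
        "down_proj_input_scale": inter_scale,
        "down_proj.a": inter_scale * w_down_scale,
    }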
