Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inference] add llama mlp for smoothquant #4854

Merged
merged 5 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion colossalai/inference/quant/smoothquant/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
)

if HAS_TORCH_INT:
from .smoothquant_layer import LLamaSmoothquantAttention
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
49 changes: 49 additions & 0 deletions colossalai/inference/quant/smoothquant/models/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from torch_int.functional.quantization import quantize_per_tensor_absmax

try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
HAS_SMOOTHQUANT_CUDA = False
raise ImportError("CUDA smoothquant linear is not installed")


class W8A8BFP32O32LinearSiLU(torch.nn.Module):
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer(
"weight",
torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False))
self.register_buffer("a", torch.tensor(alpha))

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self

@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
y = y.view(*x_shape[:-1], -1)
return y

@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
int8_module.bias.data.copy_(module.bias.to(torch.float))
int8_module.a = alpha
return int8_module
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from torch import nn
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

from colossalai.kernel.triton import int8_rotary_embedding_fwd

from .linear import W8A8BFP32O32LinearSiLU


class LLamaSmoothquantAttention(nn.Module):
def __init__(
Expand Down Expand Up @@ -100,7 +102,11 @@ def forward(
self.rotary_output_scale,
)
int8_rotary_embedding_fwd(
key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale
key_states.view(-1, self.num_heads, self.head_dim),
cos,
sin,
self.k_output_scale,
self.rotary_output_scale,
)

if past_key_value is not None:
Expand Down Expand Up @@ -183,3 +189,44 @@ def forward(
attn_output = self.out_proj(attn_output)

return attn_output, attn_probs_reshaped, past_key_value


class LlamaSmoothquantMLP(nn.Module):
def __init__(self, intermediate_size, hidden_size):
super().__init__()
self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
self.down_proj_input_scale = 1.0
self.inter_out_scale = 1.0

def pack(
self,
mlp_module: LlamaMLP,
gate_proj_input_scale: float,
up_proj_input_scale: float,
down_proj_input_scale: float,
):
int8_module = LlamaSmoothquantMLP(
mlp_module.intermediate_size,
mlp_module.hidden_size,
)

int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
self.down_proj_input_scale = down_proj_input_scale
Xu-Kai marked this conversation as resolved.
Show resolved Hide resolved
return int8_module

def forward(
self,
hidden_states: torch.Tensor,
):
x_shape = hidden_states.shape
gate_out = self.gate_proj(hidden_states)
up_out = self.up_proj(hidden_states)
inter_out = gate_out * up_out
inter_out = inter_out.div_(self.inter_out_scale).round().clamp(-128, 127).to(torch.int8)
down_out = self.down_proj(inter_out)
down_out = down_out.view(*x_shape[:-1], -1)
return down_out
85 changes: 85 additions & 0 deletions tests/test_smoothquant/test_llama_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import warnings

import pytest
import torch
from packaging import version

try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
except ImportError:
warnings.warn("CUDA smoothquant linear is not installed")
HAS_SMOOTHQUANT_CUDA = False

from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP

try:
from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
Xu-Kai marked this conversation as resolved.
Show resolved Hide resolved

HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
gate_out = torch.mm(x, gate_proj)
silu = torch.nn.SiLU()
gate_out = silu(gate_out)
up_out = torch.mm(x, up_proj)

o_out = gate_out * up_out

max_up = torch.max(torch.abs(o_out))
min_up = torch.min(torch.abs(o_out))

torch_out = torch.mm(o_out, down_proj)

return (torch_out, max_up, min_up)


@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
reason="smoothquant linear not installed properly or not install torch_int",
)
def test_linear():
hidden_size = 256
intermediate_size = 512

smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)

smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")

smooth_mlp.up_proj.weight = torch.randint(
-10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
)
smooth_mlp.down_proj.weight = torch.randint(
-10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
)

x = torch.ones((1, 256), dtype=torch.int8, device="cuda")

torch_out, max_inter, min_inter = torch_llama_mlp(
smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
x.to(torch.float),
)

smooth_mlp.inter_out_scale = max_inter.item() / 127
smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
smooth_mlp.up_proj.a = torch.tensor(1 / 127)
smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))

smooth_out = smooth_mlp(x)

assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)


if __name__ == "__main__":
test_linear()
Loading