[inference] add smooth function and delete useless code for smoothquant #4895

Merged
4 commits merged on Oct 12, 2023
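For context, the smoothing step introduced by this PR follows the SmoothQuant recipe: per-channel activation outliers are divided out of the preceding LayerNorm and multiplied into the following Linear weights, so activations become easier to quantize. Below is a minimal sketch of the scale computation, mirroring `smooth_ln_fcs` in `base_model.py`; the function name and shapes are illustrative, not part of the PR.

```python
import torch

# Minimal sketch of the smoothing math added in this PR (not the PR's API):
# per-channel activation maxima are migrated into the following Linear's
# weights, with alpha controlling how much of the outliers is moved.
def smoothing_scales(act_scales: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # act_scales: per-input-channel max |activation|, shape (in_features,)
    # weight: nn.Linear weight, shape (out_features, in_features)
    weight_scales = weight.abs().max(dim=0)[0].clamp(min=1e-5)
    return (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)

# The transform is lossless in float because dividing the LayerNorm weight
# by s and multiplying the Linear weight column-wise by s leaves
# fc(ln(x)) unchanged, which is exactly what smooth_ln_fcs below does.
```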
53 changes: 0 additions & 53 deletions colossalai/inference/quant/smoothquant/calibration.py

This file was deleted.

119 changes: 72 additions & 47 deletions colossalai/inference/quant/smoothquant/models/base_model.py
@@ -4,71 +4,31 @@
import types
import warnings
from abc import abstractmethod
from functools import partial
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import accelerate
import numpy as np
import torch
import torch.nn as nn
import transformers
from safetensors.torch import save_file as safe_save
from torch import device
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file

from ....tensor_parallel.batch_infer_state import BatchInferState
from ....tensor_parallel.kvcache_manager import MemoryManager
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager

CPU = device("cpu")
CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = ["llama"]


def get_module_by_name_suffix(model, module_name: str):
for name, module in model.named_modules():
if name.endswith(module_name):
return module


def simple_dispatch_model(model, device_map):
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

if "" in device_map:
d = device_map[""]
model = model.to(torch.device(d))
model.hf_device_map = device_map
return model

tied_params = accelerate.utils.modeling.find_tied_parameters(model)
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
main_device = "cpu"
else:
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
prev_hook = None
for idx, (n, d) in enumerate(cpu_offload_group):
m = get_module_by_name_suffix(model, n)
_, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
if len(cpu_offload_group) > 1:
get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook

for n, d in device_map.items():
m = get_module_by_name_suffix(model, n)
if d != "cpu":
d = torch.device(d)
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
add_hook_to_module(m, hook)
accelerate.utils.modeling.retie_parameters(model, tied_params)
model.hf_device_map = device_map

return model


class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
layer_type: str = None

@@ -132,6 +92,7 @@ def init_batch_state(self, max_output_len=256, **kwargs):
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
batch_infer_state.cache_manager.free_all()
return batch_infer_state

@abstractmethod
@@ -157,15 +118,79 @@ def generate(self, **kwargs):
if self.config.model_type == "llama":
setattr(self.model.model, "infer_state", batch_infer_state)

batch_infer_state.is_context_stage = True

with torch.inference_mode():
return self.model.generate(**kwargs)

def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
for text in tqdm(dataset):
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)

def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
pbar = tqdm(dataset)
for text in pbar:
input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
model(input_ids)
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")

def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
act_scales = {}

def stat_tensor(name, tensor):
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs().detach()
comming_max = torch.max(tensor, dim=0)[0].float().cpu()
if name in act_scales:
act_scales[name] = torch.max(act_scales[name], comming_max)
else:
act_scales[name] = comming_max

def stat_input_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
stat_tensor(name, x)

hooks = []
for name, m in model.named_modules():
if isinstance(m, nn.Linear):
hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))

self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)

for h in hooks:
h.remove()

return act_scales

@torch.no_grad()
def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
if not isinstance(fcs, list):
fcs = [fcs]
for fc in fcs:
assert isinstance(fc, nn.Linear)
assert ln.weight.numel() == fc.in_features == act_scales.numel()

device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
act_scales = act_scales.to(device=device, dtype=dtype)
weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)

ln.weight.div_(scales)
if hasattr(ln, "bias"):
ln.bias.div_(scales)

for fc in fcs:
fc.weight.mul_(scales.view(1, -1))

def save_quantized(
self,
save_dir: str,
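The new `get_act_scales`/`collect_act_scales` helpers above gather per-channel activation statistics with forward hooks on every `nn.Linear`. Here is a self-contained toy sketch of that hook pattern; the toy model, names, and random data are illustrative assumptions, not code from this PR.

```python
import torch
import torch.nn as nn
from functools import partial

# Toy illustration of the forward-hook calibration pattern used above:
# record the per-channel max |input| seen by every nn.Linear.
act_scales = {}

def stat_input_hook(module, inputs, output, name):
    x = inputs[0] if isinstance(inputs, tuple) else inputs
    x = x.view(-1, x.shape[-1]).abs().detach()
    current_max = x.max(dim=0)[0].float().cpu()
    act_scales[name] = torch.max(act_scales[name], current_max) if name in act_scales else current_max

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
hooks = [m.register_forward_hook(partial(stat_input_hook, name=n))
         for n, m in model.named_modules() if isinstance(m, nn.Linear)]

with torch.no_grad():
    for _ in range(8):            # stand-in for the calibration dataset
        model(torch.randn(4, 16))

for h in hooks:
    h.remove()
print({name: scales.shape for name, scales in act_scales.items()})
```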
51 changes: 30 additions & 21 deletions colossalai/inference/quant/smoothquant/models/llama.py
@@ -7,14 +7,10 @@
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch import nn
from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from tqdm import tqdm
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -756,15 +752,14 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
def __init__(self, model: PreTrainedModel, quantized: bool = False):
super().__init__(model, quantized)

def quantized(
def get_act_dict(
self,
tokenizer,
dataset_path,
dataset,
num_samples=512,
seq_len=512,
):
llama_model = self.model
llama_config = llama_model.config

llama_model.eval()
device = next(llama_model.parameters()).device
@@ -798,23 +793,37 @@ def stat_io_hook(m, x, y, name):
if isinstance(m, torch.nn.Linear):
hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))

print("Collecting activation scales...")
pbar = tqdm(range(num_samples))
dataset = load_dataset("json", data_files=dataset_path, split="train")
dataset = dataset.shuffle(seed=42)
for i in pbar:
input_ids = tokenizer(
dataset["rows"][0][i]["row"]["text"],
return_tensors="pt",
max_length=seq_len,
truncation=True,
).input_ids.to(device)
llama_model(input_ids)
mean_scale = np.mean([v["input"] for v in act_dict.values()])
pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)

for hook in hooks:
hook.remove()
return act_dict

def smooth_fn(self, scales, alpha=0.5):
model = self.model
for name, module in model.named_modules():
if isinstance(module, LlamaDecoderLayer):
attn_ln = module.input_layernorm
qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
qkv_input_scales = scales[name + ".self_attn.q_proj"]
self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

def quantized(
self,
tokenizer,
dataset,
num_samples=512,
seq_len=512,
alpha=0.5,
):
llama_model = self.model
llama_config = llama_model.config

act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)

self.smooth_fn(act_scales, alpha)

act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
decoder_layer_scales = []

for idx in range(llama_config.num_hidden_layers):
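Putting the pieces together, the refactored `SmoothLlamaForCausalLM.quantized` now takes a tokenizer and an in-memory dataset of texts (instead of a JSON file path) and runs calibration, smoothing, and per-layer scale collection in sequence. A rough usage sketch under that reading; the checkpoint path and calibration texts are placeholders, and the import path is inferred from the file layout in this PR.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

model_path = "path/to/llama"                     # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
base_model = AutoModelForCausalLM.from_pretrained(model_path).cuda()

calib_texts = ["example calibration sentence one.",
               "example calibration sentence two."]   # placeholder dataset

smooth_model = SmoothLlamaForCausalLM(base_model)
# Collect activation scales, fold them into LayerNorm/Linear pairs via
# smooth_fn, then gather the decoder-layer quantization scales.
smooth_model.quantized(tokenizer, calib_texts, num_samples=len(calib_texts),
                       seq_len=512, alpha=0.5)
```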