From 934afde29490766a1337fbaf42503d0bf87e9285 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 12 Oct 2023 16:44:46 +0800 Subject: [PATCH] [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file --- .../quant/smoothquant/calibration.py | 53 -------- .../quant/smoothquant/models/base_model.py | 119 +++++++++++------- .../quant/smoothquant/models/llama.py | 51 ++++---- 3 files changed, 102 insertions(+), 121 deletions(-) delete mode 100644 colossalai/inference/quant/smoothquant/calibration.py diff --git a/colossalai/inference/quant/smoothquant/calibration.py b/colossalai/inference/quant/smoothquant/calibration.py deleted file mode 100644 index 66ac49826592..000000000000 --- a/colossalai/inference/quant/smoothquant/calibration.py +++ /dev/null @@ -1,53 +0,0 @@ -# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ - -import functools - -import torch -import torch.nn as nn -from datasets import load_dataset -from tqdm import tqdm - - -def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = {} - - def stat_tensor(name, tensor): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float().cpu() - if name in act_scales: - act_scales[name] = torch.max(act_scales[name], comming_max) - else: - act_scales[name] = comming_max - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x) - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear): - hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name))) - - dataset = load_dataset("json", data_files=dataset_path) - - print("text", dataset["train"]["rows"][0][1]["row"]["text"]) - - dataset = dataset.shuffle(seed=42) - - for i in tqdm(range(num_samples)): - input_ids = tokenizer( - dataset["train"]["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) - model(input_ids) - - for h in hooks: - h.remove() - - return act_scales diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 326c3df6e038..73cdbb39e53f 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -4,71 +4,31 @@ import types import warnings from abc import abstractmethod +from functools import partial from os.path import isdir, isfile, join from typing import Dict, List, Optional, Union import accelerate +import numpy as np import torch import torch.nn as nn import transformers from safetensors.torch import save_file as safe_save from torch import device +from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel from transformers.modeling_utils import no_init_weights from transformers.utils.generic import ContextManagers from transformers.utils.hub import PushToHubMixin, cached_file -from ....tensor_parallel.batch_infer_state import BatchInferState -from ....tensor_parallel.kvcache_manager import MemoryManager +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager CPU = device("cpu") -CUDA_0 = 
device("cuda:0") SUPPORTED_MODELS = ["llama"] -def get_module_by_name_suffix(model, module_name: str): - for name, module in model.named_modules(): - if name.endswith(module_name): - return module - - -def simple_dispatch_model(model, device_map): - from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - if "" in device_map: - d = device_map[""] - model = model.to(torch.device(d)) - model.hf_device_map = device_map - return model - - tied_params = accelerate.utils.modeling.find_tied_parameters(model) - if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: - main_device = "cpu" - else: - main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] - - cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"] - prev_hook = None - for idx, (n, d) in enumerate(cpu_offload_group): - m = get_module_by_name_suffix(model, n) - _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook) - # set first cpu offload module's prev_module_hook to the last cpu offload module's hook - if len(cpu_offload_group) > 1: - get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook - - for n, d in device_map.items(): - m = get_module_by_name_suffix(model, n) - if d != "cpu": - d = torch.device(d) - hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) - add_hook_to_module(m, hook) - accelerate.utils.modeling.retie_parameters(model, tied_params) - model.hf_device_map = device_map - - return model - - class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): layer_type: str = None @@ -132,6 +92,7 @@ def init_batch_state(self, max_output_len=256, **kwargs): batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() return batch_infer_state @abstractmethod @@ -157,8 +118,6 @@ def generate(self, **kwargs): if self.config.model_type == "llama": setattr(self.model.model, "infer_state", batch_infer_state) - batch_infer_state.is_context_stage = True - with torch.inference_mode(): return self.model.generate(**kwargs) @@ -166,6 +125,72 @@ def prepare_inputs_for_generation(self, *args, **kwargs): """shortcut for model.prepare_inputs_for_generation""" return self.model.prepare_inputs_for_generation(*args, **kwargs) + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + 
def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + def save_quantized( self, save_dir: str, diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index b201347825b2..014fb640e060 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -7,14 +7,10 @@ from functools import partial from typing import List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from datasets import load_dataset -from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T -from tqdm import tqdm from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.configuration_llama import LlamaConfig @@ -756,15 +752,14 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): def __init__(self, model: PreTrainedModel, quantized: bool = False): super().__init__(model, quantized) - def quantized( + def get_act_dict( self, tokenizer, - dataset_path, + dataset, num_samples=512, seq_len=512, ): llama_model = self.model - llama_config = llama_model.config llama_model.eval() device = next(llama_model.parameters()).device @@ -798,23 +793,37 @@ def stat_io_hook(m, x, y, name): if isinstance(m, torch.nn.Linear): hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) - print("Collecting activation scales...") - pbar = tqdm(range(num_samples)) - dataset = load_dataset("json", data_files=dataset_path, split="train") - dataset = dataset.shuffle(seed=42) - for i in pbar: - input_ids = tokenizer( - dataset["rows"][0][i]["row"]["text"], - return_tensors="pt", - max_length=seq_len, - truncation=True, - ).input_ids.to(device) - llama_model(input_ids) - mean_scale = np.mean([v["input"] for v in act_dict.values()]) - pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + for hook in hooks: hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + 
".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) decoder_layer_scales = [] for idx in range(llama_config.num_hidden_layers):