Skip to content

Commit

Permalink
[inference] add smoothquant llama (hpcaitech#4861)
Browse files Browse the repository at this point in the history
* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code
  • Loading branch information
Xu-Kai committed Oct 13, 2023
1 parent c93bc2d commit d023076
Show file tree
Hide file tree
Showing 12 changed files with 2,166 additions and 145 deletions.
Empty file.
53 changes: 53 additions & 0 deletions colossalai/inference/quant/smoothquant/calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ

import functools

import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm


def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512):
    """Collect per-channel max absolute input activations of every ``nn.Linear``.

    A forward hook is attached to each ``nn.Linear`` submodule of ``model``.
    The model is then run on calibration text, and for every hooked layer the
    running element-wise maximum of ``|input|`` over all seen tokens is kept.

    Args:
        model: Module to calibrate. It is switched to eval mode, and inputs
            are moved to the device of its first parameter.
        tokenizer: Callable used as
            ``tokenizer(text, return_tensors="pt", max_length=..., truncation=True)``
            whose result exposes ``.input_ids``.
        dataset_path: Passed as ``data_files`` to ``load_dataset("json", ...)``.
            The ``train`` split must have a ``"rows"`` field whose first record
            is a sequence of items with the text stored at ``item["row"]["text"]``.
        num_samples: Maximum number of calibration samples to run. Clamped to
            the number of available samples.
        seq_len: Token limit per sample (longer texts are truncated).

    Returns:
        dict mapping each ``nn.Linear`` module name (as yielded by
        ``model.named_modules()``) to a 1-D float32 CPU tensor holding the
        maximum absolute input value seen per channel of the last dimension.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        # Fold all leading dimensions so the reduction runs over every token.
        hidden_dim = tensor.shape[-1]
        # reshape (rather than view) also accepts non-contiguous activations.
        flat = tensor.detach().reshape(-1, hidden_dim).abs()
        incoming_max = torch.max(flat, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], incoming_max)
        else:
            act_scales[name] = incoming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple; only the
        # first one is the activation fed to the linear layer.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = []
    try:
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name)))

        dataset = load_dataset("json", data_files=dataset_path)
        dataset = dataset.shuffle(seed=42)

        # Loop-invariant: resolve the sample list once instead of re-reading
        # the whole column on every iteration.
        rows = dataset["train"]["rows"][0]
        num_samples = min(num_samples, len(rows))

        # Calibration only needs forward statistics; skip autograd bookkeeping.
        with torch.no_grad():
            for i in tqdm(range(num_samples)):
                input_ids = tokenizer(
                    rows[i]["row"]["text"],
                    return_tensors="pt",
                    max_length=seq_len,
                    truncation=True,
                ).input_ids.to(device)
                model(input_ids)
    finally:
        # Always detach the hooks so a failure mid-calibration does not leave
        # the caller's model instrumented.
        for h in hooks:
            h.remove()

    return act_scales
Loading

0 comments on commit d023076

Please sign in to comment.