
Remove progress bars (#550)
casper-hansen committed Jul 23, 2024
1 parent 47f64ac commit 5889753
Showing 1 changed file with 40 additions and 51 deletions.
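
Every change in this commit follows the same pattern: a tqdm(...) wrapper around a loop (and its pbar.set_description status updates) is replaced by plain iteration over the underlying iterable. A minimal sketch of the pattern, with items/process as hypothetical stand-ins, not lines from the diff:

    # before: loop instrumented with a tqdm progress bar
    for item in tqdm(items, desc="Step", leave=False):
        process(item)

    # after: plain iteration, no progress reporting
    for item in items:
        process(item)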
91 changes: 40 additions & 51 deletions awq/quantize/quantizer.py
@@ -165,7 +165,7 @@ def quantize(self):
             )
             scales_list = [
                 self._search_best_scale(self.modules[i], **layer)
-                for layer in tqdm(module_config, desc="Best Scales", leave=False)
+                for layer in module_config
             ]
             apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             scales_list = append_str_prefix(
@@ -252,9 +252,7 @@ def _module_forward(
             # but only n_parallel_calib_samples at a time
             module_output = []
             partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
-            for x_partial in tqdm(
-                partitioned_inputs, desc="Module forward", leave=False
-            ):
+            for x_partial in partitioned_inputs:
                 partial_output = module(x_partial, **module_kwargs)

                 if isinstance(partial_output, tuple):
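
For context, the loop changed above runs the module over the calibration inputs in slices of n_parallel_calib_samples to bound peak memory; only the progress bar around that slicing is removed. A self-contained sketch of the slicing pattern, assuming a module that returns either a tensor or a tuple (the helper name is illustrative, not from the repository):

    import torch

    def forward_in_slices(module, x, n_parallel_calib_samples):
        # Split the batch along dim 0 and run each slice separately, so peak
        # activation memory scales with the slice size, not the full batch.
        outputs = []
        for x_partial in torch.split(x, n_parallel_calib_samples):
            out = module(x_partial)
            if isinstance(out, tuple):  # some modules return (hidden, ...) tuples
                out = out[0]
            outputs.append(out)
        return torch.cat(outputs, dim=0)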
@@ -370,43 +368,41 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)

-        with tqdm(range(n_grid), desc="Grid Search", leave=False) as pbar:
-            for ratio in pbar:
-                # create new scales
-                ratio = ratio / n_grid
+        for ratio in range(n_grid):
+            # create new scales
+            ratio = ratio / n_grid

-                # NOTE: s^-1 * x is fused here, according to paper
-                if self.duo_scaling:
-                    scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
-                else:
-                    scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
-                scales = scales / (scales.max() * scales.min()).sqrt()
-                scales_view = scales.view(1, -1).to(device)
-
-                # avoid scaling values that overflow
-                scales[torch.isinf(scales)] = 1
-                scales[torch.isnan(scales)] = 1
-
-                # Q(W * s)
-                for fc in linears2scale:
-                    fc.weight.mul_(scales_view)
-                    fc.weight.data = (
-                        self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
-                    )
-
-                # W * X
-                int_w_output = self._module_forward(x, module2inspect, kwargs)
-
-                # compute mean squared error (L2 norm)
-                loss = self._compute_loss(fp16_output, int_w_output, device)
-
-                history.append(loss)
-                if loss < best_error:
-                    best_error = loss
-                    best_ratio = ratio
-                    best_scales = scales.clone()
-                module2inspect.load_state_dict(org_sd)
-                pbar.set_description(f"Grid Search (Best: {best_ratio})")
+            # NOTE: s^-1 * x is fused here, according to paper
+            if self.duo_scaling:
+                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
+            else:
+                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
+            scales = scales / (scales.max() * scales.min()).sqrt()
+            scales_view = scales.view(1, -1).to(device)
+
+            # avoid scaling values that overflow
+            scales[torch.isinf(scales)] = 1
+            scales[torch.isnan(scales)] = 1
+
+            # Q(W * s)
+            for fc in linears2scale:
+                fc.weight.mul_(scales_view)
+                fc.weight.data = (
+                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
+                )
+
+            # W * X
+            int_w_output = self._module_forward(x, module2inspect, kwargs)
+
+            # compute mean squared error (L2 norm)
+            loss = self._compute_loss(fp16_output, int_w_output, device)
+
+            history.append(loss)
+            if loss < best_error:
+                best_error = loss
+                best_ratio = ratio
+                best_scales = scales.clone()
+            module2inspect.load_state_dict(org_sd)

         if best_ratio == -1:
             logging.debug(history)
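
For context, the loop losing its progress bar here is AWQ's grid search over scale candidates: ratio sweeps [0, 1) in n_grid steps, each candidate scale is built from the activation and weight means, the weights are pseudo-quantized under that scale, and the candidate with the lowest MSE against the fp16 output wins. A standalone sketch of the search skeleton, with eval_loss as a hypothetical stand-in for the scale-quantize-forward-compare step:

    import torch

    def grid_search_scales(x_mean, w_mean, n_grid, eval_loss):
        # eval_loss(scales) -> float stands in for: scale the weights,
        # pseudo-quantize them, run the module, and measure the MSE
        # against the original fp16 output.
        best_error, best_ratio, best_scales = float("inf"), -1.0, None
        for step in range(n_grid):
            ratio = step / n_grid  # sweep ratio across [0, 1)
            # duo scaling: s = x_mean^ratio / w_mean^(1 - ratio)
            scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
            scales = scales / (scales.max() * scales.min()).sqrt()  # normalize spread
            scales[torch.isinf(scales) | torch.isnan(scales)] = 1  # guard overflow
            loss = eval_loss(scales)
            if loss < best_error:
                best_error, best_ratio, best_scales = loss, ratio, scales.clone()
        return best_ratio, best_scales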
@@ -439,16 +435,9 @@ def _compute_loss(
         int_w_chunks = torch.split(int_w_output_flat, chunk_size)

         # Compute the loss for each chunk
-        with tqdm(
-            zip(fp16_chunks, int_w_chunks),
-            total=len(fp16_chunks),
-            desc="Computing Loss",
-            leave=False,
-        ) as pbar:
-            for fp16_chunk, int_w_chunk in pbar:
-                chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
-                loss += chunk_loss
-                pbar.set_description(f"Computing Loss (loss: {loss:.2f})")
+        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
+            chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
+            loss += chunk_loss

         # Normalize the loss by the total number of elements
         loss /= num_elements
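
The loss above is an accumulated sum of squared errors, computed chunk by chunk so only one pair of chunks needs to sit on the device at a time, then normalized by the element count, i.e. the MSE between the fp16 and pseudo-quantized outputs. A minimal sketch, assuming equally shaped inputs (the function name and default chunk size are illustrative):

    import torch

    def chunked_mse(fp16_output, int_w_output, device, chunk_size=1 << 20):
        # Flatten and split so only one chunk pair is resident on `device`.
        fp16_chunks = torch.split(fp16_output.flatten(), chunk_size)
        int_w_chunks = torch.split(int_w_output.flatten(), chunk_size)
        loss = 0.0
        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
            diff = fp16_chunk.to(device) - int_w_chunk.to(device)
            loss += diff.float().pow(2).sum().item()
        return loss / fp16_output.numel()  # normalize by total element count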
@@ -460,7 +449,7 @@ def _search_best_clip(self, layer, named_linears, input_feat):
         clip_list = []
         avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

-        for name in tqdm(named_linears, desc="Computing Best Clip", leave=False):
+        for name in named_linears:
             # due to qk bmm, it is hard to clip precisely
             if any([_ in name for _ in avoid_clipping]):
                 continue
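
The filter above skips query/key projections by substring match, since their outputs feed the attention QK matmul and clipping them precisely is hard. A tiny sketch of that check (the function name is illustrative):

    avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

    def should_search_clip(name):
        # Skip any linear whose name marks it as a query/key projection.
        return not any(token in name for token in avoid_clipping)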
