diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index 91685f14..cd9fb0dd 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -165,7 +165,7 @@ def quantize(self):
             )
             scales_list = [
                 self._search_best_scale(self.modules[i], **layer)
-                for layer in tqdm(module_config, desc="Best Scales", leave=False)
+                for layer in module_config
             ]
             apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             scales_list = append_str_prefix(
@@ -252,9 +252,7 @@ def _module_forward(
             # but only n_parallel_calib_samples at a time
             module_output = []
             partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
-            for x_partial in tqdm(
-                partitioned_inputs, desc="Module forward", leave=False
-            ):
+            for x_partial in partitioned_inputs:
                 partial_output = module(x_partial, **module_kwargs)
 
                 if isinstance(partial_output, tuple):
@@ -370,43 +368,41 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        with tqdm(range(n_grid), desc="Grid Search", leave=False) as pbar:
-            for ratio in pbar:
-                # create new scales
-                ratio = ratio / n_grid
+        for ratio in range(n_grid):
+            # create new scales
+            ratio = ratio / n_grid
 
-                # NOTE: s^-1 * x is fused here, according to paper
-                if self.duo_scaling:
-                    scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
-                else:
-                    scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
-                scales = scales / (scales.max() * scales.min()).sqrt()
-                scales_view = scales.view(1, -1).to(device)
-
-                # avoid scaling values that overflow
-                scales[torch.isinf(scales)] = 1
-                scales[torch.isnan(scales)] = 1
-
-                # Q(W * s)
-                for fc in linears2scale:
-                    fc.weight.mul_(scales_view)
-                    fc.weight.data = (
-                        self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
-                    )
-
-                # W * X
-                int_w_output = self._module_forward(x, module2inspect, kwargs)
-
-                # compute mean squared error (L2 norm)
-                loss = self._compute_loss(fp16_output, int_w_output, device)
-
-                history.append(loss)
-                if loss < best_error:
-                    best_error = loss
-                    best_ratio = ratio
-                    best_scales = scales.clone()
-                module2inspect.load_state_dict(org_sd)
-                pbar.set_description(f"Grid Search (Best: {best_ratio})")
+            # NOTE: s^-1 * x is fused here, according to paper
+            if self.duo_scaling:
+                scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
+            else:
+                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
+            scales = scales / (scales.max() * scales.min()).sqrt()
+            scales_view = scales.view(1, -1).to(device)
+
+            # avoid scaling values that overflow
+            scales[torch.isinf(scales)] = 1
+            scales[torch.isnan(scales)] = 1
+
+            # Q(W * s)
+            for fc in linears2scale:
+                fc.weight.mul_(scales_view)
+                fc.weight.data = (
+                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
+                )
+
+            # W * X
+            int_w_output = self._module_forward(x, module2inspect, kwargs)
+
+            # compute mean squared error (L2 norm)
+            loss = self._compute_loss(fp16_output, int_w_output, device)
+
+            history.append(loss)
+            if loss < best_error:
+                best_error = loss
+                best_ratio = ratio
+                best_scales = scales.clone()
+            module2inspect.load_state_dict(org_sd)
 
         if best_ratio == -1:
             logging.debug(history)
@@ -439,16 +435,9 @@ def _compute_loss(
         int_w_chunks = torch.split(int_w_output_flat, chunk_size)
 
         # Compute the loss for each chunk
-        with tqdm(
-            zip(fp16_chunks, int_w_chunks),
-            total=len(fp16_chunks),
-            desc="Computing Loss",
-            leave=False,
-        ) as pbar:
-            for fp16_chunk, int_w_chunk in pbar:
-                chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
-                loss += chunk_loss
-                pbar.set_description(f"Computing Loss (loss: {loss:.2f})")
+        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
+            chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item()
+            loss += chunk_loss
 
         # Normalize the loss by the total number of elements
         loss /= num_elements
@@ -460,7 +449,7 @@ def _search_best_clip(self, layer, named_linears, input_feat):
         clip_list = []
         avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
 
-        for name in tqdm(named_linears, desc="Computing Best Clip", leave=False):
+        for name in named_linears:
             # due to qk bmm, it is hard to clip precisely
             if any([_ in name for _ in avoid_clipping]):
                 continue