Skip to content

Commit

Permalink
fix empty state_dict() and bump to 0.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Aug 29, 2024
1 parent 293812c commit 5f8f0d2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
2 changes: 1 addition & 1 deletion hqq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.2.0"
__version__ = "0.2.1"
__author__ = 'Dr. Hicham Badri'
__credits__ = 'Mobius Labs GmbH'
61 changes: 54 additions & 7 deletions hqq/core/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .utils import is_divisible, encode_safetensor_type, decode_safetensor_type
from .optimize import optimize_weights_proximal
from .bitpack import BitPack
from termcolor import colored

_META_TYPE = {
"scale": torch.Tensor,
Expand Down Expand Up @@ -386,6 +387,8 @@ def __init__(
self.ready = False
self.in_gpu = False
self.bias = None
self.axis = None
self.channel_wise = None
self.device = device
self.compute_dtype = compute_dtype
self.quant_config = copy.deepcopy(quant_config)
Expand All @@ -408,6 +411,9 @@ def __init__(
if initialize:
self.initialize()

def is_initialized(self):
return False if (None in [self.W_q, self.meta]) else True

def initialize(self):
if self.linear_layer is not None:
self.quantize(self.linear_layer.weight.data, **self.quant_config)
Expand Down Expand Up @@ -524,9 +530,11 @@ def cuda(self, device):
)

if self.bias is not None:
if(isinstance(self.bias, torch.nn.Parameter)):
self.bias.data = self.bias.data.to(device=device, dtype=self.compute_dtype)
if(isinstance(self.bias, torch.Tensor)):
if isinstance(self.bias, torch.nn.Parameter):
self.bias.data = self.bias.data.to(
device=device, dtype=self.compute_dtype
)
if isinstance(self.bias, torch.Tensor):
self.bias = self.bias.to(device=device, dtype=self.compute_dtype)

self.W_q = nn.Parameter(self.W_q, requires_grad=False)
Expand Down Expand Up @@ -569,7 +577,36 @@ def cpu(self):

# state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False. \
# Note: you can't change the state once it's done
def state_dict_keys(self):
return set(
[
"W_q",
"nbits",
"group_size",
"shape",
"scale",
"zero",
"axis",
"packing",
"unpack_view_dtype",
"view_as_float",
"quant_scale",
"quant_zero",
"compute_dtype",
"bias",
"offload_meta",
"encoded_state_dict",
"stores_quant_config",
"channel_wise",
"optimize",
"round_zero",
]
)

def state_dict(self, *args, **kwargs): # nn.Module override compatible
if not self.is_initialized():
return {k: None for k in self.state_dict_keys()}

if (
self.quant_config["scale_quant_params"]
or self.quant_config["zero_quant_params"]
Expand Down Expand Up @@ -1027,11 +1064,21 @@ def hqq_base_quant_config(
"view_as_float": view_as_float,
}

if(quant_zero or quant_scale):
print(colored('Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
if quant_zero or quant_scale:
print(
colored(
"Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.",
"yellow",
)
)

if(offload_meta):
print(colored('Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
if offload_meta:
print(
colored(
"Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.",
"yellow",
)
)

if offload_meta:
if quant_scale != quant_zero:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def run(self):

setup(
name="hqq",
version="0.2.0",
version="0.2.1",
description="Half-Quadratic Quantization (HQQ)",
url="https://github.com/mobiusml/hqq/",
author="Dr. Hicham Badri",
Expand Down

0 comments on commit 5f8f0d2

Please sign in to comment.