Update patch version v1.4.2
Co-authored-by: Alchan Kim <algy@friendli.ai>
Co-authored-by: Yunmo Koo <yunmorning@friendli.ai>
Co-authored-by: Soomin Chun <soomin@friendli.ai>
4 people committed Jul 21, 2024
1 parent 64ee249 commit 25a0d6d
Showing 40 changed files with 2,943 additions and 232 deletions.
86 changes: 63 additions & 23 deletions friendli/cli/model.py
@@ -13,7 +13,13 @@
import yaml

from friendli.enums import CheckpointFileType, ModelDataType
from friendli.errors import CheckpointConversionError, InvalidConfigError, NotFoundError
from friendli.errors import (
CheckpointConversionError,
InvalidConfigError,
NotFoundError,
NotSupportedQuantConfigError,
QuantizationError,
)
from friendli.formatter import TableFormatter
from friendli.sdk.client import Friendli
from friendli.utils.compat import model_dump, model_parse
@@ -144,6 +150,7 @@ def convert(
lookup_column_name: text
num_samples: 128
max_length: 512
batch_size: 1
awq_args:
quant_bit: 4
quant_group_size: 64
@@ -160,6 +167,7 @@ def convert(
- **`lookup_column_name`**: The name of a column in the dataset to be used as calibration inputs. Defaults to "text".
- **`num_samples`**: The number of dataset samples to use for calibration. Note that the dataset will be shuffled before sampling. Defaults to 512.
- **`max_length`**: The maximum length of a calibration input sequence. Defaults to 512.
- **`batch_size`**: The number of samples to process in a single batch. Defaults to 1.
- **`awq_args`** (Fill in this field only for "awq" mode)
- **`quant_bit`** : Bit width of integers to represent weights. Possible values are `4` or `8`. Defaults to 4.
- **`quant_group_size`**: Group size of quantized matrices. 64 is the only supported value at this time. Defaults to 64.
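
Taken together, the config handling added further down in this diff amounts to the short standalone sketch below (not the exact CLI code; the config file path is a hypothetical placeholder and error handling is omitted):

```python
# Minimal sketch of the mode switch this commit introduces when a
# quantization config file is supplied. The config path is a placeholder.
import yaml

from friendli.modules.quantizer.schema.config import QuantConfig
from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig
from friendli.utils.compat import model_parse

with open("quant_config.yaml", encoding="utf-8") as f:  # hypothetical path
    quant_config_dict = yaml.safe_load(f)

if quant_config_dict["mode"] == "int8":
    # New in this commit: int8 configs are parsed directly for quantizer_v2.
    quant_config = model_parse(Int8QuantConfig, quant_config_dict)
else:
    # Existing path: all other modes are validated through the wrapped QuantConfig.
    quant_config = model_parse(QuantConfig, {"config": quant_config_dict}).config
```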
@@ -187,15 +195,19 @@ def convert(
:::
"""
# pylint: disable=too-many-branches
try:
from friendli.modules.converter.convert import ( # pylint: disable=import-outside-toplevel
convert_checkpoint,
)
from friendli.modules.quantizer.schema.config import ( # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel
from friendli.modules.converter.convert import convert_checkpoint
from friendli.modules.quantizer.schema.config import (
AWQConfig,
OneOfQuantConfig,
QuantConfig,
)
from friendli.modules.quantizer_v2.quantize import quantize_checkpoint
from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig

# pylint: enable=import-outside-toplevel
except ModuleNotFoundError as exc:
secho_error_and_exit(str(exc))

@@ -205,18 +217,29 @@ def convert(
os.mkdir(output_dir)

quant_config: Optional[OneOfQuantConfig] = None
use_quantizer_v2 = False
if quantize:
if quant_config_file:
try:
quant_config_dict = cast(dict, yaml.safe_load(quant_config_file.read()))
except yaml.YAMLError as err:
secho_error_and_exit(f"Failed to load the quant config file: {err}")
quant_config = model_parse(
QuantConfig, {"config": quant_config_dict}
).config
if quant_config_dict["mode"] == "int8":
quant_config = model_parse( # type: ignore
Int8QuantConfig, quant_config_dict
)
else:
quant_config = model_parse(
QuantConfig, {"config": quant_config_dict}
).config

# TODO(SA): All quantization modes will be migrated to V2. After migration, please remove this.
else:
quant_config = AWQConfig()

if isinstance(quant_config, Int8QuantConfig):
use_quantizer_v2 = True

default_names = {
CheckpointFileType.HDF5: "model.h5",
CheckpointFileType.SAFETENSORS: "model.safetensors",
@@ -225,21 +248,38 @@ def convert(
output_model_file_name or default_names[output_ckpt_file_type]
)

try:
convert_checkpoint(
model_name_or_path=model_name_or_path,
output_model_file_name=output_model_file_name,
output_ckpt_file_type=output_ckpt_file_type,
output_attr_file_name=output_attr_file_name,
output_dir=output_dir,
data_type=data_type,
cache_dir=cache_dir,
dry_run=dry_run,
quantize=quantize,
quant_config=quant_config,
)
except (NotFoundError, CheckpointConversionError, InvalidConfigError) as exc:
secho_error_and_exit(str(exc))
if use_quantizer_v2:
if output_ckpt_file_type == CheckpointFileType.HDF5:
secho_error_and_exit(
f"int8 quantization only supports `safetensors` output_ckpt_file_type. Current output_ckpt_file_type: {output_ckpt_file_type}"
)
try:
assert isinstance(quant_config, Int8QuantConfig)
quantize_checkpoint(
model_name_or_path=model_name_or_path,
output_dir=output_dir,
cache_dir=cache_dir,
dry_run=dry_run,
quant_config=quant_config,
)
except (NotFoundError, QuantizationError, NotSupportedQuantConfigError) as exc:
secho_error_and_exit(str(exc))
else:
try:
convert_checkpoint(
model_name_or_path=model_name_or_path,
output_model_file_name=output_model_file_name,
output_ckpt_file_type=output_ckpt_file_type,
output_attr_file_name=output_attr_file_name,
output_dir=output_dir,
data_type=data_type,
cache_dir=cache_dir,
dry_run=dry_run,
quantize=quantize,
quant_config=quant_config,
)
except (NotFoundError, CheckpointConversionError, InvalidConfigError) as exc:
secho_error_and_exit(str(exc))

msg = (
f"Checkpoint({model_name_or_path}) can be converted."
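
The new quantizer_v2 entry point can in principle also be driven from Python with the same keyword arguments the CLI passes above; whether it is intended as a public API is not stated in this commit, and the model id, output directory, and config file below are placeholders:

```python
# Illustrative only: calling the new int8 quantization path directly, using
# the argument names shown in the CLI code above. All paths and the model id
# are hypothetical placeholders.
import yaml

from friendli.modules.quantizer_v2.quantize import quantize_checkpoint
from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig
from friendli.utils.compat import model_parse

with open("int8_quant_config.yaml", encoding="utf-8") as f:  # hypothetical path
    quant_config = model_parse(Int8QuantConfig, yaml.safe_load(f))

quantize_checkpoint(
    model_name_or_path="mistralai/Mistral-7B-v0.1",  # placeholder model id
    output_dir="./mistral-7b-int8",                  # placeholder output dir
    cache_dir=None,
    dry_run=False,
    quant_config=quant_config,
)
```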
2 changes: 2 additions & 0 deletions friendli/modules/converter/maps.py
@@ -29,6 +29,7 @@

from friendli.errors import NotSupportedCheckpointError
from friendli.modules.converter.base import OneOfAdapterConverter, OneOfConverter
from friendli.modules.converter.models.arctic import ArcticForCausalLMConverter
from friendli.modules.converter.models.blenderbot import BlenderbotConverter
from friendli.modules.converter.models.bloom import BloomForCausalLMConverter
from friendli.modules.converter.models.codegen import CodegenForCausalLMConverter
@@ -83,6 +84,7 @@
"CohereForCausalLM": (CohereForCausalLM, CohereForCausalLMConverter),
"DbrxForCausalLM": (DbrxForCausalLM, DbrxForCausalLMConverter),
"Phi3ForCausalLM": (Phi3ForCausalLM, Phi3ForCausalLMConverter),
"ArcticForCausalLM": (AutoModelForCausalLM, ArcticForCausalLMConverter),
}

MODEL_ARCH_ADAPTER_CONVERTER_MAP: Dict[
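
Since this hunk only adds a single entry, here is a sketch of how such an architecture map is typically consumed; the name MODEL_ARCH_CONVERTER_MAP is assumed by analogy with MODEL_ARCH_ADAPTER_CONVERTER_MAP (the dict's own name is not visible in this excerpt), and the model id is a placeholder:

```python
# Illustrative lookup of a converter class by Hugging Face architecture name.
# MODEL_ARCH_CONVERTER_MAP is an assumed name; the model id is a placeholder.
from transformers import AutoConfig

from friendli.modules.converter.maps import MODEL_ARCH_CONVERTER_MAP

# Arctic ships custom modeling code, which is presumably why the new entry
# pairs it with AutoModelForCausalLM rather than a dedicated model class.
hf_config = AutoConfig.from_pretrained(
    "Snowflake/snowflake-arctic-instruct", trust_remote_code=True
)
arch = hf_config.architectures[0]  # "ArcticForCausalLM", newly mapped in this commit

model_cls, converter_cls = MODEL_ARCH_CONVERTER_MAP[arch]  # KeyError if unsupported
```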