From 25a0d6db73c97746cac2187d425c36f2159a4e1d Mon Sep 17 00:00:00 2001 From: SeungAh Lee Date: Wed, 26 Jun 2024 18:14:16 +0900 Subject: [PATCH] Update patch version v1.4.2 Co-authored-by: Alchan Kim Co-authored-by: Yunmo Koo Co-authored-by: Soomin Chun --- friendli/cli/model.py | 86 ++- friendli/modules/converter/maps.py | 2 + friendli/modules/converter/models/arctic.py | 254 ++++++++ friendli/modules/quantizer/base.py | 10 +- friendli/modules/quantizer/maps.py | 2 + friendli/modules/quantizer/models/arctic.py | 114 ++++ friendli/modules/quantizer_v2/__init__.py | 3 + friendli/modules/quantizer_v2/base.py | 257 ++++++++ friendli/modules/quantizer_v2/enums.py | 41 ++ .../modules/quantizer_v2/int8/__init__.py | 3 + friendli/modules/quantizer_v2/int8/base.py | 125 ++++ friendli/modules/quantizer_v2/int8/utils.py | 97 +++ friendli/modules/quantizer_v2/layers.py | 94 +++ friendli/modules/quantizer_v2/maps.py | 86 +++ friendli/modules/quantizer_v2/models/llama.py | 169 ++++++ friendli/modules/quantizer_v2/models/phi3.py | 144 +++++ friendli/modules/quantizer_v2/quantize.py | 89 +++ .../modules/quantizer_v2/schema/__init__.py | 3 + .../modules/quantizer_v2/schema/config.py | 77 +++ friendli/modules/quantizer_v2/schema/data.py | 66 ++ friendli/modules/quantizer_v2/utils.py | 565 ++++++++++++++++++ .../schema/api/v1/chat/completion_chunk.py | 52 ++ friendli/schema/api/v1/chat/completions.py | 79 ++- .../api/v1/codegen/chat_completions_pb2.py | 32 +- .../api/v1/codegen/chat_completions_pb2.pyi | 148 ++++- .../schema/api/v1/codegen/completions_pb2.py | 38 +- .../schema/api/v1/codegen/completions_pb2.pyi | 29 +- .../api/v1/codegen/completions_pb2_grpc.py | 18 +- .../api/v1/codegen/response_format_pb2.py | 33 + .../api/v1/codegen/response_format_pb2.pyi | 34 ++ .../api/v1/codegen/text_to_image_pb2.py | 12 +- friendli/sdk/api/base.py | 6 +- friendli/sdk/api/chat/completions.py | 96 ++- friendli/utils/request.py | 18 + proto/chat_completions.proto | 77 ++- proto/completions.proto | 138 ++--- proto/response_format.proto | 17 + pyproject.toml | 2 +- scripts/fix_imports.py | 48 ++ tox.ini | 11 +- 40 files changed, 2943 insertions(+), 232 deletions(-) create mode 100644 friendli/modules/converter/models/arctic.py create mode 100644 friendli/modules/quantizer/models/arctic.py create mode 100644 friendli/modules/quantizer_v2/__init__.py create mode 100644 friendli/modules/quantizer_v2/base.py create mode 100644 friendli/modules/quantizer_v2/enums.py create mode 100644 friendli/modules/quantizer_v2/int8/__init__.py create mode 100644 friendli/modules/quantizer_v2/int8/base.py create mode 100644 friendli/modules/quantizer_v2/int8/utils.py create mode 100644 friendli/modules/quantizer_v2/layers.py create mode 100644 friendli/modules/quantizer_v2/maps.py create mode 100644 friendli/modules/quantizer_v2/models/llama.py create mode 100644 friendli/modules/quantizer_v2/models/phi3.py create mode 100644 friendli/modules/quantizer_v2/quantize.py create mode 100644 friendli/modules/quantizer_v2/schema/__init__.py create mode 100644 friendli/modules/quantizer_v2/schema/config.py create mode 100644 friendli/modules/quantizer_v2/schema/data.py create mode 100644 friendli/modules/quantizer_v2/utils.py create mode 100644 friendli/schema/api/v1/chat/completion_chunk.py create mode 100644 friendli/schema/api/v1/codegen/response_format_pb2.py create mode 100644 friendli/schema/api/v1/codegen/response_format_pb2.pyi create mode 100644 proto/response_format.proto create mode 100644 scripts/fix_imports.py diff --git 
a/friendli/cli/model.py b/friendli/cli/model.py index 00821e8c..dff7bba3 100644 --- a/friendli/cli/model.py +++ b/friendli/cli/model.py @@ -13,7 +13,13 @@ import yaml from friendli.enums import CheckpointFileType, ModelDataType -from friendli.errors import CheckpointConversionError, InvalidConfigError, NotFoundError +from friendli.errors import ( + CheckpointConversionError, + InvalidConfigError, + NotFoundError, + NotSupportedQuantConfigError, + QuantizationError, +) from friendli.formatter import TableFormatter from friendli.sdk.client import Friendli from friendli.utils.compat import model_dump, model_parse @@ -144,6 +150,7 @@ def convert( lookup_column_name: text num_samples: 128 max_length: 512 + batch_size: 1 awq_args: quant_bit: 4 quant_group_size: 64 @@ -160,6 +167,7 @@ def convert( - **`lookup_column_name`**: The name of a column in the dataset to be used as calibration inputs. Defaults to "text". - **`num_samples`**: The number of dataset samples to use for calibration. Note that the dataset will be shuffled before sampling. Defaults to 512. - **`max_length`**: The maximum length of a calibration input sequence. Defauts to 512. + - **`batch_size`**: The number of samples to process in a single batch. Defaults to 1. - **`awq_args`** (Fill in this field only for "awq" mode) - **`quant_bit`** : Bit width of integers to represent weights. Possible values are `4` or `8`. Defaults to 4. - **`quant_group_size`**: Group size of quantized matrices. 64 is the only supported value at this time. Defaults to 64. @@ -187,15 +195,19 @@ def convert( ::: """ + # pylint: disable=too-many-branches try: - from friendli.modules.converter.convert import ( # pylint: disable=import-outside-toplevel - convert_checkpoint, - ) - from friendli.modules.quantizer.schema.config import ( # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from friendli.modules.converter.convert import convert_checkpoint + from friendli.modules.quantizer.schema.config import ( AWQConfig, OneOfQuantConfig, QuantConfig, ) + from friendli.modules.quantizer_v2.quantize import quantize_checkpoint + from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig + + # pylint: enable=import-outside-toplevel except ModuleNotFoundError as exc: secho_error_and_exit(str(exc)) @@ -205,18 +217,29 @@ def convert( os.mkdir(output_dir) quant_config: Optional[OneOfQuantConfig] = None + use_quantizer_v2 = False if quantize: if quant_config_file: try: quant_config_dict = cast(dict, yaml.safe_load(quant_config_file.read())) except yaml.YAMLError as err: secho_error_and_exit(f"Failed to load the quant config file: {err}") - quant_config = model_parse( - QuantConfig, {"config": quant_config_dict} - ).config + if quant_config_dict["mode"] == "int8": + quant_config = model_parse( # type: ignore + Int8QuantConfig, quant_config_dict + ) + else: + quant_config = model_parse( + QuantConfig, {"config": quant_config_dict} + ).config + + # TODO(SA): All Quantization mode will be migrated to V2. After migration, please remove it. 
else: quant_config = AWQConfig() + if isinstance(quant_config, Int8QuantConfig): + use_quantizer_v2 = True + default_names = { CheckpointFileType.HDF5: "model.h5", CheckpointFileType.SAFETENSORS: "model.safetensors", @@ -225,21 +248,38 @@ def convert( output_model_file_name or default_names[output_ckpt_file_type] ) - try: - convert_checkpoint( - model_name_or_path=model_name_or_path, - output_model_file_name=output_model_file_name, - output_ckpt_file_type=output_ckpt_file_type, - output_attr_file_name=output_attr_file_name, - output_dir=output_dir, - data_type=data_type, - cache_dir=cache_dir, - dry_run=dry_run, - quantize=quantize, - quant_config=quant_config, - ) - except (NotFoundError, CheckpointConversionError, InvalidConfigError) as exc: - secho_error_and_exit(str(exc)) + if use_quantizer_v2: + if output_ckpt_file_type == CheckpointFileType.HDF5: + secho_error_and_exit( + f"int8 quantization only supports `safetensors` output_ckpt_file_type. Current output_ckpt_file_type: {output_ckpt_file_type}" + ) + try: + assert isinstance(quant_config, Int8QuantConfig) + quantize_checkpoint( + model_name_or_path=model_name_or_path, + output_dir=output_dir, + cache_dir=cache_dir, + dry_run=dry_run, + quant_config=quant_config, + ) + except (NotFoundError, QuantizationError, NotSupportedQuantConfigError) as exc: + secho_error_and_exit(str(exc)) + else: + try: + convert_checkpoint( + model_name_or_path=model_name_or_path, + output_model_file_name=output_model_file_name, + output_ckpt_file_type=output_ckpt_file_type, + output_attr_file_name=output_attr_file_name, + output_dir=output_dir, + data_type=data_type, + cache_dir=cache_dir, + dry_run=dry_run, + quantize=quantize, + quant_config=quant_config, + ) + except (NotFoundError, CheckpointConversionError, InvalidConfigError) as exc: + secho_error_and_exit(str(exc)) msg = ( f"Checkpoint({model_name_or_path}) can be converted." diff --git a/friendli/modules/converter/maps.py b/friendli/modules/converter/maps.py index 5284ed02..7a8bcd37 100644 --- a/friendli/modules/converter/maps.py +++ b/friendli/modules/converter/maps.py @@ -29,6 +29,7 @@ from friendli.errors import NotSupportedCheckpointError from friendli.modules.converter.base import OneOfAdapterConverter, OneOfConverter +from friendli.modules.converter.models.arctic import ArcticForCausalLMConverter from friendli.modules.converter.models.blenderbot import BlenderbotConverter from friendli.modules.converter.models.bloom import BloomForCausalLMConverter from friendli.modules.converter.models.codegen import CodegenForCausalLMConverter @@ -83,6 +84,7 @@ "CohereForCausalLM": (CohereForCausalLM, CohereForCausalLMConverter), "DbrxForCausalLM": (DbrxForCausalLM, DbrxForCausalLMConverter), "Phi3ForCausalLM": (Phi3ForCausalLM, Phi3ForCausalLMConverter), + "ArcticForCausalLM": (AutoModelForCausalLM, ArcticForCausalLMConverter), } MODEL_ARCH_ADAPTER_CONVERTER_MAP: Dict[ diff --git a/friendli/modules/converter/models/arctic.py b/friendli/modules/converter/models/arctic.py new file mode 100644 index 00000000..293d21d9 --- /dev/null +++ b/friendli/modules/converter/models/arctic.py @@ -0,0 +1,254 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. 
+ +"""Friendli Arctic Checkpoint Converter.""" + + +from __future__ import annotations + +from typing import cast + +from transformers import PretrainedConfig # type: ignore[import] + +from friendli.errors import CheckpointConversionError, NotSupportedCheckpointError +from friendli.logging import logger +from friendli.modules.converter.base import FP8OnlyConverter +from friendli.modules.converter.interface import RotaryEmbeddingConversionInterface + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + ```python + >>> from transformers import ArcticModel, ArcticConfig + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class ArcticForCausalLMConverter(FP8OnlyConverter, 
RotaryEmbeddingConversionInterface): + """ArcticForCausalLM Architectures Converter Class.""" + + def check_config(self) -> None: + """Check if Arctic architectures' config can be converted to Friendli format.""" + super().check_config() + config = cast(ArcticConfig, self.config) + try: + if config.tie_word_embeddings: + raise NotSupportedCheckpointError( + invalid_option="'tie_word_embeddings=True'", + valid_options=[False], + ) + if config.hidden_act not in ["silu"]: + raise NotSupportedCheckpointError( + invalid_option=f"'hidden_act={config.hidden_act}'", + valid_options=["silu"], + ) + if config.moe_layer_frequency != 1: + raise NotSupportedCheckpointError( + invalid_option=f"'moe_layer_frequency={config.moe_layer_frequency}'", + valid_options=[1], + ) + if not config.parallel_attn_mlp_res: + raise NotSupportedCheckpointError( + invalid_option=f"'parallel_attn_mlp_res={config.parallel_attn_mlp_res}'", + valid_options=[True], + ) + + except AttributeError as exc: + raise CheckpointConversionError(str(exc)) from exc + + @property + def model_type(self) -> str: + """Model type.""" + return "arctic" + + @property + def decoder_layer_prefix(self) -> str: + """The layer name prefix used before Arctic's transformer block number.""" + return "model.layers." + + @property + def decoder_layer_num(self) -> int: + """The number of decoder layers in Arctic.""" + return cast(ArcticConfig, self.config).num_hidden_layers + + @property + def decoder_hidden_size(self) -> int: + """The hidden size in Arctic.""" + return cast(ArcticConfig, self.config).hidden_size + + @property + def decoder_num_attention_heads(self) -> int: + """The number of attention heads in Arctic.""" + return cast(ArcticConfig, self.config).num_attention_heads + + @property + def decoder_num_kv_attention_heads(self) -> int: + """The number of key-value attention heads in Arctic.""" + config = cast(ArcticConfig, self.config) + if config.num_key_value_heads is None: + return self.decoder_num_attention_heads + return config.num_key_value_heads + + @property + def decoder_head_size(self) -> int: + """The head size of Arctic.""" + return self.decoder_hidden_size // self.decoder_num_attention_heads + + @property + def decoder_ff_intermediate_size(self) -> int: + """The intermediate size of the linear layer in Arctic MLP.""" + return cast(ArcticConfig, self.config).intermediate_size + + @property + def rotary_dim(self) -> int: + """The rotary embedding dimension of Arctic.""" + return self.decoder_head_size + + @property + def rotary_emb_base(self) -> float: + """The rotary embedding base of Arctic.""" + return cast(ArcticConfig, self.config).rope_theta + + @property + def num_experts(self) -> int: + """The number of moe experts per transformer block in Arctic.""" + return cast(ArcticConfig, self.config).num_local_experts + + @property + def num_selected_moe_experts(self) -> int: + """The number of selected moe experts per transformer block in Arctic.""" + return cast(ArcticConfig, self.config).num_experts_per_tok diff --git a/friendli/modules/quantizer/base.py b/friendli/modules/quantizer/base.py index de9c65db..ea97e092 100644 --- a/friendli/modules/quantizer/base.py +++ b/friendli/modules/quantizer/base.py @@ -350,8 +350,12 @@ def _get_weight_act_quantize_results( ), "currently support fp8_e4m3" max_val = 448.0 min_val = -448.0 - - input_max = torch.concat([max_input_stats[name] for name in names]) + input_max = None + for name in names: + input_max = max_input_stats.get(name) + if input_max is not None: + break + assert input_max is 
not None target_weights = [model.get_submodule(name).weight for name in names] target_weight = torch.concat(target_weights) @@ -403,7 +407,7 @@ def quantize( for tf_quant_input in tqdm( self.hook.iter_tf_quant_inputs(model), total=len(self.hook.get_tf_blocks(model)), - desc="Qunatize", + desc="Quantize", unit="layer", ): assert isinstance(tf_quant_input, HFTFQuantInputs) diff --git a/friendli/modules/quantizer/maps.py b/friendli/modules/quantizer/maps.py index 17e70ed7..465d5c3e 100644 --- a/friendli/modules/quantizer/maps.py +++ b/friendli/modules/quantizer/maps.py @@ -16,6 +16,7 @@ from friendli.modules.quantizer.awq.models.llama import AWQLlamaHook from friendli.modules.quantizer.awq.models.mpt import AWQMPTHook from friendli.modules.quantizer.base import CommonQuantizer, FP8QuantHook, FP8Quantizer +from friendli.modules.quantizer.models.arctic import ArcticHook from friendli.modules.quantizer.models.dbrx import DbrxHook from friendli.modules.quantizer.models.llama import LlamaHook from friendli.modules.quantizer.models.mixtral import MixtralHook @@ -66,6 +67,7 @@ "CohereForCausalLM": LlamaHook, "DbrxForCausalLM": DbrxHook, "Phi3ForCausalLM": Phi3Hook, + "ArcticForCausalLM": ArcticHook, } diff --git a/friendli/modules/quantizer/models/arctic.py b/friendli/modules/quantizer/models/arctic.py new file mode 100644 index 00000000..cc7d3fd9 --- /dev/null +++ b/friendli/modules/quantizer/models/arctic.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli ArcticForCausalLM QuantizerHook.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from typing import Iterator, List, Tuple, Type + +import torch + +from friendli.modules.quantizer.base import FP8QuantHook +from friendli.modules.quantizer.schema.data import ( + HFQuantInput, + HFTFQuantInputs, + TFQuantInputs, +) + + +class ArcticHook(FP8QuantHook): + """FP8QuantHook for ArcticForCausalLM.""" + + def get_tf_blocks(self, model: torch.nn.Module) -> List[torch.nn.Module]: + """Returns the transformer blocks in ArcticForCausalLM.""" + return model.model.layers + + def get_linear_layer_types(self) -> Tuple[Type[torch.nn.Module]]: + """Returns the linear layer types in ArcticForCausalLM.""" + return (torch.nn.Linear,) + + def iter_tf_quant_inputs( + self, model: torch.nn.Module + ) -> Iterator[TFQuantInputs] | Iterator[HFTFQuantInputs]: + """Returns the layers which should be quantized in transformer block of ArcticForCausalLM.""" + for index, decoder_layer in enumerate( + self.get_tf_blocks(model) # type: ignore[union-attr, arg-type] + ): + self_attn = decoder_layer.self_attn + block_sparse_moe = decoder_layer.block_sparse_moe + mlp = decoder_layer.residual_mlp + moe_ff1_ff_gate_target_names = [] + for expert_idx in range(self.converter.num_experts): + moe_ff1_ff_gate_target_names.extend( + [ + f"{self.quantized_layer_prefix}{index}.block_sparse_moe.experts.{expert_idx}.w1", + f"{self.quantized_layer_prefix}{index}.block_sparse_moe.experts.{expert_idx}.w3", + ] + ) + + yield HFTFQuantInputs( + layer_index=index, + block=decoder_layer, + quant_inputs=[ + HFQuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.q_proj", + f"{self.quantized_layer_prefix}{index}.self_attn.k_proj", + f"{self.quantized_layer_prefix}{index}.self_attn.v_proj", + ], + local_names=["q_proj", "k_proj", "v_proj"], + ), + HFQuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.o_proj", + ], + local_names=[ 
+ "o_proj", + ], + ), + # router + HFQuantInput( + parent_module=block_sparse_moe, + target_names=[ + f"{self.quantized_layer_prefix}{index}.block_sparse_moe.gate", + ], + local_names=["gate"], + ), + # ff1, ff_gate in each moe + HFQuantInput( + parent_module=block_sparse_moe.experts, + target_names=moe_ff1_ff_gate_target_names, + local_names=["w1", "w3"], + ), + # ff2 in each moe + HFQuantInput( + parent_module=block_sparse_moe.experts, + target_names=[ + f"{self.quantized_layer_prefix}{index}.block_sparse_moe.experts.{expert_idx}.w2" + for expert_idx in range(self.converter.num_experts) + ], + local_names=["w2"], + ), + # ff1, ff_gate in parallel mlp + HFQuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.residual_mlp.w1", + f"{self.quantized_layer_prefix}{index}.residual_mlp.w3", + ], + local_names=["w1", "w3"], + ), + # ff2 in parallel mlp + HFQuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.residual_mlp.w2" + ], + local_names=["w2"], + ), + ], + ) diff --git a/friendli/modules/quantizer_v2/__init__.py b/friendli/modules/quantizer_v2/__init__.py new file mode 100644 index 00000000..9ee5a33d --- /dev/null +++ b/friendli/modules/quantizer_v2/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Model Quantizer V2.""" diff --git a/friendli/modules/quantizer_v2/base.py b/friendli/modules/quantizer_v2/base.py new file mode 100644 index 00000000..08c48f2d --- /dev/null +++ b/friendli/modules/quantizer_v2/base.py @@ -0,0 +1,257 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Model Quantization Interface.""" + +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Tuple, Type + +import huggingface_hub # type: ignore +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import PretrainedConfig, PreTrainedModel # type: ignore + +from friendli.errors import NotSupportedQuantConfigError +from friendli.logging import logger +from friendli.modules.quantizer_v2.enums import QuantDatasetFormat +from friendli.modules.quantizer_v2.layers import ( + WeightActQuantizedLinearLayer, + WeightOnlyQuantizedLinearLayer, +) +from friendli.modules.quantizer_v2.schema.config import OneOfQuantConfig +from friendli.modules.quantizer_v2.schema.data import TFQuantInputs +from friendli.modules.quantizer_v2.utils import ( + collect_stats, + get_weight_act_quant_scales, + get_weight_only_quant_scales, + offload_module_sequence, + send_model_to_device, +) + + +class AbstractQuantHookV2(ABC): + """Abstract Quantization Hook for a specific model.""" + + def __init__(self, quant_config: OneOfQuantConfig, model_config: PretrainedConfig): + """Initialize the Quantization Hook. + + Args: + quant_config (OneOfQuantConfig): Quantization configuration. + model_config (PretrainedConfig): Model configuration. 
+ """ + self.quant_config = quant_config + self.model_config = model_config + + @abstractmethod + def check_model_config(self) -> None: + """Check if the model is quantizable.""" + + @abstractmethod + def get_linear_layer_types(self) -> Tuple[Type[torch.nn.Module], ...]: + """Get linear layer types in the model.""" + + @abstractmethod + def get_tf_blocks(self, model: PreTrainedModel) -> List[torch.nn.Module]: + """Get tensor fusion blocks in the model.""" + + @abstractmethod + def iter_tf_quant_inputs(self, model: PreTrainedModel) -> Iterator[TFQuantInputs]: + """Iterate over TFQuantInputs.""" + + @property + @abstractmethod + def quantized_layer_prefix(self) -> str: + """Returns the prefix of the transformer block name.""" + + +class AbstractQuantizerV2(ABC): + """Abstract class for quantizer.""" + + def __init__(self, hook: AbstractQuantHookV2, config: OneOfQuantConfig): + """Initialize AbstractQuantizer.""" + self.config = config + self.hook = hook + + def check_config(self) -> None: + """Check if the model is quantizable.""" + self.hook.check_model_config() + calibration_dataset_config = self.config.calibration_dataset + data_path_or_name = calibration_dataset_config.path_or_name + percentile = self.config.percentile + if percentile <= 0 or percentile > 100: + raise NotSupportedQuantConfigError( + invalid_option=str(percentile), + valid_options=["0 < percentile <= 100"], + ) + if not os.path.exists(data_path_or_name): + data_name = data_path_or_name.split(":")[0] + if data_name not in ( + data.id for data in huggingface_hub.list_datasets(search=data_name) + ): + raise NotSupportedQuantConfigError( + invalid_option=data_name, + valid_options=["datasets on the huggingface hub", "local path"], + ) + else: + if calibration_dataset_config.format not in QuantDatasetFormat: + raise NotSupportedQuantConfigError( + invalid_option=calibration_dataset_config.format, + valid_options=list(QuantDatasetFormat), + ) + try: + torch.device(self.config.device) + except ValueError as err: + raise NotSupportedQuantConfigError( + invalid_option=self.config.device, + valid_options=["cpu", "cuda"], + ) from err + + @contextmanager + def _try_offload_model(self, model: PreTrainedModel): + if not self.config.offload: + logger.info("Offloading not enabled. 
Skipping.") + model.to(self.config.device) + yield + else: + logger.info("Offloading enabled.") + tf_blocks = self.hook.get_tf_blocks(model) + send_model_to_device(model, self.config.device, exclude=tf_blocks) + with offload_module_sequence(tf_blocks, self.config.device): + yield + + @abstractmethod + def quantize(self, model: PreTrainedModel) -> PreTrainedModel: + """Quantize model.""" + + def pre_quantize(self, model: PreTrainedModel) -> PreTrainedModel: + """Preprocess model before quantization.""" + + def post_quantize(self, model: PreTrainedModel) -> PreTrainedModel: + """Postprocess model after quantization.""" + + @abstractmethod + def get_quant_config(self) -> Dict[str, Any]: + """Get quantizer config.""" + + +class AbstractWeightOnlyQuantizer(AbstractQuantizerV2): + """Abstract class for weight only quantizer.""" + + def quantize(self, model: PreTrainedModel) -> PreTrainedModel: + """Return quantized model.""" + with self._try_offload_model(model): + for tf_quant_inputs in tqdm( + self.hook.iter_tf_quant_inputs(model), + total=len(self.hook.get_tf_blocks(model)), + desc="Quantize model..", + ): + for quant_input in tf_quant_inputs.quant_inputs: + parent_module, local_names, names = ( + quant_input.parent_module, + quant_input.local_names, + quant_input.target_names, + ) + parent_modules_w_local_name = [] + if isinstance(parent_module, torch.nn.ModuleList): + # For MoE models with seperate expert layers + for p_module in parent_module: + for local_name in local_names: + parent_modules_w_local_name.append( + (p_module, local_name) + ) + else: + assert isinstance(parent_module, torch.nn.Module) + for local_name in local_names: + parent_modules_w_local_name.append( + (parent_module, local_name) + ) + layers = [ + p_module.get_submodule(local_name) + for p_module, local_name in parent_modules_w_local_name + ] + assert self.config.quant_scale_dtype + quant_results = get_weight_only_quant_scales( + model, + names, + quant_dtype=self.config.quant_dtype, + quant_scale_dtype=self.config.quant_scale_dtype, + q_group_size=self.config.quant_group_size, + use_symmetric=self.config.use_symmetric, + ) + q_layers = [ + WeightOnlyQuantizedLinearLayer.from_layer(layer, quant_result) + for layer, quant_result in zip(layers, quant_results) + ] + for (p_module, local_name), q_layer in zip( + parent_modules_w_local_name, q_layers + ): + setattr(p_module, local_name, q_layer) + return model + + +class AbstractWeightActQuantizer(AbstractQuantizerV2): + """Abstract class for weight and activation quantizer.""" + + @abstractmethod + def get_calib_dataloader(self) -> DataLoader: + """Get encoded calibration dataset.""" + + def quantize(self, model: PreTrainedModel) -> PreTrainedModel: + """Return quantized model.""" + with self._try_offload_model(model): + max_input_stats, _ = collect_stats( + model, + self.config.device, + self.get_calib_dataloader(), + self.hook.get_linear_layer_types(), + percentile=self.config.percentile, + tqdm_desc="Collecting stats for Static Quantization.", + ) + for tf_quant_inputs in tqdm( + self.hook.iter_tf_quant_inputs(model), + total=len(self.hook.get_tf_blocks(model)), + desc="Quantize model..", + ): + for quant_input in tf_quant_inputs.quant_inputs: + parent_module, local_names, names = ( + quant_input.parent_module, + quant_input.local_names, + quant_input.target_names, + ) + parent_modules_w_local_name = [] + if isinstance(parent_module, torch.nn.ModuleList): + # For MoE models with seperate expert layers + for p_module in parent_module: + for local_name in local_names: + 
parent_modules_w_local_name.append( + (p_module, local_name) + ) + else: + assert isinstance(parent_module, torch.nn.Module) + for local_name in local_names: + parent_modules_w_local_name.append((p_module, local_name)) + layers = [ + p_module.get_submodule(local_name) + for p_module, local_name in parent_modules_w_local_name + ] + assert self.config.quant_scale_dtype + quant_results = get_weight_act_quant_scales( + model, + names, + max_input_stats, + quant_scale_dtype=self.config.quant_scale_dtype, + quant_dtype=self.config.quant_dtype, + ) + q_layers = [ + WeightActQuantizedLinearLayer.from_layer(layer, quant_result) + for layer, quant_result in zip(layers, quant_results) + ] + for (p_module, local_name), q_layer in zip( + parent_modules_w_local_name, q_layers + ): + setattr(p_module, local_name, q_layer) + return model diff --git a/friendli/modules/quantizer_v2/enums.py b/friendli/modules/quantizer_v2/enums.py new file mode 100644 index 00000000..18bc60c7 --- /dev/null +++ b/friendli/modules/quantizer_v2/enums.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Model Quantizer Enums.""" + + +from __future__ import annotations + +from enum import Enum + + +class QuantMode(str, Enum): + """Supported quantization modes.""" + + INT8 = "int8" + DUMMY = "dummy" + + +class QuantDatasetFormat(str, Enum): + """Supported file format for calibration datasets for quantization.""" + + JSON = "json" + CSV = "csv" + PARQUET = "parquet" + TXT = "txt" + + +class Int8QuantType(str, Enum): + """Int8Quant modes.""" + + DYNAMIC = "dynamic" + + +class ModelDataType(str, Enum): + """Model dtype enums.""" + + BF16 = "bf16" + FP16 = "fp16" + FP32 = "fp32" + FP8_E4M3 = "fp8_e4m3" + INT8 = "int8" + INT4 = "int4" diff --git a/friendli/modules/quantizer_v2/int8/__init__.py b/friendli/modules/quantizer_v2/int8/__init__.py new file mode 100644 index 00000000..9f651b15 --- /dev/null +++ b/friendli/modules/quantizer_v2/int8/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Int8 Quantizer.""" diff --git a/friendli/modules/quantizer_v2/int8/base.py b/friendli/modules/quantizer_v2/int8/base.py new file mode 100644 index 00000000..66e200a8 --- /dev/null +++ b/friendli/modules/quantizer_v2/int8/base.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. 
+ +"""Friendli Int8 Quantizer Base.""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Dict, Iterator, List, Tuple, cast + +import torch +from torch.utils.data import DataLoader +from transformers import PreTrainedModel # type: ignore + +from friendli.modules.converter.utils import get_tokenizer +from friendli.modules.quantizer_v2.base import ( + AbstractQuantHookV2, + AbstractQuantizerV2, + AbstractWeightActQuantizer, + AbstractWeightOnlyQuantizer, +) +from friendli.modules.quantizer_v2.int8.utils import perform_smoothing +from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig +from friendli.modules.quantizer_v2.schema.data import ModuleName +from friendli.modules.quantizer_v2.utils import collect_stats, safe_load_datasets + + +class Int8QuantHook(AbstractQuantHookV2): + """Int8 Quant Hook Base.""" + + @abstractmethod + def get_attn_fc_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Returns the attention fc layer in the decoder block.""" + + @abstractmethod + def get_ff2_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Returns the second feed-forward layer in the decoder block.""" + + @abstractmethod + def iter_pre_act_post_act_params( + self, model: PreTrainedModel + ) -> Iterator[Tuple[List[torch.Tensor], List[torch.Tensor], ModuleName]]: + """Returns iterator of pre_act_params and post_act_params per transformer block.""" + + +class Int8Quantizer(AbstractQuantizerV2): + """Int8 Quantizer Base.""" + + def get_smoothing_calib_dataloader(self) -> DataLoader: + """Get calibration dataset for Int8.""" + data_cfg = self.config.calibration_dataset + dataset = safe_load_datasets(data_cfg) + tokenizer = get_tokenizer(self.hook.model_config.name_or_path) + dataset = ( + dataset.shuffle(self.config.seed) + .select(range(data_cfg.num_samples)) + .select_columns([data_cfg.lookup_column_name]) + ) + encoded_dataset = tokenizer( + dataset[data_cfg.lookup_column_name], + return_tensors="pt", + truncation=True, + padding=True, + max_length=data_cfg.max_length, + ) + return DataLoader(encoded_dataset["input_ids"], batch_size=data_cfg.batch_size) + + def _smooth( + self, + model: PreTrainedModel, + ) -> None: + """Smooths the models before Quantization.""" + model.eval() + # collect stats for Int8 quantization scale. 
+ with self._try_offload_model(model): + calib_dataloader = self.get_smoothing_calib_dataloader() + quant_config = cast(Int8QuantConfig, self.config) + max_input_stats, _ = collect_stats( + model, + quant_config.device, + calib_dataloader, + self.hook.get_linear_layer_types(), + tqdm_desc="Collecting stats for Smoothing.", + percentile=100.0, + ) + + for pre_act_params, post_act_params, name in cast( + Int8QuantHook, self.hook + ).iter_pre_act_post_act_params(model): + perform_smoothing( + pre_act_params, + post_act_params, + max_input_stats[name], + migration_strength=quant_config.int8_args.migration_strength, + inplace=True, + ) + + def pre_quantize( + self, + model: PreTrainedModel, + ) -> None: + """Pre-procedure that should be called before quantize() is called.""" + self._smooth(model) + + def quantize(self, model: PreTrainedModel) -> torch.nn.Module: + """Quantize the model.""" + self.pre_quantize(model) + return super().quantize(model) + + def get_quant_config(self) -> Dict[str, Any]: + """Get the quantization configuration.""" + return { + "bits": 8, + "mode": cast(Int8QuantConfig, self.config).int8_args.quant_type.value, + "zero_point": False, + "quant_method": "int8", + "quant_group_size": self.config.quant_group_size, + } + + +class Int8StaticQuantizer(Int8Quantizer, AbstractWeightActQuantizer): + """Int8 Dynamic Quantizer Base.""" + + +class Int8DynamicQuantizer(Int8Quantizer, AbstractWeightOnlyQuantizer): + """Int8 Dynamic Quantizer Base.""" diff --git a/friendli/modules/quantizer_v2/int8/utils.py b/friendli/modules/quantizer_v2/int8/utils.py new file mode 100644 index 00000000..c482f87d --- /dev/null +++ b/friendli/modules/quantizer_v2/int8/utils.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Int8 Quantizer Base.""" + +from __future__ import annotations + +from typing import List, Tuple + +import torch + + +@torch.no_grad() +def perform_smoothing( + pre_act_params: List[torch.Tensor], + post_act_params: List[torch.Tensor], + activation_max: torch.Tensor, + *, + migration_strength: float = 0.5, + epsilon: float = 1e-5, + inplace: bool = False, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Perform activation-weight smoothing in SmoothQuant. + + Performs the activation-weight smoothing scheme described in SmoothQuant + (Xiao et al., 2023), which migrates the amplitude of outliers from activations + to weights of matmul layers. The function takes in the following parameters: + + Args: + pre_act_params: torch.Tensors representing affine parameters + before each matmul layer. + post_act_params: torch.Tensors representing the weight matrices of the matmul layer. + activation_max: The maximum activation value of inputs of the matmul layer. + migration_strength: the strength of the activation migration. Default is 0.5. + epsilon: The epsilon used for numerical stability when calculating the scales. + Default is 1e-5. + + Returns: + A tuple of three torch.Tensors: (smoothed_pre_act_params, smoothed_post_act_params) + + The function calculates "scales" as `pow(|Activation|, migration_strength) / + pow(|Weight|, 1-migration_strength)` and applies the smoothing effect into + a normalization layer that exists before every matmul layer. This is done because + it is more efficient than introducing a new smoothing layer before every matmul layer. + Fusing the smoothing effect into the normalization layer results in a faster and + more efficient implementation of the smoothing scheme. 
+ + The function returns the smoothed normalization coefficients and the smoothed weight + matrices after the smoothing process. + """ + # shape of activation norms: [InChannels] + # shape of fc weights: [OutChannels, InChannels] + # shape of activation_max: [InChannels] + + # pylint: disable=too-many-locals + assert pre_act_params + assert post_act_params + + in_channels = pre_act_params[0].size(0) + device = pre_act_params[0].device + dtype = pre_act_params[0].dtype + + for pre_act_param in pre_act_params: + assert pre_act_param.device == device + assert pre_act_param.dtype == dtype + + for weight in post_act_params: + assert weight.ndim == 2 + assert weight.size(1) == in_channels, (weight.size(), in_channels) + assert weight.device == device + + activation_max = activation_max.to(device=device) + weight_max = post_act_params[0].abs().max(dim=0).values + for weight in post_act_params[1:]: + weight_max = torch.maximum(weight_max, weight.abs().max(dim=0).values) + + assert tuple(activation_max.size()) == (in_channels,) + assert tuple(weight_max.size()) == (in_channels,) + alpha = migration_strength + scales = ( + ( + activation_max.to(dtype=torch.float32).pow(alpha) + / weight_max.to(dtype=torch.float32).pow(1 - alpha) + ) + .clamp(min=epsilon) + .to(dtype=dtype) + ) + + scaled_pre_act_params = [act_norm / scales for act_norm in pre_act_params] + scaled_weights = [w * scales.view(1, -1) for w in post_act_params] + + if inplace: + for dst, src in zip(pre_act_params, scaled_pre_act_params): + dst.copy_(src) + for dst, src in zip(post_act_params, scaled_weights): + dst.copy_(src) + + return scaled_pre_act_params, scaled_weights diff --git a/friendli/modules/quantizer_v2/layers.py b/friendli/modules/quantizer_v2/layers.py new file mode 100644 index 00000000..3a203210 --- /dev/null +++ b/friendli/modules/quantizer_v2/layers.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved. + +"""Friendli Quantization Layers.""" + +from __future__ import annotations + +from typing import Optional, cast + +import torch + +from friendli.modules.quantizer_v2.schema.data import ( + WeightActQuantResult, + WeightOnlyQuantResult, +) + + +class WeightOnlyQuantizedLinearLayer(torch.nn.Module): + """Linear Layer with weight only quantization.""" + + def __init__( + self, + in_features: int, + out_features: int, + q_weight: torch.Tensor, + weight_scale: torch.Tensor, + zeros: Optional[torch.nn.Parameter] = None, + bias: Optional[torch.nn.Parameter] = None, + ): + """Initialize the Weight Only Quantized Linear Layer.""" + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_scale = torch.nn.Parameter(weight_scale) + self.weight = torch.nn.Parameter(q_weight, requires_grad=False) + self.register_parameter("zeros", zeros) + self.register_parameter("bias", bias) + + @staticmethod + def from_layer( + layer: torch.nn.Module, quant_result: WeightOnlyQuantResult + ) -> torch.nn.Module: + """Returns the quantized layer from the original layer.""" + zeros = ( + torch.nn.Parameter(quant_result.zero_point) + if quant_result.zero_point + else None + ) + return WeightOnlyQuantizedLinearLayer( + cast(torch.nn.Linear, layer).in_features, + cast(torch.nn.Linear, layer).out_features, + quant_result.q_weight, + quant_result.weight_scale, + zeros, + cast(torch.nn.Linear, layer).bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with fake quantization. 
Not used in conversion.""" + raise NotImplementedError("Not used in conversion.") + + +class WeightActQuantizedLinearLayer(torch.nn.Module): + """Linear Layer with weight-act quantization.""" + + def __init__( # pylint: disable=too-many-arguments + self, + q_weight: torch.Tensor, + weight_scale: torch.Tensor, + act_scale: torch.Tensor, + bias: Optional[torch.nn.Parameter] = None, + ): + """Initialize the Weight Only Quantized Linear Layer.""" + super().__init__() + self.in_scale = torch.nn.Parameter(act_scale) + self.weight_scale = torch.nn.Parameter(weight_scale) + self.weight = torch.nn.Parameter(q_weight, requires_grad=False) + self.register_parameter("bias", bias) + + @staticmethod + def from_layer( + layer: torch.nn.Module, quant_result: WeightActQuantResult + ) -> torch.nn.Module: + """Returns the quantized layer from the original layer.""" + q_result = cast(WeightActQuantResult, quant_result) + return WeightActQuantizedLinearLayer( + q_result.q_weight, + q_result.weight_scale, + q_result.act_scale, + cast(torch.nn.Linear, layer).bias if hasattr(layer, "bias") else None, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass with fake quantization. Not used in conversion.""" + raise NotImplementedError("Not used in conversion.") diff --git a/friendli/modules/quantizer_v2/maps.py b/friendli/modules/quantizer_v2/maps.py new file mode 100644 index 00000000..48e972eb --- /dev/null +++ b/friendli/modules/quantizer_v2/maps.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Quantizer V2 Maps.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Tuple, Type, cast + +import transformers # type: ignore +from transformers import ( # type: ignore + LlamaForCausalLM, + MistralForCausalLM, + Phi3ForCausalLM, + PretrainedConfig, + PreTrainedModel, +) + +from friendli.errors import NotSupportedQuantModeError, QuantizationError +from friendli.modules.quantizer_v2.base import AbstractQuantizerV2 +from friendli.modules.quantizer_v2.enums import Int8QuantType, QuantMode +from friendli.modules.quantizer_v2.int8.base import Int8DynamicQuantizer, Int8QuantHook +from friendli.modules.quantizer_v2.models.llama import LlamaInt8QuantHook +from friendli.modules.quantizer_v2.models.phi3 import Phi3Int8QuantHook +from friendli.modules.quantizer_v2.schema.config import ( + Int8QuantConfig, + OneOfQuantConfig, +) + +model_arch_int8_hook_map: Dict[PreTrainedModel, type[Int8QuantHook]] = { + LlamaForCausalLM: LlamaInt8QuantHook, + MistralForCausalLM: LlamaInt8QuantHook, + Phi3ForCausalLM: Phi3Int8QuantHook, +} + + +def get_quanthook_map(quant_mode: QuantMode) -> Dict[Type[PreTrainedModel], Any]: + """Get quantizer map.""" + if quant_mode == QuantMode.INT8: + return model_arch_int8_hook_map + raise NotSupportedQuantModeError( + invalid_option=quant_mode, + valid_options=[e.value for e in QuantMode], + ) + + +def get_model_class(config: PretrainedConfig) -> PreTrainedModel: + """Get HuggingFace model architecture from config.""" + model_arch_list = cast(List[str], cast(PretrainedConfig, config).architectures) + if len(model_arch_list) == 0: + raise QuantizationError("Model architecture not found in config.") + model_arch = model_arch_list[0] + try: + cls_type = getattr(transformers, model_arch, None) + except AttributeError as exc: + raise QuantizationError(str(exc)) from exc + return cls_type + + +def get_quantizer_class(quant_config: OneOfQuantConfig) -> Type[AbstractQuantizerV2]: + """Get quantizer class.""" + 
quant_mode = quant_config.mode + if quant_mode == QuantMode.INT8: + if ( + cast(Int8QuantConfig, quant_config).int8_args.quant_type + == Int8QuantType.DYNAMIC + ): + return Int8DynamicQuantizer + raise QuantizationError( + "Only Dynamic quantization is supported for int8 quantization." + ) + raise NotSupportedQuantModeError( + invalid_option=quant_mode, + valid_options=[e.value for e in QuantMode], + ) + + +def get_hf_quantizer_factory( + model_config: PretrainedConfig, + quant_config: OneOfQuantConfig, +) -> Tuple[PreTrainedModel, AbstractQuantizerV2]: + """Get quantizer for specific model architecture with quant mode and args.""" + hf_model_cls = get_model_class(model_config) + quantizer = get_quantizer_class(quant_config) + quanthook_map = get_quanthook_map(quant_config.mode) + quanthook = quanthook_map[hf_model_cls](quant_config, model_config) + return hf_model_cls, quantizer(quanthook, quant_config) diff --git a/friendli/modules/quantizer_v2/models/llama.py b/friendli/modules/quantizer_v2/models/llama.py new file mode 100644 index 00000000..649d8471 --- /dev/null +++ b/friendli/modules/quantizer_v2/models/llama.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli LlamaForCausalLM QuantizerHook.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from typing import Iterator, List, Tuple, Type, cast + +import torch +from transformers import LlamaConfig, LlamaForCausalLM, PreTrainedModel + +from friendli.errors import NotSupportedCheckpointError, QuantizationError +from friendli.modules.quantizer_v2.base import AbstractQuantHookV2 +from friendli.modules.quantizer_v2.int8.base import Int8QuantHook +from friendli.modules.quantizer_v2.schema.config import Int8QuantConfig +from friendli.modules.quantizer_v2.schema.data import ( + ModuleName, + QuantInput, + TFQuantInputs, +) + + +class LlamaQuantHook(AbstractQuantHookV2): + """BaseQuantHook for LlamaForCausalLM.""" + + def check_model_config(self) -> None: + """Check if LLaMA architectures' config can be converted to Friendli format.""" + try: + if cast(LlamaConfig, self.model_config).hidden_act not in ["silu"]: + raise NotSupportedCheckpointError( + invalid_option=f"'hidden_act={cast(LlamaConfig, self.model_config).hidden_act}'", + valid_options=["silu"], + ) + if cast(LlamaConfig, self.model_config).tie_word_embeddings: + raise NotSupportedCheckpointError( + invalid_option="'tie_word_embeddings=True'", + valid_options=[False], + ) + if cast(LlamaConfig, self.model_config).rms_norm_eps not in (1e-5, 1e-6): + raise NotSupportedCheckpointError( + invalid_option=f"'rms_norm_eps={cast(LlamaConfig, self.model_config).rms_norm_eps}'", + valid_options=[1e-5, 1e-6], + ) + except AttributeError as exc: + raise QuantizationError(str(exc)) from exc + + def get_tf_blocks(self, model: PreTrainedModel) -> List[torch.nn.Module]: + """Return the transformer blocks in LlamaForCausalLM.""" + return model.model.layers + + def get_linear_layer_types(self) -> Tuple[Type[torch.nn.Module]]: + """Return the linear layer types in LlamaForCausalLM.""" + return (torch.nn.Linear,) + + @property + def quantized_layer_prefix(self) -> str: + """The layer name prefix used before LLaMA's transformer block number.""" + return "model.layers." 
+ + +class LlamaInt8QuantHook(LlamaQuantHook, Int8QuantHook): + """Int8QuantHook for LlamaForCausalLM.""" + + def get_attn_fc_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Return the linear layer after attention in the decoder layer.""" + return decoder_layer.self_attn.o_proj + + def get_ff2_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Return the linear layer after FF1 in the decoder layer.""" + return decoder_layer.mlp.down_proj + + def iter_pre_act_post_act_params( + self, + model: LlamaForCausalLM, + ) -> Iterator[Tuple[List[torch.Tensor], List[torch.Tensor], ModuleName]]: + """Return iterator of layernorm's weight and linear layer's weight per transformer block in LlamaForCausalLM.""" + + for index, decoder_layer in enumerate(model.model.layers): # type: ignore[union-attr] + # [LayerNorm 1] - [ QKV projection ] gets smoothed + yield ( + [ + decoder_layer.input_layernorm.weight.data, + ], + [ + decoder_layer.self_attn.q_proj.weight.data, + decoder_layer.self_attn.k_proj.weight.data, + decoder_layer.self_attn.v_proj.weight.data, + ], + f"{self.quantized_layer_prefix}{index}.self_attn.q_proj", # the input tensors fed into Q, K, V matrices are identical. + ) + # [LayerNorm 2] - [ MLP FF 1, MLP FF GATE ] gets smoothed + yield ( + [ + decoder_layer.post_attention_layernorm.weight.data, + ], + [ + decoder_layer.mlp.up_proj.weight.data, + decoder_layer.mlp.gate_proj.weight.data, + ], + f"{self.quantized_layer_prefix}{index}.mlp.up_proj", + ) + + def iter_tf_quant_inputs(self, model: PreTrainedModel) -> Iterator[TFQuantInputs]: + """Return the layers which should be quantized in transformer block of LlamaForCausalLM.""" + for index, decoder_layer in enumerate( + self.get_tf_blocks(model) # type: ignore[union-attr, arg-type] + ): + self_attn = decoder_layer.self_attn + mlp = decoder_layer.mlp + + yield TFQuantInputs( + layer_index=index, + block=decoder_layer, + quant_inputs=[ + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.q_proj", + ], + local_names=["q_proj"], + ), + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.k_proj", + ], + local_names=["k_proj"], + ), + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.v_proj", + ], + local_names=["v_proj"], + ), + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.o_proj", + ], + local_names=[ + "o_proj", + ], + ), + QuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.mlp.up_proj", + ], + local_names=["up_proj"], + ), + QuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.mlp.gate_proj", + ], + local_names=["gate_proj"], + ), + QuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.mlp.down_proj" + ], + local_names=["down_proj"], + ), + ], + ) diff --git a/friendli/modules/quantizer_v2/models/phi3.py b/friendli/modules/quantizer_v2/models/phi3.py new file mode 100644 index 00000000..0fdc095f --- /dev/null +++ b/friendli/modules/quantizer_v2/models/phi3.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. 
+ +"""Friendli Phi3ForCausalLM QuantizerHook.""" + +# mypy: ignore-errors + +from __future__ import annotations + +from typing import Iterator, List, Tuple, Type, cast + +import torch +from transformers import Phi3Config, Phi3ForCausalLM, PreTrainedModel + +from friendli.errors import NotSupportedCheckpointError, QuantizationError +from friendli.modules.quantizer_v2.base import AbstractQuantHookV2 +from friendli.modules.quantizer_v2.int8.base import Int8QuantHook +from friendli.modules.quantizer_v2.schema.data import ( + ModuleName, + QuantInput, + TFQuantInputs, +) + + +class Phi3QuantHook(AbstractQuantHookV2): + """BaseQuantHook for Phi3ForCausalLM.""" + + def check_model_config(self) -> None: + """Check if Phi3 architectures' config can be converted to Friendli format.""" + try: + if cast(Phi3Config, self.model_config).hidden_act not in ["silu"]: + raise NotSupportedCheckpointError( + invalid_option=f"'hidden_act={cast(Phi3Config, self.model_config).hidden_act}'", + valid_options=["silu"], + ) + if cast(Phi3Config, self.model_config).tie_word_embeddings: + raise NotSupportedCheckpointError( + invalid_option="'tie_word_embeddings=True'", + valid_options=[False], + ) + if cast(Phi3Config, self.model_config).rms_norm_eps not in (1e-5, 1e-6): + raise NotSupportedCheckpointError( + invalid_option=f"'rms_norm_eps={cast(Phi3Config, self.model_config).rms_norm_eps}'", + valid_options=[1e-5, 1e-6], + ) + except AttributeError as exc: + raise QuantizationError(str(exc)) from exc + + def get_tf_blocks(self, model: PreTrainedModel) -> List[torch.nn.Module]: + """Return the transformer blocks in Phi3ForCausalLM.""" + return model.model.layers + + def get_linear_layer_types(self) -> Tuple[Type[torch.nn.Module]]: + """Return the linear layer types in Phi3ForCausalLM.""" + return (torch.nn.Linear,) + + @property + def quantized_layer_prefix(self) -> str: + """The layer name prefix used before Phi3's transformer block number.""" + return "model.layers." 
+ + +class Phi3Int8QuantHook(Phi3QuantHook, Int8QuantHook): + """Int8QuantHook for Phi3ForCausalLM.""" + + def get_attn_fc_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Return the linear layer after attention in the decoder layer.""" + return decoder_layer.self_attn.o_proj + + def get_ff2_layer(self, decoder_layer: torch.nn.Module) -> torch.nn.Linear: + """Return the linear layer after FF1 in the decoder layer.""" + return decoder_layer.mlp.down_proj + + def iter_pre_act_post_act_params( + self, + model: Phi3ForCausalLM, + ) -> Iterator[Tuple[List[torch.Tensor], List[torch.Tensor], ModuleName]]: + """Return iterator of layernorm's weight and linear layer's weight per transformer block in Phi3ForCausalLM.""" + + for index, decoder_layer in enumerate(model.model.layers): # type: ignore[union-attr] + # [LayerNorm 1] - [ QKV projection ] gets smoothed + yield ( + [ + decoder_layer.input_layernorm.weight.data, + ], + [ + decoder_layer.self_attn.qkv_proj.weight.data, + ], + f"{self.quantized_layer_prefix}{index}.self_attn.qkv_proj", + ) + # [LayerNorm 2] - [ MLP FF 1, MLP FF GATE ] gets smoothed + yield ( + [ + decoder_layer.post_attention_layernorm.weight.data, + ], + [ + decoder_layer.mlp.gate_up_proj.weight.data, + ], + f"{self.quantized_layer_prefix}{index}.mlp.gate_up_proj", + ) + + def iter_tf_quant_inputs(self, model: PreTrainedModel) -> Iterator[TFQuantInputs]: + """Return the layers which should be quantized in transformer block of Phi3ForCausalLM.""" + for index, decoder_layer in enumerate( + self.get_tf_blocks(model) # type: ignore[union-attr, arg-type] + ): + self_attn = decoder_layer.self_attn + mlp = decoder_layer.mlp + + yield TFQuantInputs( + layer_index=index, + block=decoder_layer, + quant_inputs=[ + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.qkv_proj", + ], + local_names=["qkv_proj"], + ), + QuantInput( + parent_module=self_attn, + target_names=[ + f"{self.quantized_layer_prefix}{index}.self_attn.o_proj", + ], + local_names=[ + "o_proj", + ], + ), + QuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.mlp.gate_up_proj", + ], + local_names=["gate_up_proj"], + ), + QuantInput( + parent_module=mlp, + target_names=[ + f"{self.quantized_layer_prefix}{index}.mlp.down_proj" + ], + local_names=["down_proj"], + ), + ], + ) diff --git a/friendli/modules/quantizer_v2/quantize.py b/friendli/modules/quantizer_v2/quantize.py new file mode 100644 index 00000000..8187db5f --- /dev/null +++ b/friendli/modules/quantizer_v2/quantize.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Friendli Model Converter.""" + +from __future__ import annotations + +from typing import Optional + +from friendli.errors import TokenizerNotFoundError +from friendli.logging import logger +from friendli.modules.quantizer_v2.maps import get_hf_quantizer_factory +from friendli.modules.quantizer_v2.schema.config import OneOfQuantConfig +from friendli.modules.quantizer_v2.utils import ( + get_model_dtype, + get_model_pretrained_config, + save_tokenizer, +) + + +def quantize_checkpoint( + model_name_or_path: str, + output_dir: str, + quant_config: OneOfQuantConfig, + *, + cache_dir: Optional[str] = None, + dry_run: bool = False, +) -> None: + """Quantize HuggingFace model checkpoint to Friendli format. + + Args: + model_name_or_path (str): Hugging Face model name or local path to the checkpoint. 
+        output_dir (str): Directory path to save the converted checkpoint, the attribute YAML,
+            and the tokenizer configuration file.
+        quant_config (OneOfQuantConfig): Quantization configuration.
+        cache_dir (Optional[str], optional): Path for downloading checkpoint. Defaults to None.
+        dry_run (bool, optional): Only check whether the checkpoint is convertible. Defaults to False.
+
+    Raises:
+        InvalidConfigError: Raised when `data_type` is not supported.
+        NotFoundError: Raised when `model_name_or_path` or `tokenizer_output_dir` is not found.
+        NotSupportedCheckpointError: Raised when the model architecture is not supported for quantization.
+    """
+    model_config = get_model_pretrained_config(
+        model_name_or_path, output_dir, cache_dir
+    )
+    if quant_config.quant_scale_dtype is None:
+        model_dtype = get_model_dtype(model_config.torch_dtype)
+        quant_config.quant_scale_dtype = model_dtype
+        logger.warn(
+            "quant_scale_dtype is not set. Setting it to %s, the same as the Hugging Face model dtype.",
+            model_dtype,
+        )
+    hf_factory, quantizer = get_hf_quantizer_factory(model_config, quant_config)
+    dtype = model_config.torch_dtype
+    quantizer.check_config()
+
+    if not dry_run:
+        logger.info(
+            "Start loading Hugging Face checkpoint (%s) for conversion...",
+            model_name_or_path,
+        )
+        model = hf_factory.from_pretrained(
+            model_name_or_path,
+            torch_dtype=dtype,
+            cache_dir=cache_dir,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            # `low_cpu_mem_usage` speeds up model loading and keeps peak CPU memory
+            # usage at roughly 1x the model size.
+            # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained.example
+        )
+        logger.info(
+            "Hugging Face checkpoint (%s) is successfully loaded!",
+            model_name_or_path,
+        )
+        model = quantizer.quantize(model)
+        model.config.update({"quantization_config": quantizer.get_quant_config()})
+        model.save_pretrained(output_dir)
+        try:
+            save_tokenizer(
+                model_name_or_path=model_name_or_path,
+                cache_dir=cache_dir,
+                save_dir=output_dir,
+            )
+        except TokenizerNotFoundError as exc:
+            logger.warn(str(exc))
+        logger.info(
+            "Hugging Face checkpoint (%s) is successfully quantized to Friendli format!",
+            model_name_or_path,
+        )
diff --git a/friendli/modules/quantizer_v2/schema/__init__.py b/friendli/modules/quantizer_v2/schema/__init__.py
new file mode 100644
index 00000000..f5d8dd04
--- /dev/null
+++ b/friendli/modules/quantizer_v2/schema/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
+
+"""Friendli Model Quantizer Schema."""
diff --git a/friendli/modules/quantizer_v2/schema/config.py b/friendli/modules/quantizer_v2/schema/config.py
new file mode 100644
index 00000000..37b481c2
--- /dev/null
+++ b/friendli/modules/quantizer_v2/schema/config.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
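+
+# An illustrative example of the kind of YAML config this schema validates.
+# `mode` selects the config variant; every other field falls back to the
+# defaults declared below:
+#
+#   mode: int8
+#   device: cuda:0
+#   seed: 42
+#   percentile: 100.0
+#   calibration_dataset:
+#     path_or_name: cnn_dailymail:3.0.0
+#     split: validation
+#     lookup_column_name: article
+#     num_samples: 512
+#     max_length: 512
+#     batch_size: 1
+#   int8_args:
+#     migration_strength: 0.5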
+ +"""Friendli Model Quantizer Config Schema.""" + +from __future__ import annotations + +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from friendli.modules.quantizer_v2.enums import ( + Int8QuantType, + ModelDataType, + QuantDatasetFormat, + QuantMode, +) + + +class CalibrationDatasetConfig(BaseModel): + """Calibration dataset config.""" + + path_or_name: str = "cnn_dailymail:3.0.0" + format: QuantDatasetFormat = QuantDatasetFormat.JSON + split: str = "validation" + lookup_column_name: str = "article" + num_samples: int = 512 + max_length: int = 512 + batch_size: int = 1 + + +class AbstractQuantConfig(BaseModel): + """Abstract quantization config.""" + + mode: QuantMode + device: str = "cuda:0" + offload: bool = True + seed: int = 42 + percentile: float = 100.0 + quant_dtype: ModelDataType = ModelDataType.INT8 + quant_scale_dtype: Optional[ModelDataType] = None + use_symmetric: bool = True + quant_group_size: int = -1 # no grouping + calibration_dataset: CalibrationDatasetConfig = Field( + default_factory=CalibrationDatasetConfig + ) + + +class Int8QuantArtgs(BaseModel): + """Int8Quant args.""" + + migration_strength: float = 0.5 + quant_type: Int8QuantType = Int8QuantType.DYNAMIC + + +class Int8QuantConfig(AbstractQuantConfig): + """Int8Quant config.""" + + mode: Literal[QuantMode.INT8] = QuantMode.INT8 + int8_args: Int8QuantArtgs = Field(default_factory=Int8QuantArtgs) + + +class DummyQuantConfig(AbstractQuantConfig): + """Dummy quant config.""" + + mode: Literal[QuantMode.DUMMY] = QuantMode.DUMMY + + +OneOfQuantConfig = Annotated[ + Union[Int8QuantConfig, DummyQuantConfig], Field(discriminator="mode") +] + + +class QuantConfig(BaseModel): + """Quantization config.""" + + config: OneOfQuantConfig diff --git a/friendli/modules/quantizer_v2/schema/data.py b/friendli/modules/quantizer_v2/schema/data.py new file mode 100644 index 00000000..a5d8e29d --- /dev/null +++ b/friendli/modules/quantizer_v2/schema/data.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved. + +"""Friendli Model Quantizer Data Schema.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +ModuleName = str + + +@dataclass +class BaseQuantResult: + """Dataclass for quantization result per layer.""" + + q_group_size: int + zero_point: Optional[torch.Tensor] + q_weight: torch.Tensor + weight_scale: torch.Tensor + + +@dataclass +class WeightOnlyQuantResult(BaseQuantResult): + """Dataclass for weight-only quantization result per layer.""" + + +@dataclass +class WeightActQuantResult(BaseQuantResult): + """Dataclass for weight-activation quantization result per layer.""" + + act_scale: torch.Tensor + q_group_size: int + + +@dataclass +class QuantInput: + """Dataclass for quantization input of each layer in transformer block. + + When you want to quantize specific layers at once, the target layers should be + included in this dataclass. For example, if the quantization scale of the q_proj, + k_proj, and v_proj layers in the self-attention layer are calculated together, + the target_names and local_names of these layers should be included in the + same QuantInput dataclass. + + Attributes: + parent_module: module contains target layers. + target_names: list of target module's full name + (ex. model.model.layers.0.self_attn.q_proj, ) + local_names: list of target module's name using when access from parent_module + (ex. 
q_proj, k_proj, v_proj)
+    """
+
+    parent_module: torch.nn.Module
+    target_names: List[ModuleName]
+    local_names: List[str]
+
+
+@dataclass
+class TFQuantInputs:
+    """Container for the quantization inputs of a single transformer block."""
+
+    layer_index: int
+    block: torch.nn.Module
+    quant_inputs: List[QuantInput]
diff --git a/friendli/modules/quantizer_v2/utils.py b/friendli/modules/quantizer_v2/utils.py
new file mode 100644
index 00000000..368ba95b
--- /dev/null
+++ b/friendli/modules/quantizer_v2/utils.py
@@ -0,0 +1,565 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
+
+"""Friendli Quantizer Utils."""
+
+from __future__ import annotations
+
+import os
+from contextlib import contextmanager
+from pathlib import Path
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
+
+import datasets  # type: ignore[import]
+import torch
+from accelerate import cpu_offload_with_hook  # type: ignore
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (  # type: ignore
+    AutoConfig,
+    AutoTokenizer,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+from friendli.errors import (
+    InvalidConfigError,
+    NotFoundError,
+    QuantizationError,
+    TokenizerNotFoundError,
+)
+from friendli.logging import logger
+from friendli.modules.quantizer_v2.enums import ModelDataType
+from friendli.modules.quantizer_v2.schema.config import CalibrationDatasetConfig
+from friendli.modules.quantizer_v2.schema.data import (
+    ModuleName,
+    WeightActQuantResult,
+    WeightOnlyQuantResult,
+)
+
+
+def get_tokenizer(
+    model_name_or_path: str,
+    *,
+    cache_dir: Optional[str] = None,
+) -> PreTrainedTokenizer:
+    """Try to get the tokenizer of a pretrained model."""
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            cache_dir=cache_dir,
+            trust_remote_code=True,
+        )
+    except OSError as exc:
+        raise TokenizerNotFoundError(str(exc)) from exc
+
+    if not tokenizer.is_fast:
+        raise TokenizerNotFoundError(
+            "This model does not support a Friendli-compatible tokenizer."
+        )
+
+    if tokenizer.pad_token != "":
+        tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return tokenizer
+
+
+def save_tokenizer(
+    model_name_or_path: str,
+    *,
+    cache_dir: Optional[str] = None,
+    save_dir: str,
+) -> Tuple[str, ...]:
+    """Try to save `tokenizer.json` of a pretrained model."""
+    if not os.path.isdir(save_dir):
+        raise NotFoundError(f"Directory '{save_dir}' is not found.")
+
+    tokenizer = get_tokenizer(model_name_or_path, cache_dir=cache_dir)
+    saved_file_paths = tokenizer.save_pretrained(save_directory=save_dir)
+    tokenizer_json_path = None
+    for path in saved_file_paths:
+        if "tokenizer.json" == os.path.basename(path):
+            tokenizer_json_path = path
+            break
+
+    if tokenizer_json_path is None:
+        raise TokenizerNotFoundError(
+            "This model has a Friendli-compatible tokenizer implementation, but "
+            "the 'tokenizer.json' file is not found."
+        )
+    return saved_file_paths
+
+
+def get_model_pretrained_config(
+    model_name_or_path: str, model_output_path: str, cache_dir: Optional[str] = None
+) -> PretrainedConfig:
+    """Get the Hugging Face model config."""
+    try:
+        config = AutoConfig.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, trust_remote_code=True
+        )
+    except OSError as exc:  # from AutoConfig.from_pretrained()
+        config_dir = Path(model_name_or_path)
+        model_output_dir = Path(model_output_path).parent
+        if config_dir.exists() and model_output_dir.absolute() == config_dir.absolute():
+            raise NotFoundError(
+                f"'output_dir' ({model_output_dir.as_posix()}) and "
+                f"'model_name_or_path' ({model_name_or_path}) are the same. "
+                "In such a case, checkpoints should be prepared in 'output_dir'."
+            ) from exc
+        raise NotFoundError(str(exc)) from exc
+
+    return config
+
+
+def safe_load_datasets(data_cfg: CalibrationDatasetConfig) -> datasets.Dataset:
+    """Load dataset from calibration dataset config."""
+    data_path = data_cfg.path_or_name
+    data_split = data_cfg.split
+
+    try:
+        if os.path.exists(data_path):
+            dataset = datasets.load_dataset(
+                data_cfg.format,
+                data_files=data_path,
+                split=data_split,
+            )
+        else:
+            data_name_parts = data_path.split(":")
+            if len(data_name_parts) == 1:
+                dataset = datasets.load_dataset(data_path, split=data_split)
+            elif len(data_name_parts) == 2:
+                data_name, subset_name = data_name_parts
+                dataset = datasets.load_dataset(
+                    data_name, subset_name, split=data_split
+                )
+            else:
+                raise InvalidConfigError(
+                    "Dataset name is in an invalid format. "
+                    "(valid format: 'dataset_name' or 'dataset_name:subset_name')"
+                )
+    except ValueError as err:
+        raise QuantizationError(f"datasets.load_dataset failed. {str(err)}") from err
+
+    if not isinstance(dataset, datasets.Dataset):
+        raise InvalidConfigError(
+            "This dataset format is not supported for calibration."
+        )
+
+    return dataset
+
+
+def build_percentile_statistics(
+    scale_percentile: float,
+    symmetric: bool = True,
+) -> Tuple[Callable, Callable, Callable]:
+    """Build hooks that collect percentile statistics of a model's input and output activations."""
+    logger.info(
+        "Building percentile statistics hooks.
scale_percentile: (%s)", + scale_percentile, + ) + + max_input_M1: Dict[str, torch.Tensor] = {} + max_input_M2: Dict[str, torch.Tensor] = {} + max_input_num: Dict[str, torch.Tensor] = {} + max_output_M1: Dict[str, torch.Tensor] = {} + max_output_M2: Dict[str, torch.Tensor] = {} + max_output_num: Dict[str, torch.Tensor] = {} + + def create_hook(name: ModuleName): + def update_stats( + max_M1: Dict[str, torch.Tensor], + max_M2: Dict[str, torch.Tensor], + max_num: Dict[str, int], + new_t: torch.Tensor, + ) -> None: + # Chan's method for computing mean and variance incrementally + new_t = new_t.detach().reshape(-1, new_t.size(-1)) + new_numel = new_t.size(0) + new_t_M1 = new_t.to(torch.float64).mean(dim=0) + if symmetric: + # it is assumed samples are always centered on zero + # in the symmetric quantization scheme + new_t_M1.zero_() + new_t_M2 = ((new_t.to(torch.float64) - new_t_M1) ** 2).sum(dim=0) + try: + pre_numel = max_num[name] + max_num[name] += new_numel + delta = new_t_M1 - max_M1[name] + max_M1[name] += delta * (new_numel / max_num[name]) + max_M2[name] += new_t_M2 + torch.pow(delta, 2) * ( + pre_numel * new_numel / max_num[name] + ) + except KeyError: + max_num[name] = new_numel + max_M1[name] = new_t_M1 + max_M2[name] = new_t_M2 + + def hook(module, in_t_tup, out_t): # pylint: disable=unused-argument + with torch.no_grad(): + in_t = in_t_tup[0] + update_stats(max_input_M1, max_input_M2, max_input_num, in_t) + update_stats(max_output_M1, max_output_M2, max_output_num, out_t) + + return hook + + def finish_input_stats(): + return { + name: torch.distributions.Normal( + loc=max_input_M1[name], + scale=torch.sqrt(max_input_M2[name] / max_input_num[name]).clip( + min=1e-7 + ), + ).icdf( + torch.Tensor([(scale_percentile / 100.0) * 0.5 + 0.5]).to( + max_input_M1[name].device + ) + ) + for name in list(max_input_M1.keys()) + } + + def finish_output_stats(): + return { + name: torch.distributions.Normal( + loc=max_output_M1[name], + scale=torch.sqrt(max_output_M2[name] / max_output_num[name]).clip( + min=1e-7 + ), + ).icdf( + torch.Tensor([(scale_percentile / 100.0) * 0.5 + 0.5]).to( + max_output_M1[name].device + ) + ) + for name in list(max_output_M1.keys()) + } + + return finish_input_stats, finish_output_stats, create_hook + + +def build_max_statistics() -> Tuple[Callable, Callable, Callable]: + """Builds the hooks for getting the max input and output activations of a model.""" + logger.info("Building max statistics hooks") + max_input_stats: Dict[str, torch.Tensor] = {} + max_output_stats: Dict[str, torch.Tensor] = {} + + def create_hook(name: ModuleName): + def hook(modules, in_t_tup, out_t): # pylint: disable=unused-argument + in_t = in_t_tup[0] + in_t = ( + in_t.detach().abs().reshape(-1, in_t.size(-1)).max(dim=0).values + ) # reduce-max only leaving the hidden dim (supposing the last dim is the hidden dim) + out_t = out_t.detach().reshape(-1, out_t.size(-1)) + out_t = out_t.abs().max(dim=0).values + try: + max_input_stats[name] = torch.maximum(max_input_stats[name], in_t) + except KeyError: + max_input_stats[name] = in_t + try: + max_output_stats[name] = torch.maximum(max_output_stats[name], out_t) + except KeyError: + max_output_stats[name] = out_t + + return hook + + def finish_input_stats(): + return max_input_stats + + def finish_output_stats(): + return max_output_stats + + return finish_input_stats, finish_output_stats, create_hook + + +@torch.no_grad() +def collect_stats( + model: PreTrainedModel, + device: str, + calib_dataloader: DataLoader, + target_classes: 
Tuple[Type[torch.nn.Module], ...], + tqdm_desc: str, + percentile: float, +) -> Tuple[Dict[ModuleName, torch.Tensor], Dict[ModuleName, torch.Tensor]]: + """Collects the maximum values of input and output activations of a specific model. + + Args: + model (torch.nn.Module): The model for which we want to collect the max statistics. + dataset (Dataset): Dataset that contains input tensors. + target_classes (Tuple[Type[torch.nn.Module], ...]): A tuple of the target classes. + + Returns: + A tuple of two dictionaries: (max_input_stats, max_output_stats), where: + max_input_stats: The maximum input activation values for each module of the model. + max_output_stats: The maximum output activation values for each module of the model. + + This function uses a forward hook to capture the maximum input and output activation values + of the specified target_classes. The max_batch_size parameter controls the size of the input + batches that are passed through the model. + + The function returns two dictionaries containing the maximum input and output activation + values for each module of the model, respectively. These dictionaries can be used to calculate + scaling factors for weight quantization and activation smoothing. + + """ + # pylint: disable=too-many-locals + max_input_stats, max_output_stats, create_hook = ( + build_percentile_statistics(percentile) + if percentile < 100.0 + else build_max_statistics() + ) + name_mods = [ + (name, module) + for name, module in model.named_modules() + if isinstance(module, target_classes) + ] + + removables = [] + for name, module in name_mods: + removables.append(module.register_forward_hook(create_hook(name))) + try: + for inputs in tqdm(calib_dataloader, desc=tqdm_desc): + model(inputs.to(device)) + finally: + for removable in removables: + removable.remove() + return max_input_stats(), max_output_stats() + + +def convert_tensor_to_quant_dtype( + param: torch.Tensor, + quant_dtype: ModelDataType, +) -> torch.Tensor: + """Convert tensor format to the given data type. + + Args: + param (torch.Tensor): The tensor to be converted. + data_type (ModelDataType): The data type of the tensor. + + Returns: + torch.Tensor: The converted tensor. 
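+
+    For example, with `quant_dtype` set to INT4, two adjacent int4 values are
+    packed into one byte, so an `[out_dim, in_dim]` tensor becomes an int8 tensor
+    of shape `[out_dim, in_dim // 2]`; with INT8, the tensor is cast to int8 as is.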
+ + """ + assert quant_dtype in [ModelDataType.INT4, ModelDataType.INT8] + if quant_dtype is ModelDataType.INT4: + pack_num = 8 // 4 + int4_param = torch.zeros( + (param.shape[0], param.shape[1] // pack_num), + dtype=torch.uint8, + device=param.device, + ) + for col in range(int4_param.shape[1]): + for i in range(pack_num): + int4_param[:, col] |= param[:, col * pack_num + i] << (i * 4) + param = int4_param.to(torch.int8) + + elif quant_dtype is ModelDataType.INT8: + param = param.to(torch.int8) + + return param.detach().to("cpu") + + +@torch.no_grad() +def get_weight_act_quant_scales( + model: PreTrainedModel, + layer_names: List[str], + max_input_stats: Dict[ModuleName, torch.Tensor], + device: str = "cpu", + quant_dtype: ModelDataType = ModelDataType.INT8, + quant_scale_dtype: ModelDataType = ModelDataType.FP32, +) -> List[WeightActQuantResult]: + """Get the quantization scales and int8 weight for a specific layer.""" + input_max = torch.concat([max_input_stats[name] for name in layer_names]) + target_weights = [model.get_submodule(name).weight for name in layer_names] + target_weight = torch.concat(target_weights) + + max_val = 2 ** (8 - 1) - 1 + min_val = -(2 ** (8 - 1)) + + act_scale = float(input_max.detach().abs().max().item()) / float(max_val) + weight_scale = float(target_weight.detach().abs().max().item()) / float(max_val) + + q_weights = [ + ( + convert_tensor_to_quant_dtype( + (weight.detach().float() / weight_scale).clip(min_val, max_val), + quant_dtype, + ).to(device) + ) + for weight in target_weights + ] + quant_scale_torch_dtype = get_torch_data_type(quant_scale_dtype) + return [ + WeightActQuantResult( + act_scale=torch.tensor(act_scale, dtype=quant_scale_torch_dtype), + weight_scale=torch.tensor(weight_scale, dtype=quant_scale_torch_dtype), + q_weight=q_weight, + q_group_size=-1, + zero_point=None, + ) + for _, q_weight in zip(layer_names, q_weights) + ] + + +def get_weight_only_quant_scales( + model: PreTrainedModel, + layer_names: List[str], + quant_dtype: ModelDataType, + quant_scale_dtype: ModelDataType, + q_group_size: int = -1, + use_symmetric: bool = True, + device: Union[str, torch.device] = "cpu", +) -> List[WeightOnlyQuantResult]: + """Return the quantization scales of weight for a specific layer.""" + # pylint: disable=too-many-locals + assert quant_dtype in [ModelDataType.INT4, ModelDataType.INT8] + q_bit = 4 if quant_dtype == ModelDataType.INT4 else 8 + target_weights = [model.get_submodule(name).weight for name in layer_names] + org_w_shape = target_weights[0].shape # [OutDim, InDim] + w = torch.concat(target_weights) + + if q_group_size != -1: + w = w.reshape(-1, q_group_size) # [OutDim x num_groups, group_size] + + if use_symmetric: + max_val = w.abs().amax(dim=1, keepdim=True) + max_int = 2 ** (q_bit - 1) - 1 + min_int = -(2 ** (q_bit - 1)) + scales = (max_val / float(max_int)).clamp(min=1e-5) + zeros = torch.zeros_like(max_val) + else: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**q_bit - 1 + min_int = 0 + + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + + q_weights = [ + convert_tensor_to_quant_dtype( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) + .reshape(org_w_shape) + .detach(), + quant_dtype, + ).to(device) + for w in target_weights + ] + quant_scale_torch_dtype = get_torch_data_type(quant_scale_dtype) + scales = ( + scales.view(org_w_shape[0], -1).detach().transpose(0, 1).to(device) + ) # [num_groups, 
OutDim]
+    zeros = (
+        zeros.view(org_w_shape[0], -1).detach().transpose(0, 1).to(device)
+    )  # [num_groups, OutDim]
+
+    if q_group_size == -1:
+        scales = scales.squeeze(0)
+        zeros = zeros.squeeze(0)
+
+    return [
+        WeightOnlyQuantResult(
+            zero_point=None if use_symmetric else zeros.to(quant_scale_torch_dtype),
+            q_group_size=q_group_size,
+            weight_scale=scales.to(quant_scale_torch_dtype),
+            q_weight=q_weight,
+        )
+        for q_weight in q_weights
+    ]
+
+
+def get_model_dtype(torch_dtype: torch.dtype) -> ModelDataType:
+    """Convert a torch dtype to the corresponding ModelDataType."""
+    if torch_dtype == torch.float16:
+        return ModelDataType.FP16
+    if torch_dtype == torch.float32:
+        return ModelDataType.FP32
+    if torch_dtype == torch.bfloat16:
+        return ModelDataType.BF16
+    raise QuantizationError(f"{torch_dtype} is not a valid dtype for a Hugging Face model.")
+
+
+def get_torch_data_type(data_type: ModelDataType) -> torch.dtype:
+    """Convert ModelDataType to the corresponding torch dtype."""
+    if data_type == ModelDataType.FP16:
+        return torch.float16
+    if data_type == ModelDataType.FP32:
+        return torch.float32
+    if data_type == ModelDataType.BF16:
+        return torch.bfloat16
+    raise QuantizationError(
+        f"Cannot convert the original param to {data_type}. Only FP16, FP32, and BF16 are supported."
+    )
+
+
+def send_model_to_device(
+    model: PreTrainedModel,
+    device: Union[str, torch.device],
+    *,
+    exclude: Iterable[torch.nn.Module] = (),
+):
+    """Send the model and its submodules onto device, except for modules designated by `exclude`."""
+    exclude_set = set(exclude)
+
+    @torch.no_grad()
+    def recurse(m: torch.nn.Module):
+        if m in exclude_set:
+            return
+        for name, p in list(m.named_parameters(recurse=False)):
+            m.register_parameter(name, torch.nn.Parameter(p.to(device)))
+        for name, b in list(m.named_buffers(recurse=False)):
+            m.register_buffer(name, b.to(device))
+
+        for child in m.children():
+            recurse(child)
+
+    recurse(model)
+
+
+class RemovableOffloaderHook(Protocol):
+    """Hook protocol for cpu offloader."""
+
+    def offload(self) -> None:
+        """Offload the associated block onto CPU."""
+
+    def remove(self) -> None:
+        """Remove this hook."""
+
+
+@contextmanager
+def offload_module_sequence(
+    blocks: Sequence[torch.nn.Module], device: Union[str, torch.device]
+):
+    """Offload a sequence of torch modules automatically.
+
+    All blocks are assumed to reside on CPU at the beginning. When the i-th block
+    is called, it is loaded onto `device` on the fly, and the (i-1)-th block is
+    offloaded back to CPU at the same time.
+    """
+    module_hooks: List[RemovableOffloaderHook] = []
+    if blocks:
+        prev_module_hook = None
+        for tf_block in blocks:
+            _, module_hook = cpu_offload_with_hook(
+                tf_block, device, prev_module_hook=prev_module_hook
+            )
+            prev_module_hook = module_hook
+            module_hooks.append(module_hook)
+    try:
+        yield
+    finally:
+        for hook in module_hooks:
+            hook.offload()
+        for hook in module_hooks:
+            hook.remove()
diff --git a/friendli/schema/api/v1/chat/completion_chunk.py b/friendli/schema/api/v1/chat/completion_chunk.py
new file mode 100644
index 00000000..07b9a224
--- /dev/null
+++ b/friendli/schema/api/v1/chat/completion_chunk.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.
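+
+# Illustrative example of a single streamed chunk parsed by the models below
+# (the field values are made up for demonstration purposes):
+#
+#   {
+#     "choices": [
+#       {
+#         "index": 0,
+#         "delta": {"role": "assistant", "content": "Hello"},
+#         "finish_reason": null,
+#         "logprobs": null
+#       }
+#     ],
+#     "created": 1719000000
+#   }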
+ +"""Friendli V1 Chat Completion Chunk Serving API Schemas.""" + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import BaseModel +from typing_extensions import Literal + +from friendli.schema.api.v1.chat.completions import ChatCompletionTokenLogprob + + +class ChoiceDeltaFunctionCall(BaseModel): + arguments: Optional[str] = None + name: Optional[str] = None + + +class ChoiceDeltaToolCallFunction(BaseModel): + arguments: Optional[str] = None + name: Optional[str] = None + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: Optional[str] = None + function: Optional[ChoiceDeltaToolCallFunction] = None + type: Optional[Literal["function"]] = None + + +class ChoiceDelta(BaseModel): + content: Optional[str] = None + function_call: Optional[ChoiceDeltaFunctionCall] = None + role: Optional[Literal["system", "user", "assistant", "tool"]] = None + tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + + +class ChoiceLogprobs(BaseModel): + content: Optional[List[ChatCompletionTokenLogprob]] = None + + +class Choice(BaseModel): + delta: ChoiceDelta + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + index: int + logprobs: Optional[ChoiceLogprobs] = None + + +class ChatCompletionChunk(BaseModel): + choices: List[Choice] + created: int diff --git a/friendli/schema/api/v1/chat/completions.py b/friendli/schema/api/v1/chat/completions.py index 34837788..931d86cb 100644 --- a/friendli/schema/api/v1/chat/completions.py +++ b/friendli/schema/api/v1/chat/completions.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel from typing_extensions import Required, TypedDict @@ -17,11 +17,72 @@ class MessageParam(TypedDict, total=False): content: Required[str] +class ToolFunctionParam(TypedDict, total=False): + """Tool function param schema.""" + + name: Required[str] + description: Optional[str] + parameters: Dict[str, Any] + + +class ToolParam(TypedDict, total=False): + """Tool param schema.""" + + type: Required[str] + function: Required[ToolFunctionParam] + + +class ResponseFormatParam(TypedDict, total=True): + """Response format param schema.""" + + type: Required[str] + schema: Optional[str] + + +class Function(BaseModel): + """Function schema.""" + + arguments: str + name: str + + +class ChatCompletionMessageToolCall(BaseModel): + """Tool call schema.""" + + id: str + function: Function + type: Literal["function"] + + class Message(BaseModel): """Message schema.""" role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None + + +class TopLogprob(BaseModel): + """Top logprob schema.""" + + token: str + bytes: Optional[List[int]] = None + logprob: float + + +class ChatCompletionTokenLogprob(BaseModel): + """Chat completion token log prob schema.""" + + token: str + bytes: Optional[List[int]] = None + logprob: float + top_logprobs: List[TopLogprob] + + +class ChoiceLogprobs(BaseModel): + """Schema of log prob info.""" + + content: Optional[List[ChatCompletionTokenLogprob]] = None class ChatCompletionChoice(BaseModel): @@ -30,14 +91,7 @@ class ChatCompletionChoice(BaseModel): index: int message: Message finish_reason: str - - -class ChatCompletionDeltaChoice(BaseModel): - """Schema of chat completion choice with delta.""" - - index: int - delta: Message - finish_reason: Optional[str] = None + logprobs: Optional[ChoiceLogprobs] = None class 
ChatCompletionUsage(BaseModel): @@ -54,10 +108,3 @@ class ChatCompletion(BaseModel): choices: List[ChatCompletionChoice] usage: ChatCompletionUsage created: int - - -class ChatCompletionLine(BaseModel): - """Chat completion line schema.""" - - choices: List[ChatCompletionDeltaChoice] - created: int diff --git a/friendli/schema/api/v1/codegen/chat_completions_pb2.py b/friendli/schema/api/v1/codegen/chat_completions_pb2.py index 856fd7d6..00ca809a 100644 --- a/friendli/schema/api/v1/codegen/chat_completions_pb2.py +++ b/friendli/schema/api/v1/codegen/chat_completions_pb2.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: friendli/schema/api/v1/codegen/chat_completions.proto +# source: chat_completions.proto # Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" from __future__ import annotations @@ -17,19 +17,33 @@ _sym_db = _symbol_database.Default() +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 + +from friendli.schema.api.v1.codegen import response_format_pb2 as response__format__pb2 + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n5friendli/schema/api/v1/codegen/chat_completions.proto"\xf1\x03\n\x18V1ChatCompletionsRequest\x12\x33\n\x08messages\x18\x01 \x03(\x0b\x32!.V1ChatCompletionsRequest.Message\x12\x12\n\x05model\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x1e\n\x11\x66requency_penalty\x18\x03 \x01(\x02H\x01\x88\x01\x01\x12\x17\n\nmax_tokens\x18\x05 \x01(\x05H\x02\x88\x01\x01\x12\x0e\n\x01n\x18\x06 \x01(\x05H\x03\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x0c\n\x04stop\x18\x08 \x03(\t\x12\x13\n\x06stream\x18\t \x01(\x08H\x05\x88\x01\x01\x12\x18\n\x0btemperature\x18\n \x01(\x02H\x06\x88\x01\x01\x12\x12\n\x05top_p\x18\x0b \x01(\x02H\x07\x88\x01\x01\x12!\n\x14timeout_microseconds\x18\x1e \x01(\x05H\x08\x88\x01\x01\x1a(\n\x07Message\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t\x12\x0c\n\x04role\x18\x02 \x01(\tB\x08\n\x06_modelB\x14\n\x12_frequency_penaltyB\r\n\x0b_max_tokensB\x04\n\x02_nB\x13\n\x11_presence_penaltyB\t\n\x07_streamB\x0e\n\x0c_temperatureB\x08\n\x06_top_pB\x17\n\x15_timeout_microsecondsb\x06proto3' + b'\n\x16\x63hat_completions.proto\x12\x04orca\x1a\x1cgoogle/protobuf/struct.proto\x1a\x15response_format.proto"|\n\x08ToolCall\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12)\n\x08\x66unction\x18\x03 \x01(\x0b\x32\x17.orca.ToolCall.Function\x1a+\n\x08\x46unction\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\targuments\x18\x02 \x01(\t"\xa5\x01\n\x07Message\x12\x14\n\x07\x63ontent\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x0c\n\x04role\x18\x02 \x01(\t\x12\x11\n\x04name\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x19\n\x0ctool_call_id\x18\x04 \x01(\tH\x02\x88\x01\x01\x12"\n\ntool_calls\x18\x05 \x03(\x0b\x32\x0e.orca.ToolCallB\n\n\x08_contentB\x07\n\x05_nameB\x0f\n\r_tool_call_id"\xac\x01\n\x04Tool\x12\x0c\n\x04type\x18\x01 \x01(\t\x12%\n\x08\x66unction\x18\x02 \x01(\x0b\x32\x13.orca.Tool.Function\x1ao\n\x08\x46unction\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x12+\n\nparameters\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructB\x0e\n\x0c_description"\xf6\x07\n\x18V1ChatCompletionsRequest\x12\x1f\n\x08messages\x18\x01 \x03(\x0b\x32\r.orca.Message\x12\x12\n\x05model\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x1e\n\x11\x66requency_penalty\x18\x03 \x01(\x02H\x01\x88\x01\x01\x12\x41\n\nlogit_bias\x18\x04 
\x03(\x0b\x32-.orca.V1ChatCompletionsRequest.LogitBiasEntry\x12\x17\n\nmin_tokens\x18\x05 \x01(\x05H\x02\x88\x01\x01\x12\x17\n\nmax_tokens\x18\x06 \x01(\x05H\x03\x88\x01\x01\x12\x0e\n\x01n\x18\x07 \x01(\x05H\x04\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x08 \x01(\x02H\x05\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x13\n\x06stream\x18\n \x01(\x08H\x06\x88\x01\x01\x12\x18\n\x0btemperature\x18\x0b \x01(\x02H\x07\x88\x01\x01\x12\x12\n\x05top_p\x18\x0c \x01(\x02H\x08\x88\x01\x01\x12!\n\x14timeout_microseconds\x18\r \x01(\x05H\t\x88\x01\x01\x12\x15\n\x08logprobs\x18\x0e \x01(\x08H\n\x88\x01\x01\x12\x19\n\x0ctop_logprobs\x18\x0f \x01(\x05H\x0b\x88\x01\x01\x12\x12\n\x05top_k\x18\x13 \x01(\x05H\x0c\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x14 \x01(\x02H\r\x88\x01\x01\x12\x0c\n\x04seed\x18\x15 \x03(\x04\x12\x11\n\teos_token\x18\x16 \x03(\x05\x12\x19\n\x05tools\x18\x17 \x03(\x0b\x32\n.orca.Tool\x12\x32\n\x0fresponse_format\x18\x18 \x01(\x0b\x32\x14.orca.ResponseFormatH\x0e\x88\x01\x01\x12\x30\n\x0btool_choice\x18\x19 \x01(\x0b\x32\x16.google.protobuf.ValueH\x0f\x88\x01\x01\x12 \n\x13parallel_tool_calls\x18\x1a \x01(\x08H\x10\x88\x01\x01\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x08\n\x06_modelB\x14\n\x12_frequency_penaltyB\r\n\x0b_min_tokensB\r\n\x0b_max_tokensB\x04\n\x02_nB\x13\n\x11_presence_penaltyB\t\n\x07_streamB\x0e\n\x0c_temperatureB\x08\n\x06_top_pB\x17\n\x15_timeout_microsecondsB\x0b\n\t_logprobsB\x0f\n\r_top_logprobsB\x08\n\x06_top_kB\x15\n\x13_repetition_penaltyB\x12\n\x10_response_formatB\x0e\n\x0c_tool_choiceB\x16\n\x14_parallel_tool_callsb\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "friendli.schema.api.v1.codegen.chat_completions_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "chat_completions_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_V1CHATCOMPLETIONSREQUEST"]._serialized_start = 58 - _globals["_V1CHATCOMPLETIONSREQUEST"]._serialized_end = 555 - _globals["_V1CHATCOMPLETIONSREQUEST_MESSAGE"]._serialized_start = 379 - _globals["_V1CHATCOMPLETIONSREQUEST_MESSAGE"]._serialized_end = 419 + _globals["_V1CHATCOMPLETIONSREQUEST_LOGITBIASENTRY"]._loaded_options = None + _globals["_V1CHATCOMPLETIONSREQUEST_LOGITBIASENTRY"]._serialized_options = b"8\001" + _globals["_TOOLCALL"]._serialized_start = 85 + _globals["_TOOLCALL"]._serialized_end = 209 + _globals["_TOOLCALL_FUNCTION"]._serialized_start = 166 + _globals["_TOOLCALL_FUNCTION"]._serialized_end = 209 + _globals["_MESSAGE"]._serialized_start = 212 + _globals["_MESSAGE"]._serialized_end = 377 + _globals["_TOOL"]._serialized_start = 380 + _globals["_TOOL"]._serialized_end = 552 + _globals["_TOOL_FUNCTION"]._serialized_start = 441 + _globals["_TOOL_FUNCTION"]._serialized_end = 552 + _globals["_V1CHATCOMPLETIONSREQUEST"]._serialized_start = 555 + _globals["_V1CHATCOMPLETIONSREQUEST"]._serialized_end = 1569 + _globals["_V1CHATCOMPLETIONSREQUEST_LOGITBIASENTRY"]._serialized_start = 1247 + _globals["_V1CHATCOMPLETIONSREQUEST_LOGITBIASENTRY"]._serialized_end = 1295 # @@protoc_insertion_point(module_scope) diff --git a/friendli/schema/api/v1/codegen/chat_completions_pb2.pyi b/friendli/schema/api/v1/codegen/chat_completions_pb2.pyi index 61fc4aac..6e477f37 100644 --- a/friendli/schema/api/v1/codegen/chat_completions_pb2.pyi +++ 
b/friendli/schema/api/v1/codegen/chat_completions_pb2.pyi @@ -10,15 +10,93 @@ from typing import Union as _Union from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message +from google.protobuf import struct_pb2 as _struct_pb2 from google.protobuf.internal import containers as _containers +from friendli.schema.api.v1.codegen import response_format_pb2 as _response_format_pb2 + DESCRIPTOR: _descriptor.FileDescriptor +class ToolCall(_message.Message): + __slots__ = ("id", "type", "function") + + class Function(_message.Message): + __slots__ = ("name", "arguments") + NAME_FIELD_NUMBER: _ClassVar[int] + ARGUMENTS_FIELD_NUMBER: _ClassVar[int] + name: str + arguments: str + def __init__( + self, name: _Optional[str] = ..., arguments: _Optional[str] = ... + ) -> None: ... + ID_FIELD_NUMBER: _ClassVar[int] + TYPE_FIELD_NUMBER: _ClassVar[int] + FUNCTION_FIELD_NUMBER: _ClassVar[int] + id: str + type: str + function: ToolCall.Function + def __init__( + self, + id: _Optional[str] = ..., + type: _Optional[str] = ..., + function: _Optional[_Union[ToolCall.Function, _Mapping]] = ..., + ) -> None: ... + +class Message(_message.Message): + __slots__ = ("content", "role", "name", "tool_call_id", "tool_calls") + CONTENT_FIELD_NUMBER: _ClassVar[int] + ROLE_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + TOOL_CALL_ID_FIELD_NUMBER: _ClassVar[int] + TOOL_CALLS_FIELD_NUMBER: _ClassVar[int] + content: str + role: str + name: str + tool_call_id: str + tool_calls: _containers.RepeatedCompositeFieldContainer[ToolCall] + def __init__( + self, + content: _Optional[str] = ..., + role: _Optional[str] = ..., + name: _Optional[str] = ..., + tool_call_id: _Optional[str] = ..., + tool_calls: _Optional[_Iterable[_Union[ToolCall, _Mapping]]] = ..., + ) -> None: ... + +class Tool(_message.Message): + __slots__ = ("type", "function") + + class Function(_message.Message): + __slots__ = ("name", "description", "parameters") + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + PARAMETERS_FIELD_NUMBER: _ClassVar[int] + name: str + description: str + parameters: _struct_pb2.Struct + def __init__( + self, + name: _Optional[str] = ..., + description: _Optional[str] = ..., + parameters: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., + ) -> None: ... + TYPE_FIELD_NUMBER: _ClassVar[int] + FUNCTION_FIELD_NUMBER: _ClassVar[int] + type: str + function: Tool.Function + def __init__( + self, + type: _Optional[str] = ..., + function: _Optional[_Union[Tool.Function, _Mapping]] = ..., + ) -> None: ... + class V1ChatCompletionsRequest(_message.Message): __slots__ = ( "messages", "model", "frequency_penalty", + "logit_bias", + "min_tokens", "max_tokens", "n", "presence_penalty", @@ -27,20 +105,32 @@ class V1ChatCompletionsRequest(_message.Message): "temperature", "top_p", "timeout_microseconds", + "logprobs", + "top_logprobs", + "top_k", + "repetition_penalty", + "seed", + "eos_token", + "tools", + "response_format", + "tool_choice", + "parallel_tool_calls", ) - class Message(_message.Message): - __slots__ = ("content", "role") - CONTENT_FIELD_NUMBER: _ClassVar[int] - ROLE_FIELD_NUMBER: _ClassVar[int] - content: str - role: str + class LogitBiasEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: int + value: float def __init__( - self, content: _Optional[str] = ..., role: _Optional[str] = ... + self, key: _Optional[int] = ..., value: _Optional[float] = ... ) -> None: ... 
MESSAGES_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int] + LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] + MIN_TOKENS_FIELD_NUMBER: _ClassVar[int] MAX_TOKENS_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int] PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int] @@ -49,11 +139,21 @@ class V1ChatCompletionsRequest(_message.Message): TEMPERATURE_FIELD_NUMBER: _ClassVar[int] TOP_P_FIELD_NUMBER: _ClassVar[int] TIMEOUT_MICROSECONDS_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[ - V1ChatCompletionsRequest.Message - ] + LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOP_K_FIELD_NUMBER: _ClassVar[int] + REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int] + SEED_FIELD_NUMBER: _ClassVar[int] + EOS_TOKEN_FIELD_NUMBER: _ClassVar[int] + TOOLS_FIELD_NUMBER: _ClassVar[int] + RESPONSE_FORMAT_FIELD_NUMBER: _ClassVar[int] + TOOL_CHOICE_FIELD_NUMBER: _ClassVar[int] + PARALLEL_TOOL_CALLS_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[Message] model: str frequency_penalty: float + logit_bias: _containers.ScalarMap[int, float] + min_tokens: int max_tokens: int n: int presence_penalty: float @@ -62,13 +162,23 @@ class V1ChatCompletionsRequest(_message.Message): temperature: float top_p: float timeout_microseconds: int + logprobs: bool + top_logprobs: int + top_k: int + repetition_penalty: float + seed: _containers.RepeatedScalarFieldContainer[int] + eos_token: _containers.RepeatedScalarFieldContainer[int] + tools: _containers.RepeatedCompositeFieldContainer[Tool] + response_format: _response_format_pb2.ResponseFormat + tool_choice: _struct_pb2.Value + parallel_tool_calls: bool def __init__( self, - messages: _Optional[ - _Iterable[_Union[V1ChatCompletionsRequest.Message, _Mapping]] - ] = ..., + messages: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., model: _Optional[str] = ..., frequency_penalty: _Optional[float] = ..., + logit_bias: _Optional[_Mapping[int, float]] = ..., + min_tokens: _Optional[int] = ..., max_tokens: _Optional[int] = ..., n: _Optional[int] = ..., presence_penalty: _Optional[float] = ..., @@ -77,4 +187,16 @@ class V1ChatCompletionsRequest(_message.Message): temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., timeout_microseconds: _Optional[int] = ..., + logprobs: bool = ..., + top_logprobs: _Optional[int] = ..., + top_k: _Optional[int] = ..., + repetition_penalty: _Optional[float] = ..., + seed: _Optional[_Iterable[int]] = ..., + eos_token: _Optional[_Iterable[int]] = ..., + tools: _Optional[_Iterable[_Union[Tool, _Mapping]]] = ..., + response_format: _Optional[ + _Union[_response_format_pb2.ResponseFormat, _Mapping] + ] = ..., + tool_choice: _Optional[_Union[_struct_pb2.Value, _Mapping]] = ..., + parallel_tool_calls: bool = ..., ) -> None: ... diff --git a/friendli/schema/api/v1/codegen/completions_pb2.py b/friendli/schema/api/v1/codegen/completions_pb2.py index f63bd321..04e3c4c3 100644 --- a/friendli/schema/api/v1/codegen/completions_pb2.py +++ b/friendli/schema/api/v1/codegen/completions_pb2.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: friendli/schema/api/v1/codegen/completions.proto +# source: completions.proto # Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" from __future__ import annotations @@ -17,31 +17,27 @@ _sym_db = _symbol_database.Default() +from friendli.schema.api.v1.codegen import response_format_pb2 as response__format__pb2 + DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n0friendli/schema/api/v1/codegen/completions.proto\x12\x04orca"\xe1\x0f\n\x14V1CompletionsRequest\x12\x13\n\x06stream\x18\x01 \x01(\x08H\x00\x88\x01\x01\x12\x12\n\x05model\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x13\n\x06prompt\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tokens\x18\x04 \x03(\x05\x12!\n\x14timeout_microseconds\x18\x05 \x01(\x05H\x03\x88\x01\x01\x12\x17\n\nmax_tokens\x18\x06 \x01(\x05H\x04\x88\x01\x01\x12\x1d\n\x10max_total_tokens\x18\x07 \x01(\x05H\x05\x88\x01\x01\x12\x17\n\nmin_tokens\x18\x08 \x01(\x05H\x06\x88\x01\x01\x12\x1d\n\x10min_total_tokens\x18\t \x01(\x05H\x07\x88\x01\x01\x12\x0e\n\x01n\x18\n \x01(\x05H\x08\x88\x01\x01\x12\x16\n\tnum_beams\x18\x0b \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0elength_penalty\x18\x0c \x01(\x02H\n\x88\x01\x01\x12\x1b\n\x0e\x65\x61rly_stopping\x18\r \x01(\x08H\x0b\x88\x01\x01\x12\x1c\n\x0fno_repeat_ngram\x18\x0e \x01(\x05H\x0c\x88\x01\x01\x12$\n\x17\x65ncoder_no_repeat_ngram\x18\x0f \x01(\x05H\r\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x10 \x01(\x02H\x0e\x88\x01\x01\x12\'\n\x1a\x65ncoder_repetition_penalty\x18\x11 \x01(\x02H\x0f\x88\x01\x01\x12\x1e\n\x11\x66requency_penalty\x18\x12 \x01(\x02H\x10\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x13 \x01(\x02H\x11\x88\x01\x01\x12\x18\n\x0btemperature\x18\x14 \x01(\x02H\x12\x88\x01\x01\x12\x12\n\x05top_k\x18\x15 \x01(\x05H\x13\x88\x01\x01\x12\x12\n\x05top_p\x18\x16 \x01(\x02H\x14\x88\x01\x01\x12\x0c\n\x04stop\x18\x17 \x03(\t\x12=\n\x0bstop_tokens\x18\x18 \x03(\x0b\x32(.orca.V1CompletionsRequest.TokenSequence\x12\x0c\n\x04seed\x18\x19 \x03(\x04\x12\x1e\n\x16token_index_to_replace\x18\x1a \x03(\x05\x12\x1c\n\x14\x65mbedding_to_replace\x18\x1b \x03(\x02\x12H\n\x10\x62\x65\x61m_search_type\x18\x1c \x01(\x0e\x32).orca.V1CompletionsRequest.BeamSearchTypeH\x15\x88\x01\x01\x12*\n\x1d\x62\x65\x61m_compat_pre_normalization\x18\x1d \x01(\x08H\x16\x88\x01\x01\x12.\n!beam_compat_no_post_normalization\x18\x1e \x01(\x08H\x17\x88\x01\x01\x12\x11\n\tbad_words\x18\x1f \x03(\t\x12\x41\n\x0f\x62\x61\x64_word_tokens\x18 \x03(\x0b\x32(.orca.V1CompletionsRequest.TokenSequence\x12"\n\x15include_output_logits\x18! 
\x01(\x08H\x18\x88\x01\x01\x12$\n\x17include_output_logprobs\x18" \x01(\x08H\x19\x88\x01\x01\x12\x1c\n\x14\x66orced_output_tokens\x18# \x03(\x05\x12\x11\n\teos_token\x18$ \x03(\x05\x12G\n\x0fresponse_format\x18% \x01(\x0b\x32).orca.V1CompletionsRequest.ResponseFormatH\x1a\x88\x01\x01\x1a\x1f\n\rTokenSequence\x12\x0e\n\x06tokens\x18\x01 \x03(\x05\x1a\x9c\x01\n\x0eResponseFormat\x12<\n\x04type\x18\x01 \x01(\x0e\x32..orca.V1CompletionsRequest.ResponseFormat.Type\x12\x13\n\x06schema\x18\x02 \x01(\tH\x00\x88\x01\x01",\n\x04Type\x12\x08\n\x04text\x10\x00\x12\x0f\n\x0bjson_object\x10\x01\x12\t\n\x05regex\x10\x02\x42\t\n\x07_schema"G\n\x0e\x42\x65\x61mSearchType\x12\x11\n\rDETERMINISTIC\x10\x00\x12\x0e\n\nSTOCHASTIC\x10\x01\x12\x12\n\x0eNAIVE_SAMPLING\x10\x02\x42\t\n\x07_streamB\x08\n\x06_modelB\t\n\x07_promptB\x17\n\x15_timeout_microsecondsB\r\n\x0b_max_tokensB\x13\n\x11_max_total_tokensB\r\n\x0b_min_tokensB\x13\n\x11_min_total_tokensB\x04\n\x02_nB\x0c\n\n_num_beamsB\x11\n\x0f_length_penaltyB\x11\n\x0f_early_stoppingB\x12\n\x10_no_repeat_ngramB\x1a\n\x18_encoder_no_repeat_ngramB\x15\n\x13_repetition_penaltyB\x1d\n\x1b_encoder_repetition_penaltyB\x14\n\x12_frequency_penaltyB\x13\n\x11_presence_penaltyB\x0e\n\x0c_temperatureB\x08\n\x06_top_kB\x08\n\x06_top_pB\x13\n\x11_beam_search_typeB \n\x1e_beam_compat_pre_normalizationB$\n"_beam_compat_no_post_normalizationB\x18\n\x16_include_output_logitsB\x1a\n\x18_include_output_logprobsB\x12\n\x10_response_format"\x9e\x01\n\x15V1CompletionsResponse\x12\x30\n\x05\x65vent\x18\x01 \x01(\x0e\x32!.orca.V1CompletionsResponse.Event\x12\r\n\x05token\x18\x02 \x03(\x05\x12\x11\n\x04text\x18\x03 \x01(\tH\x00\x88\x01\x01"(\n\x05\x45vent\x12\x11\n\rTOKEN_SAMPLED\x10\x00\x12\x0c\n\x08\x43OMPLETE\x10\x01\x42\x07\n\x05_text2`\n\x15TextGenerationService\x12G\n\x08Generate\x12\x1a.orca.V1CompletionsRequest\x1a\x1b.orca.V1CompletionsResponse"\x00\x30\x01\x62\x06proto3' + b'\n\x11\x63ompletions.proto\x12\x04orca\x1a\x15response_format.proto"\xad\x0e\n\x14V1CompletionsRequest\x12\x13\n\x06stream\x18\x01 \x01(\x08H\x00\x88\x01\x01\x12\x12\n\x05model\x18\x39 \x01(\tH\x01\x88\x01\x01\x12\x13\n\x06prompt\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x0e\n\x06tokens\x18\x04 \x03(\x05\x12!\n\x14timeout_microseconds\x18\x05 \x01(\x05H\x03\x88\x01\x01\x12\x17\n\nmax_tokens\x18\x06 \x01(\x05H\x04\x88\x01\x01\x12\x1d\n\x10max_total_tokens\x18\x07 \x01(\x05H\x05\x88\x01\x01\x12\x17\n\nmin_tokens\x18\x08 \x01(\x05H\x06\x88\x01\x01\x12\x1d\n\x10min_total_tokens\x18\t \x01(\x05H\x07\x88\x01\x01\x12\x0e\n\x01n\x18\n \x01(\x05H\x08\x88\x01\x01\x12\x16\n\tnum_beams\x18\x0b \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0elength_penalty\x18\x0c \x01(\x02H\n\x88\x01\x01\x12\x1b\n\x0e\x65\x61rly_stopping\x18\x0f \x01(\x08H\x0b\x88\x01\x01\x12\x1c\n\x0fno_repeat_ngram\x18\x11 \x01(\x05H\x0c\x88\x01\x01\x12$\n\x17\x65ncoder_no_repeat_ngram\x18\x12 \x01(\x05H\r\x88\x01\x01\x12\x1f\n\x12repetition_penalty\x18\x13 \x01(\x02H\x0e\x88\x01\x01\x12\'\n\x1a\x65ncoder_repetition_penalty\x18" \x01(\x02H\x0f\x88\x01\x01\x12\x1e\n\x11\x66requency_penalty\x18\x35 \x01(\x02H\x10\x88\x01\x01\x12\x1d\n\x10presence_penalty\x18\x36 \x01(\x02H\x11\x88\x01\x01\x12\x18\n\x0btemperature\x18\x14 \x01(\x02H\x12\x88\x01\x01\x12\x12\n\x05top_k\x18\x15 \x01(\x05H\x13\x88\x01\x01\x12\x12\n\x05top_p\x18\x16 \x01(\x02H\x14\x88\x01\x01\x12\x0c\n\x04stop\x18\x17 \x03(\t\x12=\n\x0bstop_tokens\x18\x18 \x03(\x0b\x32(.orca.V1CompletionsRequest.TokenSequence\x12\x0c\n\x04seed\x18\x1a \x03(\x04\x12\x1e\n\x16token_index_to_replace\x18\x1b 
\x03(\x05\x12\x1c\n\x14\x65mbedding_to_replace\x18\x1c \x03(\x02\x12H\n\x10\x62\x65\x61m_search_type\x18\x1d \x01(\x0e\x32).orca.V1CompletionsRequest.BeamSearchTypeH\x15\x88\x01\x01\x12*\n\x1d\x62\x65\x61m_compat_pre_normalization\x18\x1e \x01(\x08H\x16\x88\x01\x01\x12.\n!beam_compat_no_post_normalization\x18\x1f \x01(\x08H\x17\x88\x01\x01\x12\x11\n\tbad_words\x18 \x03(\t\x12\x41\n\x0f\x62\x61\x64_word_tokens\x18! \x03(\x0b\x32(.orca.V1CompletionsRequest.TokenSequence\x12"\n\x15include_output_logits\x18/ \x01(\x08H\x18\x88\x01\x01\x12$\n\x17include_output_logprobs\x18\x32 \x01(\x08H\x19\x88\x01\x01\x12\x1c\n\x14\x66orced_output_tokens\x18\x33 \x03(\x05\x12\x11\n\teos_token\x18. \x03(\x05\x12\x32\n\x0fresponse_format\x18= \x01(\x0b\x32\x14.orca.ResponseFormatH\x1a\x88\x01\x01\x1a\x1f\n\rTokenSequence\x12\x0e\n\x06tokens\x18\x01 \x03(\x05"G\n\x0e\x42\x65\x61mSearchType\x12\x11\n\rDETERMINISTIC\x10\x00\x12\x0e\n\nSTOCHASTIC\x10\x01\x12\x12\n\x0eNAIVE_SAMPLING\x10\x02\x42\t\n\x07_streamB\x08\n\x06_modelB\t\n\x07_promptB\x17\n\x15_timeout_microsecondsB\r\n\x0b_max_tokensB\x13\n\x11_max_total_tokensB\r\n\x0b_min_tokensB\x13\n\x11_min_total_tokensB\x04\n\x02_nB\x0c\n\n_num_beamsB\x11\n\x0f_length_penaltyB\x11\n\x0f_early_stoppingB\x12\n\x10_no_repeat_ngramB\x1a\n\x18_encoder_no_repeat_ngramB\x15\n\x13_repetition_penaltyB\x1d\n\x1b_encoder_repetition_penaltyB\x14\n\x12_frequency_penaltyB\x13\n\x11_presence_penaltyB\x0e\n\x0c_temperatureB\x08\n\x06_top_kB\x08\n\x06_top_pB\x13\n\x11_beam_search_typeB \n\x1e_beam_compat_pre_normalizationB$\n"_beam_compat_no_post_normalizationB\x18\n\x16_include_output_logitsB\x1a\n\x18_include_output_logprobsB\x12\n\x10_response_format"\x9e\x01\n\x15V1CompletionsResponse\x12\x30\n\x05\x65vent\x18\x01 \x01(\x0e\x32!.orca.V1CompletionsResponse.Event\x12\r\n\x05token\x18\x02 \x03(\x05\x12\x11\n\x04text\x18\x03 \x01(\tH\x00\x88\x01\x01"(\n\x05\x45vent\x12\x11\n\rTOKEN_SAMPLED\x10\x00\x12\x0c\n\x08\x43OMPLETE\x10\x01\x42\x07\n\x05_text2`\n\x15TextGenerationService\x12G\n\x08Generate\x12\x1a.orca.V1CompletionsRequest\x1a\x1b.orca.V1CompletionsResponse"\x00\x30\x01\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "friendli.schema.api.v1.codegen.completions_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "completions_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_V1COMPLETIONSREQUEST"]._serialized_start = 59 - _globals["_V1COMPLETIONSREQUEST"]._serialized_end = 2076 - _globals["_V1COMPLETIONSREQUEST_TOKENSEQUENCE"]._serialized_start = 1278 - _globals["_V1COMPLETIONSREQUEST_TOKENSEQUENCE"]._serialized_end = 1309 - _globals["_V1COMPLETIONSREQUEST_RESPONSEFORMAT"]._serialized_start = 1312 - _globals["_V1COMPLETIONSREQUEST_RESPONSEFORMAT"]._serialized_end = 1468 - _globals["_V1COMPLETIONSREQUEST_RESPONSEFORMAT_TYPE"]._serialized_start = 1413 - _globals["_V1COMPLETIONSREQUEST_RESPONSEFORMAT_TYPE"]._serialized_end = 1457 - _globals["_V1COMPLETIONSREQUEST_BEAMSEARCHTYPE"]._serialized_start = 1470 - _globals["_V1COMPLETIONSREQUEST_BEAMSEARCHTYPE"]._serialized_end = 1541 - _globals["_V1COMPLETIONSRESPONSE"]._serialized_start = 2079 - _globals["_V1COMPLETIONSRESPONSE"]._serialized_end = 2237 - _globals["_V1COMPLETIONSRESPONSE_EVENT"]._serialized_start = 2188 - _globals["_V1COMPLETIONSRESPONSE_EVENT"]._serialized_end = 2228 - _globals["_TEXTGENERATIONSERVICE"]._serialized_start = 2239 - 
_globals["_TEXTGENERATIONSERVICE"]._serialized_end = 2335 + _globals["_V1COMPLETIONSREQUEST"]._serialized_start = 51 + _globals["_V1COMPLETIONSREQUEST"]._serialized_end = 1888 + _globals["_V1COMPLETIONSREQUEST_TOKENSEQUENCE"]._serialized_start = 1249 + _globals["_V1COMPLETIONSREQUEST_TOKENSEQUENCE"]._serialized_end = 1280 + _globals["_V1COMPLETIONSREQUEST_BEAMSEARCHTYPE"]._serialized_start = 1282 + _globals["_V1COMPLETIONSREQUEST_BEAMSEARCHTYPE"]._serialized_end = 1353 + _globals["_V1COMPLETIONSRESPONSE"]._serialized_start = 1891 + _globals["_V1COMPLETIONSRESPONSE"]._serialized_end = 2049 + _globals["_V1COMPLETIONSRESPONSE_EVENT"]._serialized_start = 2000 + _globals["_V1COMPLETIONSRESPONSE_EVENT"]._serialized_end = 2040 + _globals["_TEXTGENERATIONSERVICE"]._serialized_start = 2051 + _globals["_TEXTGENERATIONSERVICE"]._serialized_end = 2147 # @@protoc_insertion_point(module_scope) diff --git a/friendli/schema/api/v1/codegen/completions_pb2.pyi b/friendli/schema/api/v1/codegen/completions_pb2.pyi index 9001d32b..0c951624 100644 --- a/friendli/schema/api/v1/codegen/completions_pb2.pyi +++ b/friendli/schema/api/v1/codegen/completions_pb2.pyi @@ -13,6 +13,8 @@ from google.protobuf import message as _message from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from friendli.schema.api.v1.codegen import response_format_pb2 as _response_format_pb2 + DESCRIPTOR: _descriptor.FileDescriptor class V1CompletionsRequest(_message.Message): @@ -70,29 +72,6 @@ class V1CompletionsRequest(_message.Message): TOKENS_FIELD_NUMBER: _ClassVar[int] tokens: _containers.RepeatedScalarFieldContainer[int] def __init__(self, tokens: _Optional[_Iterable[int]] = ...) -> None: ... - - class ResponseFormat(_message.Message): - __slots__ = ("type", "schema") - - class Type(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () - text: _ClassVar[V1CompletionsRequest.ResponseFormat.Type] - json_object: _ClassVar[V1CompletionsRequest.ResponseFormat.Type] - regex: _ClassVar[V1CompletionsRequest.ResponseFormat.Type] - text: V1CompletionsRequest.ResponseFormat.Type - json_object: V1CompletionsRequest.ResponseFormat.Type - regex: V1CompletionsRequest.ResponseFormat.Type - TYPE_FIELD_NUMBER: _ClassVar[int] - SCHEMA_FIELD_NUMBER: _ClassVar[int] - type: V1CompletionsRequest.ResponseFormat.Type - schema: str - def __init__( - self, - type: _Optional[ - _Union[V1CompletionsRequest.ResponseFormat.Type, str] - ] = ..., - schema: _Optional[str] = ..., - ) -> None: ... STREAM_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] PROMPT_FIELD_NUMBER: _ClassVar[int] @@ -170,7 +149,7 @@ class V1CompletionsRequest(_message.Message): include_output_logprobs: bool forced_output_tokens: _containers.RepeatedScalarFieldContainer[int] eos_token: _containers.RepeatedScalarFieldContainer[int] - response_format: V1CompletionsRequest.ResponseFormat + response_format: _response_format_pb2.ResponseFormat def __init__( self, stream: bool = ..., @@ -216,7 +195,7 @@ class V1CompletionsRequest(_message.Message): forced_output_tokens: _Optional[_Iterable[int]] = ..., eos_token: _Optional[_Iterable[int]] = ..., response_format: _Optional[ - _Union[V1CompletionsRequest.ResponseFormat, _Mapping] + _Union[_response_format_pb2.ResponseFormat, _Mapping] ] = ..., ) -> None: ... 
diff --git a/friendli/schema/api/v1/codegen/completions_pb2_grpc.py b/friendli/schema/api/v1/codegen/completions_pb2_grpc.py index 877851ec..f8e668c2 100644 --- a/friendli/schema/api/v1/codegen/completions_pb2_grpc.py +++ b/friendli/schema/api/v1/codegen/completions_pb2_grpc.py @@ -8,9 +8,7 @@ import grpc -from friendli.schema.api.v1.codegen import ( - completions_pb2 as friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2, -) +from friendli.schema.api.v1.codegen import completions_pb2 as completions__pb2 GRPC_GENERATED_VERSION = "1.64.1" GRPC_VERSION = grpc.__version__ @@ -30,7 +28,7 @@ if _version_not_supported: warnings.warn( f"The grpc package installed is at version {GRPC_VERSION}," - + f" but the generated code in friendli/schema/api/v1/codegen/completions_pb2_grpc.py depends on" + + f" but the generated code in completions_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." @@ -51,8 +49,8 @@ def __init__(self, channel): """ self.Generate = channel.unary_stream( "/orca.TextGenerationService/Generate", - request_serializer=friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsRequest.SerializeToString, - response_deserializer=friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsResponse.FromString, + request_serializer=completions__pb2.V1CompletionsRequest.SerializeToString, + response_deserializer=completions__pb2.V1CompletionsResponse.FromString, _registered_method=True, ) @@ -71,8 +69,8 @@ def add_TextGenerationServiceServicer_to_server(servicer, server): rpc_method_handlers = { "Generate": grpc.unary_stream_rpc_method_handler( servicer.Generate, - request_deserializer=friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsRequest.FromString, - response_serializer=friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsResponse.SerializeToString, + request_deserializer=completions__pb2.V1CompletionsRequest.FromString, + response_serializer=completions__pb2.V1CompletionsResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -105,8 +103,8 @@ def Generate( request, target, "/orca.TextGenerationService/Generate", - friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsRequest.SerializeToString, - friendli_dot_schema_dot_api_dot_v1_dot_codegen_dot_completions__pb2.V1CompletionsResponse.FromString, + completions__pb2.V1CompletionsRequest.SerializeToString, + completions__pb2.V1CompletionsResponse.FromString, options, channel_credentials, insecure, diff --git a/friendli/schema/api/v1/codegen/response_format_pb2.py b/friendli/schema/api/v1/codegen/response_format_pb2.py new file mode 100644 index 00000000..8b6b442a --- /dev/null +++ b/friendli/schema/api/v1/codegen/response_format_pb2.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: response_format.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from __future__ import annotations + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x15response_format.proto\x12\x04orca"\x87\x01\n\x0eResponseFormat\x12\'\n\x04type\x18\x01 \x01(\x0e\x32\x19.orca.ResponseFormat.Type\x12\x13\n\x06schema\x18\x02 \x01(\tH\x00\x88\x01\x01",\n\x04Type\x12\x08\n\x04text\x10\x00\x12\x0f\n\x0bjson_object\x10\x01\x12\t\n\x05regex\x10\x02\x42\t\n\x07_schemab\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "response_format_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_RESPONSEFORMAT"]._serialized_start = 32 + _globals["_RESPONSEFORMAT"]._serialized_end = 167 + _globals["_RESPONSEFORMAT_TYPE"]._serialized_start = 112 + _globals["_RESPONSEFORMAT_TYPE"]._serialized_end = 156 +# @@protoc_insertion_point(module_scope) diff --git a/friendli/schema/api/v1/codegen/response_format_pb2.pyi b/friendli/schema/api/v1/codegen/response_format_pb2.pyi new file mode 100644 index 00000000..901942ae --- /dev/null +++ b/friendli/schema/api/v1/codegen/response_format_pb2.pyi @@ -0,0 +1,34 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +from __future__ import annotations + +from typing import ClassVar as _ClassVar +from typing import Optional as _Optional +from typing import Union as _Union + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper + +DESCRIPTOR: _descriptor.FileDescriptor + +class ResponseFormat(_message.Message): + __slots__ = ("type", "schema") + + class Type(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + text: _ClassVar[ResponseFormat.Type] + json_object: _ClassVar[ResponseFormat.Type] + regex: _ClassVar[ResponseFormat.Type] + text: ResponseFormat.Type + json_object: ResponseFormat.Type + regex: ResponseFormat.Type + TYPE_FIELD_NUMBER: _ClassVar[int] + SCHEMA_FIELD_NUMBER: _ClassVar[int] + type: ResponseFormat.Type + schema: str + def __init__( + self, + type: _Optional[_Union[ResponseFormat.Type, str]] = ..., + schema: _Optional[str] = ..., + ) -> None: ... diff --git a/friendli/schema/api/v1/codegen/text_to_image_pb2.py b/friendli/schema/api/v1/codegen/text_to_image_pb2.py index 346b2a7b..9f8021eb 100644 --- a/friendli/schema/api/v1/codegen/text_to_image_pb2.py +++ b/friendli/schema/api/v1/codegen/text_to_image_pb2.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: friendli/schema/api/v1/codegen/text_to_image.proto +# source: text_to_image.proto # Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" from __future__ import annotations @@ -18,16 +18,14 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n2friendli/schema/api/v1/codegen/text_to_image.proto"\xd8\x02\n\x14V1TextToImageRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x1c\n\x0fnegative_prompt\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x0bnum_outputs\x18\x03 \x01(\x05H\x01\x88\x01\x01\x12 \n\x13num_inference_steps\x18\x04 \x01(\x05H\x02\x88\x01\x01\x12\x1b\n\x0eguidance_scale\x18\x05 \x01(\x02H\x03\x88\x01\x01\x12\x11\n\x04seed\x18\x06 \x01(\x05H\x04\x88\x01\x01\x12\x1c\n\x0fresponse_format\x18\x07 \x01(\tH\x05\x88\x01\x01\x12\x12\n\x05model\x18\x08 \x01(\tH\x06\x88\x01\x01\x42\x12\n\x10_negative_promptB\x0e\n\x0c_num_outputsB\x16\n\x14_num_inference_stepsB\x11\n\x0f_guidance_scaleB\x07\n\x05_seedB\x12\n\x10_response_formatB\x08\n\x06_modelb\x06proto3' + b'\n\x13text_to_image.proto"\xd8\x02\n\x14V1TextToImageRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x1c\n\x0fnegative_prompt\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x0bnum_outputs\x18\x03 \x01(\x05H\x01\x88\x01\x01\x12 \n\x13num_inference_steps\x18\x04 \x01(\x05H\x02\x88\x01\x01\x12\x1b\n\x0eguidance_scale\x18\x05 \x01(\x02H\x03\x88\x01\x01\x12\x11\n\x04seed\x18\x06 \x01(\x05H\x04\x88\x01\x01\x12\x1c\n\x0fresponse_format\x18\x07 \x01(\tH\x05\x88\x01\x01\x12\x12\n\x05model\x18\x08 \x01(\tH\x06\x88\x01\x01\x42\x12\n\x10_negative_promptB\x0e\n\x0c_num_outputsB\x16\n\x14_num_inference_stepsB\x11\n\x0f_guidance_scaleB\x07\n\x05_seedB\x12\n\x10_response_formatB\x08\n\x06_modelb\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages( - DESCRIPTOR, "friendli.schema.api.v1.codegen.text_to_image_pb2", _globals -) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_to_image_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_V1TEXTTOIMAGEREQUEST"]._serialized_start = 55 - _globals["_V1TEXTTOIMAGEREQUEST"]._serialized_end = 399 + _globals["_V1TEXTTOIMAGEREQUEST"]._serialized_start = 24 + _globals["_V1TEXTTOIMAGEREQUEST"]._serialized_end = 368 # @@protoc_insertion_point(module_scope) diff --git a/friendli/sdk/api/base.py b/friendli/sdk/api/base.py index 9fd72d55..2e6aad86 100644 --- a/friendli/sdk/api/base.py +++ b/friendli/sdk/api/base.py @@ -24,7 +24,7 @@ from friendli.auth import get_auth_header from friendli.errors import APIError -from friendli.utils.request import DEFAULT_REQ_TIMEOUT +from friendli.utils.request import DEFAULT_REQ_TIMEOUT, transform_request_data _GenerationLine = TypeVar("_GenerationLine", bound=BaseModel) @@ -258,6 +258,8 @@ def _request( if self._endpoint_id is not None and model is not None: raise ValueError("`model` is not allowed for dedicated endpoints.") + data = transform_request_data(data) + if self._use_grpc: grpc_request = self._build_grpc_request(data=data, model=model) if not self._grpc_channel: @@ -343,6 +345,8 @@ async def _request( if self._endpoint_id is not None and model is not None: raise ValueError("`model` is not allowed for dedicated endpoints.") + data = transform_request_data(data) + if self._use_grpc: grpc_request = self._build_grpc_request(data=data, model=model) if not self._grpc_channel: diff --git a/friendli/sdk/api/chat/completions.py b/friendli/sdk/api/chat/completions.py index efe339ae..e2a103a5 100644 --- 
a/friendli/sdk/api/chat/completions.py +++ b/friendli/sdk/api/chat/completions.py @@ -2,20 +2,22 @@ """Friendli Completion API.""" -# pylint: disable=line-too-long, no-name-in-module +# pylint: disable=line-too-long, no-name-in-module, too-many-locals from __future__ import annotations import json -from typing import List, Literal, Optional, Type, Union, overload +from typing import Dict, List, Literal, Optional, Type, Union, overload from pydantic import ValidationError from friendli.errors import InvalidGenerationError +from friendli.schema.api.v1.chat.completion_chunk import ChatCompletionChunk from friendli.schema.api.v1.chat.completions import ( ChatCompletion, - ChatCompletionLine, MessageParam, + ResponseFormatParam, + ToolParam, ) from friendli.schema.api.v1.codegen.chat_completions_pb2 import V1ChatCompletionsRequest from friendli.sdk.api.base import ( @@ -55,12 +57,21 @@ def create( model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> ChatCompletionStream: """[skip-doc].""" @@ -73,12 +84,21 @@ def create( model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> ChatCompletion: """[skip-doc].""" @@ -86,16 +106,25 @@ def create( self, *, messages: List[MessageParam], - stream: bool, + stream: bool = False, model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> Union[ChatCompletionStream, ChatCompletion]: """Creates a chat completion. 
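The expanded create() overloads above surface the extra sampling controls (repetition_penalty, min_tokens, top_k, logit_bias, logprobs, top_logprobs) together with tools, tool_choice, and response_format, and stream now defaults to False. A rough usage sketch, not part of the patch; the top-level Friendli export, the token, the model id, and passing plain dicts for the message and response_format parameters are assumptions made for illustration:

    from friendli import Friendli  # assumed top-level client export

    client = Friendli(token="YOUR_FRIENDLI_TOKEN")  # placeholder credential
    completion = client.chat.completions.create(
        model="meta-llama-3-8b-instruct",  # illustrative model id
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        top_k=40,
        repetition_penalty=1.1,
        min_tokens=1,
        response_format={"type": "json_object"},  # assumed ResponseFormatParam-compatible dict
    )
    # Field layout assumed to follow the usual chat-completion response schema.
    print(completion.choices[0].message.content)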
@@ -121,12 +150,21 @@ def create( "stream": stream, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, + "repetition_penalty": repetition_penalty, "max_tokens": max_tokens, + "min_tokens": min_tokens, "n": n, "stop": stop, "temperature": temperature, "top_p": top_p, + "top_k": top_k, + "logit_bias": logit_bias, + "logprobs": logprobs, + "top_logprobs": top_logprobs, "timeout_microseconds": timeout_microseconds, + "tools": tools, + "tool_choice": tool_choice, + "response_format": response_format, } response = self._request(data=request_dict, stream=stream, model=model) @@ -163,12 +201,21 @@ async def create( model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> AsyncChatCompletionStream: """[skip-doc].""" @@ -181,12 +228,21 @@ async def create( model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> ChatCompletion: """[skip-doc].""" @@ -194,16 +250,25 @@ async def create( self, *, messages: List[MessageParam], - stream: bool, + stream: bool = False, model: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, n: Optional[int] = None, stop: Optional[List[str]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, + logit_bias: Optional[Dict[int, float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, timeout_microseconds: Optional[int] = None, + tools: Optional[List[ToolParam]] = None, + tool_choice: Optional[str] = "auto", + response_format: Optional[ResponseFormatParam] = None, ) -> Union[AsyncChatCompletionStream, ChatCompletion]: """Creates a completion asynchronously. 
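The asynchronous overloads mirror the synchronous ones; when stream=True, iteration yields the new ChatCompletionChunk type introduced by this patch (see the stream classes in the later hunks). A rough sketch, not part of the patch; AsyncFriendli, the model id, and the chunk field layout are assumptions:

    import asyncio

    from friendli import AsyncFriendli  # assumed async client export

    async def main() -> None:
        client = AsyncFriendli(token="YOUR_FRIENDLI_TOKEN")  # placeholder credential
        stream = await client.chat.completions.create(
            model="meta-llama-3-8b-instruct",  # illustrative model id
            messages=[{"role": "user", "content": "Tell me a joke."}],
            stream=True,
            logprobs=True,
            top_logprobs=5,
        )
        async for chunk in stream:  # items are ChatCompletionChunk after this change
            # Delta layout assumed to mirror the chunked chat-completion schema.
            print(chunk.choices[0].delta.content or "", end="", flush=True)

    asyncio.run(main())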
@@ -229,12 +294,21 @@ async def create( "stream": stream, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, + "repetition_penalty": repetition_penalty, "max_tokens": max_tokens, + "min_tokens": min_tokens, "n": n, "stop": stop, "temperature": temperature, "top_p": top_p, + "top_k": top_k, + "logit_bias": logit_bias, + "logprobs": logprobs, + "top_logprobs": top_logprobs, "timeout_microseconds": timeout_microseconds, + "tools": tools, + "tool_choice": tool_choice, + "response_format": response_format, } response = await self._request(data=request_dict, stream=stream, model=model) @@ -243,10 +317,10 @@ async def create( return model_parse(ChatCompletion, response.json()) -class ChatCompletionStream(GenerationStream[ChatCompletionLine]): +class ChatCompletionStream(GenerationStream[ChatCompletionChunk]): """Completion stream.""" - def __next__(self) -> ChatCompletionLine: # noqa: D105 + def __next__(self) -> ChatCompletionChunk: # noqa: D105 line = next(self._iter) while not line: line = next(self._iter) @@ -257,17 +331,17 @@ def __next__(self) -> ChatCompletionLine: # noqa: D105 parsed = json.loads(data) try: - return model_parse(ChatCompletionLine, parsed) + return model_parse(ChatCompletionChunk, parsed) except ValidationError as exc: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" ) from exc -class AsyncChatCompletionStream(AsyncGenerationStream[ChatCompletionLine]): +class AsyncChatCompletionStream(AsyncGenerationStream[ChatCompletionChunk]): """Asynchronous completion stream.""" - async def __anext__(self) -> ChatCompletionLine: # noqa: D105 + async def __anext__(self) -> ChatCompletionChunk: # noqa: D105 line = await self._iter.__anext__() while not line: line = await self._iter.__anext__() @@ -278,7 +352,7 @@ async def __anext__(self) -> ChatCompletionLine: # noqa: D105 parsed = json.loads(data) try: - return model_parse(ChatCompletionLine, parsed) + return model_parse(ChatCompletionChunk, parsed) except ValidationError as exc: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" diff --git a/friendli/utils/request.py b/friendli/utils/request.py index 039d5955..49157af5 100644 --- a/friendli/utils/request.py +++ b/friendli/utils/request.py @@ -4,8 +4,12 @@ from __future__ import annotations +from typing import Any + +import pydantic from requests.exceptions import HTTPError +from friendli.utils.compat import model_dump from friendli.utils.url import discuss_url DEFAULT_REQ_TIMEOUT = 30 @@ -38,3 +42,17 @@ def decode_http_err(exc: HTTPError) -> str: error_str = exc.response.content.decode() return error_str + + +def transform_request_data(data: Any) -> Any: + """Transform the data to be serializable.""" + if isinstance(data, dict): + return {k: transform_request_data(v) for k, v in data.items()} + + if isinstance(data, list): + return [transform_request_data(e) for e in data] + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + return data diff --git a/proto/chat_completions.proto b/proto/chat_completions.proto index c01b6bc3..a5f84e89 100644 --- a/proto/chat_completions.proto +++ b/proto/chat_completions.proto @@ -4,21 +4,66 @@ syntax = "proto3"; +import "google/protobuf/struct.proto"; +import "response_format.proto"; + +package orca; + + +message ToolCall { + message Function { + string name = 1; + string arguments = 2; + } + + string id = 1; + string type = 2; + Function function = 3; +} + +message Message { + optional string content = 1; + string role 
= 2; + optional string name = 3; + optional string tool_call_id = 4; + repeated ToolCall tool_calls = 5; +} + +message Tool { + message Function { + string name = 1; + optional string description = 2; + google.protobuf.Struct parameters = 3; // Json schema + } + + string type = 1; + Function function = 2; +} + message V1ChatCompletionsRequest { - message Message { - string content = 1; - string role = 2; - } - - repeated Message messages = 1; - optional string model = 2; - optional float frequency_penalty = 3; - optional int32 max_tokens = 5; - optional int32 n = 6; - optional float presence_penalty = 7; - repeated string stop = 8; - optional bool stream = 9; - optional float temperature = 10; - optional float top_p = 11; - optional int32 timeout_microseconds = 30; + repeated Message messages = 1; + optional string model = 2; + optional float frequency_penalty = 3; + map logit_bias = 4; + optional int32 min_tokens = 5; + optional int32 max_tokens = 6; + optional int32 n = 7; + optional float presence_penalty = 8; + repeated string stop = 9; + optional bool stream = 10; + optional float temperature = 11; + optional float top_p = 12; + optional int32 timeout_microseconds = 13; + optional bool logprobs = 14; + optional int32 top_logprobs = 15; + optional int32 top_k = 19; + optional float repetition_penalty = 20; + repeated uint64 seed = 21; + repeated int32 eos_token = 22; + repeated Tool tools = 23; + optional ResponseFormat response_format = 24; + + // "auto", "none", "required" or {"type": "function", "function": {"name": "my_function"}} + optional google.protobuf.Value tool_choice = 25; + optional bool parallel_tool_calls = 26; } diff --git a/proto/completions.proto b/proto/completions.proto index cc151885..f338b657 100644 --- a/proto/completions.proto +++ b/proto/completions.proto @@ -6,84 +6,76 @@ syntax = "proto3"; package orca; +import "response_format.proto"; + message V1CompletionsRequest { - message TokenSequence { - repeated int32 tokens = 1; - } - - enum BeamSearchType { - DETERMINISTIC = 0; // Use the standard beam search - STOCHASTIC = 1; // Stochastic beam search by Kool et al. 
(2019) - NAIVE_SAMPLING = 2; // Huggingface's beam sampling - } - - message ResponseFormat { - enum Type { - text = 0; - json_object = 1; - regex = 2; - } - Type type = 1; - optional string schema = 2; - } - - optional bool stream = 1; - optional string model = 2; - optional string prompt = 3; - repeated int32 tokens = 4; - optional int32 timeout_microseconds = 5; - optional int32 max_tokens = 6; - optional int32 max_total_tokens = 7; - optional int32 min_tokens = 8; - optional int32 min_total_tokens = 9; - optional int32 n = 10; - optional int32 num_beams = 11; - optional float length_penalty = 12; - optional bool early_stopping = 13; - optional int32 no_repeat_ngram = 14; - optional int32 encoder_no_repeat_ngram = 15; - optional float repetition_penalty = 16; - optional float encoder_repetition_penalty = 17; - optional float frequency_penalty = 18; - optional float presence_penalty = 19; - optional float temperature = 20; - optional int32 top_k = 21; - optional float top_p = 22; - - repeated string stop = 23; - repeated TokenSequence stop_tokens = 24; - - repeated uint64 seed = 25; - - repeated int32 token_index_to_replace = 26; - repeated float embedding_to_replace = 27; - - optional BeamSearchType beam_search_type = 28; - optional bool beam_compat_pre_normalization = 29; - optional bool beam_compat_no_post_normalization = 30; - - repeated string bad_words = 31; - repeated TokenSequence bad_word_tokens = 32; - - optional bool include_output_logits = 33; - optional bool include_output_logprobs = 34; - repeated int32 forced_output_tokens = 35; - - repeated int32 eos_token = 36; - - optional ResponseFormat response_format = 37; + message TokenSequence { + repeated int32 tokens = 1; + } + + enum BeamSearchType { + DETERMINISTIC = 0; // Use the standard beam search + STOCHASTIC = 1; // Stochastic beam search by Kool et al. 
(2019) + NAIVE_SAMPLING = 2; // Huggingface's beam sampling + } + + optional bool stream = 1; + optional string model = 57; + optional string prompt = 3; + repeated int32 tokens = 4; + optional int32 timeout_microseconds = 5; + optional int32 max_tokens = 6; + optional int32 max_total_tokens = 7; + optional int32 min_tokens = 8; + optional int32 min_total_tokens = 9; + optional int32 n = 10; + optional int32 num_beams = 11; + optional float length_penalty = 12; + optional bool early_stopping = 15; + optional int32 no_repeat_ngram = 17; + optional int32 encoder_no_repeat_ngram = 18; + optional float repetition_penalty = 19; + optional float encoder_repetition_penalty = 34; + optional float frequency_penalty = 53; + optional float presence_penalty = 54; + optional float temperature = 20; + optional int32 top_k = 21; + optional float top_p = 22; + + repeated string stop = 23; + repeated TokenSequence stop_tokens = 24; + + repeated uint64 seed = 26; + + repeated int32 token_index_to_replace = 27; + repeated float embedding_to_replace = 28; + + optional BeamSearchType beam_search_type = 29; + optional bool beam_compat_pre_normalization = 30; + optional bool beam_compat_no_post_normalization = 31; + + repeated string bad_words = 32; + repeated TokenSequence bad_word_tokens = 33; + + optional bool include_output_logits = 47; + optional bool include_output_logprobs = 50; + repeated int32 forced_output_tokens = 51; + + repeated int32 eos_token = 46; + + optional ResponseFormat response_format = 61; } message V1CompletionsResponse { - enum Event { - TOKEN_SAMPLED = 0; - COMPLETE = 1; - } - Event event = 1; - repeated int32 token = 2; - optional string text = 3; + enum Event { + TOKEN_SAMPLED = 0; + COMPLETE = 1; + } + Event event = 1; + repeated int32 token = 2; + optional string text = 3; } service TextGenerationService { - rpc Generate(V1CompletionsRequest) returns (stream V1CompletionsResponse) {} + rpc Generate(V1CompletionsRequest) returns (stream V1CompletionsResponse) {} } diff --git a/proto/response_format.proto b/proto/response_format.proto new file mode 100644 index 00000000..efd8fb70 --- /dev/null +++ b/proto/response_format.proto @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2023-present, FriendliAI Inc. All rights reserved. + */ + +syntax = "proto3"; + +package orca; + +message ResponseFormat { + enum Type { + text = 0; + json_object = 1; + regex = 2; + } + Type type = 1; + optional string schema = 2; // Json schema or regex +} diff --git a/pyproject.toml b/pyproject.toml index aa6c7956..dc20e960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "friendli-client" -version = "1.4.1" +version = "1.4.2" description = "Client of Friendli Suite." license = "Apache-2.0" authors = ["FriendliAI teams "] diff --git a/scripts/fix_imports.py b/scripts/fix_imports.py new file mode 100644 index 00000000..0c53ac7a --- /dev/null +++ b/scripts/fix_imports.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. 
+ +"""Fix imports.""" + +from __future__ import annotations + +import argparse + + +def adjust_imports(file_path, prefix_to_attach, imports_to_fix): + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + + # Split the comma-separated import targets + imports_list = imports_to_fix.split(",") + + for import_target in imports_list: + import_target = import_target.strip() + old_import = f"import {import_target}" + new_import = f"from {prefix_to_attach} import {import_target}" + content = content.replace(old_import, new_import) + + with open(file_path, "w", encoding="utf-8") as file: + file.write(content) + + +def main(): + parser = argparse.ArgumentParser( + description="Fix import statements in generated protobuf files." + ) + parser.add_argument("--source-path", required=True, help="The path to file to fix") + parser.add_argument( + "--prefix-to-attach", + required=True, + help="The path prefix to attach (e.g., friendli.schema.api.v1.codegen)", + ) + parser.add_argument( + "--imports-to-fix", + required=True, + help="A comma-separated string of import target to fix (e.g., response_format_pb2,completions_pb2)", + ) + + args = parser.parse_args() + adjust_imports(args.source_path, args.prefix_to_attach, args.imports_to_fix) + + +if __name__ == "__main__": + main() diff --git a/tox.ini b/tox.ini index f733da06..24e9eb2d 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ description = run unittests allowlist_externals = rm poetry + pytest commands_pre = poetry install --with dev --all-extras pip install torch --index-url https://download.pytorch.org/whl/cpu @@ -42,6 +43,10 @@ commands = description = run protobuf codegen commands_pre = poetry install commands = - python -m grpc_tools.protoc -Ifriendli/schema/api/v1/codegen=proto --python_out=. --pyi_out=. --grpc_python_out=. proto/completions.proto - python -m grpc_tools.protoc -Ifriendli/schema/api/v1/codegen=proto --python_out=. --pyi_out=. proto/chat_completions.proto - python -m grpc_tools.protoc -Ifriendli/schema/api/v1/codegen=proto --python_out=. --pyi_out=. proto/text_to_image.proto + python -m grpc_tools.protoc -Iproto --python_out=friendli/schema/api/v1/codegen --pyi_out=friendli/schema/api/v1/codegen proto/chat_completions.proto proto/response_format.proto proto/text_to_image.proto + python -m grpc_tools.protoc -Iproto --python_out=friendli/schema/api/v1/codegen --pyi_out=friendli/schema/api/v1/codegen --grpc_python_out=friendli/schema/api/v1/codegen proto/completions.proto + python scripts/fix_imports.py --source-path friendli/schema/api/v1/codegen/chat_completions_pb2.py --prefix-to-attach friendli.schema.api.v1.codegen --imports-to-fix response_format_pb2 + python scripts/fix_imports.py --source-path friendli/schema/api/v1/codegen/chat_completions_pb2.pyi --prefix-to-attach friendli.schema.api.v1.codegen --imports-to-fix response_format_pb2 + python scripts/fix_imports.py --source-path friendli/schema/api/v1/codegen/completions_pb2.py --prefix-to-attach friendli.schema.api.v1.codegen --imports-to-fix response_format_pb2 + python scripts/fix_imports.py --source-path friendli/schema/api/v1/codegen/completions_pb2.pyi --prefix-to-attach friendli.schema.api.v1.codegen --imports-to-fix response_format_pb2 + python scripts/fix_imports.py --source-path friendli/schema/api/v1/codegen/completions_pb2_grpc.py --prefix-to-attach friendli.schema.api.v1.codegen --imports-to-fix completions_pb2