diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py
index 6e63ef5b6..84a8e5c69 100644
--- a/outlines/models/mlxlm.py
+++ b/outlines/models/mlxlm.py
@@ -1,5 +1,16 @@
 import dataclasses
-from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+    Generator,
+)
+
+from typing_extensions import Unpack
 
 from .transformers import TransformerTokenizer
 
@@ -12,6 +23,14 @@
     from outlines.processors import OutlinesLogitsProcessor
 
 
+class MLXLMParams(TypedDict, total=False):
+    top_p: float  # so top_p can be passed as a parameter to generate() without defining a sampler
+    repetition_penalty: float
+    repetition_context_size: int
+
+
 class MLXLM:
     """
     Represents an `mlx_lm` model
@@ -28,24 +47,28 @@ def __init__(
             tokenizer._tokenizer
         )  # _tokenizer is HF Tokenizer
 
+
     def generate(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: "GenerationParameters",
         logits_processor,
         sampling_parameters: "SamplingParameters",
+        **mlx_lm_params: Unpack[MLXLMParams],
     ) -> str:
         streamer = self.stream(
-            prompts, generation_parameters, logits_processor, sampling_parameters
+            prompts, generation_parameters, logits_processor, sampling_parameters, **mlx_lm_params
         )
         return "".join(list(streamer))
 
+
     def stream(
         self,
         prompts: Union[str, List[str]],
         generation_parameters: "GenerationParameters",
         logits_processor,
         sampling_parameters: "SamplingParameters",
+        **mlx_lm_params: Unpack[MLXLMParams],
     ) -> Iterator[str]:
         """Generate text using `mlx_lm`.
 
@@ -63,6 +86,9 @@ def stream(
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.
+        mlx_lm_params
+            Of type `MLXLMParams`: optional `top_p`, `repetition_penalty` and `repetition_context_size`.
+
        Returns
        -------
        The generated text.
@@ -100,6 +126,7 @@ def stream(
             "top_p": top_p,
             "sampler": sampler,
             "logits_processor": logits_processor,
+            **mlx_lm_params,
         }
 
         # Adapted from
@@ -121,6 +148,7 @@ def stream(
         detokenizer.finalize()
         yield detokenizer.last_segment
 
+
     def generate_step(
         self,
         prompt: "mx.array",
@@ -128,33 +156,53 @@ def generate_step(
         top_p: Optional[float],
         sampler: str,
         logits_processor: "OutlinesLogitsProcessor",
+        repetition_penalty: Optional[float] = None,
+        repetition_context_size: Optional[int] = 20,
     ) -> Generator[Tuple[int, float], None, None]:
         """
         Adapted from
-        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
+        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
+        and updated (in Sept 2024, to add the repetition_* args) from
+        https://github.com/ml-explore/mlx-examples/blob/bd29aec299c8fa59c161a9c1207bfc59db31d845/llms/mlx_lm/utils.py#L149
 
         A generator producing token ids based on the given prompt from the model.
 
         Args:
             prompt (mx.array): The input prompt.
             temp (float): The temperature for sampling, if 0 the argmax is used.
-                Default: ``0``.
+                Default: ``0``.
             top_p (float, optional): Nulceus sampling, higher means model considers
-                more less likely words.
+                more less likely words.
             sampler (str): The sampler string defined by SequenceGeneratorAdapter
             logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
+            repetition_penalty (float, optional): The penalty factor for repeating tokens.
+                1.0 for no penalty, >1.0 for penalty. Default: ``None``.
+            repetition_context_size (int, optional): The number of tokens to
+                consider for repetition penalty. Default: ``20``.
+
+        Yields:
+            Generator[Tuple[int, float], None, None]: A generator producing
+            one token id and its probability.
         """
         import mlx.core as mx
         import mlx_lm
 
-        temperature: float = temp or 1.0
+        if repetition_penalty:
+            if not isinstance(repetition_penalty, float) or repetition_penalty <= 0:
+                raise ValueError(
+                    f"repetition_penalty must be a positive float, got {repetition_penalty}"
+                )
+            if not isinstance(repetition_context_size, int) or repetition_context_size <= 2:
+                raise ValueError(
+                    f"repetition_context_size must be an integer greater than 2, got {repetition_context_size}"
+                )
 
         def sample(logits: "mx.array") -> Tuple["mx.array", float]:
             softmax_logits = mx.softmax(logits)
 
-            if temperature == 0.0 or sampler == "greedy":
+            if temp == 0.0 or sampler == "greedy":  # compare temp, not temperature, which can never be 0
                 token = mx.argmax(logits, axis=-1)
             elif sampler == "multinomial":
+                temperature: float = temp or 1.0
                 if top_p is not None and top_p > 0 and top_p < 1.0:
                     token = mlx_lm.sample_utils.top_p_sampling(
                         logits, top_p, temperature
@@ -167,6 +215,8 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
             prob = softmax_logits[0, token]
             return token, prob
 
+
+        # Create the KV cache for generation
         kv_heads = (
             [self.model.n_kv_heads] * len(self.model.layers)
             if isinstance(self.model.n_kv_heads, int)
@@ -174,6 +224,13 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
         )
         cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]
 
+        # Init the repetition context
+        repetition_context = prompt.tolist()
+        if repetition_context_size:
+            repetition_context = repetition_context[-repetition_context_size:]
+
+
         # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
         unprocessed_input_ids = prompt
         generated_ids: List[int] = []
@@ -182,6 +239,10 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
             logits = self.model(unprocessed_input_ids[None], cache=cache)
             logits = logits[:, -1, :]
 
+            if repetition_penalty:
+                logits = mlx_lm.utils.apply_repetition_penalty(
+                    logits, repetition_context, repetition_penalty
+                )
+
             if logits_processor is not None:
                 # convert to logits_processor 1d expectation, apply, then convert back
                 logits_1d = logits.reshape(-1)
@@ -191,11 +252,17 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
             new_token_single, prob = sample(logits)
             new_token = new_token_single.item()
             yield new_token, prob
+
+            if repetition_penalty:
+                repetition_context.append(new_token)
+                if repetition_context_size and len(repetition_context) > repetition_context_size:
+                    repetition_context = repetition_context[-repetition_context_size:]
 
             generated_ids.append(new_token)
             unprocessed_input_ids = new_token_single
 
 
+
 def mlxlm(
     model_name: str,
     tokenizer_config: dict = {},
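Usage sketch (not part of the patch): a minimal, self-contained illustration of the `MLXLMParams` / `Unpack` pattern this diff introduces. Only `MLXLMParams` itself comes from the patch; the standalone `generate()` helper below is hypothetical and stands in for `MLXLM.generate()`, which forwards the same keyword arguments through `stream()` down to `generate_step()`.

from typing import TypedDict

from typing_extensions import Unpack


class MLXLMParams(TypedDict, total=False):
    # total=False: every key is optional, so callers may pass any subset.
    top_p: float
    repetition_penalty: float
    repetition_context_size: int


def generate(prompt: str, **mlx_lm_params: Unpack[MLXLMParams]) -> str:
    # Unpack[MLXLMParams] lets type checkers validate the keyword arguments
    # at the call site while the function still receives a plain dict.
    top_p = mlx_lm_params.get("top_p")
    repetition_penalty = mlx_lm_params.get("repetition_penalty")
    repetition_context_size = mlx_lm_params.get("repetition_context_size", 20)
    return (
        f"top_p={top_p}, repetition_penalty={repetition_penalty}, "
        f"repetition_context_size={repetition_context_size}"
    )


print(generate("hello", top_p=0.9, repetition_penalty=1.3))

A `total=False` TypedDict keeps the extra `mlx_lm` knobs optional and type-checkable without widening the positional signature that `MLXLM` shares with the other model wrappers.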