Update of generate_step to support repetition_penalty and repetition_context_size params

ea167 authored and rlouf committed Sep 17, 2024
1 parent 1894fa3 commit 515c197
Showing 1 changed file with 74 additions and 7 deletions.
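
In practice, the new keyword arguments flow from the user-facing generator through MLXLM.generate/stream down into generate_step. A minimal usage sketch (not part of this commit; it assumes the SequenceGeneratorAdapter returned by outlines.generate.text forwards extra keyword arguments to MLXLM.generate, and the model id is only illustrative):

import outlines

model = outlines.models.mlxlm("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
generator = outlines.generate.text(model)

# repetition_penalty > 1.0 discourages tokens seen in the last
# repetition_context_size tokens; top_p now works without defining a sampler
answer = generator(
    "Write a haiku about the sea.",
    top_p=0.9,
    repetition_penalty=1.3,
    repetition_context_size=64,
)
print(answer)
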
81 changes: 74 additions & 7 deletions outlines/models/mlxlm.py
@@ -1,5 +1,16 @@
import dataclasses
from typing import (
    TYPE_CHECKING,
    Generator,
    Iterator,
    List,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

from typing_extensions import Unpack

from .transformers import TransformerTokenizer

@@ -12,6 +23,14 @@
from outlines.processors import OutlinesLogitsProcessor



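# total=False: every key is optional, so callers may pass any subset as **kwargs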
class MLXLMParams(TypedDict, total=False):
    # Lets top_p be passed as a parameter to generate() without defining a sampler
    top_p: float
    repetition_penalty: float
    repetition_context_size: int



class MLXLM:
"""
Represents an `mlx_lm` model
@@ -28,24 +47,28 @@ def __init__(
tokenizer._tokenizer
) # _tokenizer is HF Tokenizer


def generate(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
**mlx_lm_params: Unpack[MLXLMParams],
) -> str:
streamer = self.stream(
            prompts,
            generation_parameters,
            logits_processor,
            sampling_parameters,
            **mlx_lm_params,
)
return "".join(list(streamer))


def stream(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
**mlx_lm_params: Unpack[MLXLMParams],
) -> Iterator[str]:
"""Generate text using `mlx_lm`.
@@ -63,6 +86,9 @@ def stream(
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
        mlx_lm_params
            Of type `MLXLMParams`: optional `top_p`, `repetition_penalty` and
            `repetition_context_size`, forwarded to the `mlx_lm` generation step.

Returns
-------
The generated text.
@@ -100,6 +126,7 @@ def stream(
"top_p": top_p,
"sampler": sampler,
"logits_processor": logits_processor,
**mlx_lm_params
}

# Adapted from
@@ -121,40 +148,61 @@
detokenizer.finalize()
yield detokenizer.last_segment


def generate_step(
self,
prompt: "mx.array",
temp: Optional[float],
top_p: Optional[float],
sampler: str,
logits_processor: "OutlinesLogitsProcessor",
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
) -> Generator[Tuple[int, float], None, None]:
"""
Adapted from
https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
and updated (on Sept 2024 to add repetition_* args) from
https://github.com/ml-explore/mlx-examples/blob/bd29aec299c8fa59c161a9c1207bfc59db31d845/llms/mlx_lm/utils.py#L149
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
Default: ``0``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
more less likely words.
sampler (str): The sampler string defined by SequenceGeneratorAdapter
logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
1.0 for no penalty. >1.0 for penalty. Default: ``None``.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
"""
import mlx.core as mx
import mlx_lm

        if repetition_penalty:
            if not isinstance(repetition_penalty, float) or repetition_penalty <= 0:
                raise ValueError(
                    f"repetition_penalty must be a positive float, got {repetition_penalty}"
                )
            if not isinstance(repetition_context_size, int) or repetition_context_size <= 2:
                raise ValueError(
                    f"repetition_context_size must be an integer greater than 2, got {repetition_context_size}"
                )


def sample(logits: "mx.array") -> Tuple["mx.array", float]:
softmax_logits = mx.softmax(logits)

            if temp == 0.0 or sampler == "greedy":  # temp == 0, not temperature, which can never be 0
token = mx.argmax(logits, axis=-1)
elif sampler == "multinomial":
temperature: float = temp or 1.0
if top_p is not None and top_p > 0 and top_p < 1.0:
token = mlx_lm.sample_utils.top_p_sampling(
logits, top_p, temperature
@@ -167,13 +215,22 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
prob = softmax_logits[0, token]
return token, prob


# Create the KV cache for generation
kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]


        # Initialize the repetition context from the prompt, keeping only the
        # most recent repetition_context_size tokens
        repetition_context = prompt.tolist()
        if repetition_context_size:
            repetition_context = repetition_context[-repetition_context_size:]


        # The KV cache already holds the processed input ids; we pass only the
        # unprocessed inputs together with the cache to model()
unprocessed_input_ids = prompt
generated_ids: List[int] = []
@@ -182,6 +239,10 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
logits = self.model(unprocessed_input_ids[None], cache=cache)
logits = logits[:, -1, :]

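            # Down-weight tokens already present in repetition_context so the
            # sampler is less likely to repeat recent output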
            if repetition_penalty:
                logits = mlx_lm.utils.apply_repetition_penalty(
                    logits, repetition_context, repetition_penalty
                )

if logits_processor is not None:
                # logits_processor expects 1-D logits: flatten, apply, then restore the shape
logits_1d = logits.reshape(-1)
@@ -191,11 +252,17 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
new_token_single, prob = sample(logits)
new_token = new_token_single.item()
yield new_token, prob

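            # Slide the repetition window: record the new token and trim the
            # context back to repetition_context_size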
if repetition_penalty:
repetition_context.append(new_token)
if repetition_context_size and len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]

generated_ids.append(new_token)
unprocessed_input_ids = new_token_single



def mlxlm(
model_name: str,
tokenizer_config: dict = {},
