forked from dottxt-ai/outlines
Showing 11 changed files with 592 additions and 4 deletions.
@@ -0,0 +1,219 @@
import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters

try:
    import mlx.core as mx
    import mlx_lm
except ImportError:
    pass


class MLXLM:
    """
    Represents an `mlx_lm` model.
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)  # HF Tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as
            available in Outlines.

        Returns
        -------
        The generated text.
        """
        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not support drawing multiple samples."
            )

        # PR TODO: error if top_k or seed are set OR implement them
        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        np_tokens = self.tokenizer.encode(prompts)[0].numpy()
        prompt_tokens = mx.array(np_tokens)

        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break  # PR TODO: should use stop_at instead?
            yield self.tokenizer.decode([token])[0]

    def generate_step(self, prompt, temp, top_p, sampler, logits_processor):
        """
        A generator producing token ids based on the given prompt from the model.

        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        Args:
            prompt (mx.array): The input prompt.
            temp (float): The temperature for sampling; if 0, the argmax is
                used. Default: ``0``.
            top_p (float, optional): Nucleus sampling; higher values make the
                model consider less likely words.
            sampler (str): PR TODO
            logits_processor: PR TODO

        Yields:
            Generator[Tuple[mx.array, mx.array]]: A generator producing
            one token and probability per call.
        """
        # PR TODO: Type hints in fn signature

        temp = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temp == 0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(logits, top_p, temp)
                else:
                    token = mx.random.categorical(logits * (1 / temp))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob

        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

        input_ids = prompt
        while True:
            logits = self.model(input_ids, cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                logits = logits_processor(input_ids, logits)

            new_token, prob = sample(logits)
            yield new_token.item(), prob

            input_ids = mx.concatenate([input_ids, new_token[None]], axis=1)


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name (str)
        The path or the Hugging Face repository to load the model from.
    tokenizer_config (dict, optional)
        Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config (dict, optional)
        Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path (str, optional)
        Path to the LoRA adapters. If provided, applies LoRA layers to the
        model. Default: ``None``.
    lazy (bool)
        If False, eval the model parameters to make sure they are loaded in
        memory before returning; otherwise they will be loaded when needed.
        Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.
    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal).")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)
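
As a usage sketch (not part of this commit): the snippet below shows how the new `mlxlm` loader and `MLXLM.generate` could be driven end to end on an Apple Silicon machine (the `mx.metal.is_available()` check above requires Metal). The module path `outlines.models.mlxlm` and the repo id are assumptions, and the dataclass field orders mirror the `dataclasses.astuple` unpacking in `stream()` above.

# Hedged usage sketch; module path and repo id are assumptions, not from this commit.
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.mlxlm import mlxlm  # assumed module path

model = mlxlm("mlx-community/some-4bit-model")  # placeholder Hugging Face repo id

text = model.generate(
    "Write a haiku about autumn.",
    GenerationParameters(max_tokens=64, stop_at=None, seed=None),  # assumed field order
    None,  # logits_processor=None -> unconstrained generation
    SamplingParameters("multinomial", 1, 0.9, None, 0.7),  # assumed field order
)
print(text)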
@@ -0,0 +1,6 @@
from .structured import (
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
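
For downstream callers, this re-export flattens the import path. A minimal sketch, assuming the package is importable as `outlines.processors`:

# Assumed import path based on the re-exports above.
from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor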
@@ -0,0 +1,78 @@
from abc import abstractmethod
from typing import List, Protocol, Union

import numpy as np
import torch
from numpy.typing import NDArray


def is_mlx_array(logits):
    """Return True if `logits` is an mlx array, without requiring mlx."""
    try:
        import mlx.core as mx
    except ImportError:
        return False
    return isinstance(logits, mx.array)


class BaseLogitsProcessor(Protocol):
    """
    Base class for logits processors, which normalizes the type of the logits:
    - np.ndarray (used by llama-cpp-python): converted to torch.Tensor
    - mlx.core.array (used by mlx-lm): converted to torch.Tensor
    - torch.Tensor (used by everything else): used as-is

    Where possible, the conversion to torch.Tensor doesn't move memory;
    it just reinterprets the underlying buffer.

    Normalizing the types allows all logits processors inheriting from this
    class to implement a single method for all the business logic:
    `process_logits()`.
    """

    @abstractmethod
    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> torch.Tensor:
        ...

    def __call__(
        self,
        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
        logits: Union[NDArray[np.float32], torch.Tensor],
    ) -> Union[NDArray[np.int64], torch.Tensor]:
        """
        Apply the logits processor.

        Unify the types:
        - convert input_ids (ndarray, List[int], or Tensor) -> List[int]
        - convert logits (ndarray, mlx array, or Tensor) -> Tensor

        Then call `process_logits()` to perform the business logic.
        """
        if not isinstance(input_ids, list):
            input_ids = input_ids.tolist()

        if isinstance(logits, np.ndarray):
            # Unify type: convert the numpy array to a Tensor.
            # from_numpy() and .numpy() don't copy the data; they share
            # the same memory address.
            torch_logits = torch.from_numpy(logits)
            processed_torch_logits = self.process_logits(input_ids, torch_logits)
            return processed_torch_logits.detach().numpy()

        elif isinstance(logits, torch.Tensor):
            return self.process_logits(input_ids, logits)

        elif is_mlx_array(logits):
            import mlx.core as mx

            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
            np_view_logits = np.array(logits.astype(mx.float32), copy=False)
            torch_logits = torch.from_numpy(np_view_logits)
            processed_torch_logits = self.process_logits(input_ids, torch_logits)
            processed_np_logits = processed_torch_logits.detach().numpy()
            # TODO: don't copy from numpy, instead cast from torch once mlx can
            # https://github.com/ml-explore/mlx/issues/413
            return mx.array(processed_np_logits)

        else:
            raise TypeError(
                "LogitsProcessor must be called with either np.ndarray, "
                "torch.Tensor, or mlx.core.array typed logits"
            )
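
To illustrate the normalization contract, here is a minimal sketch of a subclass; the class name and its allow-list logic are illustrative, not part of this commit. The same instance accepts numpy, torch, or mlx logits because `__call__` handles the conversion.

import math
from typing import List

import numpy as np
import torch

class AllowTokensProcessor(BaseLogitsProcessor):
    """Illustrative subclass: mask every token id except an allowed set."""

    def __init__(self, allowed_token_ids: List[int]):
        self.allowed = allowed_token_ids

    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> torch.Tensor:
        # Start from -inf everywhere, then copy back the allowed logits.
        masked = torch.full_like(logits, -math.inf)
        masked[..., self.allowed] = logits[..., self.allowed]
        return masked

processor = AllowTokensProcessor(allowed_token_ids=[0, 1, 2])

# torch path: Tensor in, Tensor out.
out_t = processor(torch.tensor([5, 7]), torch.randn(1, 10))

# numpy path: ndarray in, ndarray out (zero-copy via torch.from_numpy).
out_np = processor(np.array([5, 7]), np.random.randn(1, 10).astype(np.float32))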