Introduce outlines.models.mlxlm
lapp0 committed Jun 7, 2024
1 parent ed44a47 commit 1b3f8fe
Showing 11 changed files with 592 additions and 4 deletions.
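
For orientation, a minimal usage sketch of the backend this commit introduces (the model repo id and the call arguments are illustrative assumptions, not part of the diff):

import outlines

# Hypothetical mlx-community checkpoint; any mlx_lm-compatible repo id works.
model = outlines.models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

# generate.text dispatches MLXLM to SequenceGeneratorAdapter (see outlines/generate/text.py below).
generator = outlines.generate.text(model)
print(generator("Write a haiku about autumn.", max_tokens=64))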
8 changes: 5 additions & 3 deletions outlines/generate/cfg.py
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -33,14 +34,15 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
     return generator


+@cfg.register(MLXLM)
 @cfg.register(VLLM)
-def cfg_vllm(
-    model: VLLM,
+def cfg_unimplemented(
+    model,
     cfg_str: str,
     sampler: Sampler = multinomial(),
 ):
     raise NotImplementedError(
-        "The CFG Logits processor is not available for the vLLM integration."
+        f"The CFG Logits processor is not available for {type(model)}."
     )
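
CFG generation is explicitly unimplemented for MLXLM, so dispatch now fails loudly rather than silently. A sketch, assuming `model` is the MLXLM instance from the example above:

import outlines

try:
    outlines.generate.cfg(model, "start: /[a-z]+/")  # illustrative grammar string
except NotImplementedError as exc:
    print(exc)  # The CFG Logits processor is not available for <class 'outlines.models.mlxlm.MLXLM'>.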
14 changes: 14 additions & 0 deletions outlines/generate/regex.py
@@ -4,6 +4,7 @@
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM
 from outlines.models.vllm import VLLM
 from outlines.samplers import Sampler, multinomial

@@ -37,6 +38,19 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     return generator


+# TODO: have all models dispatched to this function
+@regex.register(MLXLM)
+def regex_unified(
+    model: MLXLM,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    from outlines.processors import RegexLogitsProcessor
+
+    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
+    return SequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
 @regex.register(LlamaCpp)
 def regex_llamacpp(
     model: LlamaCpp,
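
`regex_unified` pairs the adapter with the new `RegexLogitsProcessor`. A hedged usage sketch (the pattern, prompt, and `max_tokens` keyword are illustrative):

import outlines

# Assumes `model` is an MLXLM instance; output is constrained to a US-style phone number.
generator = outlines.generate.regex(model, r"\d{3}-\d{3}-\d{4}")
print(generator("Customer service number: ", max_tokens=16))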
8 changes: 7 additions & 1 deletion outlines/generate/text.py
@@ -2,7 +2,7 @@

 from outlines.fsm.guide import StopAtEOSGuide
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
-from outlines.models import VLLM, LlamaCpp, OpenAI
+from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
 from outlines.samplers import Sampler, multinomial


@@ -36,6 +36,12 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
     return generator


+# TODO: have all models dispatched to this function
+@text.register(MLXLM)
+def text_unified(model, sampler: Sampler = multinomial()):
+    return SequenceGeneratorAdapter(model, None, sampler)
+
+
 @text.register(VLLM)
 def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
     return SequenceGeneratorAdapter(model, None, sampler)
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -10,6 +10,7 @@
 from .exllamav2 import ExLlamaV2Model, exl2
 from .llamacpp import LlamaCpp, llamacpp
 from .mamba import Mamba, mamba
+from .mlxlm import MLXLM, mlxlm
 from .openai import OpenAI, azure_openai, openai
 from .transformers import Transformers, transformers
 from .vllm import VLLM, vllm
219 changes: 219 additions & 0 deletions outlines/models/mlxlm.py
@@ -0,0 +1,219 @@
import dataclasses
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters

try:
    import mlx.core as mx
    import mlx_lm
except ImportError:
    pass


class MLXLM:
    """
    Represents an `mlx_lm` model
    """

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)  # HF Tokenizer

    def generate(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> str:
        streamer = self.stream(
            prompts, generation_parameters, logits_processor, sampling_parameters
        )
        return "".join(list(streamer))

    def stream(
        self,
        prompts: Union[str, List[str]],
        generation_parameters: "GenerationParameters",
        logits_processor,
        sampling_parameters: "SamplingParameters",
    ) -> Iterator[str]:
        """Generate text using `mlx_lm`.

        Arguments
        ---------
        prompts
            A prompt or list of prompts.
        generation_parameters
            An instance of `GenerationParameters` that contains the prompt,
            the maximum number of tokens, stop sequences and seed. All the
            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
        logits_processor
            The logits processor to use when generating text.
        sampling_parameters
            An instance of `SamplingParameters`, a dataclass that contains
            the name of the sampler to use and related parameters as available
            in Outlines.

        Returns
        -------
        The generated text.
        """
        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
            sampling_parameters
        )

        if not isinstance(prompts, str):
            raise NotImplementedError(
                "The `mlx-lm` library does not support batch inference."
            )
        if sampler == "beam_search":
            raise NotImplementedError(
                "The `mlx-lm` library does not support Beam Search."
            )
        if num_samples != 1:
            raise NotImplementedError(
                "The `mlx-lm` library does not allow taking several samples."
            )

        # PR TODO: error if top_k or seed are set OR implement them
        generate_kwargs = {
            "temp": temperature,
            "top_p": top_p,
            "sampler": sampler,
            "logits_processor": logits_processor,
        }

        # Adapted from
        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
        np_tokens = self.tokenizer.encode(prompts)[0].numpy()
        prompt_tokens = mx.array(np_tokens)

        for (token, prob), n in zip(
            self.generate_step(prompt_tokens, **generate_kwargs),
            range(max_tokens),
        ):
            if token == self.tokenizer.eos_token_id:
                break  # PR TODO: should use stop_at instead?
            yield self.tokenizer.decode([token])[0]

    def generate_step(self, prompt, temp, top_p, sampler, logits_processor):
        """
        Adapted from
        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129

        A generator producing token ids based on the given prompt from the model.

        Args:
            prompt (mx.array): The input prompt.
            temp (float): The temperature for sampling, if 0 the argmax is used.
                Default: ``0``.
            top_p (float, optional): Nucleus sampling; higher values mean the
                model considers more less-likely words.
            sampler (str): PR TODO
            logits_processor: PR TODO

        Yields:
            Generator[Tuple[mx.array, mx.array]]: A generator producing
            one token and probability per call.
        """
        # PR TODO: Type hints in fn signature

        temp = temp or 1.0

        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
            softmax_logits = mx.softmax(logits)

            if temp == 0 or sampler == "greedy":
                token = mx.argmax(logits, axis=-1)
            elif sampler == "multinomial":
                if top_p is not None and top_p > 0 and top_p < 1.0:
                    token = mlx_lm.sample_utils.top_p_sampling(logits, top_p, temp)
                else:
                    token = mx.random.categorical(logits * (1 / temp))
            else:
                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

            prob = softmax_logits[0, token]
            return token, prob

        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

        input_ids = prompt
        while True:
            logits = self.model(input_ids, cache=cache)
            logits = logits[:, -1, :]

            if logits_processor is not None:
                logits = logits_processor(input_ids, logits)

            new_token, prob = sample(logits)
            yield new_token.item(), prob

            input_ids = mx.concatenate([input_ids, new_token[None]], axis=1)


def mlxlm(
    model_name: str,
    tokenizer_config: dict = {},
    model_config: dict = {},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
):
    """Instantiate a model from the `mlx_lm` library and its tokenizer.

    Signature adapted from
    https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422

    Parameters
    ----------
    model_name (str): The path or the Hugging Face repository to load the model from.
    tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
        Defaults to an empty dictionary.
    model_config (dict, optional): Configuration parameters specifically for the model.
        Defaults to an empty dictionary.
    adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
        to the model. Default: ``None``.
    lazy (bool): If False, evaluate the model parameters to make sure they are
        loaded in memory before returning; otherwise they will be loaded
        when needed. Default: ``False``.

    Returns
    -------
    A `MLXLM` model instance.
    """
    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError:
        raise ImportError(
            "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
        )
    if not mx.metal.is_available():
        raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

    model, tokenizer = mlx_lm.load(
        model_name,
        tokenizer_config=tokenizer_config,
        model_config=model_config,
        adapter_path=adapter_path,
        lazy=lazy,
    )
    return MLXLM(model, tokenizer)
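
`MLXLM.generate` simply joins the output of `stream`, so token-by-token streaming is available as well. A sketch, assuming `SequenceGeneratorAdapter` exposes a `stream()` method mirroring its `__call__` arguments (an assumption here, consistent with the other Outlines integrations):

import outlines

model = outlines.models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-4bit")  # illustrative repo id
generator = outlines.generate.text(model)
for piece in generator.stream("Tell me a story.", max_tokens=32):
    print(piece, end="", flush=True)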
6 changes: 6 additions & 0 deletions outlines/processors/__init__.py
@@ -0,0 +1,6 @@
from .structured import (
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
78 changes: 78 additions & 0 deletions outlines/processors/base_logits_processor.py
@@ -0,0 +1,78 @@
from abc import abstractmethod
from typing import List, Protocol, Union

import numpy as np
import torch
from numpy.typing import NDArray


def is_mlx_array(logits):
    try:
        import mlx.core as mx
    except ImportError:
        return False
    return isinstance(logits, mx.array)


class BaseLogitsProcessor(Protocol):
    """
    Base class for logits processors, which normalizes the type of the logits:
    - ndarray (used by llama-cpp-python), converted to torch.Tensor
    - mlx.core.array (used by mlx-lm), converted to torch.Tensor
    - torch.Tensor (used by everything else)

    Normalization of types and conversion to torch.Tensor
    doesn't move memory, it just casts the type.

    Normalizing the types allows all logits processors inheriting from this class
    to implement a single method for all the business logic: `process_logits()`.
    """

    @abstractmethod
    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> torch.Tensor:
        ...

    def __call__(
        self,
        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
        logits: Union[NDArray[np.float32], torch.Tensor],
    ) -> Union[NDArray[np.int64], torch.Tensor]:
        """
        Apply the logits processor.

        Unify the types:
        - convert input_ids: ndarray, List[int], or Tensor -> List[int]
        - convert logits: ndarray, mlx array, or Tensor -> Tensor

        Then call process_logits() to perform the business logic.
        """
        if not isinstance(input_ids, list):
            input_ids = input_ids.tolist()

        if isinstance(logits, np.ndarray):
            # Unify type, convert numpy array to Tensor
            # from_numpy and .numpy() don't copy the data, they use the same memory address
            torch_logits = torch.from_numpy(logits)
            processed_torch_logits = self.process_logits(input_ids, torch_logits)
            return processed_torch_logits.detach().numpy()

        elif isinstance(logits, torch.Tensor):
            return self.process_logits(input_ids, logits)

        elif is_mlx_array(logits):
            import mlx.core as mx

            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
            np_view_logits = np.array(logits.astype(mx.float32), copy=False)
            torch_logits = torch.from_numpy(np_view_logits)
            processed_torch_logits = self.process_logits(input_ids, torch_logits)
            processed_np_logits = processed_torch_logits.detach().numpy()
            # TODO: don't copy from numpy, instead cast from torch once mlx can
            # https://github.com/ml-explore/mlx/issues/413
            mx_processed_logits = mx.array(processed_np_logits)
            return mx_processed_logits

        else:
            raise TypeError(
                "LogitsProcessor must be called with either np.NDArray, "
                "torch.Tensor, or mlx.core.array typed logits"
            )
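
A minimal sketch of a concrete processor built on this base class (the banned-token logic is illustrative, not part of the commit):

from typing import List

import torch

from outlines.processors.base_logits_processor import BaseLogitsProcessor


class BlockTokenProcessor(BaseLogitsProcessor):
    """Illustrative processor that forbids a fixed set of token ids."""

    def __init__(self, banned_token_ids: List[int]):
        self.banned_token_ids = banned_token_ids

    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> torch.Tensor:
        # Operates on torch.Tensor only; __call__ converts numpy/mlx inputs.
        logits[..., self.banned_token_ids] = -float("inf")
        return logits

Because `__call__` normalizes types, the same instance can be handed llama-cpp-python's numpy logits or mlx-lm's mx.array logits unchanged.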