From 1537695a621d5d904bbf8098091e3cb52faffbc4 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 13 Jun 2024 18:59:36 -0500 Subject: [PATCH] Use LogitsProcessors for models.transformers -> outlines.generate.* --- docs/reference/models/transformers.md | 53 +++++- outlines/__init__.py | 1 + outlines/generate/cfg.py | 42 +---- outlines/generate/regex.py | 6 +- outlines/generate/text.py | 5 +- outlines/models/__init__.py | 2 +- outlines/models/mlxlm.py | 6 +- outlines/models/transformers.py | 179 ++++++++++++++++++- outlines/processors/__init__.py | 2 +- outlines/processors/base_logits_processor.py | 41 ++++- outlines/processors/structured.py | 30 ++-- tests/generate/conftest.py | 36 ++-- tests/generate/test_generate.py | 127 +++++++++++-- 13 files changed, 447 insertions(+), 83 deletions(-) diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index 286df4367..7c1febd02 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -15,7 +15,7 @@ Outlines provides an integration with the `torch` implementation of causal model ```python from outlines import models -model = models.transformers("mistralai/Mistral-7B-v0.1", device="cuda") +model = models.transformers("mistralai/Mistral-7B-v0.3", device="cuda") ``` If you need more fine-grained control you can also initialize the model and tokenizer separately: @@ -30,4 +30,55 @@ tokenizer = AutoTokenizer.from_pretrained("gpt2") model = models.Transformers(llm, tokenizer) ``` +# Using Logits Processors + +There are two ways to use Outlines Structured Generation with HuggingFace Transformers: +- 1) Use Outlines generation wrapper, `outlines.models.transformers` +- 2) Use `OutlinesLogitsProcessor` with `transformers.AutoModelForCausalLM` + +Outlines supports a myriad of logits processors for structured generation. In these examples, we will use the `RegexLogitsProcessor`, which guarantees that generated text matches the specified pattern.
+ +## Example: `outlines.models.transformers` + +``` +import outlines + +time_regex_pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?" + +model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct", device="cuda") +generator = outlines.generate.regex(model, time_regex_pattern) + +output = generator("The best time to visit a dentist is at ") +print(output) +# 2:30 pm +``` + +## Example: Direct `transformers` library use + +``` +import outlines +import transformers + + +model_uri = "microsoft/Phi-3-mini-4k-instruct" + +outlines_tokenizer = outlines.models.TransformerTokenizer( + transformers.AutoTokenizer.from_pretrained(model_uri) +) +phone_number_logits_processor = outlines.processors.RegexLogitsProcessor( + "\\+?[1-9][0-9]{7,14}", # phone number pattern + outlines_tokenizer, +) + +generator = transformers.pipeline('text-generation', model=model_uri) + +output = generator( + "Jenny gave me her number it's ", + logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor]) +) +print(output) +# [{'generated_text': "Jenny gave me her number it's 2125550182"}] +# not quite the 8675309 we expected, but it is a valid phone number +``` + [transformers]: https://github.com/huggingface/transformers diff --git a/outlines/__init__.py b/outlines/__init__.py index 3eb6a2f94..307d2ba6f 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -2,6 +2,7 @@ import outlines.generate import outlines.grammars import outlines.models +import outlines.processors import outlines.types from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index e473c26a6..0df833067 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -1,16 +1,14 @@ from functools import singledispatch -from outlines.fsm.guide import CFGGuide -from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter +from outlines.generate.api
import SequenceGeneratorAdapter from outlines.models import OpenAI -from outlines.models.llamacpp import LlamaCpp -from outlines.models.mlxlm import MLXLM -from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @singledispatch -def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator: +def cfg( + model, cfg_str: str, sampler: Sampler = multinomial() +) -> SequenceGeneratorAdapter: """Generate text in the language of a Context-Free Grammar Arguments @@ -24,40 +22,16 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera Returns ------- - A `SequenceGenerator` instance that generates text. + A `SequenceGeneratorAdapter` instance that generates text. """ - fsm = CFGGuide(cfg_str, model.tokenizer) - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - - return generator - - -@cfg.register(MLXLM) -@cfg.register(VLLM) -def cfg_unimplemented( - model, - cfg_str: str, - sampler: Sampler = multinomial(), -): raise NotImplementedError( - f"The CFG Logits processor is not available for {type(model)}." + f"The CFG Logits processor is not available for {type(model)}. " + + "Please subscribe to https://github.com/outlines-dev/outlines/issues/684" + + " for updates on the fix." 
) -@cfg.register(LlamaCpp) -def cfg_llamacpp( - model: LlamaCpp, - cfg_str: str, - sampler: Sampler = multinomial(), -): - from outlines.integrations.llamacpp import CFGLogitsProcessor - - logits_processor = CFGLogitsProcessor(cfg_str, model.model) - return SequenceGeneratorAdapter(model, logits_processor, sampler) - - @cfg.register(OpenAI) def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError( diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 6b6656fe9..cdf64a21f 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -5,6 +5,7 @@ from outlines.models import OpenAI from outlines.models.llamacpp import LlamaCpp from outlines.models.mlxlm import MLXLM +from outlines.models.transformers import Transformers from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @@ -39,8 +40,9 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): @regex.register(MLXLM) -def regex_mlxlm( - model: MLXLM, +@regex.register(Transformers) +def regex_unified( + model, regex_str: str, sampler: Sampler = multinomial(), ): diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 081ba0920..b8feb7659 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -2,7 +2,7 @@ from outlines.fsm.guide import StopAtEOSGuide from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter -from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI +from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI, Transformers from outlines.samplers import Sampler, multinomial @@ -37,7 +37,8 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: @text.register(MLXLM) -def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()): +@text.register(Transformers) +def text_unified(model, sampler: Sampler = multinomial()): return SequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/__init__.py 
b/outlines/models/__init__.py index fb18824b3..65491ad18 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -12,7 +12,7 @@ from .mamba import Mamba, mamba from .mlxlm import MLXLM, mlxlm from .openai import OpenAI, azure_openai, openai -from .transformers import Transformers, transformers +from .transformers import Transformers, TransformerTokenizer, transformers from .vllm import VLLM, vllm LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba] diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index f561f269d..57aa6f596 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -9,7 +9,7 @@ from transformers import PreTrainedTokenizer from outlines.generate.api import GenerationParameters, SamplingParameters - from outlines.processors import BaseLogitsProcessor + from outlines.processors import OutlinesLogitsProcessor class MLXLM: @@ -120,7 +120,7 @@ def generate_step( temp: Optional[float], top_p: Optional[float], sampler: str, - logits_processor: "BaseLogitsProcessor", + logits_processor: "OutlinesLogitsProcessor", ) -> Generator[Tuple[int, float], None, None]: """ Adapted from @@ -135,7 +135,7 @@ def generate_step( top_p (float, optional): Nulceus sampling, higher means model considers more less likely words. sampler (str): The sampler string defined by SequenceGeneratorAdapter - logits_processor (BaseLogitsProcessor): Augment logits before sampling. + logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
""" import mlx.core as mx import mlx_lm diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index fae9b8e74..7d32a43bd 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,12 +1,17 @@ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import dataclasses +from threading import Thread +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union from datasets.fingerprint import Hasher +from outlines.generate.api import GenerationParameters, SamplingParameters from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: import torch - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizer + + from outlines.processors import OutlinesLogitsProcessor __all__ = ["transformers"] @@ -129,7 +134,6 @@ def __init__( model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", ): - self.device = model.device self.model = model self.tokenizer = TransformerTokenizer(tokenizer) @@ -190,6 +194,175 @@ def __call__( return next_token_logits, kv_cache + def _get_generation_config( + self, + generation_parameters: GenerationParameters, + sampling_parameters: SamplingParameters, + ) -> "GenerationConfig": + """ + Convert outlines generation parameters into a transformers.GenerationConfig + """ + from transformers import GenerationConfig, set_seed + + max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( + sampling_parameters + ) + if max_new_tokens is None: + max_new_tokens = int(1e9) + + if isinstance(stop_at, str): + stop_at = [stop_at] + + # sets the seed globally, which is not desirable + if seed is not None: + set_seed(seed) + + return GenerationConfig( + max_new_tokens=max_new_tokens, + stop_strings=stop_at, + num_return_sequences=(num_samples or 1), + top_p=top_p, + top_k=top_k, + temperature=temperature, + do_sample=(sampler
== "multinomial"), + num_beams=(num_samples if sampler == "beam_search" else 1), + eos_token_id=self.tokenizer.eos_token_id, + ) + + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor: Optional["OutlinesLogitsProcessor"], + sampling_parameters: SamplingParameters, + ) -> Union[str, List[str]]: + """Generate text using `transformers`. + + Arguments + --------- + prompts + A prompt or list of prompts. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__call__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + + Returns + ------- + The generated text + """ + from transformers import LogitsProcessorList + + if isinstance(prompts, str): + # convert to 2d + input_ids, attention_mask = self.tokenizer.encode([prompts]) + else: + input_ids, attention_mask = self.tokenizer.encode(prompts) + + input_ids = input_ids.to(self.model.device) + attention_mask = attention_mask.to(self.model.device) + + generation_config = self._get_generation_config( + generation_parameters, + sampling_parameters, + ) + + if logits_processor is not None: + logits_processor_list = LogitsProcessorList([logits_processor]) + else: + logits_processor_list = None + + output_ids = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + logits_processor=logits_processor_list, + generation_config=generation_config, + ) + # encoder-decoder returns output_ids only, decoder-only returns full seq ids + if self.model.config.is_encoder_decoder: + generated_ids = output_ids + else: + generated_ids = output_ids[:, input_ids.shape[1] :] + + outputs =
self.tokenizer.decode(generated_ids) + + if isinstance(prompts, str): + # convert back to 1d + return outputs[0] + else: + return outputs + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor: Optional["OutlinesLogitsProcessor"], + sampling_parameters: SamplingParameters, + ) -> Iterator[str]: + """Stream text using `transformers`. + + Arguments + --------- + prompts + A single prompt; batch inputs raise a `TypeError`. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__call__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + + Returns + ------- + A token generator + """ + from transformers import LogitsProcessorList, TextIteratorStreamer + + if not isinstance(prompts, str): + raise TypeError("Cannot stream batch inputs") + + input_ids, attention_mask = self.tokenizer.encode(prompts) + input_ids = input_ids.to(self.model.device) + attention_mask = attention_mask.to(self.model.device) + + generation_config = self._get_generation_config( + generation_parameters, + sampling_parameters, + ) + + if logits_processor is not None: + logits_processor_list = LogitsProcessorList([logits_processor]) + else: + logits_processor_list = None + + streamer = TextIteratorStreamer( + self.tokenizer.tokenizer, skip_prompt=True, skip_special_tokens=True + ) + kwargs = dict( + input_ids=input_ids, + attention_mask=attention_mask, + logits_processor=logits_processor_list, + streamer=streamer, + generation_config=generation_config, + ) + thread = Thread(target=self.model.generate, kwargs=kwargs) + thread.start() + try: + yield from streamer + finally: + thread.join() + def transformers(
model_name: str, diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py index 5c6a697ed..22c10d905 100644 --- a/outlines/processors/__init__.py +++ b/outlines/processors/__init__.py @@ -1,7 +1,7 @@ from .structured import ( - BaseLogitsProcessor, CFGLogitsProcessor, FSMLogitsProcessor, JSONLogitsProcessor, + OutlinesLogitsProcessor, RegexLogitsProcessor, ) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index dabfd91b0..a831e02bc 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -14,7 +14,7 @@ def is_mlx_array(logits): return isinstance(logits, mx.array) -class BaseLogitsProcessor(Protocol): +class OutlinesLogitsProcessor(Protocol): """ Base class for logits processors which normalizes types of logits: - ndarray (used by llama-cpp-python), converted to torch.Tensor @@ -29,10 +29,37 @@ class BaseLogitsProcessor(Protocol): @abstractmethod def process_logits( - self, input_ids: List[int], logits: torch.Tensor + self, input_ids: List[List[int]], logits: torch.Tensor ) -> torch.Tensor: ... + def _batch_process_logits(self, input_ids, logits: torch.Tensor) -> torch.Tensor: + """ + Arguments: + input_ids: List[int] | List[List[int]] + logits: 1D tensor | 2D tensor + Returns: + Augmented logits: 1D tensor | 2D tensor + + If given a list of input sequences and a 2D tensor, handle batch request, + otherwise handle request normally. 
+ """ + logits_in_batch_mode = len(logits.shape) > 1 + input_ids_in_batch_mode = isinstance(input_ids[0], list) + + if logits_in_batch_mode != input_ids_in_batch_mode: + raise TypeError( + f"Logits and input_ids incompatible: " + f"logits bach mode={logits_in_batch_mode}, " + f"input ids in batch mode={input_ids_in_batch_mode}" + ) + + if logits_in_batch_mode: + return self.process_logits(input_ids, logits) + else: + logits_2d = self.process_logits([input_ids], logits.unsqueeze(0)) + return logits_2d.squeeze(0) + def __call__( self, input_ids: Union[NDArray[np.int64], List[int], torch.Tensor], @@ -53,11 +80,13 @@ def __call__( # Unify type, convert numpy array to Tensor # from_numpy and .numpy() don't copy the data, it uses the same memory address torch_logits = torch.from_numpy(logits) - processed_torch_logits = self.process_logits(input_ids, torch_logits) + processed_torch_logits = self._batch_process_logits( + input_ids, torch_logits + ) return processed_torch_logits.detach().numpy() elif isinstance(logits, torch.Tensor): - return self.process_logits(input_ids, logits) + return self._batch_process_logits(input_ids, logits) elif is_mlx_array(logits): # mlx -> torch -> mlx conversion docs: @@ -65,7 +94,9 @@ def __call__( import mlx.core as mx torch_logits = torch.from_dlpack(logits) - processed_torch_logits = self.process_logits(input_ids, torch_logits) + processed_torch_logits = self._batch_process_logits( + input_ids, torch_logits + ) # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch logits_float32_numpy = processed_torch_logits.float().numpy() diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index b8ef5b2da..9eb62194b 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -24,7 +24,7 @@ limitations under the License. 
""" import math -from typing import TYPE_CHECKING, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union import numpy as np import torch @@ -35,13 +35,13 @@ from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import convert_json_schema_to_str -from .base_logits_processor import BaseLogitsProcessor +from .base_logits_processor import OutlinesLogitsProcessor if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer -class FSMLogitsProcessor(BaseLogitsProcessor): +class FSMLogitsProcessor(OutlinesLogitsProcessor): """Bias generation using a finite state machine. Attributes @@ -63,15 +63,19 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide): The finite state machine which is used to bias the logits. """ self.tokenizer = tokenizer - self._fsm_state = 0 + self._fsm_states: Dict[int, int] = {} self.fsm: Guide = fsm self._is_first_token = True def process_logits( - self, input_ids: List[int], logits: torch.Tensor + self, input_ids: List[List[int]], logits: torch.Tensor ) -> NDArray[np.float32]: """Use the FSM to bias the logits before sampling the next token. 
+ Assumptions: + - input_ids and logits are for batch requests + - logits processors are only used once and never re-applied for a new sequence generator + Parameters ---------- input_ids @@ -86,15 +90,19 @@ def process_logits( """ if self._is_first_token: self._is_first_token = False + self._fsm_states = {i: 0 for i in range(len(input_ids))} else: - last_token = input_ids[-1] - self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) - - allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens - allowed_tokens = torch.tensor(allowed_tokens, device=logits.device) + for i, seq_ids in enumerate(input_ids): + last_token = seq_ids[-1] + self._fsm_states[i] = self.fsm.get_next_state( + self._fsm_states[i], last_token + ) mask = torch.full_like(logits, -math.inf) - mask[allowed_tokens] = logits[allowed_tokens] + for i, fsm_state in self._fsm_states.items(): + allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens + mask[i, allowed_tokens] = logits[i, allowed_tokens] + return mask def copy(self) -> "FSMLogitsProcessor": diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py index 5b3e6f79c..1b73faaf8 100644 --- a/tests/generate/conftest.py +++ b/tests/generate/conftest.py @@ -1,24 +1,40 @@ from importlib import reload import pytest +import torch -def pytest_collection_modifyitems(config, items): - """If mlxlm and Metal aren't available, skip mlxlm tests""" +def is_mlx_available(): try: import mlx.core as mx import mlx_lm # noqa: F401 assert mx.metal.is_available() except (ImportError, AssertionError): - skip_marker = pytest.mark.skip( - reason="Skipping test because mlx-lm or Metal are not available" - ) - for item in items: - if "model_fixture" in item.fixturenames: - model_param = item.callspec.params.get("model_fixture", None) - if model_param == "model_mlxlm": - item.add_marker(skip_marker) + return False + return True + + +def is_cuda_available(): + return torch.cuda.is_available() + + +def 
pytest_collection_modifyitems(config, items): + """Skip mlxlm tests if mlx-lm or Metal is unavailable, and vllm tests if CUDA is unavailable""" + skipped_models = [] + if not is_mlx_available(): + skipped_models.append("model_mlxlm") + if not is_cuda_available(): + skipped_models.append("model_vllm") + + skip_marker = pytest.mark.skip( + reason="Skipping test due to incompatible hardware / system." + ) + for item in items: + if "model_fixture" in item.fixturenames: + model_param = item.callspec.params.get("model_fixture", None) + if model_param in skipped_models: + item.add_marker(skip_marker) @pytest.fixture diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 1f1a3aea2..5ceb82589 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -1,9 +1,11 @@ +import contextlib import re import pytest import outlines.generate as generate import outlines.models as models +import outlines.samplers as samplers @pytest.fixture(scope="session") @@ -24,21 +26,87 @@ def model_transformers(tmp_path_factory): return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu") -@pytest.mark.parametrize( - "model_fixture", - ("model_llamacpp", "model_mlxlm", "model_transformers"), +@pytest.fixture(scope="session") +def model_vllm(tmp_path_factory): + return models.vllm("facebook/opt-125m") + + +# TODO: mamba / exllamav2 failing in main, address in https://github.com/outlines-dev/outlines/issues/808 +""" +@pytest.fixture(scope="session") +def model_exllamav2(tmp_path_factory): + return models.exllamav2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + device="cpu" + ) + + +@pytest.fixture(scope="session") +def model_mamba(tmp_path_factory): + return models.mamba( + model_name="state-spaces/mamba-130m-hf", + device="cpu" + ) + +ALL_MODEL_FIXTURES = ("model_llamacpp", "model_mlxlm", "model_transformers", "model_vllm", "model_exllamav2", "model_mamba") +""" + + +ALL_MODEL_FIXTURES = ( + "model_llamacpp", + "model_mlxlm", +
"model_transformers", + "model_vllm", ) -def test_generate_text(request, model_fixture): + + +NOT_IMPLEMENTED = { + "batch": ["model_llamacpp"], + "stream": ["model_vllm"], + "beam_search": ["model_llamacpp"], +} + + +def enforce_not_implemented(task_name, model_fixture): + """ + Per `NOT_IMPLEMENTED`, mapping, if a model hasn't implemented a task, + assert an NotImplementedError is raised. Otherwise, run normally + """ + if model_fixture in NOT_IMPLEMENTED.get(task_name, []): + return pytest.raises(NotImplementedError) + else: + return contextlib.nullcontext() + + +@pytest.mark.parametrize("sampler_name", ("greedy", "multinomial", "beam_search")) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_text(request, model_fixture, sampler_name): + model = request.getfixturevalue(model_fixture) + generator = generate.text(model, getattr(samplers, sampler_name)()) + with enforce_not_implemented(sampler_name, model_fixture): + res = generator("test", max_tokens=10) + assert isinstance(res, str) + + +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_batch_text(request, model_fixture): model = request.getfixturevalue(model_fixture) generator = generate.text(model) - res = generator("test", max_tokens=10) - assert isinstance(res, str) + with enforce_not_implemented("batch", model_fixture): + res = generator(["test", "test2"], max_tokens=10) + assert isinstance(res, list) + assert isinstance(res[0], str) + + +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_text_stream(request, model_fixture): + model = request.getfixturevalue(model_fixture) + generator = generate.text(model) + with enforce_not_implemented("stream", model_fixture): + for token in generator.stream("a b c ", max_tokens=10): + assert isinstance(token, str) -@pytest.mark.parametrize( - "model_fixture", - ("model_llamacpp", "model_mlxlm", "model_transformers"), -) @pytest.mark.parametrize( "pattern", ( @@ -47,8 +115,47 @@ 
def test_generate_text(request, model_fixture): "\\+?[1-9][0-9]{7,14}", ), ) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) def test_generate_regex(request, model_fixture, pattern): model = request.getfixturevalue(model_fixture) generator = generate.regex(model, pattern) res = generator("foobarbaz", max_tokens=20) assert re.match(pattern, res) is not None, res + + +@pytest.mark.parametrize( + "pattern", + ( + "[0-9]", + "abc*", + "\\+?[1-9][0-9]{7,14}", + ), +) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_stream(request, model_fixture, pattern): + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + with enforce_not_implemented("stream", model_fixture): + output = "" + for token in generator.stream("output:", max_tokens=10): + output += token + assert re.match(pattern, output) is not None, output + + +@pytest.mark.parametrize( + "pattern", + ( + "(123456789)|(abcdefghijklmnop)", + "abc*", + "\\+?[1-9][0-9]{7,14}", + ), +) +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_regex_batch(request, model_fixture, pattern): + """Ensure batch requests work and fsm order is maintained""" + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + with enforce_not_implemented("batch", model_fixture): + outputs = generator(["abc", "123", "123bce", "33aa"], max_tokens=20) + for output in outputs: + assert re.match(pattern, output) is not None, output