From 4879606b63fe952445390de309eb11403e71643d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 22 Jul 2024 15:53:02 +0200 Subject: [PATCH] Move function that adapts the tokenizer --- outlines/integrations/__init__.py | 1 - outlines/integrations/utils.py | 70 ------------- outlines/integrations/vllm.py | 158 ------------------------------ outlines/models/vllm.py | 41 +++++++- outlines/serve/serve.py | 9 +- outlines/serve/vllm.py | 4 - tests/fsm/test_regex.py | 2 +- 7 files changed, 47 insertions(+), 238 deletions(-) delete mode 100644 outlines/integrations/__init__.py delete mode 100644 outlines/integrations/utils.py delete mode 100644 outlines/integrations/vllm.py delete mode 100644 outlines/serve/vllm.py diff --git a/outlines/integrations/__init__.py b/outlines/integrations/__init__.py deleted file mode 100644 index b0a90d5ea..000000000 --- a/outlines/integrations/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utility functions and classes used to integrate `outlines` into other packages.""" diff --git a/outlines/integrations/utils.py b/outlines/integrations/utils.py deleted file mode 100644 index edf77e5c7..000000000 --- a/outlines/integrations/utils.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Utility functions used in integrations with other packages. - - _______________________________ -/ Don't want to self-host? \ -\\ Try .json at http://dottxt.co / - ------------------------------- - \\ ^__^ - \\ (oo)\\_______ - (__)\\ )\\/\ - ||----w | - || || - -Copyright 2024- the Outlines developers - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import json -from typing import Type, Union - -from pydantic import BaseModel -from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase - - -def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: - """Adapt a tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of `transformers`. In - addition we need to handle the missing spaces to Llama's tokenizer to be able to - compile FSMs for this model. - - Parameters - ---------- - tokenizer - The tokenizer of the model. - - Returns - ------- - PreTrainedTokenizerBase - The adapted tokenizer. - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: Union[str, bytes]) -> str: - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if ( - type(token) is str - and token.startswith(SPIECE_UNDERLINE) - or token == "<0x20>" - ): - return " " + string - - return string - - tokenizer.convert_token_to_string = convert_token_to_string - - return tokenizer diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py deleted file mode 100644 index 2a5f26e35..000000000 --- a/outlines/integrations/vllm.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Make vLLM compatible with Outlines' structured generation. - - _______________________________ -/ Don't want to self-host? \ -\\ Try .json at http://dottxt.co / - ------------------------------- - \\ ^__^ - \\ (oo)\\_______ - (__)\\ )\\/\ - ||----w | - || || - -Copyright 2024- the Outlines developers - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import math -from collections import defaultdict -from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Type, Union - -import torch -from pydantic import BaseModel - -from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema -from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str - -if TYPE_CHECKING: - from vllm import LLM - - -class RegexLogitsProcessor: - """Bias vLLM generation based on a regular expression. - - Attributes - ---------- - fsm - The finite state machine which is used to bias the logits. - """ - - def __init__(self, regex_string: str, llm: "LLM"): - """Compile the FSM that drives the regex-structured generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression. - llm - The vLLM model. - - Raises - ------ - ValueError - If the provided LLM instance in `RegexLogitsProcessor` neither has a - `tokenizer` attribute or a `get_tokenizer` method. - """ - if hasattr(llm, "get_tokenizer"): - tokenizer = llm.get_tokenizer() - elif hasattr(llm, "tokenizer"): - if hasattr(llm.tokenizer, "tokenizer"): - tokenizer = llm.tokenizer.tokenizer - else: - tokenizer = llm.tokenizer - else: - raise ValueError( - "The provided LLM instance in `RegexLogitsProcessor` neither has a " - "`tokenizer` attribute or a `get_tokenizer` method." - ) - tokenizer = adapt_tokenizer(tokenizer=tokenizer) - self.mask_cache: Dict[int, torch.Tensor] = {} - self.fsm = RegexGuide(regex_string, tokenizer) - self._fsm_state: DefaultDict[int, int] = defaultdict(int) - - def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token. - - Parameters - ---------- - input_ids - The tokens of the current sentence. - scores - The logits of the current sentence. - - Returns - ------- - torch.Tensor - The biased logits. - """ - seq_id = hash(tuple(input_ids)) - - # Initialize the FSM state dictionary if the input_ids are empty, as this means - # that the input_ids are the first tokens of the sequence. - if len(input_ids) > 0: - last_token = input_ids[-1] - last_seq_id = hash(tuple(input_ids[:-1])) - self._fsm_state[seq_id] = self.fsm.get_next_state( - state=self._fsm_state[last_seq_id], token_id=last_token - ) - - state_id = self._fsm_state[seq_id] - if state_id not in self.mask_cache: - allowed_tokens = self.fsm.get_next_instruction( - state=self._fsm_state[seq_id] - ).tokens - mask = torch.full((scores.shape[-1],), -math.inf) - mask[allowed_tokens] = 0 - mask = mask.pin_memory() - self.mask_cache[state_id] = mask - else: - mask = self.mask_cache[state_id] - mask = mask.to(device=scores.device, non_blocking=True) - biased_scores = scores + mask - - return biased_scores - - -class JSONLogitsProcessor(RegexLogitsProcessor): - """Bias vLLM generation based on a JSON schema. - - Attributes - ---------- - fsm - The finite state machine which is used to bias the logits. - """ - - def __init__( - self, - schema: Union[dict, Type[BaseModel], str], - llm: "LLM", - whitespace_pattern: Optional[str] = None, - ): - """Compile the FSM that drives the JSON-guided generation. - - Parameters - ---------- - schema - A JSON schema that encodes the structure we want the model to generate. - llm - The vLLM model. - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string - literals). For example, to allow only a single space or newline with - `whitespace_pattern=r"[\n ]?"` - """ - schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string=regex_string, llm=llm) diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py index 2ae3d99d4..d1f97bde2 100644 --- a/outlines/models/vllm.py +++ b/outlines/models/vllm.py @@ -1,8 +1,9 @@ import dataclasses from typing import TYPE_CHECKING, List, Optional, Union +from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase + from outlines.generate.api import GenerationParameters, SamplingParameters -from outlines.integrations.utils import adapt_tokenizer if TYPE_CHECKING: from vllm import LLM @@ -185,3 +186,41 @@ def vllm(model_name: str, **vllm_model_params): model = LLM(model_name, **vllm_model_params) return VLLM(model) + + +def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: + """Adapt a tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of `transformers`. In + addition we need to handle the missing spaces to Llama's tokenizer to be able to + compile FSMs for this model. + + Parameters + ---------- + tokenizer + The tokenizer of the model. + + Returns + ------- + PreTrainedTokenizerBase + The adapted tokenizer. + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: Union[str, bytes]) -> str: + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if ( + type(token) is str + and token.startswith(SPIECE_UNDERLINE) + or token == "<0x20>" + ): + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index fb8c80139..ba40e6ffa 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -35,12 +35,14 @@ from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor +from outlines.models.vllm import adapt_tokenizer +from outlines.processors.structured import JSONLogitsProcessor, RegexLogitsProcessor TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. app = FastAPI() engine = None +tokenizer = None @app.get("/health") @@ -69,9 +71,9 @@ async def generate(request: Request) -> Response: json_schema = request_dict.pop("schema", None) regex_string = request_dict.pop("regex", None) if json_schema is not None: - logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] + logits_processors = [JSONLogitsProcessor(json_schema, tokenizer)] elif regex_string is not None: - logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] + logits_processors = [RegexLogitsProcessor(regex_string, tokenizer)] else: logits_processors = [] @@ -124,6 +126,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: # Sets default for the model (`facebook/opt-125m`) engine = AsyncLLMEngine.from_engine_args(engine_args) + tokenizer = adapt_tokenizer(tokenizer=engine.engine.tokenizer.tokenizer) uvicorn.run( app, diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py deleted file mode 100644 index ddc50b47d..000000000 --- a/outlines/serve/vllm.py +++ /dev/null @@ -1,4 +0,0 @@ -from outlines.integrations.vllm import ( # noqa[F401] - JSONLogitsProcessor, - RegexLogitsProcessor, -) diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index fa72ad0dc..a1340d6d1 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -18,8 +18,8 @@ reduced_vocabulary, walk_fsm, ) -from outlines.integrations.utils import adapt_tokenizer from outlines.models.transformers import TransformerTokenizer +from outlines.models.vllm import adapt_tokenizer def identity(s):