Move function that adapts the tokenizer
rlouf committed Aug 16, 2024
1 parent 5a7de73 commit 4879606
Showing 7 changed files with 47 additions and 238 deletions.
1 change: 0 additions & 1 deletion outlines/integrations/__init__.py

This file was deleted.

70 changes: 0 additions & 70 deletions outlines/integrations/utils.py

This file was deleted.

158 changes: 0 additions & 158 deletions outlines/integrations/vllm.py

This file was deleted.

41 changes: 40 additions & 1 deletion outlines/models/vllm.py
@@ -1,8 +1,9 @@
 import dataclasses
 from typing import TYPE_CHECKING, List, Optional, Union
 
+from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase
+
 from outlines.generate.api import GenerationParameters, SamplingParameters
-from outlines.integrations.utils import adapt_tokenizer
 
 if TYPE_CHECKING:
     from vllm import LLM
@@ -185,3 +186,41 @@ def vllm(model_name: str, **vllm_model_params):
     model = LLM(model_name, **vllm_model_params)
 
     return VLLM(model)
+
+
+def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
+    """Adapt a tokenizer for use when compiling the FSM.
+
+    The API of Outlines tokenizers is slightly different from that of
+    `transformers`. In addition, we need to handle the missing spaces in
+    Llama's tokenizer to be able to compile FSMs for this model.
+
+    Parameters
+    ----------
+    tokenizer
+        The tokenizer of the model.
+
+    Returns
+    -------
+    PreTrainedTokenizerBase
+        The adapted tokenizer.
+    """
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: Union[str, bytes]) -> str:
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces in HF's Llama tokenizers
+        if (
+            type(token) is str
+            and token.startswith(SPIECE_UNDERLINE)
+            or token == "<0x20>"
+        ):
+            return " " + string
+
+        return string
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+
+    return tokenizer
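For reference, a minimal sketch of how the relocated helper can be exercised from its new home; the checkpoint name is illustrative and not part of this change:

from transformers import AutoTokenizer

from outlines.models.vllm import adapt_tokenizer

# Any Hugging Face tokenizer can be adapted; this checkpoint is illustrative.
hf_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
adapted = adapt_tokenizer(hf_tokenizer)

# The adapter attaches the attributes Outlines' FSM compilation expects
# and patches token-to-string conversion in place.
assert adapted.vocabulary == adapted.get_vocab()
assert adapted.special_tokens == set(adapted.all_special_tokens)
print(adapted.convert_token_to_string("hello"))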
9 changes: 6 additions & 3 deletions outlines/serve/serve.py
@@ -35,12 +35,14 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
-from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor
+from outlines.models.vllm import adapt_tokenizer
+from outlines.processors.structured import JSONLogitsProcessor, RegexLogitsProcessor
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None
+tokenizer = None
 
 
 @app.get("/health")
@@ -69,9 +71,9 @@ async def generate(request: Request) -> Response:
     json_schema = request_dict.pop("schema", None)
     regex_string = request_dict.pop("regex", None)
     if json_schema is not None:
-        logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
+        logits_processors = [JSONLogitsProcessor(json_schema, tokenizer)]
     elif regex_string is not None:
-        logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
+        logits_processors = [RegexLogitsProcessor(regex_string, tokenizer)]
     else:
         logits_processors = []
@@ -124,6 +126,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
     # Sets default for the model (`facebook/opt-125m`)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    tokenizer = adapt_tokenizer(tokenizer=engine.engine.tokenizer.tokenizer)
 
     uvicorn.run(
         app,
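With this change the server adapts the engine's underlying tokenizer once at startup and hands it to the logits processors, rather than passing `engine.engine` on every request. A hedged client-side sketch of exercising the structured endpoint (host, port, and the POST route are assumptions based on earlier revisions of this file; the payload keys match the handler above):

import requests  # any HTTP client works; requests is assumed installed

# Assumes the server runs locally on the default port and that `generate`
# is exposed as a POST route.
response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "Give me a three-digit number: ",
        "regex": r"[0-9]{3}",
    },
)
print(response.json())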
4 changes: 0 additions & 4 deletions outlines/serve/vllm.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/fsm/test_regex.py
@@ -18,8 +18,8 @@
     reduced_vocabulary,
     walk_fsm,
 )
-from outlines.integrations.utils import adapt_tokenizer
 from outlines.models.transformers import TransformerTokenizer
+from outlines.models.vllm import adapt_tokenizer
 
 
 def identity(s):
