Skip to content

Commit

Permalink
Add CFG to vllm serving
Browse files Browse the repository at this point in the history
  • Loading branch information
mory91 authored and rlouf committed Jan 12, 2024
1 parent 04bbb96 commit fde61a8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 36 deletions.
13 changes: 12 additions & 1 deletion docs/reference/vllm.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ You can then query the model in shell by passing a prompt and either

1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
2. a [Regex][regex]{:target="_blank"} pattern
3. an EBNF grammar

with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.
with the `schema`, `regex` or `cfg` parameters, respectively, to the `/generate` endpoint. If more than one is specified, `schema` takes precedence over `regex`, which takes precedence over `cfg`. If none is specified, the generated text will be unconstrained.

For example, to generate a string that matches the schema `{"type": "string"}` (any string):

Expand All @@ -47,6 +48,16 @@ curl http://127.0.0.1:8000/generate \
}'
```

To generate a string that matches the grammar `<grammar>`:

```bash
curl http://127.0.0.1:8000/generate \
-d '{
"prompt": "What is Pi? Give me the first 15 digits: ",
"cfg": <grammar>
}'
```

Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program.

Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.
Expand Down
4 changes: 4 additions & 0 deletions outlines/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm.utils import random_uuid

from .vllm import (
CFGLogitsProcessor,
JSONLogitsProcessor,
RegexLogitsProcessor,
_patched_apply_logits_processors,
Expand Down Expand Up @@ -65,10 +66,13 @@ async def generate(request: Request) -> Response:

json_schema = request_dict.pop("schema", None)
regex_string = request_dict.pop("regex", None)
cfg_string = request_dict.pop("cfg", None)
if json_schema is not None:
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
elif regex_string is not None:
logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
elif cfg_string is not None:
logits_processors = [CFGLogitsProcessor(cfg_string, engine.engine)]
else:
logits_processors = []

Expand Down
100 changes: 65 additions & 35 deletions outlines/serve/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,50 @@
import json
import math
from collections import defaultdict
from typing import DefaultDict, List
from typing import Callable, DefaultDict, List

import torch

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def _adapt_tokenizer(tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)

def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string

def change_decoder(
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]

return new_decoder

tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)

return tokenizer


def _patched_apply_logits_processors(
logits,
sampling_metadata,
Expand Down Expand Up @@ -39,21 +75,9 @@ def _patched_apply_logits_processors(
return logits


class RegexLogitsProcessor:
def __init__(self, regex_string, llm):
"""Compile the FSM that drives the regex-guided generation.
Parameters
----------
regex_string
A string that represents a regular expression
llm
An instance of `vllm.LLM`
"""
tokenizer = self.adapt_tokenizer(llm.tokenizer)

fsm = RegexFSM(regex_string, tokenizer)
class FSMLogitsProcessor:
def __init__(self):
fsm = FSM()
self.fsm = fsm

def __call__(
Expand All @@ -77,31 +101,37 @@ def __call__(

return biased_scores

def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.

"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
class RegexLogitsProcessor(FSMLogitsProcessor):
def __init__(self, regex_string, llm):
"""Compile the FSM that drives the regex-guided generation.
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
Parameters
----------
regex_string
A string that represents a regular expression
llm
An instance of `vllm.LLM`
string = tokenizer.convert_tokens_to_string([token])
"""
fsm = RegexFSM(regex_string, llm.tokenizer)
self.fsm = fsm

# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string

return string
class CFGLogitsProcessor(FSMLogitsProcessor):
def __init__(self, cfg_string, llm):
"""Compile the FSM that drives the cfg-guided generation.
tokenizer.convert_token_to_string = convert_token_to_string
Parameters
----------
cfg_string
A string that represents a grammar in the EBNF format
llm
An instance of `vllm.LLM`
return tokenizer
"""
fsm = CFGFSM(cfg_string, llm.tokenizer)
self.fsm = fsm


class JSONLogitsProcessor(RegexLogitsProcessor):
Expand Down

0 comments on commit fde61a8

Please sign in to comment.