Update the documentation
tscholak authored and rlouf committed Dec 29, 2023
1 parent 0032c65 commit 5f6166a
Showing 2 changed files with 32 additions and 212 deletions.
35 changes: 28 additions & 7 deletions docs/reference/vllm.md
@@ -6,29 +6,50 @@ Outlines can be deployed as an LLM service using the vLLM inference engine and a
pip install outlines[serve]
```

Note: only vLLM v0.2.6 with ray 2.9.0 is supported at the moment.

You can then start the server with:

```bash
python -m outlines.serve.serve
```

This will by default start a server at `http://127.0.0.1:8000` (check what the console says, though) with the OPT-125M model. If you want to specify another model (e.g. Mistral-7B-Instruct-v0.2), you can do so with the `--model` parameter:

```bash
python -m outlines.serve.serve --model="mistralai/Mistral-7B-Instruct-v0.2"
```

You can then query the model from the shell by passing a prompt and either

1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
2. a [Regex][regex]{:target="_blank"} pattern

with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.

For example, to generate a string that matches the schema `{"type": "string"}` (any string):

```bash
curl http://127.0.0.1:8000/generate \
-d '{
"prompt": "What is the capital of France?",
"schema": {"type": "string"}
}'
```
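
The `schema` parameter is not limited to primitive types; a JSON Schema describing an object can be passed in the same way. A hypothetical sketch (the prompt and field names are made up for illustration):

```bash
curl http://127.0.0.1:8000/generate \
-d '{
"prompt": "Give me a character description.",
"schema": {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    }
}
}'
```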

To generate a string that matches the regex `(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?` (a number):

```bash
curl http://127.0.0.1:8000/generate \
-d '{
"prompt": "What is Pi? Give me the first 15 digits: ",
"regex": "(-)?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-][0-9]+)?"
}'
```

Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another Python program.
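
A minimal sketch (assuming the server above is running on the default host and port; the prompt is just an example):

```python
import requests

# Send the same request as the first curl example above, this time from Python.
response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "What is the capital of France?",
        "schema": {"type": "string"},
    },
)
print(response.json())
```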

Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters.

You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.

209 changes: 4 additions & 205 deletions examples/vllm_integration.py
@@ -1,212 +1,11 @@
import math
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn as nn
import vllm
import vllm.model_executor.layers.sampler as sampler
from pydantic import BaseModel
from vllm.model_executor.layers.sampler import (
    _SAMPLING_EPS,
    _apply_min_p,
    _apply_penalties,
    _apply_top_p_top_k,
    _build_sampler_output,
    _get_logits,
    _get_logprobs,
    _get_penalties,
    _get_temperatures,
    _get_top_p_top_k_min_p,
    _prune_hidden_states,
    _sample,
)

from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


def _patched_apply_logits_processors(
    logits,
    sampling_metadata,
):
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(seq_id, token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


class JSONLogitsProcessor:
    def __init__(self, pydantic_model, llm):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        pydantic_model
            A Pydantic `BaseModel` that encodes the structure we want
            the model to generate.
        llm
            An instance of `vllm.LLM`

        """
        schema = pydantic_model.schema_json()
        regex_str = build_regex_from_object(schema)
        tokenizer = self.adapt_tokenizer(llm.get_tokenizer())

        fsm = RegexFSM(regex_str, tokenizer)
        self.fsm = fsm

    def __call__(
        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
    ) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""

        if len(input_ids) == 0:  # Initialize the fsm states
            self.fsm_state = defaultdict(int)
        else:
            last_token = input_ids[-1]
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[seq_id], last_token
            )

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        biased_scores = scores + mask

        return biased_scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the missing spaces in
        Llama's tokenizer output to be able to compile FSMs for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ):
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)

        # Apply logits processors (if any).
        logits = _patched_apply_logits_processors(logits, sampling_metadata)
        # Apply presence and frequency penalties.
        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
            sampling_metadata
        )
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        assert len(repetition_penalties) == logits.shape[0]
        logits = _apply_penalties(
            logits,
            sampling_metadata,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
        )

        # Apply temperature scaling.
        temperatures = _get_temperatures(sampling_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
            sampling_metadata, self.vocab_size
        )
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
        if do_min_p:
            logits = _apply_min_p(logits, min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results
        )
        return _build_sampler_output(
            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
        )

from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

vllm.model_executor.layers.sampler.Sampler = Sampler
# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
sampler._apply_logits_processors = _patched_apply_logits_processors
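
# A hypothetical usage sketch: once the patch above is applied, a
# `JSONLogitsProcessor` built from the `User` model defined below can be
# passed to vLLM via `SamplingParams`. The model name, prompt and
# `max_tokens` value are placeholders.
#
# llm = vllm.LLM(model="mistralai/Mistral-7B-Instruct-v0.2")
# logits_processor = JSONLogitsProcessor(User, llm)
# outputs = llm.generate(
#     ["Give me a user profile described as JSON: "],
#     sampling_params=vllm.SamplingParams(
#         max_tokens=100, logits_processors=[logits_processor]
#     ),
# )
# print(outputs[0].outputs[0].text)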


class User(BaseModel):
