diff --git a/docs/reference/vllm.md b/docs/reference/vllm.md
index 4f5c095d2..787e3b24a 100644
--- a/docs/reference/vllm.md
+++ b/docs/reference/vllm.md
@@ -6,29 +6,50 @@ Outlines can be deployed as an LLM service using the vLLM inference engine and a
 pip install outlines[serve]
 ```

+Note: only vLLM v0.2.6 with ray 2.9.0 is supported at the moment.
+
 You can then start the server with:

-```python
+```bash
 python -m outlines.serve.serve
 ```

-This will by default start a server at `http://127.0.0.1:8000` (check what the console says, though) with the OPT-125M model. If you want to specify another model:
+This will by default start a server at `http://127.0.0.1:8000` (check what the console says, though) with the OPT-125M model. If you want to specify another model (e.g. Mistral-7B-Instruct-v0.2), you can do so with the `--model` parameter:

-```python
-python -m outlines.serve.serve --model="mistralai/Mistral-7B-v0.1"
+```bash
+python -m outlines.serve.serve --model="mistralai/Mistral-7B-Instruct-v0.2"
 ```

-You can then query the model in shell by passing a prompt and a [JSON Schema][jsonschema]{:target="_blank"} specification for the structure of the output:
+You can then query the model in a shell by passing a prompt and either
+
+1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
+2. a [Regex][regex]{:target="_blank"} pattern
+
+with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.
+
+For example, to generate a string that matches the schema `{"type": "string"}` (any string):

 ```bash
-curl http://0.0.0.1:8000 \
+curl http://127.0.0.1:8000/generate \
     -d '{
         "prompt": "What is the capital of France?",
         "schema": {"type": "string"}
         }'
 ```

-Or use the [requests][requests]{:target="_blank"} library from another python program. You can read the [vLLM documentation][vllm]{:target="_blank"} for more details.
+To generate a string that matches the regex `(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?` (a number):
+
+```bash
+curl http://127.0.0.1:8000/generate \
+    -d '{
+        "prompt": "What is Pi? Give me the first 15 digits: ",
+        "regex": "(-)?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-][0-9]+)?"
+        }'
+```
+
+Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another Python program.
+
+Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.
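The updated docs above mention querying the `/generate` endpoint from the `requests` library without showing it. A minimal sketch of that usage, assuming the server runs on the default `http://127.0.0.1:8000` and accepts the same JSON body as the `curl` examples:

```python
# Sketch only (not part of the diff): calling the Outlines/vLLM server with
# `requests`, assuming the default address http://127.0.0.1:8000 and that
# /generate accepts the same JSON body ("prompt" plus "schema" or "regex")
# shown in the curl examples above.
import requests

response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "What is the capital of France?",
        "schema": {"type": "string"},
    },
)
print(response.json())
```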
diff --git a/examples/vllm_integration.py b/examples/vllm_integration.py
index c7ca226e6..c2d38883a 100644
--- a/examples/vllm_integration.py
+++ b/examples/vllm_integration.py
@@ -1,212 +1,11 @@
-import math
-from collections import defaultdict
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
 import vllm
+import vllm.model_executor.layers.sampler as sampler
 from pydantic import BaseModel
-from vllm.model_executor.layers.sampler import (
-    _SAMPLING_EPS,
-    _apply_min_p,
-    _apply_penalties,
-    _apply_top_p_top_k,
-    _build_sampler_output,
-    _get_logits,
-    _get_logprobs,
-    _get_penalties,
-    _get_temperatures,
-    _get_top_p_top_k_min_p,
-    _prune_hidden_states,
-    _sample,
-)
-
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
-
-
-def _patched_apply_logits_processors(
-    logits,
-    sampling_metadata,
-):
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(seq_id, token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
-class JSONLogitsProcessor:
-    def __init__(self, pydantic_model, llm):
-        """Compile the FSM that drives the JSON-guided generation.
-
-        Parameters
-        ----------
-        pydantic_model
-            A Pydantic `BaseModel` that encodes the structure we want
-            the model to generate.
-        llm
-            An instance of `vllm.LLM`
-
-        """
-        schema = pydantic_model.schema_json()
-        regex_str = build_regex_from_object(schema)
-        tokenizer = self.adapt_tokenizer(llm.get_tokenizer())
-
-        fsm = RegexFSM(regex_str, tokenizer)
-        self.fsm = fsm
-
-    def __call__(
-        self, seq_id: int, input_ids: List[int], scores: torch.Tensor
-    ) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
-
-        if len(input_ids) == 0:  # Initialize the fsm states
-            self.fsm_state = defaultdict(int)
-        else:
-            last_token = input_ids[-1]
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[seq_id], last_token
-            )
-
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
-
-        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
-        biased_scores = scores + mask
-
-        return biased_scores
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-
-class Sampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence, frequency and repetition penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self, vocab_size: int) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ):
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size)
-
-        # Apply logits processors (if any).
-        logits = _patched_apply_logits_processors(logits, sampling_metadata)
-        # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
-            sampling_metadata
-        )
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits,
-            sampling_metadata,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-        )
-
-        # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
-
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size
-        )
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
-        if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
-
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities.
-        # Use log_softmax to ensure numerical stability.
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
-        # Get the logprobs query results.
-        prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, sampling_metadata, sample_results
-        )
-        return _build_sampler_output(
-            sample_results, sampling_metadata, prompt_logprobs, sample_logprobs
-        )
+from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

-
-vllm.model_executor.layers.sampler.Sampler = Sampler
+# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
+sampler._apply_logits_processors = _patched_apply_logits_processors


 class User(BaseModel):