Commit 47dfa4b: Use outlines.processors for vLLM

lapp0 authored and rlouf committed Jul 20, 2024
1 parent a7e3381 commit 47dfa4b

Showing 11 changed files with 119 additions and 107 deletions.
4 changes: 2 additions & 2 deletions outlines/generate/choice.py
@@ -1,7 +1,7 @@
 from functools import singledispatch
 from typing import Callable, List
 
-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial
 
@@ -11,7 +11,7 @@
 @singledispatch
 def choice(
     model, choices: List[str], sampler: Sampler = multinomial()
-) -> SequenceGenerator:
+) -> SequenceGeneratorAdapter:
     regex_str = r"(" + r"|".join(choices) + r")"
 
     generator = regex(model, regex_str, sampler)
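The public API is unchanged by the swap to `SequenceGeneratorAdapter`; a minimal usage sketch (model name and prompt are illustrative):

    import outlines

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.choice(model, ["Positive", "Negative"])
    sentiment = generator("Review: a delightful film. Sentiment:")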
6 changes: 4 additions & 2 deletions outlines/generate/format.py
@@ -1,15 +1,17 @@
 from functools import singledispatch
 
 from outlines.fsm.types import python_types_to_regex
-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial
 
 from .regex import regex
 
 
 @singledispatch
-def format(model, python_type, sampler: Sampler = multinomial()) -> SequenceGenerator:
+def format(
+    model, python_type, sampler: Sampler = multinomial()
+) -> SequenceGeneratorAdapter:
     """Generate structured data that can be parsed as a Python type.
 
     Parameters
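As above, only the signature and return annotation change; a sketch of `format` constraining output to a Python type (model and prompt illustrative):

    import outlines

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.format(model, int)
    answer = generator("What is 2 + 2? Answer: ")
    # `answer` is returned as the requested Python type (here, an int)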
24 changes: 11 additions & 13 deletions outlines/generate/fsm.py
@@ -8,25 +8,13 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import MLXLM, LlamaCpp, Transformers, TransformersVision
+from outlines.models import ExLlamaV2Model, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
 @singledispatch
 def fsm(
     model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
-) -> SequenceGenerator:
-    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-    return generator
-
-
-@fsm.register(MLXLM)
-@fsm.register(Transformers)
-@fsm.register(LlamaCpp)
-def fsm_unified(
-    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
 ) -> SequenceGeneratorAdapter:
     from outlines.processors import FSMLogitsProcessor
 
@@ -42,3 +30,13 @@ def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
     fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
     logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
+@fsm.register(ExLlamaV2Model)
+def fsm_exllamav2(
+    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
+) -> SequenceGenerator:
+    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
+    device = model.device
+    generator = SequenceGenerator(fsm, model, sampler, device)
+    return generator
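With the dispatch inverted, the `FSMLogitsProcessor` path is now the default and the legacy `SequenceGenerator` survives only for `ExLlamaV2Model`. A sketch of the default path, building the FSM with `interegular` (pattern, model, and prompt illustrative):

    import interegular
    import outlines

    model = outlines.models.transformers("gpt2")
    fsm = interegular.parse_pattern(r"[0-9]{4}-[0-9]{2}-[0-9]{2}").to_fsm()
    generator = outlines.generate.fsm(model, fsm)
    date = generator("Today's date in ISO format: ")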
4 changes: 2 additions & 2 deletions outlines/generate/json.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel
 
 from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial
 
@@ -18,7 +18,7 @@ def json(
     schema_object: Union[str, object, Callable],
     sampler: Sampler = multinomial(),
     whitespace_pattern: Optional[str] = None,
-) -> SequenceGenerator:
+) -> SequenceGeneratorAdapter:
     """
     Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
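A sketch of `json` with a Pydantic schema, which now routes through the same adapter (schema and prompt illustrative):

    from pydantic import BaseModel

    import outlines

    class Character(BaseModel):
        name: str
        age: int

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.json(model, Character)
    character = generator("Invent a character with a name and an age: ")
    # `character` is a validated `Character` instance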
43 changes: 11 additions & 32 deletions outlines/generate/regex.py
@@ -6,14 +6,7 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import (
-    MLXLM,
-    VLLM,
-    LlamaCpp,
-    OpenAI,
-    Transformers,
-    TransformersVision,
-)
+from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
@@ -34,26 +27,10 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     Returns
     -------
-    A `SequenceGenerator` instance that generates text constrained by the
+    A `SequenceGeneratorAdapter` instance that generates text constrained by the
     regular expression.
 
     """
-    fsm = RegexGuide(regex_str, model.tokenizer)
-
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-
-    return generator
-
-
-@regex.register(MLXLM)
-@regex.register(Transformers)
-@regex.register(LlamaCpp)
-def regex_unified(
-    model,
-    regex_str: str,
-    sampler: Sampler = multinomial(),
-):
     from outlines.processors import RegexLogitsProcessor
 
     logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
 
@@ -72,16 +49,18 @@ def regex_vision(
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
 
 
-@regex.register(VLLM)
-def regex_vllm(
-    model: VLLM,
+@regex.register(ExLlamaV2Model)
+def regex_exllamav2(
+    model,
     regex_str: str,
     sampler: Sampler = multinomial(),
-):
-    from outlines.integrations.vllm import RegexLogitsProcessor
-
-    logits_processor = RegexLogitsProcessor(regex_str, model.model)
-    return SequenceGeneratorAdapter(model, logits_processor, sampler)
+) -> SequenceGenerator:
+    fsm = RegexGuide(regex_str, model.tokenizer)
+
+    device = model.device
+    generator = SequenceGenerator(fsm, model, sampler, device)
+
+    return generator
 
 
 @regex.register(OpenAI)
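A sketch of the default `regex` path, now served by `RegexLogitsProcessor`, with `ExLlamaV2Model` kept on the legacy `SequenceGenerator` route (pattern and prompt illustrative):

    import outlines

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.regex(model, r"\(\d{3}\) \d{3}-\d{4}")
    phone = generator("A US phone number: ")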
34 changes: 9 additions & 25 deletions outlines/generate/text.py
@@ -6,19 +6,12 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import (
-    MLXLM,
-    VLLM,
-    LlamaCpp,
-    OpenAI,
-    Transformers,
-    TransformersVision,
-)
+from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial
 
 
 @singledispatch
-def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
+def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
     """Generate text with a `Transformer` model.
 
     Note
@@ -37,33 +30,24 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
     Returns
     -------
-    A `SequenceGenerator` instance that generates text.
+    A `SequenceGeneratorAdapter` instance that generates text.
 
     """
-    fsm = StopAtEOSGuide(model.tokenizer)
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-
-    return generator
+    return SequenceGeneratorAdapter(model, None, sampler)
 
 
-@text.register(MLXLM)
-@text.register(Transformers)
-@text.register(LlamaCpp)
-def text_unified(model, sampler: Sampler = multinomial()):
-    return SequenceGeneratorAdapter(model, None, sampler)
+@text.register(ExLlamaV2Model)
+def text_exllamav2(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
+    fsm = StopAtEOSGuide(model.tokenizer)
+    device = model.device
+    return SequenceGenerator(fsm, model, sampler, device)
 
 
 @text.register(TransformersVision)
 def text_vision(model, sampler: Sampler = multinomial()):
     return VisionSequenceGeneratorAdapter(model, None, sampler)
 
 
-@text.register(VLLM)
-def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
-    return SequenceGeneratorAdapter(model, None, sampler)
-
-
 @text.register(OpenAI)
 def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
     if not isinstance(sampler, multinomial):
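Unconstrained generation now defaults to `SequenceGeneratorAdapter(model, None, sampler)`, i.e. no logits processor at all; a sketch (model name and sampler settings illustrative):

    import outlines
    from outlines.samplers import multinomial

    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.text(model, sampler=multinomial(temperature=0.7))
    story = generator("Once upon a time", max_tokens=50)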
22 changes: 22 additions & 0 deletions outlines/models/vllm.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, List, Optional, Union
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer
 
 if TYPE_CHECKING:
     from vllm import LLM
@@ -22,6 +23,23 @@ def __init__(self, model: "LLM"):
         self.model = model
         self.lora_request = None
 
+        self.tokenizer = self._get_tokenizer()
+
+    def _get_tokenizer(self):
+        if hasattr(self.model, "get_tokenizer"):
+            tokenizer = self.model.get_tokenizer()
+        elif hasattr(self.model, "tokenizer"):
+            if hasattr(self.model.tokenizer, "tokenizer"):
+                tokenizer = self.model.tokenizer.tokenizer
+            else:
+                tokenizer = self.model.tokenizer
+        else:
+            raise ValueError(
+                "The provided LLM instance has neither a "
+                "`tokenizer` attribute nor a `get_tokenizer` method."
+            )
+        return adapt_tokenizer(tokenizer=tokenizer)
+
     def generate(
         self,
         prompts: Union[str, List[str]],
@@ -100,6 +118,10 @@ def generate(
             sampling_params.top_p = top_p
         if top_k is not None and sampling_params.top_k == -1:
             sampling_params.top_k = top_k
+            # TODO: remove this if statement once fixed
+            # https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897
+            if top_k == 1:
+                sampling_params.repetition_penalty = 0
         if temperature is not None and sampling_params.temperature == 1.0:
             sampling_params.temperature = temperature
         if sampler == "beam_search":
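With the adapted tokenizer attached at construction time, the vLLM backend can feed `outlines.processors` directly; a sketch assuming the `outlines.models.vllm` loader (model name and pattern illustrative):

    import outlines

    model = outlines.models.vllm("facebook/opt-125m")
    generator = outlines.generate.regex(model, r"[0-9]{1,3}")
    number = generator("Pick a number between 0 and 100: ")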
20 changes: 10 additions & 10 deletions outlines/processors/base_logits_processor.py
@@ -67,22 +67,18 @@ def __call__(
         3) Call self.process_logits() to perform business logic
         4) Cast logits back to original array library type
         """
-
         # ensure logits are torch Tensors
         torch_logits = self._to_torch(logits)
+        input_ids = self._to_torch(input_ids)
 
-        assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1]
-
-        # ensure input_ids are List
-        if not isinstance(input_ids, list):
-            input_ids = input_ids.tolist()  # compatible with numpy, torch, and mlx
+        assert torch_logits.shape[:-1] == input_ids.shape[:-1]
 
         # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids, torch_logits)
+            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids], torch_logits.unsqueeze(0)
+                [input_ids.tolist()], torch_logits.unsqueeze(0)
             ).squeeze(0)
 
         # return logits as passed array type
@@ -97,7 +93,7 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
         elif isinstance(tensor_like, np.ndarray):
             return torch.from_numpy(tensor_like)
 
-        elif isinstance(tensor_like, list):
+        elif isinstance(tensor_like, (list, tuple)):
             return torch.tensor(tensor_like)
 
         elif is_mlx_array_type(type(tensor_like)):
@@ -108,7 +104,8 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
         else:
             raise TypeError(
                 "LogitsProcessor must be called with either np.NDArray, "
-                "torch.Tensor, list, or mlx.core.array typed logits"
+                "torch.Tensor, list, or mlx.core.array typed logits. "
+                f"Logits type: `{type(tensor_like)}`"
             )
 
     @staticmethod
@@ -123,6 +120,9 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
         elif target_type == list:
             return tensor.detach().tolist()
 
+        elif target_type == tuple:
+            return tuple(tensor.detach().tolist())
+
         elif is_mlx_array_type(target_type):
             import mlx.core as mx
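The `__call__` contract amounts to: cast whatever array type arrives to torch, run the business logic on torch tensors, then cast back to the caller's type. A simplified standalone sketch of that pattern, not the library's actual class (`apply_processor` is hypothetical):

    import numpy as np
    import torch

    def apply_processor(process, input_ids, logits):
        # Cast list/tuple/np.ndarray/torch.Tensor inputs to torch tensors.
        original_type = type(logits)
        torch_logits = torch.as_tensor(logits)
        torch_ids = torch.as_tensor(input_ids)
        out = process(torch_ids.tolist(), torch_logits)  # work happens on torch
        # Cast the result back to whatever type the caller passed in.
        if original_type is np.ndarray:
            return out.numpy()
        if original_type in (list, tuple):
            return original_type(out.tolist())
        return out  # already a torch.Tensor

    masked = apply_processor(
        lambda ids, logits: logits * 0.5,  # toy processor: scale logits
        [[1, 2]],
        np.ones((1, 4), dtype=np.float32),
    )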
27 changes: 11 additions & 16 deletions outlines/processors/structured.py
@@ -61,9 +61,8 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_states: Dict[int, int] = {}
+        self._fsm_states: Dict[int, int] = {hash(tuple([])): 0}
         self.fsm: Guide = fsm
-        self._is_first_token = True
         self._seq_start_idx: Optional[int] = None
 
     def process_logits(
@@ -83,25 +82,21 @@ def process_logits(
         torch.Tensor
             The biased logits.
         """
-        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
-
-        if self._is_first_token:
-            self._is_first_token = False
+        if self._seq_start_idx is None:
             self._seq_start_idx = len(input_ids[0])
 
-            self._fsm_states = {hash(tuple([])): 0}
-            sequence_states = [0] * len(input_ids)
-
-        else:
-            for seq_ids in input_ids:
-                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
-                prev_state = self._fsm_states[prev_state_key]
+        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
 
-                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
-                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])
+        for seq_ids in input_ids:
+            gen_ids = seq_ids[self._seq_start_idx :]
+            curr_state_key = hash(tuple(gen_ids))
 
+            if curr_state_key not in self._fsm_states:
+                prev_state = self._fsm_states[hash(tuple(gen_ids[:-1]))]
+                curr_state = self.fsm.get_next_state(prev_state, gen_ids[-1])
                 self._fsm_states[curr_state_key] = curr_state
-                sequence_states.append(curr_state)
+
+            sequence_states.append(self._fsm_states[curr_state_key])
 
         mask = torch.full_like(logits, -math.inf)
         for i, fsm_state in enumerate(sequence_states):
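The rewrite drops the first-token flag: the state cache is seeded with the empty generation (`hash(tuple([]))` maps to state 0), so every step, including the first, resolves its previous state by hashing the generated-token prefix. A toy illustration of the memoization scheme (`next_state` is a hypothetical stand-in for `Guide.get_next_state`):

    # Seed the cache: the empty generation maps to the guide's start state 0.
    states = {hash(()): 0}

    def next_state(state, token):  # hypothetical stand-in for Guide.get_next_state
        return state + token

    def state_for(gen_ids):
        # Memoize the FSM state reached after consuming `gen_ids`.
        key = hash(tuple(gen_ids))
        if key not in states:
            prev = states[hash(tuple(gen_ids[:-1]))]  # prefix is already cached
            states[key] = next_state(prev, gen_ids[-1])
        return states[key]

    # Decoding extends each sequence one token at a time, so every prefix
    # was cached on an earlier step and the lookup never misses twice.
    for prefix in ([], [5], [5, 7]):
        print(state_for(prefix))  # 0, 5, 12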