Use outlines.processors and SequenceGeneratorAdapter for outlines.models.vllm #1053

Merged · 1 commit · Jul 20, 2024
4 changes: 2 additions & 2 deletions outlines/generate/choice.py
@@ -1,7 +1,7 @@
 from functools import singledispatch
 from typing import Callable, List

-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial

@@ -11,7 +11,7 @@
 @singledispatch
 def choice(
     model, choices: List[str], sampler: Sampler = multinomial()
-) -> SequenceGenerator:
+) -> SequenceGeneratorAdapter:
     regex_str = r"(" + r"|".join(choices) + r")"

     generator = regex(model, regex_str, sampler)
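With this change, `outlines.generate.choice` returns a `SequenceGeneratorAdapter` for local models, vLLM included, instead of the legacy `SequenceGenerator`. A minimal usage sketch; the checkpoint name and prompt are illustrative, not taken from this PR:

```python
import outlines

# Any vLLM-compatible checkpoint works; this name is only an example.
model = outlines.models.vllm("microsoft/Phi-3-mini-4k-instruct")

# `choice` joins the options into a regex ("(Positive|Negative)") and
# dispatches to the unified SequenceGeneratorAdapter path.
generator = outlines.generate.choice(model, ["Positive", "Negative"])
print(generator("Sentiment of 'I loved this movie!': "))  # "Positive" or "Negative"
```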
6 changes: 4 additions & 2 deletions outlines/generate/format.py
@@ -1,15 +1,17 @@
 from functools import singledispatch

 from outlines.fsm.types import python_types_to_regex
-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial

 from .regex import regex


 @singledispatch
-def format(model, python_type, sampler: Sampler = multinomial()) -> SequenceGenerator:
+def format(
+    model, python_type, sampler: Sampler = multinomial()
+) -> SequenceGeneratorAdapter:
     """Generate structured data that can be parsed as a Python type.

     Parameters
24 changes: 11 additions & 13 deletions outlines/generate/fsm.py
@@ -8,25 +8,13 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import MLXLM, LlamaCpp, Transformers, TransformersVision
+from outlines.models import ExLlamaV2Model, TransformersVision
 from outlines.samplers import Sampler, multinomial


 @singledispatch
 def fsm(
     model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
-) -> SequenceGenerator:
-    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-    return generator
-
-
-@fsm.register(MLXLM)
-@fsm.register(Transformers)
-@fsm.register(LlamaCpp)
-def fsm_unified(
-    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
 ) -> SequenceGeneratorAdapter:
     from outlines.processors import FSMLogitsProcessor

@@ -42,3 +30,13 @@ def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
     fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
     logits_processor = FSMLogitsProcessor(tokenizer=model.tokenizer, fsm=fsm)
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)
+
+
+@fsm.register(ExLlamaV2Model)
+def fsm_exllamav2(
+    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
+) -> SequenceGenerator:
+    fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
+    device = model.device
+    generator = SequenceGenerator(fsm, model, sampler, device)
+    return generator
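The base `fsm` dispatch now covers every model except `ExLlamaV2Model`, which keeps the legacy `SequenceGenerator` path. A sketch of driving this entry point with an `interegular` FSM (the checkpoint name is again illustrative):

```python
import interegular
import outlines

model = outlines.models.vllm("microsoft/Phi-3-mini-4k-instruct")  # example checkpoint

# `fsm()` wraps the interegular FSM in a RegexGuide, builds an
# FSMLogitsProcessor from it, and returns a SequenceGeneratorAdapter.
date_fsm = interegular.parse_pattern(r"[0-9]{4}-[0-9]{2}-[0-9]{2}").to_fsm()
generator = outlines.generate.fsm(model, date_fsm)
print(generator("Today's date in ISO-8601: "))
```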
4 changes: 2 additions & 2 deletions outlines/generate/json.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel

 from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
-from outlines.generate.api import SequenceGenerator
+from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial

@@ -18,7 +18,7 @@ def json(
     schema_object: Union[str, object, Callable],
     sampler: Sampler = multinomial(),
     whitespace_pattern: Optional[str] = None,
-) -> SequenceGenerator:
+) -> SequenceGeneratorAdapter:
     """
     Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
43 changes: 11 additions & 32 deletions outlines/generate/regex.py
@@ -6,14 +6,7 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import (
-    MLXLM,
-    VLLM,
-    LlamaCpp,
-    OpenAI,
-    Transformers,
-    TransformersVision,
-)
+from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial


@@ -34,26 +27,10 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     Returns
     -------
-    A `SequenceGenerator` instance that generates text constrained by the
+    A `SequenceGeneratorAdapter` instance that generates text constrained by the
     regular expression.

     """
-    fsm = RegexGuide(regex_str, model.tokenizer)
-
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-
-    return generator
-
-
-@regex.register(MLXLM)
-@regex.register(Transformers)
-@regex.register(LlamaCpp)
-def regex_unified(
-    model,
-    regex_str: str,
-    sampler: Sampler = multinomial(),
-):
     from outlines.processors import RegexLogitsProcessor

     logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
@@ -72,16 +49,18 @@ def regex_vision(
     return VisionSequenceGeneratorAdapter(model, logits_processor, sampler)


-@regex.register(VLLM)
-def regex_vllm(
-    model: VLLM,
+@regex.register(ExLlamaV2Model)
+def regex_exllamav2(
+    model,
     regex_str: str,
     sampler: Sampler = multinomial(),
-):
-    from outlines.integrations.vllm import RegexLogitsProcessor
+) -> SequenceGenerator:
+    fsm = RegexGuide(regex_str, model.tokenizer)

-    logits_processor = RegexLogitsProcessor(regex_str, model.model)
-    return SequenceGeneratorAdapter(model, logits_processor, sampler)
+    device = model.device
+    generator = SequenceGenerator(fsm, model, sampler, device)
+
+    return generator


 @regex.register(OpenAI)
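The base `regex` dispatch is now the unified implementation: build a `RegexLogitsProcessor` from the model's tokenizer, then hand it to a `SequenceGeneratorAdapter`. A hand-written sketch of what the dispatch does; `build_regex_generator` is a hypothetical helper, not part of the library:

```python
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.processors import RegexLogitsProcessor
from outlines.samplers import Sampler, multinomial

def build_regex_generator(
    model, regex_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    # The processor masks logits so that, at every step, only tokens that
    # keep the partial output matching `regex_str` stay reachable.
    logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    # The adapter translates outlines' generation/sampling parameters into
    # the backend's native generate() call.
    return SequenceGeneratorAdapter(model, logits_processor, sampler)
```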
34 changes: 9 additions & 25 deletions outlines/generate/text.py
@@ -6,19 +6,12 @@
     SequenceGeneratorAdapter,
     VisionSequenceGeneratorAdapter,
 )
-from outlines.models import (
-    MLXLM,
-    VLLM,
-    LlamaCpp,
-    OpenAI,
-    Transformers,
-    TransformersVision,
-)
+from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision
 from outlines.samplers import Sampler, multinomial


 @singledispatch
-def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
+def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
     """Generate text with a `Transformer` model.

     Note
@@ -37,33 +30,24 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
     Returns
     -------
-    A `SequenceGenerator` instance that generates text.
+    A `SequenceGeneratorAdapter` instance that generates text.

     """
-    fsm = StopAtEOSGuide(model.tokenizer)
-    device = model.device
-    generator = SequenceGenerator(fsm, model, sampler, device)
-
-    return generator
+    return SequenceGeneratorAdapter(model, None, sampler)


-@text.register(MLXLM)
-@text.register(Transformers)
-@text.register(LlamaCpp)
-def text_unified(model, sampler: Sampler = multinomial()):
-    return SequenceGeneratorAdapter(model, None, sampler)
+@text.register(ExLlamaV2Model)
+def text_exllamav2(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
+    fsm = StopAtEOSGuide(model.tokenizer)
+    device = model.device
+    return SequenceGenerator(fsm, model, sampler, device)


 @text.register(TransformersVision)
 def text_vision(model, sampler: Sampler = multinomial()):
     return VisionSequenceGeneratorAdapter(model, None, sampler)


-@text.register(VLLM)
-def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
-    return SequenceGeneratorAdapter(model, None, sampler)
-
-
 @text.register(OpenAI)
 def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
     if not isinstance(sampler, multinomial):
22 changes: 22 additions & 0 deletions outlines/models/vllm.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, List, Optional, Union

 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.integrations.utils import adapt_tokenizer

 if TYPE_CHECKING:
     from vllm import LLM

@@ -22,6 +23,23 @@ def __init__(self, model: "LLM"):
         self.model = model
         self.lora_request = None

+        self.tokenizer = self._get_tokenizer()
+
+    def _get_tokenizer(self):
+        if hasattr(self.model, "get_tokenizer"):
+            tokenizer = self.model.get_tokenizer()
Review thread on lines +29 to +30 (the `_get_tokenizer` / `self.model.get_tokenizer()` lines):

Reviewer: Heads up, for the AsyncLLMEngine (as shown in the outlines vLLM example), this will return a coroutine: https://github.com/vllm-project/vllm/blob/main/vllm/engine/async_llm_engine.py#L506

I'm trying to figure out the best path forward because I'd love to use this with my vLLM-based service, but it seems like this work is part of something bigger, so I don't want to dive in and start propagating async through this code without checking in with you first. Happy to contribute, but could use a little guidance on the strategy 😁

lapp0 (Collaborator, Author), Jul 29, 2024: It's uncertain whether we will move towards async for outlines.generate, but it has been proposed in #655. Currently, outlines.generate with outlines.models.vllm uses a vllm.LLM.

Bear in mind that outlines.serve already has a vLLM server integration and, vice versa, vLLM has an outlines.processors integration in progress.

Does outlines.serve or vLLM's outlines integration satisfy your needs, or were you thinking of something different?

Reviewer: Ahhh, thank you for the sanity check! After re-reviewing the outlines.serve code, I realized I didn't go deep enough and needed to pass my engine.engine (engines all the way down 🐢) to get all the way to the vllm.LLM. Thanks again for the pointers!

lapp0 (Collaborator, Author), Jul 29, 2024: No problem! Bear in mind that after our next major release (because of this PR) the tokenizer, not the engine, will be passed to the processor. serve.py has a PR to reflect this behavior: https://github.com/outlines-dev/outlines/pull/1061/files#diff-535a1da5f8addb89d07782185c32b54f85189b25786d1c9b7cbd002b55939e16R74

Reviewer: Noted! Will keep an eye out for that. Thanks again for everything; super excited for the awesome capabilities you all have enabled with outlines!
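For readers hitting the same issue, here is a minimal sketch of the unwrapping the commenter describes: reaching the inner engine behind an `AsyncLLMEngine` so it can be handed to the pre-release `outlines.integrations.vllm` processor. The attribute chain and processor signature follow the comments above and may differ across versions:

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine
from outlines.integrations.vllm import RegexLogitsProcessor

async_engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="microsoft/Phi-3-mini-4k-instruct")  # example model
)

# AsyncLLMEngine wraps an LLMEngine ("engines all the way down"); the
# integrations-era processor expects the inner engine, whose tokenizer
# it can reach synchronously.
logits_processor = RegexLogitsProcessor(r"[0-9]{4}", async_engine.engine)
```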

+        elif hasattr(self.model, "tokenizer"):
+            if hasattr(self.model.tokenizer, "tokenizer"):
+                tokenizer = self.model.tokenizer.tokenizer
+            else:
+                tokenizer = self.model.tokenizer
+        else:
+            raise ValueError(
+                "The provided LLM instance neither has a "
+                "`tokenizer` attribute or a `get_tokenizer` method."
+            )
+        return adapt_tokenizer(tokenizer=tokenizer)

def generate(
self,
prompts: Union[str, List[str]],
Expand Down Expand Up @@ -100,6 +118,10 @@ def generate(
sampling_params.top_p = top_p
if top_k is not None and sampling_params.top_k == -1:
sampling_params.top_k = top_k
# TODO: remove this if statement once fixed
# https://github.com/vllm-project/vllm/issues/5404#issuecomment-2175972897
if top_k == 1:
sampling_params.repetition_penalty = 0
if temperature is not None and sampling_params.temperature == 1.0:
sampling_params.temperature = temperature
if sampler == "beam_search":
Expand Down
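Taken together, this is the point of the PR: `outlines.models.vllm` now exposes an adapted tokenizer, so the shared `outlines.processors` logits processors drive vLLM the same way they drive transformers, llama.cpp, and mlx-lm. A usage sketch (checkpoint and pattern are illustrative):

```python
import outlines

model = outlines.models.vllm("microsoft/Phi-3-mini-4k-instruct")  # example checkpoint

# Under the hood: RegexLogitsProcessor is built from model.tokenizer,
# the adapted tokenizer created by _get_tokenizer() above.
generator = outlines.generate.regex(model, r"\+?[1-9][0-9]{7,14}")
print(generator("A plausible E.164 phone number: "))
```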
20 changes: 10 additions & 10 deletions outlines/processors/base_logits_processor.py
@@ -67,22 +67,18 @@ def __call__(
         3) Call self.process_logits() to perform business logic
         4) Cast logits back to original array library type
         """
-
-        # ensure logits are torch Tensors
         torch_logits = self._to_torch(logits)
+        input_ids = self._to_torch(input_ids)

-        assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1]
-
-        # ensure input_ids are List
-        if not isinstance(input_ids, list):
-            input_ids = input_ids.tolist()  # compatible with numpy, torch, and mlx
+        assert torch_logits.shape[:-1] == input_ids.shape[:-1]

         # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids, torch_logits)
+            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids], torch_logits.unsqueeze(0)
+                [input_ids.tolist()], torch_logits.unsqueeze(0)
             ).squeeze(0)

         # return logits as passed array type
@@ -97,7 +93,7 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
         elif isinstance(tensor_like, np.ndarray):
             return torch.from_numpy(tensor_like)

-        elif isinstance(tensor_like, list):
+        elif isinstance(tensor_like, (list, tuple)):
             return torch.tensor(tensor_like)

         elif is_mlx_array_type(type(tensor_like)):
@@ -108,7 +104,8 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
         else:
             raise TypeError(
                 "LogitsProcessor must be called with either np.NDArray, "
-                "torch.Tensor, list, or mlx.core.array typed logits"
+                "torch.Tensor, list, or mlx.core.array typed logits. "
+                f"Logits type: `{type(tensor_like)}`"
             )

@@ -123,6 +120,9 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
         elif target_type == list:
             return tensor.detach().tolist()

+        elif target_type == tuple:
+            return tuple(tensor.detach().tolist())
+
         elif is_mlx_array_type(target_type):
             import mlx.core as mx
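The contract `__call__` maintains (accept logits from any supported array library, do the work in torch, return the caller's original type) is easiest to see with a toy subclass. `DoubleLogitsProcessor` is hypothetical, written only to exercise the base class, which is assumed here to be exported as `OutlinesLogitsProcessor`:

```python
import numpy as np
import torch

from outlines.processors.base_logits_processor import OutlinesLogitsProcessor

class DoubleLogitsProcessor(OutlinesLogitsProcessor):
    """Toy processor that doubles every logit (illustration only)."""

    def process_logits(self, input_ids, logits: torch.Tensor) -> torch.Tensor:
        # Receives input_ids as List[List[int]] and logits as a 2D tensor,
        # courtesy of the normalization in __call__.
        return logits * 2

proc = DoubleLogitsProcessor()
out = proc(np.array([1, 2, 3]), np.array([0.1, 0.2, 0.3], dtype=np.float32))
print(type(out), out)  # <class 'numpy.ndarray'> [0.2 0.4 0.6]
```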
27 changes: 11 additions & 16 deletions outlines/processors/structured.py
@@ -61,9 +61,8 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_states: Dict[int, int] = {}
+        self._fsm_states: Dict[int, int] = {hash(tuple([])): 0}
         self.fsm: Guide = fsm
-        self._is_first_token = True
         self._seq_start_idx: Optional[int] = None

@@ -83,25 +82,21 @@ def process_logits(
         torch.Tensor
             The biased logits.
         """
-        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
-
-        if self._is_first_token:
-            self._is_first_token = False
-            self._seq_start_idx = len(input_ids[0])
-
-            self._fsm_states = {hash(tuple([])): 0}
-            sequence_states = [0] * len(input_ids)
-
-        else:
-            for seq_ids in input_ids:
-                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
-                prev_state = self._fsm_states[prev_state_key]
-
-                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
-                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])
-
-                self._fsm_states[curr_state_key] = curr_state
-                sequence_states.append(curr_state)
+        if self._seq_start_idx is None:
+            self._seq_start_idx = len(input_ids[0])
+
+        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+
+        for seq_ids in input_ids:
+            gen_ids = seq_ids[self._seq_start_idx :]
+            curr_state_key = hash(tuple(gen_ids))
+
+            if curr_state_key not in self._fsm_states:
+                prev_state = self._fsm_states[hash(tuple(gen_ids[:-1]))]
+                curr_state = self.fsm.get_next_state(prev_state, gen_ids[-1])
+                self._fsm_states[curr_state_key] = curr_state
+
+            sequence_states.append(self._fsm_states[curr_state_key])

         mask = torch.full_like(logits, -math.inf)
         for i, fsm_state in enumerate(sequence_states):
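The rewrite replaces the `_is_first_token` flag with a state cache seeded at construction: each sequence's FSM state is stored under the hash of its generated-token tuple, so batch reordering (for example, beam search shuffling sequences between steps) still resolves every sequence to the correct state. A standalone toy of the caching scheme; `ToyGuide` is a stand-in, not the real `Guide`:

```python
from typing import Dict, List

class ToyGuide:
    """Stand-in for outlines' Guide: state advances by summing token ids."""
    def get_next_state(self, state: int, token_id: int) -> int:
        return state + token_id

def states_for(guide: ToyGuide, cache: Dict[int, int], batch: List[List[int]]) -> List[int]:
    states = []
    for gen_ids in batch:  # generated tokens only (prompt stripped)
        key = hash(tuple(gen_ids))
        if key not in cache:
            prev = cache[hash(tuple(gen_ids[:-1]))]  # parent prefix is always cached
            cache[key] = guide.get_next_state(prev, gen_ids[-1])
        states.append(cache[key])
    return states

cache = {hash(tuple([])): 0}  # seed: empty generation -> initial state 0
print(states_for(ToyGuide(), cache, [[5], [7]]))        # [5, 7]
print(states_for(ToyGuide(), cache, [[7, 2], [5, 1]]))  # beams reordered: [9, 6]
```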