Commit 62cc80e

introduce shared logits processors and example outlines.generate.regex
lapp0 committed May 29, 2024
1 parent 538f77a commit 62cc80e
Showing 3 changed files with 98 additions and 69 deletions.
23 changes: 12 additions & 11 deletions outlines/generate/regex.py
@@ -1,7 +1,7 @@
from functools import singledispatch

from outlines.fsm.guide import RegexGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.integrations.logits_processors import RegexLogitsProcessor
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.vllm import VLLM
@@ -29,12 +29,11 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
regular expression.
"""
fsm = RegexGuide(regex_str, model.tokenizer)
# PR TODO: add device argument
# device = model.device

device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)

return generator
logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(LlamaCpp)
@@ -43,9 +42,10 @@ def regex_llamacpp(
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.llamacpp import RegexLogitsProcessor
from outlines.integrations.llamacpp import LlamaCppTokenizer

logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
tokenizer = LlamaCppTokenizer(model=model.model)
logits_processor = RegexLogitsProcessor(regex_str, tokenizer=tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@@ -55,9 +55,10 @@ def regex_vllm(
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.integrations.vllm import RegexLogitsProcessor
from outlines.integrations.utils import get_vllm_tokenizer

logits_processor = RegexLogitsProcessor(regex_str, model.model)
tokenizer = get_vllm_tokenizer(model=model.model)
logits_processor = RegexLogitsProcessor(regex_str, tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


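For context, here is a minimal usage sketch of the new dispatch path; the model name and pattern are illustrative, not part of this commit:

    import outlines

    # Any transformers model supported by Outlines works here; this name is a placeholder.
    model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")

    # The dispatcher above builds a RegexLogitsProcessor from the model's tokenizer
    # and wraps model, processor, and sampler in a SequenceGeneratorAdapter.
    generator = outlines.generate.regex(model, r"[0-9]{4}-[0-9]{2}-[0-9]{2}")
    date = generator("What is today's date? ")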
128 changes: 70 additions & 58 deletions
@@ -1,5 +1,4 @@
"""Make LlamaCpp compatible with Outlines' structured generation.
"""
_______________________________
/ Don't want to self-host? \
\\ Try .json at http://dottxt.co /
@@ -24,9 +23,9 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
from abc import abstractmethod
from typing import List, Optional, Protocol, Type, Union

import numpy as np
import torch
@@ -36,39 +35,54 @@
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
from llama_cpp import Llama


class LlamaCppTokenizer:
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

tokenizer = model.tokenizer()

self.decode = tokenizer.decode
class BaseLogitsProcessor(Protocol):
"""
Base class for logits processors which normalizes the logits type:
- ndarray (used by llama-cpp-python) is converted to torch.Tensor
- torch.Tensor (used by everything else) is passed through unchanged
so each subclass needs only a single process_logits() implementation.
"""

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t
@abstractmethod
def process_logits(
self, input_ids: List[int], logits: torch.Tensor
) -> torch.Tensor:
...

def convert_token_to_string(self, token: str) -> str:
return token
def __call__(
self,
input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
logits: Union[NDArray[np.float32], torch.Tensor],
) -> Union[NDArray[np.int64], torch.Tensor]:
"""
Apply the logits processor.

Unify the argument types:
- input_ids: ndarray, List[int], or torch.Tensor -> List[int]
- logits: ndarray or torch.Tensor -> torch.Tensor

Then call process_logits() to perform the business logic.
"""
if not isinstance(input_ids, list):
input_ids = input_ids.tolist()

if isinstance(logits, np.ndarray):
# Unify type: convert the numpy array to a torch Tensor.
# from_numpy and .numpy() don't copy the data; both share the same memory.
torch_logits = torch.from_numpy(logits)
processed_torch_logits = self.process_logits(input_ids, torch_logits)
return processed_torch_logits.detach().numpy()
elif isinstance(logits, torch.Tensor):
return self.process_logits(input_ids, logits)
else:
raise TypeError(
"LogitsProcessor must be called with either ndarray or torch.Tensor logits"
)

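A quick standalone check of the zero-copy behavior the comment above relies on (a sketch, not part of the diff):

    import numpy as np
    import torch

    logits = np.zeros(4, dtype=np.float32)
    tensor = torch.from_numpy(logits)  # a view over the same buffer, no copy
    tensor[0] = 1.0                    # mutating the tensor is visible from numpy
    assert logits[0] == 1.0
    # .numpy() on a CPU tensor is likewise a view over the same memory
    assert tensor.numpy().ctypes.data == logits.ctypes.data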

class LogitsProcessor:
"""Bias LlamaCpp generation using a finite state machine.
class FSMLogitsProcessor(BaseLogitsProcessor):
"""Bias generation using a finite state machine.
Attributes
----------
@@ -78,7 +92,7 @@ class LogitsProcessor:
The finite state machine which is used to bias the logits.
"""

def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide):
def __init__(self, tokenizer: Tokenizer, fsm: Guide):
"""A FSM-based logits processor.
Parameters
@@ -94,7 +108,7 @@ def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide):
self._is_first_token = True

def __call__(
self, input_ids: NDArray[np.int64], scores: NDArray[np.float32]
self, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.
@@ -107,7 +121,7 @@ def __call__(
Returns
-------
NDArray[np.float32]
torch.Tensor
The biased logits.
"""
if self._is_first_token:
@@ -118,19 +132,19 @@ def __call__(

allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens

mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy()
mask = torch.full((scores.shape[-1],), -math.inf, device="cpu")
mask[allowed_tokens] = 0
biased_scores = scores + mask

return biased_scores

def copy(self) -> "LogitsProcessor":
def copy(self) -> "FSMLogitsProcessor":
"""Return a copy of the logits processor."""
return LogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())

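The biasing step above reduces to additive masking: disallowed token ids receive -inf so the softmax assigns them zero probability. A toy illustration with made-up values:

    import math
    import torch

    scores = torch.tensor([1.5, -0.3, 0.7, 2.1])  # toy logits over a 4-token vocabulary
    allowed_tokens = [0, 2]                        # hypothetical FSM instruction

    mask = torch.full((scores.shape[-1],), -math.inf, device="cpu")
    mask[allowed_tokens] = 0
    biased_scores = scores + mask
    # biased_scores == tensor([1.5, -inf, 0.7, -inf]): only tokens 0 and 2 can be sampled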

class RegexLogitsProcessor(LogitsProcessor):
"""Bias LlamaCpp generation based on a regular expression.
class RegexLogitsProcessor(FSMLogitsProcessor):
"""Bias generation based on a regular expression.
Attributes
----------
@@ -140,23 +154,22 @@ class RegexLogitsProcessor(LogitsProcessor):
The finite state machine which is used to bias the logits.
"""

def __init__(self, regex_string: str, llm: "Llama"):
def __init__(self, regex_string: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the regex-guided generation.
Parameters
----------
regex_string
A string that represents a regular expression
llm
The Llama model.
tokenizer
An Outlines tokenizer.
"""
tokenizer = LlamaCppTokenizer(model=llm)
fsm = RegexGuide(regex_string, tokenizer)
super().__init__(tokenizer=tokenizer, fsm=fsm)

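Because the processor now takes an Outlines tokenizer rather than the Llama object, it can still be plugged into llama-cpp-python directly. A hedged sketch (the model path is a placeholder):

    from llama_cpp import Llama, LogitsProcessorList

    from outlines.integrations.llamacpp import LlamaCppTokenizer

    llm = Llama("path/to/model.gguf")  # placeholder path
    tokenizer = LlamaCppTokenizer(model=llm)
    processor = RegexLogitsProcessor(r"[0-9]{2}", tokenizer=tokenizer)
    output = llm(
        "Pick a two-digit number: ",
        logits_processor=LogitsProcessorList([processor]),
    )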

class JSONLogitsProcessor(RegexLogitsProcessor):
"""Bias LlamaCpp generation based on a JSON schema.
"""Bias generation based on a JSON schema.
Attributes
----------
@@ -169,7 +182,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(
self,
schema: Union[dict, Type[BaseModel], str],
llm: "Llama",
tokenizer: Tokenizer,
whitespace_pattern: Optional[str] = None,
):
"""Compile the FSM that drives the JSON-guided generation.
@@ -178,39 +191,38 @@ def __init__(
----------
schema
A JSON schema that encodes the structure we want the model to generate.
llm
The Llama model.
tokenizer
The tokenizer used to convert tokens to ids.
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string
literals). For example, to allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
"""
schema_str = convert_json_schema_to_str(json_schema=schema)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string=regex_string, llm=llm)
super().__init__(regex_string=regex_string, tokenizer=tokenizer)

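The JSON path is the regex path with a derived pattern; the two helper calls above can be sketched in isolation (the Pydantic model is illustrative):

    from pydantic import BaseModel

    from outlines.fsm.json_schema import build_regex_from_schema
    from outlines.integrations.utils import convert_json_schema_to_str

    class Character(BaseModel):  # illustrative schema
        name: str
        age: int

    schema_str = convert_json_schema_to_str(json_schema=Character)
    regex_string = build_regex_from_schema(schema_str, whitespace_pattern=r"[\n ]?")
    # JSONLogitsProcessor(Character, tokenizer) is then equivalent to
    # RegexLogitsProcessor(regex_string, tokenizer)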

class CFGLogitsProcessor(LogitsProcessor):
"""Bias LlamaCpp generation based on a context-free grammar.
class CFGLogitsProcessor(FSMLogitsProcessor):
"""Bias generation based on a context-free grammar.
Attributes
----------
llm
The Llama model.
tokenizer
The tokenizer used to convert tokens to ids.
fsm
The finite state machine which is used to bias the logits.
"""

def __init__(self, cfg_str: str, llm: "Llama"):
def __init__(self, cfg_str: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the CFG-guided generation.
Parameters
----------
cfg_str
A string that represents a grammar
llm
The Llama model.
tokenizer
The tokenizer used to convert tokens to ids.
"""
tokenizer = LlamaCppTokenizer(model=llm)
fsm = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
super().__init__(tokenizer=tokenizer, fsm=fsm)
cfg_automata = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
super().__init__(tokenizer=tokenizer, fsm=cfg_automata)
16 changes: 16 additions & 0 deletions outlines/integrations/vllm.py
@@ -40,6 +40,22 @@
from vllm import LLM


def get_vllm_tokenizer(llm: "LLM"):
if hasattr(llm, "get_tokenizer"):
tokenizer = llm.get_tokenizer()
elif hasattr(llm, "tokenizer"):
if hasattr(llm.tokenizer, "tokenizer"):
tokenizer = llm.tokenizer.tokenizer
else:
tokenizer = llm.tokenizer
else:
raise ValueError(
"The provided LLM instance in `RegexLogitsProcessor` neither has a "
"`tokenizer` attribute or a `get_tokenizer` method."
)
return adapt_tokenizer(tokenizer)

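A usage sketch for the new helper; the model name is illustrative. The nested `.tokenizer.tokenizer` hop appears to cover vLLM's TokenizerGroup wrapper, while `get_tokenizer()` covers releases that expose it directly:

    from vllm import LLM

    llm = LLM("facebook/opt-125m")       # illustrative model name
    tokenizer = get_vllm_tokenizer(llm)  # unwraps the tokenizer, then adapt_tokenizer()
    # `tokenizer` is now compatible with the shared Outlines logits processors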

class RegexLogitsProcessor:
"""Bias vLLM generation based on a regular expression.