Integrate with llama.cpp using logit processors
dtiarks authored and rlouf committed Jan 26, 2024
1 parent 5d67a5a commit b20d502
Showing 10 changed files with 310 additions and 225 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ __pycache__
docs/build
.coverage
.idea/
*.gguf
2 changes: 1 addition & 1 deletion examples/llamacpp_example.py
@@ -31,7 +31,7 @@ class Character(BaseModel):

if __name__ == "__main__":
# Download model from https://huggingface.co/TheBloke/phi-2-GGUF
model = outlines.models.llamacpp("./phi-2.Q3_K_M.gguf", device="cpu")
model = outlines.models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu")

# Construct guided sequence generator
generator = outlines.generate.json(model, Character, max_tokens=512)
49 changes: 49 additions & 0 deletions examples/llamacpp_processor.py
@@ -0,0 +1,49 @@
from enum import Enum

from llama_cpp import Llama, LogitsProcessorList
from pydantic import BaseModel, constr

from outlines.models.llamacpp import LlamaCppTokenizer, LlamaJSONLogitsProcessor


class Weapon(str, Enum):
sword = "sword"
axe = "axe"
mace = "mace"
spear = "spear"
bow = "bow"
crossbow = "crossbow"


class Armor(str, Enum):
leather = "leather"
chainmail = "chainmail"
plate = "plate"


class Character(BaseModel):
name: constr(max_length=10)
age: int
armor: Armor
weapon: Weapon
strength: int


if __name__ == "__main__":
llama = Llama("./phi-2.Q4_K_M.gguf")
tokenizer = LlamaCppTokenizer(llama)

prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

logits_processor = LlamaJSONLogitsProcessor(Character, tokenizer)

json_str = llama.create_completion(
prompt,
top_k=40,
top_p=0.95,
temperature=0.7,
max_tokens=100,
logits_processor=LogitsProcessorList([logits_processor]),
)["choices"][0]["text"]

print(json_str)
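
A quick round-trip validation one could append to this example (a hedged sketch, not part of the commit; assumes pydantic v2, consistent with `model_json_schema` used in `json_llamacpp` below):

    # Hypothetical follow-up: parse the guided output back into the Character
    # model; with the logits processor active this should not raise.
    character = Character.model_validate_json(json_str)
    print(repr(character))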
22 changes: 22 additions & 0 deletions outlines/generate/cfg.py
@@ -3,8 +3,10 @@

from outlines.fsm.fsm import CFGFSM
from outlines.generate.api import SequenceGenerator
from outlines.generate.processors import CFGLogitsProcessor
from outlines.generate.samplers import Sampler, multinomial
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp


@singledispatch
@@ -25,6 +27,26 @@ def cfg(
return generator


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
cfg_str: str,
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
sampler: Sampler = multinomial,
):
    if sampler is not multinomial:
        raise NotImplementedError(
            "The llama.cpp integration does not currently support any sampling "
            + "algorithm other than the multinomial sampler."
        )

logits_processor = CFGLogitsProcessor(cfg_str, model.tokenizer)
model.logits_processor = logits_processor

return model
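
Note that this dispatch mutates the model in place: it attaches a `CFGLogitsProcessor` and returns the model itself rather than a `SequenceGenerator`, and the `max_tokens` and `stop_at` arguments are accepted but not applied here.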


@cfg.register(OpenAI)
def cfg_openai(
model,
31 changes: 31 additions & 0 deletions outlines/generate/json.py
@@ -7,6 +7,7 @@
from outlines.fsm.json_schema import build_regex_from_object, get_schema_from_signature
from outlines.generate.samplers import Sampler, multinomial
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp

from .regex import regex

@@ -43,6 +44,36 @@ def json(
return generator


@json.register(LlamaCpp)
def json_llamacpp(
    model: LlamaCpp,
schema_object: Union[str, object, Callable],
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
):
if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_object(schema)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_object(schema)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_object(schema)
else:
raise ValueError(
f"Cannot parse schema {schema_object}. The schema must be either "
+ "a Pydantic object, a function or a string that contains the JSON "
+ "Schema specification"
)

# TODO: format the output
# We should be able to use the same interface as transformers and make this
# function redundant by adding a `format_sequence` method to `LlamaCpp`

return regex(model, regex_str, max_tokens, sampler)
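
The callable branch means a plain function signature can supply the schema as well, via `get_schema_from_signature` (a hedged sketch; `make_character` is hypothetical, and `model` is assumed to be a `LlamaCpp` instance):

    # Hypothetical: a function's signature serves as the JSON Schema source.
    def make_character(name: str, age: int, strength: int) -> None:
        ...

    # With a LlamaCpp model this returns the model itself, now carrying a
    # RegexLogitsProcessor built from the function's signature.
    model = outlines.generate.json(model, make_character)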


@json.register(OpenAI)
def json_openai(
model,
100 changes: 100 additions & 0 deletions outlines/generate/processors.py
@@ -0,0 +1,100 @@
import json
import math
from typing import Union

import numpy as np
from numpy.typing import NDArray

from outlines.fsm.fsm import CFGFSM, FSM, FSMState, RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.models.tokenizer import Tokenizer


class LogitsProcessor:
def __init__(self, tokenizer: Tokenizer, fsm: FSM):
"""Super class for logit processors.
Parameters
----------
tokenizer
An instance of `Tokenizer`
"""
self.tokenizer = tokenizer
self.fsm_state: FSMState = None # type: ignore
self.fsm: FSM = fsm

def __call__(
self, input_ids: NDArray[np.int64], scores: NDArray[np.float32]
) -> NDArray[np.float32]:
"""Use the FSM to bias the logits before sampling the next token."""
if self.fsm is None:
raise NotImplementedError()

if self.fsm_state is None:
self.fsm_state = FSMState(0)
else:
last_token = input_ids[-1]
self.fsm_state = self.fsm.next_state(self.fsm_state, last_token)

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state)

        mask = np.full(scores.shape[-1], -math.inf, dtype=np.float32)
mask[allowed_tokens] = 0
biased_scores = scores + mask

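        # NB: this gives EOS a finite score (0) even when the FSM would mask it,
        # so generation can always terminate.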
biased_scores[self.tokenizer.eos_token_id] = 0

return biased_scores


class RegexLogitsProcessor(LogitsProcessor):
def __init__(self, regex_string: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the regex-guided generation.
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
An instance of `Tokenizer`
"""
fsm = RegexFSM(regex_string, tokenizer)
super().__init__(tokenizer, fsm)


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, dict], tokenizer: Tokenizer):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
tokenizer
An instance of `Tokenizer`
"""
# TODO: Why is this needed? We are using regexes
if isinstance(schema, dict):
schema = json.dumps(schema)
regex_string = build_regex_from_object(schema)
super().__init__(regex_string, tokenizer)


class CFGLogitsProcessor(LogitsProcessor):
def __init__(self, cfg_str: str, tokenizer: Tokenizer):
"""Compile the FSM that drives the CFG-guided generation.
Parameters
----------
cfg_str
A string that represents a grammar
tokenizer
An instance of `Tokenizer`
"""
fsm = CFGFSM(cfg_str, tokenizer)
super().__init__(tokenizer, fsm)
23 changes: 22 additions & 1 deletion outlines/generate/regex.py
@@ -3,8 +3,10 @@

from outlines.fsm.fsm import RegexFSM
from outlines.generate.api import SequenceGenerator
from outlines.generate.processors import RegexLogitsProcessor
from outlines.generate.samplers import Sampler, multinomial
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp


@singledispatch
@@ -22,9 +24,28 @@ def regex(
return generator


@regex.register(LlamaCpp)
def regex_llamacpp(
model: LlamaCpp,
regex_str: str,
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
):
    if sampler is not multinomial:
        raise NotImplementedError(
            "The llama.cpp integration does not currently support any sampling "
            + "algorithm other than the multinomial sampler."
        )

logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
model.logits_processor = logits_processor

return model
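
A minimal sketch of the dispatch in use (model path and pattern are illustrative):

    import outlines

    model = outlines.models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu")
    # Attaches a RegexLogitsProcessor and returns the same model instance.
    model = outlines.generate.regex(model, r"[-+]?[0-9]+")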


@regex.register(OpenAI)
def regex_openai(
-    model,
+    model: OpenAI,
regex_str: str,
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
24 changes: 18 additions & 6 deletions outlines/generate/text.py
@@ -6,6 +6,7 @@
from outlines.generate import SequenceGenerator
from outlines.generate.samplers import Sampler, multinomial
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp


@singledispatch
@@ -14,7 +15,6 @@ def text(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
*,
-    samples: int = 1,
sampler: Sampler = multinomial,
) -> SequenceGenerator:
"""Generate text with a `Transformer` model.
@@ -43,11 +43,6 @@
A `SequenceGenerator` instance that generates text.
"""
-    if samples > 1:
-        raise NotImplementedError(
-            "It is currently impossible to generate several samples with `transformers` models."
-        )

fsm = StopAtEosFSM(model.tokenizer)

device = model.device
@@ -58,6 +53,23 @@
return generator


@text.register(LlamaCpp)
def text_llamacpp(
model: LlamaCpp,
max_tokens: Optional[int] = None,
stop_at: Optional[Union[List[str], str]] = None,
*,
sampler: Sampler = multinomial,
):
    if sampler is not multinomial:
        raise NotImplementedError(
            "The llama.cpp integration does not currently support any sampling "
            + "algorithm other than the multinomial sampler."
        )

return model
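
Since unconstrained text generation needs no logits processor, this dispatch simply returns the model unchanged; as with the other llama.cpp dispatches, `max_tokens` and `stop_at` are accepted but not applied here.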


@text.register(OpenAI)
def text_openai(
model: OpenAI,