
Cache Legal-Token Mask as torch.tensor
lapp0 authored and rlouf committed Jul 3, 2024
1 parent 833f68f commit 25b6bcd
Showing 5 changed files with 166 additions and 84 deletions.
106 changes: 81 additions & 25 deletions benchmarks/bench_processors.py
@@ -1,8 +1,13 @@
-import mlx.core as mx
 import numpy as np
 import torch
 
-from outlines.processors import OutlinesLogitsProcessor
+import outlines.models as models
+from outlines.processors import OutlinesLogitsProcessor, RegexLogitsProcessor
+
+try:
+    import mlx.core as mx
+except ImportError:
+    pass
 
 
 def is_mlx_lm_allowed():
@@ -13,40 +18,91 @@ def is_mlx_lm_allowed():
     return mx.metal.is_available()
 
 
+def get_mock_processor_inputs(array_library, num_tokens=30000):
+    """
+    logits shape: (4, num_tokens), dtype=float
+    input_ids shape: (4, 2048), dtype=int
+    """
+    if array_library == "torch":
+        logits = torch.rand((4, num_tokens), dtype=torch.float)
+        input_ids = torch.randint(
+            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
+        )
+    elif array_library == "torch_cuda":
+        logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
+        input_ids = torch.randint(
+            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
+        )
+    elif array_library == "numpy":
+        logits = np.random.rand(4, num_tokens).astype(np.float32)
+        input_ids = np.random.randint(low=0, high=num_tokens, size=(4, 2048))
+    elif array_library == "mlx":
+        logits = mx.random.uniform(
+            low=-1e9, high=1e9, shape=(4, num_tokens), dtype=mx.float32
+        )
+        input_ids = mx.random.randint(
+            low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
+        )
+    else:
+        raise ValueError(f"Unknown array_library: {array_library}")
+
+    return logits, input_ids
+
+
 class HalvingLogitsProcessor(OutlinesLogitsProcessor):
     """Simply halve the passed logits"""
 
     def process_logits(self, input_ids, logits):
         return logits / 2
 
 
-class LogitsProcessorBenchmark:
+class LogitsProcessorPassthroughBenchmark:
     """
     Benchmark the time it takes to convert between array frameworks
     This should be on the order of microseconds
     """
 
     params = ["torch", "numpy"]
-    if mx.metal.is_available():
+    if is_mlx_lm_allowed():
         params += ["mlx"]
+    if torch.cuda.is_available():
+        params += ["torch_cuda"]
 
     def setup(self, array_library):
         self.logits_processor = HalvingLogitsProcessor()
-
-        # logits: (4, 30,000) dtype=float
-        # input_ids shape: (4, 2048) dtype=int
-        if array_library == "torch":
-            self.logits = torch.rand((4, 30000), dtype=torch.float)
-            self.input_ids = torch.randint(
-                low=0, high=30000, size=(4, 2048), dtype=torch.int
-            )
-        elif array_library == "numpy":
-            self.logits = np.random.rand(4, 30000).astype(np.float32)
-            self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
-        elif array_library == "mlx":
-            self.logits = mx.random.uniform(
-                low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
-            )
-            self.input_ids = mx.random.randint(
-                low=0, high=30000, shape=(4, 2048), dtype=mx.int32
-            )
-        else:
-            raise ValueError
-
-    def time_logits_processor(self, array_library):
+        self.logits, self.input_ids = get_mock_processor_inputs(array_library)
+
+    def time_passthrough(self, *params):
         self.logits_processor(self.input_ids, self.logits)
+
+
+class LogitsProcessorStructuredBenchmark:
+    """
+    Benchmark structured generation mask application for a single decoder pass
+    """
+
+    array_libraries = ["torch", "numpy"]
+    if is_mlx_lm_allowed():
+        array_libraries += ["mlx"]
+    # PR TODO
+    if torch.cuda.is_available():
+        array_libraries += ["torch_cuda"]
+
+    # accept very many or very few tokens, respectively
+    patterns = [r"[^Z]*", "Z*"]
+
+    params = [array_libraries, patterns]
+    param_names = ["array_library", "pattern"]
+
+    def setup(self, array_library, pattern):
+        tokenizer = models.transformers("facebook/opt-125m", device="cpu").tokenizer
+
+        self.logits_processor = RegexLogitsProcessor(pattern, tokenizer)
+
+        self.logits, self.input_ids = get_mock_processor_inputs(
+            array_library, len(tokenizer.vocabulary)
+        )
+
+    def time_structured_generation(self, array_library, pattern):
+        self.logits_processor(self.input_ids, self.logits)
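For reference, these new benchmark classes can be exercised on their own with asv's `-b` name filter, in the same style as the commands in the contributing guide below (a sketch; it assumes the repository's `benchmarks/asv.conf.json`):

```
asv run --config benchmarks/asv.conf.json -b bench_processors.LogitsProcessorPassthroughBenchmark
asv run --config benchmarks/asv.conf.json -b bench_processors.LogitsProcessorStructuredBenchmark
```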
10 changes: 7 additions & 3 deletions docs/community/contribute.md
@@ -66,17 +66,21 @@ You can run the benchmark test suite locally with the following command:
 asv run --config benchmarks/asv.conf.json
 ```
 
-Run a specific test:
+Caveats:
+- If you're on a device with CUDA, you must add the argument `--launch-method spawn`.
+- Uncommitted code will not be benchmarked; you must commit your changes first.
+
+#### Run a specific test:
 ```
 asv run --config benchmarks/asv.conf.json -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
 ```
 
-Profile a specific test:
+#### Profile a specific test:
 ```
 asv run --config benchmarks/asv.conf.json --profile -b bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm
 ```
 
-Compare to `origin/main`
+#### Compare to `origin/main`
 ```
 git fetch origin
 asv continuous origin/main HEAD --config benchmarks/asv.conf.json
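Taken together, the two caveats above mean a typical benchmarking loop on a CUDA machine looks roughly like the following sketch (the commit message is a placeholder):

```
git add -u
git commit -m "wip: benchmark changes"
asv run --config benchmarks/asv.conf.json --launch-method spawn
```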
16 changes: 12 additions & 4 deletions outlines/fsm/guide.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union
 
 import interegular
+import torch
 from lark import Lark
 
 from outlines import grammars
@@ -146,6 +147,13 @@ def __init__(self, regex_string: str, tokenizer):
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
+        # cache returned token masks as tensors
+        # this increases the performance of mask application substantially
+        self.states_to_token_mask = {
+            state: torch.tensor(list(next_tokens_to_end_states.keys()))
+            for state, next_tokens_to_end_states in self.states_to_token_maps.items()
+        }
+
     def get_next_instruction(self, state: int) -> Instruction:
         """Return the next instruction for guided generation.
@@ -169,11 +177,11 @@ def get_next_instruction(self, state: int) -> Instruction:
         A `Generate` instance that contains the model and the allowed token ids.
         """
-        next_tokens_to_end_states = self.states_to_token_maps.get(state)
-        if next_tokens_to_end_states is None:
-            return Write([self.eos_token_id])
+        next_tokens_mask = self.states_to_token_mask.get(state)
+        if next_tokens_mask is None:
+            return Write(torch.tensor([self.eos_token_id]))
 
-        return Generate(list(next_tokens_to_end_states.keys()))
+        return Generate(next_tokens_mask)
 
     def get_next_state(self, state: int, token_id: int) -> int:
         """Update the state of the guide.
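The hunk above is the heart of the commit: the per-state list of legal token ids is materialized as a `torch.Tensor` once, when the guide is constructed, rather than rebuilt on every decoding step. A minimal standalone sketch of the idea, using simplified names and a made-up transition table:

```python
import torch

# Toy stand-in for RegexGuide.states_to_token_maps: state -> {token_id: next_state}
states_to_token_maps = {0: {5: 1, 7: 2}, 1: {9: 0}}

# Build each state's mask tensor once, up front ...
states_to_token_mask = {
    state: torch.tensor(list(transitions.keys()))
    for state, transitions in states_to_token_maps.items()
}

# ... so the per-step lookup is a dict access instead of a
# list -> tensor conversion on the decoding hot path.
def allowed_token_ids(state: int) -> torch.Tensor:
    return states_to_token_mask[state]

print(allowed_token_ids(0))  # tensor([5, 7])
```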
57 changes: 32 additions & 25 deletions tests/fsm/test_fsm.py
@@ -3,6 +3,13 @@
 from outlines.fsm.fsm import CFGFSM, RegexFSM, StopAtEosFSM
 
 
+def assert_expected_tensor_ids(tensor, ids):
+    assert len(tensor) == len(ids)
+    norm_tensor = sorted(map(int, tensor))
+    norm_ids = sorted(map(int, ids))
+    assert norm_tensor == norm_ids, (norm_tensor, norm_ids)
+
+
 def test_stop_at_eos():
     class MockTokenizer:
         vocabulary = {"a": 1, "eos": 2}
@@ -50,7 +57,7 @@ def convert_token_to_string(self, token):
     fsm = RegexFSM(regex_str, tokenizer)
 
     assert fsm.states_to_token_maps == {0: {1: 1}}
-    assert fsm.allowed_token_ids(state=0) == [1]
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1])
     assert fsm.next_state(state=0, token_id=1) == 1
     assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1
 
@@ -111,27 +118,27 @@ def decode(self, token_ids):
     with pytest.warns(UserWarning):
         fsm = CFGFSM(cfg_str, tokenizer)
 
-    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3, 5}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3, 5])
     state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "{"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
     state = fsm.next_state(state=state, token_id=3)
     assert fsm.generation == "{["
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 3, 4}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3, 4])
     state = fsm.next_state(state=state, token_id=4)
     assert fsm.generation == "{[]"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {2}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2])
     state = fsm.next_state(state=state, token_id=2)
     assert fsm.generation == "{[]}"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {5}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [5])
     state = fsm.next_state(state=state, token_id=5)
     assert fsm.generation == "{[]}"
     assert fsm.is_final_state(state)
@@ -164,24 +171,24 @@ def decode(self, token_ids):
     with pytest.warns(UserWarning):
         fsm = CFGFSM(cfg_str, tokenizer)
 
-    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1])
     state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "("
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2])
    state = fsm.next_state(state=state, token_id=2)
     assert fsm.generation == "()"
     assert not fsm.is_final_state(state)
 
     # possible to continue or terminate
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3])
     state = fsm.next_state(state=state, token_id=3)  # feed eos
     assert fsm.generation == "()"
     assert fsm.is_final_state(state)
 
     # once eos generated, can only terminate
-    assert set(fsm.allowed_token_ids(state=state)) == {3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
 
 
 def test_cfg_ignore_directive():
@@ -214,38 +221,38 @@ def decode(self, token_ids):
 
     state = 0
 
-    assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2])
     state = fsm.next_state(state=0, token_id=2)
     assert fsm.generation == " "
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1, 2])
     state = fsm.next_state(state=0, token_id=1)
     assert fsm.generation == " a"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
     state = fsm.next_state(state=state, token_id=2)
     assert fsm.generation == " a "
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
     state = fsm.next_state(state=state, token_id=2)
     assert fsm.generation == " a "
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
     state = fsm.next_state(state=state, token_id=1)
     assert fsm.generation == " a a"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 2, 3])
     state = fsm.next_state(state=state, token_id=3)
     assert fsm.generation == " a a"
     assert fsm.is_final_state(state)
 
     # once eos generated, can only terminate
-    assert set(fsm.allowed_token_ids(state=state)) == {3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
 
 
 def test_cfg_multitoken_terminal():
@@ -274,19 +281,19 @@ def decode(self, token_ids):
     with pytest.warns(UserWarning):
         fsm = CFGFSM(cfg_str, tokenizer)
 
-    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 2}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 2])
     assert fsm.reset_state  # starting new regex
     state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "a"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1])
     assert not fsm.reset_state  # continuing current regex
     state = fsm.next_state(state=state, token_id=1)
     assert fsm.generation == "aa"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [3])
     assert not fsm.reset_state  # completing current regex
     state = fsm.next_state(state=state, token_id=3)
     assert fsm.generation == "aa"
@@ -319,27 +326,27 @@ def decode(self, token_ids):
     with pytest.warns(UserWarning):
         fsm = CFGFSM(cfg_str, tokenizer)
 
-    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=fsm.start_state), [1, 3])
     state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "("
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [1, 3])
     state = fsm.next_state(state=state, token_id=3)
     assert fsm.generation == "(a"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3])
     state = fsm.next_state(state=state, token_id=3)
     assert fsm.generation == "(aa"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [2, 3])
     state = fsm.next_state(state=state, token_id=2)
     assert fsm.generation == "(aa)"
     assert not fsm.is_final_state(state)
 
-    assert set(fsm.allowed_token_ids(state=state)) == {4}
+    assert_expected_tensor_ids(fsm.allowed_token_ids(state=state), [4])
     state = fsm.next_state(state=state, token_id=4)
     assert fsm.generation == "(aa)"
     assert fsm.is_final_state(state)
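Since `allowed_token_ids` now returns a tensor rather than a list, the tests compare through the order-insensitive `assert_expected_tensor_ids` helper added above. A condensed, standalone illustration of its behavior:

```python
import torch

def assert_expected_tensor_ids(tensor, ids):
    # Same multiset of ids, ignoring order and tensor vs. list types.
    assert len(tensor) == len(ids)
    assert sorted(map(int, tensor)) == sorted(map(int, ids))

assert_expected_tensor_ids(torch.tensor([3, 1, 2]), [1, 2, 3])  # passes
```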
