Use a trie to speed up index construction #887

Closed
wants to merge 2 commits into from
56 changes: 56 additions & 0 deletions benchmarks/bench_regex_fsm.py
@@ -0,0 +1,56 @@
import random

from outlines.caching import cache_disabled
from outlines.fsm.regex import reduced_vocabulary
from outlines.models.tokenizer import Tokenizer

from .common import ensure_numba_compiled


class MockTokenizer(Tokenizer):
    def __init__(self, token_strs):
        self.eos_token = "<eos>"
        self.eos_token_id = 0
        self.pad_token_id = 1
        self.special_tokens = {0, 1}

        self.vocabulary = {"<eos>": 0, "<pad>": 1}

        for i, tok in enumerate(token_strs):
            self.vocabulary[tok] = i + 2

    @classmethod
    def from_random_tokens(cls, n_tokens, max_token_length=8, seed=42):
        random.seed(seed)
        tokens = [
            "".join(
                chr(random.randint(0, 4096))
                for __ in range(random.randint(0, max_token_length))
            )
            for _ in range(n_tokens)
        ]
        return cls(tokens)

    def convert_token_to_string(self, token):
        return token

    def __hash__(self):
        return hash(tuple(sorted(self.vocabulary.items())))


def reduced_vocabulary_uncached(*args, **kwargs):
    # Call the undecorated function so each benchmark run is timed cold.
    return reduced_vocabulary.__wrapped__(*args, **kwargs)


class RegexReducedVocabularyBenchmark:
    params = [10000, 100000, 1000000]
    param_names = ["vocab_size"]

    def setup(self, vocab_size):
        # Warm up numba compilation outside the timed region.
        ensure_numba_compiled(MockTokenizer([chr(i) for i in range(128)]))

        self.tokenizer = MockTokenizer.from_random_tokens(vocab_size)

    @cache_disabled()
    def time_reduced_vocabulary(self, _):
        reduced_vocabulary_uncached(self.tokenizer)
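
Note: the benchmark class appears to follow the airspeed velocity (asv) conventions used by the existing files in benchmarks/ (params/param_names, a setup hook, and time_* methods). A minimal sketch of driving it by hand, outside of asv, assuming outlines and its numba dependencies are installed (illustration only, not part of this PR):

# Hypothetical manual invocation; asv normally calls setup() and the time_* method itself.
bench = RegexReducedVocabularyBenchmark()
bench.setup(10_000)                      # warm numba compilation and build the mock vocabulary
bench.time_reduced_vocabulary(10_000)    # timed body: reduced_vocabulary() with outlines caching disabled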
147 changes: 107 additions & 40 deletions outlines/fsm/regex.py
@@ -28,6 +28,8 @@
from numba.typed.typedobjectutils import _nonoptional
from tqdm import tqdm

from outlines.fsm.vocab_trie import VocabTrie

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer

@@ -664,30 +666,39 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, Sequence[int]]],
vocabulary_transition_keys: List[Sequence[int]],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocab_trie: VocabTrie,
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for (token, token_ids), token_transition_keys in zip(
vocabulary, vocabulary_transition_keys
):
# Initialize the stack with tokens having no prefixes
stack = numba.typed.List()
for token_transitions_seq in vocab_trie.get_children():
stack.append(token_transitions_seq)

# Process the tokens using the stack
while stack:
token_transitions_seq = stack.pop()
state_seq = _walk_fsm(
fsm_transitions,
fsm_initial,
fsm_finals,
token_transition_keys,
token_transitions_seq,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token_transition_keys):
if len(state_seq) < len(token_transitions_seq):
continue

for token_id in token_ids:
for token_id in vocab_trie.get_token_ids(token_transitions_seq):
res.add((token_id, state_seq[-1]))

# Add successors to the stack
for new_token in vocab_trie.get_children(token_transitions_seq):
stack.append(new_token)

return res
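
The core of the change is above: tokens are organised in a prefix trie over their FSM transition-key sequences, and state_scan_tokens walks that trie depth-first, pruning every descendant of a token whose keys cannot be fully consumed from start_state, since an extension of a failing prefix can never match either. A plain-Python sketch of the same pruning idea, using hypothetical walk_fsm / children_of / token_ids helpers in place of _walk_fsm and the numba VocabTrie (illustration only):

def scan_with_trie(walk_fsm, trie, start_state):
    matches = set()                                 # (token_id, end_state) pairs
    stack = list(trie.children_of(None))            # tokens with no shorter token as prefix
    while stack:
        key_seq = stack.pop()
        state_seq = walk_fsm(key_seq, start_state)  # states visited while consuming key_seq
        if len(state_seq) < len(key_seq):           # walk stopped early: prune the whole subtree
            continue
        for token_id in trie.token_ids(key_seq):
            matches.add((token_id, state_seq[-1]))
        stack.extend(trie.children_of(key_seq))     # longer tokens sharing this prefix
    return matches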


@@ -805,7 +816,7 @@ def create_fsm_index_end_to_end(
desc="Compiling FSM index for all state transitions",
)

vocabulary_transition_keys = get_vocabulary_transition_keys(
vocabulary_transitions = get_vocabulary_transition_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
@@ -815,18 +826,24 @@
else numba.typed.List.empty_list(numba.types.unicode_type)
),
)
vocab_trie = VocabTrie(vocabulary_transitions, vocabulary)

while next_states:
start_state = next_states.pop()

pbar.update(1)

if start_state not in seen:
seen.add(start_state)

token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
vocabulary,
vocabulary_transition_keys,
vocab_trie,
start_state,
)

@@ -838,10 +855,6 @@
if end_state not in seen:
next_states.add(end_state)

if start_state not in seen:
pbar.update(1)
seen.add(start_state)

pbar.close()

return states_to_token_subsets
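
VocabTrie itself is added in outlines/fsm/vocab_trie.py and is not shown in the portion of the diff above; only its constructor, get_children, and get_token_ids are used here. A rough pure-Python stand-in for that assumed interface, where a node's parent is the longest proper prefix that is itself a token's key sequence (sketch only, not the jitted implementation):

import numpy as np

class PurePythonVocabTrie:
    def __init__(self, vocabulary_transitions, vocabulary):
        # group token ids by their transition-key sequence
        self._token_ids = {}
        for (_, token_ids), keys in zip(vocabulary, vocabulary_transitions):
            self._token_ids.setdefault(tuple(keys), []).extend(token_ids)
        # attach each key sequence to its longest proper prefix present in the vocabulary
        self._children = {None: []}
        for keys in sorted(self._token_ids, key=len):
            self._children.setdefault(keys, [])
            parent = None
            for cut in range(len(keys) - 1, 0, -1):
                if keys[:cut] in self._token_ids:
                    parent = keys[:cut]
                    break
            self._children[parent].append(keys)

    def get_children(self, keys=None):
        # root children when called with no argument, as in state_scan_tokens above
        return self._children[tuple(keys) if keys is not None else None]

    def get_token_ids(self, keys):
        return np.asarray(self._token_ids[tuple(keys)], dtype=np.int64)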
@@ -887,23 +900,11 @@ def gpt2_unicode_to_bytes():
return {v: k for k, v in gpt2_bytes_to_unicode().items()}


# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
def get_normalized_vocab(tokenizer: "Tokenizer") -> Tuple[Dict[int, str], Set[int]]:
norm_vocab = {}
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue

token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
token
)

token_str = tokenizer.convert_token_to_string(token)
if token_str:
# invalid utf-8 sequences are replaced with � (\ufffd), but there
# might also be tokens specifically for �, ��, ���, etc.
@@ -927,22 +928,88 @@ def reduced_vocabulary(
)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
norm_vocab[token_idx] = token_str
else:
empty_token_ids.add(numba.int64(token_idx))

vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unicode_type,
numba.int64[:],
)
)
return norm_vocab, empty_token_ids


@numba.njit(cache=True, nogil=True)
def to_numba_dict(keys: List[int], values: List[str]):
"""
Pure-python numba dict construction is extremely slow.
This helper accepts equal length key and value arrays, and constructs a numba dict
"""
# Define the key and value types for the Numba dictionary
numba_dict = numba.typed.Dict.empty(
key_type=numba.types.int64,
value_type=numba.types.unicode_type,
)

# Fill the Numba dictionary with values from the input lists
for i in range(len(keys)):
numba_dict[keys[i]] = values[i]

return numba_dict
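
A small usage sketch for the helper above, mirroring the call made in reduced_vocabulary further down (illustration only):

import numpy as np

keys = np.array([5, 7, 11], dtype=np.int64)   # int64 keys, matching the declared key_type
values = ["foo", "bar", "baz"]                # plain Python list of strings
d = to_numba_dict(keys, values)               # numba.typed.Dict[int64, unicode_type]
assert d[7] == "bar"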


token_id_str_pair = numba.types.Tuple((nb_unicode_type, numba.int64[:]))


@numba.njit(
numba.types.ListType(token_id_str_pair)(
numba.types.DictType(numba.int64, nb_unicode_type)
),
cache=True,
nogil=True,
)
def vocab_dict_to_inverted_vocab_list(
vocab_dict_nb: Dict[int, str]
) -> List[Tuple[str, Sequence[int]]]:
"""
Helper for `reduced_vocabulary`

Convert
- from `vocab_dict_nb`: Dict[token_id, token_str]
- to `vocab_nb`: List[token_str, token_id[:]]
"""
inverse_vocab_dict = numba.typed.Dict.empty(
key_type=numba.types.unicode_type, value_type=numba.types.int64[:]
)
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_str, token_ids_np))

# Fill the temporary dictionary
for key in vocab_dict_nb:
value = vocab_dict_nb[key]
if value not in inverse_vocab_dict:
inverse_vocab_dict[value] = np.zeros(0, dtype=np.int64)
inverse_vocab_dict[value] = np.append(inverse_vocab_dict[value], key)

    # Transfer data from the temporary dictionary to the final typed list
vocab_nb = numba.typed.List.empty_list(token_id_str_pair)

for value in inverse_vocab_dict:
vocab_nb.append((value, inverse_vocab_dict[value]))

return vocab_nb
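
For intuition, the jitted inversion above behaves like the following pure-Python version (illustration only); the jitted form is presumably preferred because building numba typed containers from interpreted Python is slow, as noted for to_numba_dict:

import numpy as np

def invert_vocab(vocab_dict):
    # {token_id: token_str} -> [(token_str, int64 array of token ids)]
    inverse = {}
    for token_id, token_str in vocab_dict.items():
        inverse.setdefault(token_str, []).append(token_id)
    return [(s, np.asarray(ids, dtype=np.int64)) for s, ids in inverse.items()]

# e.g. {7: "ab", 9: "ab", 3: "c"} -> [("ab", array([7, 9])), ("c", array([3]))]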


# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""
Provided the tokenizer, calculate the
- vocabulary_nb: mapping of (normalized token str -> token_ids[:])
- empty token ids
"""
norm_vocab, empty_token_ids = get_normalized_vocab(tokenizer)
norm_vocab_dict_nb = to_numba_dict(
np.fromiter(norm_vocab.keys(), dtype=np.int64), list(norm_vocab.values())
)
vocabulary_nb = vocab_dict_to_inverted_vocab_list(norm_vocab_dict_nb)
return vocabulary_nb, empty_token_ids
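
A usage sketch tying this back to the benchmark file above (illustration only):

# Reusing the MockTokenizer defined in benchmarks/bench_regex_fsm.py
tokenizer = MockTokenizer.from_random_tokens(1_000)
vocabulary_nb, empty_token_ids = reduced_vocabulary(tokenizer)
# vocabulary_nb: numba typed list of (normalized token string, int64 array of token ids)
# empty_token_ids: set of ids whose decoded string is empty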

