From c7b3cc8df79b81e22513d42e11d129ff1310ff78 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 6 Sep 2023 16:47:03 -0500 Subject: [PATCH] Refactor Regex and introduce Numba-based FSM utilities --- examples/parsing.py | 17 - outlines/text/fsm.py | 697 ++++++++++++++++++++++++++++++ outlines/text/generate/regex.py | 237 ++++++---- outlines/text/parsing.py | 421 +----------------- pyproject.toml | 5 +- tests/text/generate/test_regex.py | 45 ++ tests/text/test_fsm.py | 353 +++++++++++++++ tests/text/test_parsing.py | 367 +--------------- 8 files changed, 1259 insertions(+), 883 deletions(-) create mode 100644 outlines/text/fsm.py create mode 100644 tests/text/test_fsm.py diff --git a/examples/parsing.py b/examples/parsing.py index 1e6e05c41..f1b71cba3 100644 --- a/examples/parsing.py +++ b/examples/parsing.py @@ -26,23 +26,6 @@ checkpoint, trust_remote_code=True, revision=revision ).to(device) -# import urllib.request -# -# sql_grammar_url = "https://github.com/zbrookle/sql_to_ibis/raw/0e9226da42065940ce21439d490f9fcacadc7f92/sql_to_ibis/grammar/sql.lark" -# sql_grammar = "".join( -# [line.decode("utf-8") for line in urllib.request.urlopen(sql_grammar_url)] -# ) -# with open("sql_grammar.lark", "w") as f: -# f.write(sql_grammar) -# -# TODO: `_STRING_ESC_INNER` from `%import common.ESCAPED_STRING` introduces a -# (potentially superfluous) look-back; we need to replace it or implement -# look-backs. -# parser = PartialLark.open( -# "sql_grammar.lark", -# parser="lalr", -# ) - parser = PartialLark.open_from_package( "tests", "partial_python.lark", diff --git a/outlines/text/fsm.py b/outlines/text/fsm.py new file mode 100644 index 000000000..5119cf047 --- /dev/null +++ b/outlines/text/fsm.py @@ -0,0 +1,697 @@ +from itertools import chain +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Tuple + +import numba +from interegular.fsm import FSM, Alphabet, OblivionError, anything_else +from joblib import Parallel, delayed +from numba.experimental import structref +from numba.typed.typedobjectutils import _nonoptional + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +class BetterAlphabet(Alphabet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.has_anything_else = anything_else in self._symbol_mapping + if self.has_anything_else: + self.anything_value = self._symbol_mapping[anything_else] + else: + self.anything_value = None + + def __getitem__(self, item): + return self._symbol_mapping.get(item, self.anything_value) + + def copy(self): + return BetterAlphabet(self._symbol_mapping.copy()) + + +class BetterFSM(FSM): + flat_transition_map: Dict[Tuple[int, int], int] + trans_key_to_states: Dict[int, List[int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(self.alphabet, BetterAlphabet): + self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) + + flat_transition_map = {} + trans_key_to_states = {} + for from_state, trans_map in self.map.items(): + for trans_key, to_state in trans_map.items(): + flat_transition_map[(from_state, trans_key)] = to_state + trans_key_to_states.setdefault(trans_key, set()).add(from_state) + + self.__dict__["trans_key_to_states"] = trans_key_to_states + self.__dict__["flat_transition_map"] = flat_transition_map + self.__dict__["_fsm_info"] = None + + def copy(self): + return BetterFSM( + alphabet=self.alphabet.copy(), + states=self.states.copy(), + initial=self.initial, + 
finals=self.finals.copy(), + map=self.map.copy(), + __no_validation__=True, + ) + + @property + def fsm_info(self): + if self._fsm_info is None: + trans_key_to_states = numba.typed.Dict.empty( + numba.int64, numba.types.ListType(numba.int64) + ) + for trans_key, states in self.trans_key_to_states.items(): + new_states = numba.typed.List.empty_list(numba.int64) + for state in states: + new_states.append(numba.int64(state)) + trans_key_to_states[numba.int64(trans_key)] = new_states + + flat_transition_map = numba.typed.Dict.empty( + numba.types.UniTuple(numba.int64, 2), numba.int64 + ) + for trans_key, state in self.flat_transition_map.items(): + flat_transition_map[ + (numba.int64(trans_key[0]), numba.int64(trans_key[1])) + ] = numba.int64(state) + + alphabet_symbol_map = numba.typed.Dict.empty( + numba.types.string, numba.int64 + ) + for symbol, trans_key in self.alphabet._symbol_mapping.items(): + if symbol is not anything_else: + alphabet_symbol_map[symbol] = numba.int64(trans_key) + + initial = numba.int64(self.initial) + + finals = numba.typed.List.empty_list(numba.int64) + for final in self.finals: + finals.append(numba.int64(final)) + + anything_value = numba.int64(self.alphabet.anything_value) + + self.__dict__["_fsm_info"] = FSMInfo( + initial, + finals, + flat_transition_map, + trans_key_to_states, + anything_value, + alphabet_symbol_map, + ) + + return self._fsm_info + + +spec = [ + ("initial", numba.int64), + ("finals", numba.types.Set(numba.int64)), + ( + "transitions", + numba.types.DictType(numba.types.UniTuple(numba.int64, 2), numba.int64), + ), + ( + "trans_key_to_states", + numba.types.DictType(numba.int64, numba.types.ListType(numba.int64)), + ), + ("alphabet_anything_value", numba.optional(numba.int64)), + ("alphabet_symbol_mapping", numba.types.DictType(numba.types.string, numba.int64)), +] + + +@structref.register +class FSMInfoType(numba.types.StructRef): + def preprocess_fields(self, fields): + return tuple((name, numba.types.unliteral(typ)) for name, typ in fields) + + +class FSMInfo(structref.StructRefProxy): + def __new__( + cls, + initial, + finals, + transitions, + trans_key_to_states, + alphabet_anything_value, + alphabet_symbol_mapping, + ): + return structref.StructRefProxy.__new__( + cls, + initial, + finals, + transitions, + trans_key_to_states, + alphabet_anything_value, + alphabet_symbol_mapping, + ) + + @property + def initial(self): + return FSMInfo_get_initial(self) + + @property + def finals(self): + return FSMInfo_get_finals(self) + + @property + def transitions(self): + return FSMInfo_get_transitions(self) + + @property + def trans_key_to_states(self): + return FSMInfo_get_trans_key_to_states(self) + + @property + def alphabet_anything_value(self): + return FSMInfo_get_alphabet_anything_value(self) + + @property + def alphabet_symbol_mapping(self): + return FSMInfo_get_alphabet_symbol_mapping(self) + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_initial(self): + return self.initial + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_finals(self): + return self.finals + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_transitions(self): + return self.transitions + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_trans_key_to_states(self): + return self.trans_key_to_states + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_alphabet_anything_value(self): + return self.alphabet_anything_value + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_alphabet_symbol_mapping(self): + 
return self.alphabet_symbol_mapping + + +structref.define_proxy(FSMInfo, FSMInfoType, [name for name, _ in spec]) +FSMInfo_type = FSMInfoType(fields=spec) + + +def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: + """Construct an equivalent FSM with deterministic state labels.""" + old_to_new_trans_keys = { + trans_key: i + for i, (trans_key, _) in enumerate( + sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) + ) + } + + new_symbol_mapping = { + symbol: old_to_new_trans_keys[trans_key] + for symbol, trans_key in fsm.alphabet._symbol_mapping.items() + } + + new_alphabet = BetterAlphabet(new_symbol_mapping) + + new_map = { + from_state: { + old_to_new_trans_keys[trans_key]: to_state + for trans_key, to_state in trans_map.items() + } + for from_state, trans_map in fsm.map.items() + } + + old_to_new_states = {} + old_to_new_states[fsm.initial] = 0 + + i = 0 + seen = {fsm.initial} + old_state_queue = [fsm.initial] + while old_state_queue: + old_state = old_state_queue.pop(-1) + transitions = new_map[old_state] + sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) + for _, old_state in sorted_transitions: + if old_state not in seen: + old_state_queue.append(old_state) + seen.add(old_state) + if old_state not in old_to_new_states: + i += 1 + old_to_new_states[old_state] = i + + new_map = dict( + sorted( + ( + ( + old_to_new_states[from_state], + dict( + sorted( + ( + (trans_key, old_to_new_states[to_state]) + for trans_key, to_state in trans_map.items() + ), + key=lambda v: v[0], + ) + ), + ) + for from_state, trans_map in new_map.items() + ), + key=lambda v: v[0], + ) + ) + + new_initial = 0 + new_finals = frozenset( + sorted(old_to_new_states[old_state] for old_state in fsm.finals) + ) + new_states = frozenset(sorted(new_map.keys())) + + new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) + + return new_fsm, old_to_new_states + + +@numba.njit(nogil=True, cache=True) +def walk_fsm( + fsm_info: BetterFSM, + input_string: str, + start_state: int, + full_match: bool = True, +) -> List[int]: + state = fsm_info.initial + accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) + last_final_idx = -1 + + # Apparently `fsm.alphabet.get` is incredibly slow, so we need to reproduce + # it here with the following: + alphabet_symbol_mapping = fsm_info.alphabet_symbol_mapping + anything_value = fsm_info.alphabet_anything_value + + for i, symbol in enumerate(input_string): + # Again, this is the logic from `fsm.alphabet.get` + trans_key = alphabet_symbol_mapping.get(symbol, anything_value) + + if state == fsm_info.initial: + new_state = fsm_info.transitions.get((start_state, trans_key)) + else: + new_state = fsm_info.transitions.get((state, trans_key)) + + if new_state is None: + if full_match: + if state in fsm_info.finals: + break + elif last_final_idx > -1: + accepted_states = accepted_states[: last_final_idx + 1] + break + + return numba.typed.List.empty_list(numba.int64) + + state = new_state + + if state in fsm_info.finals: + last_final_idx = i + + accepted_states.append(_nonoptional(state)) + + terminated = state in fsm_info.finals + if not terminated and state == fsm_info.initial: + return numba.typed.List.empty_list(numba.int64) + + return accepted_states + + +# TODO FIXME: Can't cache this due to https://github.com/numba/numba/issues/9177 +@numba.njit(nogil=True) +def find_partial_matches( + fsm_info: FSMInfo, + input_string: str, + full_match: bool = True, +) -> Generator[Tuple[int, 
List[int]], None, None]: + """Find the states in the finite state machine `fsm_info` that accept `input_string`. + + This will consider all possible states in the finite state machine (FSM) + that accept the beginning of `input_string` as starting points, unless a + specific `start_state` is provided. + + Parameters + ---------- + fsm_info + The finite state machine. + input_string + The string for which we generate partial matches. + full_match + Matches must cover the entire string. + + Returns + ------- + A set of tuples corresponding to each valid starting state in the FSM. The + first element of each tuple contains an integer indicating the position in + `input_string` at which the FSM stopped. The second element is the tuple + of states visited during execution of the FSM plus the next, unvisited + transition state. + + """ + + if len(input_string) == 0: + return + + trans_key = fsm_info.alphabet_symbol_mapping.get( + input_string[0], fsm_info.alphabet_anything_value + ) + + for state in fsm_info.trans_key_to_states.get( + trans_key, numba.typed.List.empty_list(numba.int64) # type: ignore + ): + path = walk_fsm(fsm_info, input_string, state, full_match=full_match) + if path: + path.insert(0, state) + res = (len(path) - 2, path) + yield res + + +@numba.njit(nogil=True, cache=True) +def process_token_string( + fsm_info: FSMInfo, + token: str, + token_idx: int, + final_state_string: Optional[str] = None, +) -> Set[Tuple[int, int]]: + res = set() + vocab_string_len = len(token) + + for end_idx, state_seq in find_partial_matches(fsm_info, token): + if end_idx is not None and end_idx < vocab_string_len - 1: + continue + + res.add((state_seq[0], token_idx)) + + if token == final_state_string: + # Allow transitions to EOS from all terminals FSM states + for state in fsm_info.finals: + res.add((state, token_idx)) + + return res + + +def create_fsm_index( + fsm_info: FSMInfo, + vocabulary: Dict[str, int], + final_state_string: Optional[str] = None, + n_jobs=-1, +) -> Dict[int, Set[int]]: + """Construct a map from FSM states to subsets of `vocabulary`. + + The subsets of `vocabulary` consist of elements that are accepted by--or + transition to--the corresponding partial parse states. + + Parameters + ---------- + fsm + The finite-state machine. + vocabulary + The vocabulary composed of token strings mapped to token IDs. + final_state_string + A string from `vocabulary` that is to be added to all the final states + in the FSM (e.g. ``""``). + """ + + results = Parallel(backend="threading", n_jobs=n_jobs, return_as="generator")( + delayed(process_token_string)(fsm_info, token, token_idx, final_state_string) + for token, token_idx in vocabulary.items() + ) + + states_to_token_subsets: Dict[int, Set[int]] = {} + + for fsm_state, token_idx in chain.from_iterable(results): + states_to_token_subsets.setdefault(fsm_state, set()).add(token_idx) + + return states_to_token_subsets + + +def fsm_union( + fsms: Sequence[FSM], +) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: + """Construct an FSM representing the union of the FSMs in `fsms`. + + This is an updated version of `interegular.fsm.FSM.union` made to return an + extra map of component FSMs to the sets of state transitions that + correspond to them in the new FSM. 
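+
+    A minimal, illustrative sketch of how this can be used (the exact state
+    labels in the returned map depend on the input FSMs):
+
+    >>> import interegular
+    >>> fsms = [interegular.parse_pattern(p).to_fsm().reduce() for p in ("ab", "ac")]
+    >>> union_fsm, fsms_to_trans_finals = fsm_union(fsms)
+    >>> union_fsm.accepts("ab"), union_fsm.accepts("ac"), union_fsm.accepts("bc")
+    (True, True, False)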
+ + """ + + alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) + + indexed_fsms = tuple(enumerate(fsms)) + + initial = {i: fsm.initial for (i, fsm) in indexed_fsms} + + # Dedicated function accepting a "superset" and returning the next + # "superset" obtained by following this transition in the new FSM + def follow(current_state, new_transition: int): + next = {} + for i, f in indexed_fsms: + old_transition = new_to_old[i][new_transition] + if ( + i in current_state + and current_state[i] in f.map + and old_transition in f.map[current_state[i]] + ): + next[i] = f.map[current_state[i]][old_transition] + if not next: + raise OblivionError + return next + + states = [initial] + finals: Set[int] = set() + map: Dict[int, Dict[int, int]] = {} + + # Map component FSMs to their new state-to-state transitions, finals, and a + # map translating component FSM states to aggregate FSM states + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ] = {} + + i = 0 + while i < len(states): + state = states[i] + + # Add to the finals of the aggregate FSM whenever we hit a final in a + # component FSM + if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): + finals.add(i) + + # Compute the map for this state + map[i] = {} + for transition in alphabet.by_transition: + try: + next = follow(state, transition) + except OblivionError: + # Reached an oblivion state; don't list it + continue + else: + try: + # TODO: Seems like this could--and should--be avoided + j = states.index(next) + except ValueError: + j = len(states) + states.append(next) + + map[i][transition] = j + + for fsm_id, fsm_state in next.items(): + ( + fsm_transitions, + fsm_finals, + fsm_old_to_new, + ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) + old_from = state[fsm_id] + old_to = fsm_state + fsm_old_to_new.setdefault(old_from, set()).add(i) + fsm_old_to_new.setdefault(old_to, set()).add(j) + fsm_transitions.add((i, j)) + if fsm_state in fsms[fsm_id].finals: + fsm_finals.add(j) + + i += 1 + + fsm = FSM( + alphabet=alphabet, + states=range(len(states)), + initial=0, + finals=finals, + map=map, + __no_validation__=True, + ) + + fsm, old_to_new_states = make_deterministic_fsm(fsm) + _fsms_to_trans_finals = { + fsm_id: ( + {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, + {old_to_new_states[s] for s in finals}, + { + old_state: {old_to_new_states[new_state] for new_state in new_states} + for old_state, new_states in old_to_new.items() + }, + ) + for fsm_id, (transitions, finals, old_to_new) in sorted( + fsms_to_trans_finals.items(), key=lambda x: x[0] + ) + } + + return ( + fsm, + _fsms_to_trans_finals, + ) + + +def get_sub_fsms_from_seq( + state_seq: Sequence[int], + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ], +) -> Generator[Tuple[int, bool, bool], None, None]: + """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. + + Parameters + ---------- + state_seq + A state sequence. + fsms_to_trans_finals + A map from FSM indices to tuples containing sets of their state transitions + and sets of the final/accept states. + + Returns + ------- + A generator returning tuples containing each sub-FSM index (in the order + they were union-ed to construct `fsm`) and booleans indicating whether or + not there is another valid transition from the last state in the sequence + for the associated sub-FSM (i.e. 
if the FSM can continue + accepting/matching) and whether or not the sequence ends in a final state + of the sub-FSM. + """ + state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) + last_fsm_state = state_seq[-1] + yield from ( + ( + # The sub-FMS index + fsm_idx, + # Is there another possible transition in this sub-FSM? + any(last_fsm_state == from_s for (from_s, to_s) in transitions), + # Is this sub-FSM in a final state? + state_seq[-1] in finals, + ) + for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() + if state_seq_transitions.issubset(transitions) + ) + + +@numba.njit(cache=True, nogil=True) +def state_scan_tokens( + fsm_info: FSMInfo, vocabulary: Dict[str, List[int]], start_state: int +) -> Set[Tuple[int, int]]: + res = set() + + for token, token_ids in vocabulary.items(): + state_seq = walk_fsm(fsm_info, token, start_state) + + if state_seq is not None and len(state_seq) < len(token): + continue + + for token_id in token_ids: + res.add((token_id, state_seq[-1])) + + return res + + +def create_fsm_index_end_to_end( + fsm_info: FSMInfo, + vocabulary: Dict[str, List[int]], +) -> Dict[int, Set[Tuple[int, int]]]: + """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" + + # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this + # code, too. + states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {} + seen: Set[int] = set() + next_states = {fsm_info.initial} + + while next_states: + start_state = next_states.pop() + + token_ids_end_states = state_scan_tokens(fsm_info, vocabulary, start_state) + + for token_id_and_end_state in token_ids_end_states: + states_to_token_subsets.setdefault(start_state, set()).add( + token_id_and_end_state + ) + end_state = token_id_and_end_state[1] + if end_state not in seen: + next_states.add(end_state) + + seen.add(start_state) + + return states_to_token_subsets + + +# TODO: Cache these? +def reduced_vocabulary(tokenizer: "Tokenizer"): + """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" + vocabulary = numba.typed.Dict.empty( + numba.types.string, numba.types.ListType(numba.int64) + ) + empty_token_ids = set() + for token, token_idx in tokenizer.vocabulary.items(): + if token in tokenizer.special_tokens: + continue + + token_str = tokenizer.convert_token_to_string(token) + + if token_str: + vocabulary.setdefault( + token_str, + numba.typed.List.empty_list(numba.int64), + ).append(numba.int64(token_idx)) + else: + empty_token_ids.add(token_idx) + + return vocabulary, empty_token_ids + + +def create_fsm_index_tokenizer( + fsm: BetterFSM, + tokenizer: "Tokenizer", +) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: + """Construct an FMS index from a tokenizer. + + This uses the end-to-end approach of `create_fsm_index_end_to_end`. + + .. warning:: + + `fsm` needs to be deterministically ordered so that the caching makes sense. + + """ + vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) + + states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary) + + # Allow transitions to EOS from all terminals FSM states that are + # reachable + # TODO: Do we really need this anymore? 
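+    # Concretely: for every final FSM state that already appears in the index,
+    # add an entry mapping the EOS token id back onto that same state, so that
+    # generation can terminate from any reachable accepting state.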
+ for state in fsm.fsm_info.finals: + subset = states_to_token_subsets.get(state) + if subset is not None: + subset.add((tokenizer.eos_token_id, state)) + + # Convert to token-to-end-state maps + states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()} + + return states_to_token_subsets, empty_token_ids diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py index 77cdbeb7e..65917ccc8 100644 --- a/outlines/text/generate/regex.py +++ b/outlines/text/generate/regex.py @@ -1,19 +1,14 @@ -import collections import math from json import dumps -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import interegular import torch from pydantic import BaseModel +from outlines.text.fsm import create_fsm_index_tokenizer, make_deterministic_fsm from outlines.text.generate.continuation import Continuation from outlines.text.json_schema import build_regex_from_schema -from outlines.text.parsing import ( - find_partial_matches, - make_deterministic_fsm, - map_partial_states_to_vocab, -) class Regex(Continuation): @@ -29,51 +24,48 @@ class Regex(Continuation): """ - def __init__(self, model, regex_string: str, max_tokens: Optional[int]): - super().__init__(model, max_tokens) + def __init__( + self, + model, + regex_string: str, + max_tokens: Optional[int], + allow_empty_tokens: bool = True, + ): + """ + + Parameters + ---------- + regex_string + The regex with which the token sampling process is guided/constrained. + max_tokens + The maximum number of tokens to be sampled. + allow_empty_tokens + Allow sampling of tokens corresponding to empty strings. - vocabulary = model.tokenizer.vocabulary - sorted_vocabulary = [ - model.tokenizer.convert_token_to_string(k) - for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1]) - ] + """ + super().__init__(model, max_tokens) + self.allow_empty_tokens = allow_empty_tokens regex_pattern = interegular.parse_pattern(regex_string) self.regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - def partial_match_filter(string, end_idx, state_seq): - if end_idx is not None and end_idx < len(string) - 1: - return False - return True - - pstate_to_vocab, paths = map_partial_states_to_vocab( - list(sorted_vocabulary), - {"REGEX": self.regex_fsm}, - partial_match_filter, - final_state_string=model.tokenizer.eos_token, + self.states_to_token_maps, self.empty_token_ids = create_fsm_index_tokenizer( + self.regex_fsm, model.tokenizer ) # Check whether a terminal path (from the initial state of the FSM to # one of its terminal states) exists, raise an exception otherwise. - traversed_states = set() - queue = collections.deque([self.regex_fsm.initial]) - while queue: - symbol = queue.popleft() - for prev_state in paths["REGEX"].get(symbol, ()): - if prev_state not in traversed_states: - traversed_states.add(prev_state) - queue.append(prev_state) - - if traversed_states.intersection(self.regex_fsm.finals) == set(): + if not any( + self.regex_fsm.finals.intersection(v.values()) + for v in self.states_to_token_maps.values() + ): raise ValueError( "The vocabulary does not allow us to build a sequence that matches the input regex" ) - self.pstate_to_vocab = {k: list(v) for k, v in pstate_to_vocab.items()} - # These tuples are comprised of the FSM name, last FSM state, and - # number of processed tokens. # When an EOS is observed, the last FSM state becomes `-1`. 
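+        # `states_to_token_maps` maps each FSM state to a `{token_id: next_state}`
+        # dict, and `last_fsm_states` tracks one current FSM state per sequence
+        # in the batch between calls to `create_proposal`.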
- self.pstates: List[Tuple[str, int, int]] = [] + self.last_fsm_states: List[int] = [] + self.mask_cache: Dict[Tuple[int, int], torch.LongTensor] = {} def create_proposal( self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor @@ -89,72 +81,106 @@ def create_proposal( """ - if len(self.pstates) == 0: - self.pstates = [ - ("REGEX", self.regex_fsm.initial, 0) - for _ in range(generated_token_ids.shape[0]) + assert generated_token_ids.ndim == 2 + + if len(self.last_fsm_states) == 0: + self.last_fsm_states = [ + self.regex_fsm.initial for _ in range(generated_token_ids.shape[0]) ] - if generated_token_ids.shape[-1] > 0: - new_pstates = [] - for token_seq, (_, last_fsm_state, last_token_idx) in zip( + masks = [] + + for i, (token_seq, last_state) in enumerate( + zip( generated_token_ids, - self.pstates, - ): - # Get the tokens we haven't already processed, - readable_tokens = token_seq[last_token_idx:] - # excluding any EOS tokens. - not_eos_mask = [ - tk != self.model.tokenizer.eos_token_id for tk in readable_tokens - ] - readable_tokens = readable_tokens[not_eos_mask] - if len(readable_tokens) > 0: + self.last_fsm_states, + ) + ): + if token_seq.shape[0] > 0: + # Get the last token that was sampled + last_token = int(token_seq[-1]) + + if last_token in self.empty_token_ids: + # An empty token was sampled, so the FSM state hasn't changed + next_state = last_state + next_token_ids = list(self.states_to_token_maps[last_state].keys()) + + elif last_token != self.model.tokenizer.eos_token_id: # If we previously ended with an EOS, we shouldn't be # getting/sampling any more non-EOS tokens. - assert last_fsm_state > -1 + assert last_state > -1 - sequence = self.model.tokenizer.decode(readable_tokens) + last_token_to_end_state = self.states_to_token_maps[last_state] - ((_, state_seq),) = find_partial_matches( - self.regex_fsm, - "".join(sequence), - start_state=last_fsm_state, - ) - pstate = ( - "REGEX", - state_seq[-1], - last_token_idx + len(sequence), + next_state = last_token_to_end_state[last_token] + + next_tokens_to_end_states = self.states_to_token_maps.get( + next_state ) + + if next_tokens_to_end_states is None: + # If there are no transitions from the current state, + # then we must've been in a final state of the FSM. + # We produce EOS tokens from here on. + assert next_state in self.regex_fsm.finals + next_state = -1 + next_token_ids = [self.model.tokenizer.eos_token_id] + else: + next_token_ids = list(next_tokens_to_end_states.keys()) else: - pstate = ("REGEX", -1, last_token_idx) + # Since we already have an EOS, only sample EOS tokes from + # here on. 
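+                    # `-1` is a sentinel "post-EOS" state: it has no entry in
+                    # `states_to_token_maps`, so only the EOS token remains allowed.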
+ next_state = -1 + next_token_ids = [self.model.tokenizer.eos_token_id] + else: + # There weren't any previous tokens, so we can't update the state + next_state = last_state + next_token_ids = list(self.states_to_token_maps[last_state].keys()) - new_pstates.append(pstate) + mask = self._get_mask_for_state( + next_state, logits.shape[-1], next_token_ids + ) + masks.append(mask) + self.last_fsm_states[i] = next_state - self.pstates = new_pstates + mask = torch.concatenate(masks, dim=0) - masks = [] - mask_shape = (logits.shape[-1],) - for pstate in self.pstates: - mask = torch.full(mask_shape, -math.inf, device=self.device) + return logits + mask - if pstate[1] > -1: - next_support = self.pstate_to_vocab[pstate[:2]] - else: - next_support = [self.model.tokenizer.eos_token_id] + def _get_mask_for_state( + self, state: int, size: int, next_token_ids: List[int] + ) -> torch.LongTensor: + mask = self.mask_cache.get((state, size)) - mask[next_support] = 0 - masks.append(mask.unsqueeze(0)) + if mask is None: + mask = torch.full( + (size,), + -math.inf, + device=self.device, + ) - mask = torch.concatenate(masks, dim=0) + if self.allow_empty_tokens: + token_ids = list(self.empty_token_ids) + next_token_ids + else: + token_ids = next_token_ids - return logits + mask + mask[token_ids] = 0 + mask = mask.unsqueeze(0) + self.mask_cache[(state, size)] = mask + + return mask def postprocess_completions(self, completions: List[str]) -> List[str]: - self.pstates.clear() + self.last_fsm_states.clear() return super().postprocess_completions(completions) -def regex(model, regex_string: str, max_tokens: Optional[int] = None): +def regex( + model, + regex_string: str, + max_tokens: Optional[int] = None, + allow_empty_tokens: bool = True, +): """Generate text sequences that match the input regex. Parameters @@ -165,12 +191,14 @@ def regex(model, regex_string: str, max_tokens: Optional[int] = None): The regular expression that generated expressions must match. max_tokens The maximum number of tokens to generate. + allow_empty_tokens + Allow sampling of tokens corresponding to empty strings. """ - return Regex(model, regex_string, max_tokens) + return Regex(model, regex_string, max_tokens, allow_empty_tokens) -def integer(model, max_tokens: Optional[int] = None): +def integer(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True): """Generate integers. The regex used to constrain the generation optionally matches plus or minus @@ -183,12 +211,14 @@ def integer(model, max_tokens: Optional[int] = None): The language model to use to compute the next-token logits. max_tokens The maximum number of tokens to generate. + allow_empty_tokens + Allow sampling of tokens corresponding to empty strings. """ - return Regex(model, r"[-+]?\d+", max_tokens) + return Regex(model, r"[-+]?\d+", max_tokens, allow_empty_tokens) -def float(model, max_tokens: Optional[int] = None): +def float(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True): """Generate floating-point numbers. The regex used to constrain the generation optionally matches plus or minus @@ -201,18 +231,35 @@ def float(model, max_tokens: Optional[int] = None): The language model to use to compute the next-token logits. max_tokens The maximum number of tokens to generate. + allow_empty_tokens + Allow sampling of tokens corresponding to empty strings. 
""" - return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens) - - -def choice(model, choices: List[str], max_tokens: Optional[int] = None): + return Regex( + model, + r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", + max_tokens, + allow_empty_tokens, + ) + + +def choice( + model, + choices: List[str], + max_tokens: Optional[int] = None, + allow_empty_tokens: bool = True, +): """Choose between different sequences.""" regex_str = r"(" + r"|".join(choices) + r")" - return Regex(model, regex_str, max_tokens) + return Regex(model, regex_str, max_tokens, allow_empty_tokens) -def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None): +def json( + model, + schema: Union[str, BaseModel], + max_tokens: Optional[int] = None, + allow_empty_tokens: bool = True, +): """Generate a text sequence that follows a JSON schema or Pydantic model. Parameters @@ -223,6 +270,8 @@ def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None) The JSON schema or Pydantic model that guides the generation. max_tokens The maximum number of tokens to generate. + allow_empty_tokens + Allow sampling of tokens corresponding to empty strings. """ if isinstance(schema, type(BaseModel)): @@ -230,4 +279,4 @@ def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None) regex_str = build_regex_from_schema(schema) - return Regex(model, regex_str, max_tokens) + return Regex(model, regex_str, max_tokens, allow_empty_tokens) diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py index 00a7a8a23..9bc15e69b 100644 --- a/outlines/text/parsing.py +++ b/outlines/text/parsing.py @@ -1,24 +1,10 @@ -from collections import ChainMap from copy import copy, deepcopy from dataclasses import dataclass from functools import cache -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Generator, - Iterable, - Iterator, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Any, Dict, FrozenSet, Iterator, Optional, Set, Tuple, Union import interegular -from interegular.fsm import FSM, Alphabet, OblivionError +from interegular.fsm import FSM from interegular.patterns import Unsupported from lark import Lark, Token from lark.common import LexerConf, ParserConf @@ -52,6 +38,13 @@ from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser +from outlines.text.fsm import ( + fsm_union, + get_sub_fsms_from_seq, + make_deterministic_fsm, + walk_fsm, +) + PartialParseState = Tuple[str, int] ParseStateType = Union[int, FrozenSet] @@ -72,80 +65,6 @@ class PartialTokensInfo: final_terminals_and_info: Tuple[PartialTerminalInfo, ...] 
-def make_deterministic_fsm(fsm: FSM) -> Tuple[FSM, Dict[int, int]]: - """Construct an equivalent FSM with deterministic state labels.""" - old_to_new_trans_keys = { - trans_key: i - for i, (trans_key, _) in enumerate( - sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) - ) - } - - new_symbol_mapping = { - symbol: old_to_new_trans_keys[trans_key] - for symbol, trans_key in fsm.alphabet._symbol_mapping.items() - } - - new_alphabet = Alphabet(new_symbol_mapping) - - new_map = { - from_state: { - old_to_new_trans_keys[trans_key]: to_state - for trans_key, to_state in trans_map.items() - } - for from_state, trans_map in fsm.map.items() - } - - old_to_new_states = {} - old_to_new_states[fsm.initial] = 0 - - i = 0 - seen = {fsm.initial} - old_state_queue = [fsm.initial] - while old_state_queue: - old_state = old_state_queue.pop(-1) - transitions = new_map[old_state] - sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) - for _, old_state in sorted_transitions: - if old_state not in seen: - old_state_queue.append(old_state) - seen.add(old_state) - if old_state not in old_to_new_states: - i += 1 - old_to_new_states[old_state] = i - - new_map = dict( - sorted( - ( - ( - old_to_new_states[from_state], - dict( - sorted( - ( - (trans_key, old_to_new_states[to_state]) - for trans_key, to_state in trans_map.items() - ), - key=lambda v: v[0], - ) - ), - ) - for from_state, trans_map in new_map.items() - ), - key=lambda v: v[0], - ) - ) - - new_initial = 0 - new_finals = frozenset( - sorted(old_to_new_states[old_state] for old_state in fsm.finals) - ) - new_states = frozenset(sorted(new_map.keys())) - - new_fsm = FSM(new_alphabet, new_states, new_initial, new_finals, new_map) - - return new_fsm, old_to_new_states - - class PartialParserConf(ParserConf): __serialize_fields__ = ( "rules", @@ -635,20 +554,20 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None) text_part = text[start_pos:] - res = find_partial_matches( - self.fsm, + state_seq = walk_fsm( + self.fsm.fsm_info, text_part, - start_state=start_state, + start_state, full_match=self.match_whole, ) - if len(res) == 0: + if len(state_seq) == 0: return None - ((_, state_seq),) = res - if last_fsm_state_seq: - state_seq = last_fsm_state_seq[:-1] + state_seq + state_seq = last_fsm_state_seq + tuple(state_seq) + else: + state_seq = (start_state,) + tuple(state_seq) return state_seq @@ -910,96 +829,6 @@ def get_contextual_lexer(x: Union[PartialLexerThread, PartialParsingFrontend]): return x.lexer.lexer -def find_partial_matches( - fsm: FSM, input_string: str, start_state: Optional[int] = None, full_match=True -) -> Set[Tuple[int, Tuple[int, ...]]]: - """Find the states in the finite state machine `fsm` that accept `input_string`. - - This will consider all possible states in the finite state machine (FSM) - that accept the beginning of `input_string` as starting points, unless a - specific `start_state` is provided. - - Parameters - ---------- - fsm - The finite state machine. - input_string - The string for which we generate partial matches. - start_state - A single fixed starting state to consider. For example, if this value - is set to `fsm.initial`, it attempt to read `input_string` from the - beginning of the FSM/regular expression. - full_match - Matches must cover the entire string. - - Returns - ------- - A set of tuples corresponding to each valid starting state in the FSM. 
The - first element of each tuple contains an integer indicating the position in - `input_string` at which the FSM stopped. The second element is the tuple - of states visited during execution of the FSM plus the next, unvisited - transition state. - - """ - if len(input_string) == 0: - return set() - - trans_key = fsm.alphabet[input_string[0]] - - # TODO: We could probably reuse parts of the computed paths when computing - # results for multiple starting points. - def _partial_match( - trans: Dict[int, int] - ) -> Tuple[Optional[int], Optional[Tuple[int, ...]]]: - fsm_map = ChainMap({fsm.initial: trans}, fsm.map) - state = fsm.initial - accepted_states: Tuple[int, ...] = () - last_final_idx = -1 - - for i, symbol in enumerate(input_string): - trans_key = fsm.alphabet[symbol] - - trans_map = fsm_map.get(state) - - if trans_map is None or trans_key not in trans_map: - if full_match: - if state in fsm.finals: - i -= 1 - break - else: - if last_final_idx > -1: - i = last_final_idx - accepted_states = accepted_states[: last_final_idx + 1] - break - - return None, None - - state = trans_map[trans_key] - - if state in fsm.finals: - last_final_idx = i - - accepted_states += (state,) - - terminated = state in fsm.finals - if not terminated and state == fsm.initial: - return None, None - - return i, accepted_states - - res = set() - transition_maps = ( - fsm.map if start_state is None else {start_state: fsm.map[start_state]} - ) - for state, trans in transition_maps.items(): - if trans_key in trans: - last_match_idx, path = _partial_match(trans) - if last_match_idx is not None and path is not None: - res.add((last_match_idx, (state,) + path)) - - return res - - def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]: """Construct a ``dict`` mapping terminal symbol names to their finite state machines.""" @@ -1015,221 +844,3 @@ def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]: symbol_names_and_fsms[terminal.name] = fsm return symbol_names_and_fsms - - -def map_partial_states_to_vocab( - vocabulary: Iterable[str], - terminals_to_fsms_map: Dict[str, FSM], - partial_match_filter: Callable[ - [str, Optional[int], Tuple[int, ...]], bool - ] = lambda *args: True, - final_state_string: Optional[str] = None, -) -> Tuple[Dict[PartialParseState, Set[int]], Dict[str, Dict[int, Set[int]]]]: - """Construct a map from partial parse states to subsets of `vocabulary`. - - The subsets of `vocabulary` consist of elements that are accepted by--or - transition to--the corresponding partial parse states. - - Parameters - ---------- - vocabulary - The vocabulary composed of strings. - terminals_to_fsms_map - Terminal symbol names mapped to FSMs, as provided by `terminals_to_fsms`. - partial_match_filter - A callable that determines which partial matches to keep. The first - argument is the string being match, the rest are the unpacked partial - match return values of `find_partial_matches`. - final_state_string - A string from `vocabulary` that is to be added to all the final states - in the FSM. 
- """ - - final_state_string_idx = None - - # Partial parse states to the subsets of the vocabulary that accept them - pstate_to_vocab: Dict[Tuple[str, int], Set[int]] = {} - possible_paths = {} - for symbol_name, fsm in terminals_to_fsms_map.items(): - terminal_possible_paths: Dict[int, Set[int]] = {} - for i, vocab_string in enumerate(vocabulary): - if vocab_string == final_state_string: - final_state_string_idx = i - - for end_idx, state_seq in find_partial_matches(fsm, vocab_string): - if partial_match_filter(vocab_string, end_idx, state_seq): - terminal_possible_paths.setdefault(state_seq[0], set()).add( - state_seq[-1] - ) - pstate_to_vocab.setdefault((symbol_name, state_seq[0]), set()).add( - i - ) - - possible_paths[symbol_name] = terminal_possible_paths - - if final_state_string_idx is not None: - # Allow transitions to EOS from all terminals FSM states - for symbol_name, fsm in terminals_to_fsms_map.items(): - for state in fsm.finals: - pstate_to_vocab.setdefault((symbol_name, state), set()).add( - final_state_string_idx - ) - - return pstate_to_vocab, possible_paths - - -def fsm_union( - fsms: Sequence[FSM], -) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: - """Construct an FSM representing the union of the FSMs in `fsms`. - - This is an updated version of `interegular.fsm.FSM.union` made to return an - extra map of component FSMs to the sets of state transitions that - correspond to them in the new FSM. - - """ - - alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) - - indexed_fsms = tuple(enumerate(fsms)) - - initial = {i: fsm.initial for (i, fsm) in indexed_fsms} - - # Dedicated function accepting a "superset" and returning the next - # "superset" obtained by following this transition in the new FSM - def follow(current_state, new_transition: int): - next = {} - for i, f in indexed_fsms: - old_transition = new_to_old[i][new_transition] - if ( - i in current_state - and current_state[i] in f.map - and old_transition in f.map[current_state[i]] - ): - next[i] = f.map[current_state[i]][old_transition] - if not next: - raise OblivionError - return next - - states = [initial] - finals: Set[int] = set() - map: Dict[int, Dict[int, int]] = {} - - # Map component FSMs to their new state-to-state transitions, finals, and a - # map translating component FSM states to aggregate FSM states - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ] = {} - - i = 0 - while i < len(states): - state = states[i] - - # Add to the finals of the aggregate FSM whenever we hit a final in a - # component FSM - if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): - finals.add(i) - - # Compute the map for this state - map[i] = {} - for transition in alphabet.by_transition: - try: - next = follow(state, transition) - except OblivionError: - # Reached an oblivion state; don't list it - continue - else: - try: - # TODO: Seems like this could--and should--be avoided - j = states.index(next) - except ValueError: - j = len(states) - states.append(next) - - map[i][transition] = j - - for fsm_id, fsm_state in next.items(): - ( - fsm_transitions, - fsm_finals, - fsm_old_to_new, - ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) - old_from = state[fsm_id] - old_to = fsm_state - fsm_old_to_new.setdefault(old_from, set()).add(i) - fsm_old_to_new.setdefault(old_to, set()).add(j) - fsm_transitions.add((i, j)) - if fsm_state in fsms[fsm_id].finals: - fsm_finals.add(j) - - i += 1 - - fsm = 
FSM( - alphabet=alphabet, - states=range(len(states)), - initial=0, - finals=finals, - map=map, - __no_validation__=True, - ) - - fsm, old_to_new_states = make_deterministic_fsm(fsm) - _fsms_to_trans_finals = { - fsm_id: ( - {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, - {old_to_new_states[s] for s in finals}, - { - old_state: {old_to_new_states[new_state] for new_state in new_states} - for old_state, new_states in old_to_new.items() - }, - ) - for fsm_id, (transitions, finals, old_to_new) in sorted( - fsms_to_trans_finals.items(), key=lambda x: x[0] - ) - } - - return ( - fsm, - _fsms_to_trans_finals, - ) - - -def get_sub_fsms_from_seq( - state_seq: Sequence[int], - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ], -) -> Generator[Tuple[int, bool, bool], None, None]: - """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. - - Parameters - ---------- - state_seq - A state sequence. - fsms_to_trans_finals - A map from FSM indices to tuples containing sets of their state transitions - and sets of the final/accept states. - - Returns - ------- - A generator returning tuples containing each sub-FSM index (in the order - they were union-ed to construct `fsm`) and booleans indicating whether or - not there is another valid transition from the last state in the sequence - for the associated sub-FSM (i.e. if the FSM can continue - accepting/matching) and whether or not the sequence ends in a final state - of the sub-FSM. - """ - state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) - last_fsm_state = state_seq[-1] - yield from ( - ( - # The sub-FMS index - fsm_idx, - # Is there another possible transition in this sub-FSM? - any(last_fsm_state == from_s for (from_s, to_s) in transitions), - # Is this sub-FSM in a final state? 
- state_seq[-1] in finals, - ) - for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() - if state_seq_transitions.issubset(transitions) - ) diff --git a/pyproject.toml b/pyproject.toml index 1426b524f..b5a8fb6e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "scipy", "tenacity", "torch", + "numba", + "joblib", ] dynamic = ["version"] @@ -81,7 +83,7 @@ exclude=["examples"] module = [ "diffusers", "jinja2", - "joblib", + "joblib.*", "openai", "numpy.*", "perscache.*", @@ -96,6 +98,7 @@ module = [ "transformers.*", "lark.*", "interegular.*", + "numba.*", ] ignore_missing_imports = true diff --git a/tests/text/generate/test_regex.py b/tests/text/generate/test_regex.py index 5ef3afcd7..e3ba4f6d5 100644 --- a/tests/text/generate/test_regex.py +++ b/tests/text/generate/test_regex.py @@ -15,6 +15,12 @@ class Tokenizer: tokens = list(vocabulary.keys()) special_tokens = {""} + def encode(self, tokens): + if not isinstance(tokens, (tuple, list)): + tokens = [tokens] + + return [self.vocabulary[token] for token in tokens] + def decode(self, token_ids): decoded = [] for i in range(token_ids.shape[0]): @@ -26,11 +32,21 @@ def convert_token_to_string(self, token): return token +class TokenizerWithEmpty(Tokenizer): + vocabulary = {"": 0, "-": 1, "1": 2, "0.": 3, "431": 4, "a": 5, "A": 6, "": 7} + tokens = list(vocabulary.keys()) + + class Model: tokenizer = Tokenizer() device = "cpu" +class ModelWithEmpty: + tokenizer = TokenizerWithEmpty() + device = "cpu" + + @pytest.mark.parametrize( "regex_string, valid_first_token, proposal", [ @@ -153,3 +169,32 @@ def test_float_proposal(input_ids, proposal): result, torch.tensor(proposal), ) + + +@pytest.mark.parametrize( + "input_ids, proposal, with_empty", + [ + ([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, 1]], True), + ( + [[]], + [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, -math.inf]], + False, + ), + ([[3]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, 1]], True), + ( + [[3]], + [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, -math.inf]], + False, + ), + ], +) +def test_empty_strings(input_ids, proposal, with_empty): + model = ModelWithEmpty() + generator = generate.float(model, allow_empty_tokens=with_empty) + + logits = torch.ones(len(model.tokenizer.vocabulary)) + result = generator.create_proposal(torch.tensor(input_ids), logits) + assert torch.equal( + result, + torch.tensor(proposal), + ) diff --git a/tests/text/test_fsm.py b/tests/text/test_fsm.py new file mode 100644 index 000000000..bddf77e90 --- /dev/null +++ b/tests/text/test_fsm.py @@ -0,0 +1,353 @@ +import interegular +import numba +import pytest + +from outlines.models.transformers import TransformersTokenizer +from outlines.text.fsm import ( + create_fsm_index, + create_fsm_index_end_to_end, + create_fsm_index_tokenizer, + find_partial_matches, + fsm_union, + get_sub_fsms_from_seq, + make_deterministic_fsm, + walk_fsm, +) + + +def test_partial_match(): + name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") + name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) + assert name_fsm.initial == 0 + + name_fsm = name_fsm.fsm_info + + def_pattern = interegular.parse_pattern("def") + def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) + assert def_fsm.initial == 0 + + def_fsm = def_fsm.fsm_info + + def to_python(res): + return {(x, tuple(y)) for x, y in res} + + res = to_python(find_partial_matches(def_fsm, "def")) + assert res == {(2, (0, 1, 2, 3))} + res = 
to_python(find_partial_matches(def_fsm, "de")) + assert res == {(1, (0, 1, 2))} + res = to_python(find_partial_matches(def_fsm, "d")) + assert res == {(0, (0, 1))} + res = to_python(find_partial_matches(def_fsm, "")) + assert res == set() + res = to_python(find_partial_matches(def_fsm, "df")) + assert res == set() + res = to_python(find_partial_matches(def_fsm, "ef")) + assert res == {(1, (1, 2, 3))} + res = to_python(find_partial_matches(def_fsm, "e")) + assert res == {(0, (1, 2))} + res = to_python(find_partial_matches(def_fsm, "f")) + assert res == {(0, (2, 3))} + res = to_python(find_partial_matches(def_fsm, "ef foo")) + assert res == {(1, (1, 2, 3))} + + # This string has a `DEF` token in it, but should ultimately not lex one + res = to_python(find_partial_matches(def_fsm, "defb")) + assert res == {(2, (0, 1, 2, 3))} + + # `NAME` can have multiple start states for this input + res = to_python(find_partial_matches(name_fsm, "d")) + assert res == {(0, (0, 1)), (0, (1, 1))} + # Not this case + res = to_python(find_partial_matches(name_fsm, "1d")) + assert res == {(1, (1, 1, 1))} + + res = to_python(find_partial_matches(name_fsm, "blah")) + assert res == { + (3, (0, 1, 1, 1, 1)), + (3, (1, 1, 1, 1, 1)), + } + + float_pattern = interegular.parse_pattern( + r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))" + ) + float_fsm, _ = make_deterministic_fsm(float_pattern.to_fsm().reduce()) + assert 5 in float_fsm.finals + assert 2 not in float_fsm.finals + + float_fsm = float_fsm.fsm_info + + res = to_python(find_partial_matches(float_fsm, ".")) + assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))} + + joins_fsm, _ = make_deterministic_fsm( + interegular.parse_pattern(r"(JOIN LEFT|JOIN)").to_fsm().reduce() + ) + + joins_fsm = joins_fsm.fsm_info + + res = to_python(find_partial_matches(joins_fsm, "JOIN BLAH", full_match=False)) + assert res == {(3, (0, 1, 2, 3, 4))} + + res = to_python(find_partial_matches(joins_fsm, "JOIN L", full_match=False)) + assert res == {(5, (0, 1, 2, 3, 4, 5, 6))} + + res = to_python(find_partial_matches(joins_fsm, "JOI", full_match=False)) + assert res == {(2, (0, 1, 2, 3))} + + regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + # State `1` has no transitions + assert not regex_fsm.map[1] + # This should fail, because state `1` reads nothing + res = to_python(walk_fsm(regex_fsm.fsm_info, "0", 1)) + assert res == set() + + res = to_python(find_partial_matches(regex_fsm.fsm_info, "0", 1)) + assert res == {(0, (0, 1))} + + +def test_create_fsm_index(): + regex_str = "0|[1-9][0-9]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + vocabulary = {"blah": 0, "1a": 1, "2": 2, "0": 3, "": 4} + + res = create_fsm_index(regex_fsm.fsm_info, vocabulary) + + assert res == {0: {2, 3}, 2: {2, 3}} + + res = create_fsm_index(regex_fsm.fsm_info, vocabulary, "") + + assert res == {0: {2, 3}, 1: {4}, 2: {2, 3, 4}} + + +def test_get_sub_fsms_from_seq(): + name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") + name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) + + def_pattern = interegular.parse_pattern("def") + def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) + + match_pattern = interegular.parse_pattern("match") + match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) + + peq_pattern = interegular.parse_pattern(r"\+=") + peq_fsm, _ = 
make_deterministic_fsm(peq_pattern.to_fsm().reduce()) + + plus_pattern = interegular.parse_pattern(r"\+") + plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) + + fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] + + fsm, fsms_to_trans_finals = fsm_union(fsms) + + assert fsms_to_trans_finals == { + 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), + 1: ( + {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, + {8}, + {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, + ), + 2: ( + { + (0, 2), + (0, 3), + (0, 4), + (2, 2), + (3, 2), + (3, 9), + (4, 2), + (4, 5), + (5, 2), + (5, 6), + (6, 2), + (6, 7), + (7, 2), + (7, 8), + (8, 2), + (9, 2), + (9, 10), + (10, 2), + }, + {2, 3, 4, 5, 6, 7, 8, 9, 10}, + {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, + ), + 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), + 4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}), + } + + assert not fsm.accepts("1a") + assert fsm.accepts("a1") + assert fsm.accepts("def") + assert fsm.accepts("match") + assert fsm.accepts("+=") + assert fsm.accepts("+") + + state_seq = walk_fsm(fsm.fsm_info, "def", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, False, True), (2, True, True)] + + # Make sure the old-to-new state map is correct + def_state_seq = walk_fsm(def_fsm.fsm_info, "def", fsm.initial) + def_state_seq.insert(0, fsm.initial) + + def_old_to_new_states = fsms_to_trans_finals[0][2] + assert all( + new_state in def_old_to_new_states[old_state] + for old_state, new_state in zip(def_state_seq, state_seq) + ) + + state_seq = walk_fsm(fsm.fsm_info, "ef", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(2, True, True)] + + name_state_seq = walk_fsm(name_fsm.fsm_info, "ef", fsm.initial) + name_state_seq.insert(0, fsm.initial) + + name_old_to_new_states = fsms_to_trans_finals[2][2] + assert all( + new_state in name_old_to_new_states[old_state] + for old_state, new_state in zip(name_state_seq, state_seq) + ) + + state_seq = walk_fsm(fsm.fsm_info, "match", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(1, False, True), (2, True, True)] + + match_state_seq = walk_fsm(match_fsm.fsm_info, "match", fsm.initial) + match_state_seq.insert(0, fsm.initial) + + match_old_to_new_states = fsms_to_trans_finals[1][2] + assert all( + new_state in match_old_to_new_states[old_state] + for old_state, new_state in zip(match_state_seq, state_seq) + ) + + state_seq = walk_fsm(fsm.fsm_info, "defa", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(2, True, True)] + + state_seq = walk_fsm(fsm.fsm_info, "de", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, True, False), (2, True, True)] + + state_seq = walk_fsm(fsm.fsm_info, "+", fsm.initial, False) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(3, True, False), (4, False, True)] + + state_seq = walk_fsm(fsm.fsm_info, "+=", fsm.initial) + state_seq.insert(0, fsm.initial) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(3, False, True)] + + # Test some overlapping patterns + join_fsms = [ + 
interegular.parse_pattern(r"JOIN").to_fsm().reduce(), + interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), + ] + fsm, fsms_to_trans_finals = fsm_union(join_fsms) + + ((_, state_seq),) = find_partial_matches(fsm.fsm_info, "OI", full_match=False) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, True, False), (1, True, False)] + + ((_, state_seq),) = find_partial_matches(fsm.fsm_info, "N", full_match=False) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(0, False, True), (1, True, False)] + + ((_, state_seq),) = find_partial_matches(fsm.fsm_info, " ", full_match=False) + + res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) + assert res == [(1, True, False)] + + +def test_create_fsm_index_end_to_end(): + regex_str = "0|[1-9][0-9]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + vocabulary = { + "blah": numba.typed.List([0]), + "1a": numba.typed.List([1]), + "2": numba.typed.List([2]), + "0": numba.typed.List([3]), + "": numba.typed.List([4]), + } + + vocabulary_nb = numba.typed.Dict.empty( + numba.types.string, numba.types.ListType(numba.int64) + ) + vocabulary_nb.update(vocabulary) + + res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) + + assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}} + + +def test_create_fsm_index_tokenizer(): + # The combined regular expressions of a lexer state in a Python grammar + regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + tokenizer = TransformersTokenizer("gpt2") + + states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + assert not empty_token_ids + assert len(states_to_token_subsets) / num_fsm_states > 0.94 + + +@pytest.mark.skip(reason="Only for local profiling") +def test_regex_index_performance(): + from line_profiler import LineProfiler # type: ignore [import] + + regex_str = 
"(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + tokenizer = TransformersTokenizer("gpt2") + + # Pre-compile Numba functions + res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) + assert len(res) > 1 + + profiler = LineProfiler(create_fsm_index_end_to_end) + + profiler.runctx( + "create_fsm_index_tokenizer(regex_fsm, tokenizer)", + globals(), + locals(), + ) + profiler.dump_stats("line-profiler-create_fsm_index.pkl") + profiler.print_stats(output_unit=1e-3, summarize=True) diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py index f4a08dd3b..20b96e7d6 100644 --- a/tests/text/test_parsing.py +++ b/tests/text/test_parsing.py @@ -1,22 +1,10 @@ -import random -import re from copy import copy -import interegular import pytest from lark.indenter import DedentError from lark.lexer import UnexpectedCharacters, UnexpectedToken -from outlines.text.parsing import ( - PartialLark, - PartialPythonIndenter, - find_partial_matches, - fsm_union, - get_sub_fsms_from_seq, - make_deterministic_fsm, - map_partial_states_to_vocab, - terminals_to_fsms, -) +from outlines.text.parsing import PartialLark, PartialPythonIndenter def test_partial_parsing(): @@ -204,356 +192,3 @@ def test_sequential_parse_example(): if i + 1 == len(input_tokens): assert all(tk in next_vocab for tk in ["\n", "\nde", " ", " + 1"]) - - -def test_find_partial_matches(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - assert name_fsm.initial == 0 - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - assert def_fsm.initial == 0 - - assert find_partial_matches(def_fsm, "def") == {(2, (0, 1, 2, 3))} - assert find_partial_matches(def_fsm, "de") == {(1, (0, 1, 2))} - assert find_partial_matches(def_fsm, "d") == {(0, (0, 1))} - assert find_partial_matches(def_fsm, "") == set() - assert find_partial_matches(def_fsm, "df") == set() - assert find_partial_matches(def_fsm, "ef") == {(1, (1, 2, 3))} - assert find_partial_matches(def_fsm, "e") == {(0, (1, 2))} - assert find_partial_matches(def_fsm, "f") == {(0, (2, 3))} - assert find_partial_matches(def_fsm, "ef foo") == {(1, (1, 2, 3))} - - # This string has a `DEF` token in it, but should ultimately not lex one - assert find_partial_matches(def_fsm, "defb") == {(2, (0, 1, 2, 3))} - - # `NAME` can have multiple start states for this input - 
assert find_partial_matches(name_fsm, "d") == { - (0, (0, 1)), - (0, (1, 1)), - } - # Not this case - assert find_partial_matches(name_fsm, "1d") == {(1, (1, 1, 1))} - - assert find_partial_matches(name_fsm, "blah") == { - (3, (0, 1, 1, 1, 1)), - (3, (1, 1, 1, 1, 1)), - } - - float_pattern = interegular.parse_pattern( - r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))" - ) - float_fsm, _ = make_deterministic_fsm(float_pattern.to_fsm().reduce()) - assert 5 in float_fsm.finals - assert 2 not in float_fsm.finals - - res = find_partial_matches(float_fsm, ".") - assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))} - - joins_fsm, _ = make_deterministic_fsm( - interegular.parse_pattern(r"(JOIN LEFT|JOIN)").to_fsm().reduce() - ) - res = find_partial_matches( - joins_fsm, "JOIN BLAH", joins_fsm.initial, full_match=False - ) - assert res == {(3, (0, 1, 2, 3, 4))} - - res = find_partial_matches(joins_fsm, "JOIN L", joins_fsm.initial, full_match=False) - assert res == {(5, (0, 1, 2, 3, 4, 5, 6))} - - res = find_partial_matches(joins_fsm, "JOI", joins_fsm.initial, full_match=False) - assert res == {(2, (0, 1, 2, 3))} - - -def test_map_partial_states_to_vocab_python(): - pyparser = PartialLark.open_from_package( - "tests", - "partial_python.lark", - ["text"], - parser="lalr", - postlex=PartialPythonIndenter(), - start="file_input", - ) - - symbol_names_and_fsms = terminals_to_fsms(pyparser) - test_symbols = {"DEF", "NAME", "__IGNORE_0"} - symbol_names_and_fsms = { - k: v for k, v in symbol_names_and_fsms.items() if k in test_symbols - } - - assert len(symbol_names_and_fsms["DEF"].states) == 4 - assert len(symbol_names_and_fsms["NAME"].states) == 2 - assert len(symbol_names_and_fsms["__IGNORE_0"].states) == 2 - - vocabulary = ["d", "e", "ef foo", "f ", " ", "1d", ""] - - pstate_to_vocab, possible_paths = map_partial_states_to_vocab( - vocabulary, symbol_names_and_fsms - ) - - assert dict(pstate_to_vocab) == { - ("__IGNORE_0", 0): {4}, - ("__IGNORE_0", 1): {4}, - ("NAME", 0): {0, 1, 2, 3}, - ("NAME", 1): {0, 1, 2, 3, 5}, - ("DEF", 0): {0}, - ("DEF", 1): {1, 2}, - ("DEF", 2): {3}, - } - assert possible_paths["__IGNORE_0"] == {0: {1}, 1: {1}} - assert possible_paths["NAME"] == {0: {1}, 1: {1}} - assert possible_paths["DEF"] == {0: {1}, 1: {2, 3}, 2: {3}} - - pstate_to_vocab, possible_paths = map_partial_states_to_vocab( - vocabulary, symbol_names_and_fsms, final_state_string="" - ) - - assert dict(pstate_to_vocab) == { - ("__IGNORE_0", 0): { - 4, - }, - ("__IGNORE_0", 1): {4, 6}, - ("NAME", 0): {0, 1, 2, 3}, - ("NAME", 1): {0, 1, 2, 3, 5, 6}, - ("DEF", 0): { - 0, - }, - ("DEF", 1): {1, 2}, - ("DEF", 2): { - 3, - }, - ("DEF", 3): { - 6, - }, - } - assert possible_paths["__IGNORE_0"] == {0: {1}, 1: {1}} - assert possible_paths["NAME"] == {0: {1}, 1: {1}} - assert possible_paths["DEF"] == {0: {1}, 1: {2, 3}, 2: {3}} - - -def test_map_partial_states_to_vocab_regex(): - regex_string = r"([0-9]+([.][0-9]*)?|[.][0-9]+)" - regex_pattern = interegular.parse_pattern(regex_string) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - vocabulary = [ - "1.", - "2", - "3.", - ".", - ".80", - "42", - "1a", - " ", - "0", - "a", - "b", - "$", - "", - ] - - # We want the vocabulary strings to entirely match the regex--not just the - # prefixes of the vocabulary strings - def partial_match_filter(string, end_idx, state_seq): - if end_idx is not None and end_idx < len(string) - 1: - return False - return True - - pstate_to_vocab, possible_paths = map_partial_states_to_vocab( - vocabulary, {"FLOAT": 
regex_fsm}, partial_match_filter, "" - ) - - assert sorted(pstate_to_vocab.values(), key=lambda x: -len(x)) == [ - {0, 1, 2, 3, 4, 5, 8, 12}, - {0, 1, 2, 3, 4, 5, 8}, - {1, 5, 8, 12}, - {1, 5, 8}, - ] - assert possible_paths["FLOAT"] == {2: {2, 3}, 0: {1, 2, 3}, 3: {3}, 1: {3}} - - pstate_to_vocab = {k: tuple(v) for k, v in pstate_to_vocab.items()} - - random.seed(24080) - - for n in range(50): - # Start at the initial state - pstate = ("FLOAT", regex_fsm.initial) - - sample_seq = "" - - for i in range(5): - next_support = pstate_to_vocab[pstate] - - (next_sample_idx,) = random.sample(next_support, 1) - - next_sample = vocabulary[next_sample_idx] - - if next_sample == "": - break - - sample_seq += next_sample - - # Continue matching from where we left off - (pmatch,) = find_partial_matches( - regex_fsm, next_sample, start_state=pstate[-1] - ) - - # Create the next state - pstate = (pstate[0], pmatch[1][-1]) - - # TODO: We could check if the FSM is done (i.e. in an final/accept - # state) and end the sampling loop - - # Make sure the whole thing matches the regex - assert re.fullmatch(regex_string, sample_seq) is not None - - -def test_get_sub_fsms_from_seq(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - - match_pattern = interegular.parse_pattern("match") - match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) - - peq_pattern = interegular.parse_pattern(r"\+=") - peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce()) - - plus_pattern = interegular.parse_pattern(r"\+") - plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) - - fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] - - fsm, fsms_to_trans_finals = fsm_union(fsms) - - assert fsms_to_trans_finals == { - 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), - 1: ( - {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, - {8}, - {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, - ), - 2: ( - { - (0, 2), - (0, 3), - (0, 4), - (2, 2), - (3, 2), - (3, 9), - (4, 2), - (4, 5), - (5, 2), - (5, 6), - (6, 2), - (6, 7), - (7, 2), - (7, 8), - (8, 2), - (9, 2), - (9, 10), - (10, 2), - }, - {2, 3, 4, 5, 6, 7, 8, 9, 10}, - {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, - ), - 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), - 4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}), - } - - assert not fsm.accepts("1a") - assert fsm.accepts("a1") - assert fsm.accepts("def") - assert fsm.accepts("match") - assert fsm.accepts("+=") - assert fsm.accepts("+") - - ((_, state_seq),) = find_partial_matches(fsm, "def", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (2, True, True)] - - # Make sure the old-to-new state map is correct - ((_, def_state_seq),) = find_partial_matches( - def_fsm, "def", start_state=fsm.initial - ) - def_old_to_new_states = fsms_to_trans_finals[0][2] - assert all( - new_state in def_old_to_new_states[old_state] - for old_state, new_state in zip(def_state_seq, state_seq) - ) - - ((_, state_seq),) = find_partial_matches(fsm, "ef", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - ((_, name_state_seq),) = find_partial_matches( - name_fsm, "ef", start_state=fsm.initial - ) - name_old_to_new_states = 
fsms_to_trans_finals[2][2] - assert all( - new_state in name_old_to_new_states[old_state] - for old_state, new_state in zip(name_state_seq, state_seq) - ) - - ((_, state_seq),) = find_partial_matches(fsm, "match", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, False, True), (2, True, True)] - - ((_, match_state_seq),) = find_partial_matches( - match_fsm, "match", start_state=fsm.initial - ) - match_old_to_new_states = fsms_to_trans_finals[1][2] - assert all( - new_state in match_old_to_new_states[old_state] - for old_state, new_state in zip(match_state_seq, state_seq) - ) - - ((_, state_seq),) = find_partial_matches(fsm, "defa", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - ((_, state_seq),) = find_partial_matches(fsm, "de", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (2, True, True)] - - ((_, state_seq),) = find_partial_matches(fsm, "+", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, True, False), (4, False, True)] - - ((_, state_seq),) = find_partial_matches(fsm, "+=", start_state=fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, False, True)] - - # Test some overlapping patterns - join_fsms = [ - interegular.parse_pattern(r"JOIN").to_fsm().reduce(), - interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), - ] - fsm, fsms_to_trans_finals = fsm_union(join_fsms) - ((_, state_seq),) = find_partial_matches( - fsm, "OI", start_state=None, full_match=False - ) - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (1, True, False)] - - ((_, state_seq),) = find_partial_matches( - fsm, "N", start_state=None, full_match=False - ) - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (1, True, False)] - - ((_, state_seq),) = find_partial_matches( - fsm, " ", start_state=None, full_match=False - ) - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, True, False)]
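
For reference, a minimal usage sketch of the two helpers exercised by the new
tests above, `make_deterministic_fsm` and `walk_fsm`, assuming they are
exported from the `outlines/text/fsm.py` module introduced by this patch
(illustrative only; not part of the diff):

    import interegular

    from outlines.text.fsm import make_deterministic_fsm, walk_fsm

    # Build a deterministic FSM for a single keyword, as the tests do.
    def_pattern = interegular.parse_pattern("def")
    def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce())

    # Feed a string through the FSM starting at its initial state; the
    # returned sequence holds the states reached after each character, so
    # the tests prepend the initial state by hand before passing it to
    # get_sub_fsms_from_seq.
    state_seq = walk_fsm(def_fsm.fsm_info, "def", def_fsm.initial)
    state_seq.insert(0, def_fsm.initial)

    # One state per consumed character, plus the initial state.
    assert len(state_seq) == len("def") + 1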