From da2608ded41adee5a3e3a8e0e341c9b6bbd6c386 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Fri, 31 May 2024 11:51:44 -0500
Subject: [PATCH] pass token_transition_sequence to walk_fsm in parsing.py

---
 outlines/fsm/parsing.py   |  9 +++++++-
 outlines/fsm/regex.py     | 47 ++++++++++++++++++++++++---------------
 tests/fsm/test_parsing.py | 14 +++++++-----
 3 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/outlines/fsm/parsing.py b/outlines/fsm/parsing.py
index 9ebc2af55..19deb975e 100644
--- a/outlines/fsm/parsing.py
+++ b/outlines/fsm/parsing.py
@@ -38,6 +38,7 @@
 from outlines.fsm.regex import (
     fsm_union,
     get_sub_fsms_from_seq,
+    get_token_transitions,
     make_deterministic_fsm,
     walk_fsm,
 )
@@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)
 
         text_part = text[start_pos:]
 
+        text_transitions = get_token_transitions(
+            self.fsm.fsm_info.alphabet_symbol_mapping,
+            self.fsm.fsm_info.alphabet_anything_value,
+            text_part,
+        )
+
         state_seq = walk_fsm(
             self.fsm,
-            text_part,
+            text_transitions,
             start_state,
             full_match=self.match_whole,
         )
diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index 1d27a0872..6e2b81412 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -675,6 +675,32 @@ def state_scan_tokens(
     return res
 
 
+@numba.njit(cache=True, nogil=True)
+def get_token_transitions(
+    alphabet_symbol_mapping: Dict[str, int],
+    alphabet_anything_value: int,
+    token_str: str,
+) -> Sequence[int]:
+    trans_key_seq = []
+    i = 0
+    while i < len(token_str):
+        if token_str[i] == "\x00" and i != len(token_str) - 1:
+            symbol = token_str[i : i + 3]
+            i += 3
+        else:
+            symbol = token_str[i]
+            i += 1
+
+        trans_key_seq.append(
+            alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
+        )
+
+    trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
+    for j in range(len(trans_key_seq)):
+        trans_key_seq_array[j] = trans_key_seq[j]
+    return trans_key_seq_array
+
+
 @numba.njit(cache=True, nogil=True)
 def get_tokens_trans_keys(
     alphabet_symbol_mapping: Dict[str, int],
@@ -683,24 +709,9 @@ def get_tokens_trans_keys(
 ) -> List[Sequence[int]]:
     tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
     for token_str, _ in vocabulary:
-        trans_key_seq = []
-        i = 0
-        while i < len(token_str):
-            if token_str[i] == "\x00" and i != len(token_str) - 1:
-                symbol = token_str[i : i + 3]
-                i += 3
-            else:
-                symbol = token_str[i]
-                i += 1
-
-            trans_key_seq.append(
-                alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
-            )
-
-        trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
-        for j in range(len(trans_key_seq)):
-            trans_key_seq_array[j] = trans_key_seq[j]
-
+        trans_key_seq_array = get_token_transitions(
+            alphabet_symbol_mapping, alphabet_anything_value, token_str
+        )
         tokens_trans_keys.append(trans_key_seq_array)
 
     return tokens_trans_keys
diff --git a/tests/fsm/test_parsing.py b/tests/fsm/test_parsing.py
index 4e093a994..b624fddee 100644
--- a/tests/fsm/test_parsing.py
+++ b/tests/fsm/test_parsing.py
@@ -9,7 +9,14 @@
 from outlines.fsm.parsing import PartialLark, PartialPythonIndenter
 
 
-def test_partial_parsing():
+@pytest.fixture
+def cleanup_lark_import():
+    yield
+    # Clean up lark.lark.LarkOptions._defaults
+    importlib.reload(lark.lark)
+
+
+def test_partial_parsing(cleanup_lark_import):
     lp = PartialLark.open_from_package(
         "tests",
         "partial_python.lark",
@@ -136,11 +143,8 @@
     assert len(parser_state.state_stack) == 4
     assert parser_state.value_stack[-1].type == "LPAR"
 
-    # Clean up lark.lark.LarkOptions._defaults
-    importlib.reload(lark.lark)
-
 
-def test_sequential_parse_example():
+def test_sequential_parse_example(cleanup_lark_import):
     input_tokens = [
         "x ",
         "= ",
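Note for reviewers (commentary, not part of the patch): get_token_transitions
maps a token string to the FSM's transition keys one symbol at a time. A NUL
character that is not the last character begins a three-character symbol (in
this revision the alphabet appears to encode raw byte values as "\x00" plus
two hex digits); every other character is a single-character symbol, and any
symbol missing from the alphabet falls back to the catch-all "anything"
transition key. A minimal pure-Python sketch of the same logic, without the
numba JIT or the final int64 array copy (the function name is illustrative):

    from typing import Dict, List

    def token_transitions_sketch(
        alphabet_symbol_mapping: Dict[str, int],
        alphabet_anything_value: int,
        token_str: str,
    ) -> List[int]:
        keys = []
        i = 0
        while i < len(token_str):
            if token_str[i] == "\x00" and i != len(token_str) - 1:
                # Escaped symbol: "\x00" plus the next two characters.
                symbol = token_str[i : i + 3]
                i += 3
            else:
                symbol = token_str[i]
                i += 1
            # Unknown symbols map to the catch-all transition key.
            keys.append(alphabet_symbol_mapping.get(symbol, alphabet_anything_value))
        return keys

For example, with alphabet {"a": 0, "b": 1} and anything value 2,
token_transitions_sketch({"a": 0, "b": 1}, 2, "ab!") returns [0, 1, 2],
since "!" is not in the alphabet.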
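Note for reviewers (commentary, not part of the patch): moving the
importlib.reload(lark.lark) teardown out of test_partial_parsing's tail and
into a yield fixture means the reload runs even when an assertion fails
partway through the test, and both tests that mutate
lark.lark.LarkOptions._defaults now share one cleanup path. A self-contained
sketch of the pattern (imports shown for completeness; the diff adds none, so
the test module must already have them, and the test name is hypothetical):

    import importlib

    import lark.lark
    import pytest

    @pytest.fixture
    def cleanup_lark_import():
        yield  # the test body runs here
        # Teardown: runs on success or failure, restoring
        # lark.lark.LarkOptions._defaults by reloading the module.
        importlib.reload(lark.lark)

    def test_something_that_mutates_lark(cleanup_lark_import):
        ...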