From 5224979404605125c7031fccd6912b66141bc70f Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 6 Jul 2023 17:37:38 -0500 Subject: [PATCH] Require cloning and patching before calling parse_to_end This just makes it easier to cut down on unnecessary copying. --- outlines/text/parsing.py | 9 +++++++-- tests/text/test_parsing.py | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py index 253b46f75..f3ab28a94 100644 --- a/outlines/text/parsing.py +++ b/outlines/text/parsing.py @@ -251,9 +251,14 @@ def copy_ip(ip: "InteractiveParser") -> "InteractiveParser": def parse_to_end(parser_state: ParserState) -> Tuple[ParserState, Set[str]]: - """Continue parsing from the current parse state and return partial next tokens.""" + """Continue parsing from the current parse state and return partial next tokens. - parser_state = copy_parser_state(parser_state) + .. warning:: + The parse state `parser_state` is updated in-place and must be patched + to work with this function. Either patch it manually or use + `copy_parser_state` before calling this. + + """ expected_next_tokens: Set[str] = set() try: diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py index 097d8c177..918c349d0 100644 --- a/tests/text/test_parsing.py +++ b/tests/text/test_parsing.py @@ -30,27 +30,32 @@ def test_parse_to_end(): ) ip = pyparser.parse_interactive("x") - parser_state, expected_next_tokens = parse_to_end(ip.parser_state) + parser_state = copy_parser_state(ip.parser_state) + parser_state, expected_next_tokens = parse_to_end(parser_state) assert not parser_state.value_stack assert expected_next_tokens == {"NAME"} ip = pyparser.parse_interactive("x = '") - parser_state, expected_next_tokens = parse_to_end(ip.parser_state) + parser_state = copy_parser_state(ip.parser_state) + parser_state, expected_next_tokens = parse_to_end(parser_state) assert parser_state.value_stack[-1].type == "EQUAL" assert expected_next_tokens == {"LONG_STRING", "STRING"} ip = pyparser.parse_interactive("x = 'hi") - parser_state, expected_next_tokens = parse_to_end(ip.parser_state) + parser_state = copy_parser_state(ip.parser_state) + parser_state, expected_next_tokens = parse_to_end(parser_state) assert parser_state.value_stack[-1].type == "EQUAL" assert expected_next_tokens == {"STRING"} ip = pyparser.parse_interactive("x = ('hi") - parser_state, expected_next_tokens = parse_to_end(ip.parser_state) + parser_state = copy_parser_state(ip.parser_state) + parser_state, expected_next_tokens = parse_to_end(parser_state) assert parser_state.value_stack[-1].type == "LPAR" assert expected_next_tokens == {"STRING"} ip = pyparser.parse_interactive("def") - parser_state, expected_next_tokens = parse_to_end(ip.parser_state) + parser_state = copy_parser_state(ip.parser_state) + parser_state, expected_next_tokens = parse_to_end(parser_state) assert not parser_state.value_stack assert expected_next_tokens == {"NAME", "DEF"} @@ -97,7 +102,7 @@ def test_sequential_parse_example(): start="file_input", ) ip = pyparser.parse_interactive("") - parser_state = ip.parser_state + parser_state = copy_parser_state(ip.parser_state) token_seq = "" for i, token in enumerate(input_tokens): @@ -243,6 +248,9 @@ def test_parse_from_partial_match(): (parser_state,) = create_pmatch_parser_states( lp, terminals_to_states, term_type, ptoken, first_pmatch ) + # These copies also patch the lexers in the parse state, which is now + # needed for use with `parse_to_end` + parser_state = copy_parser_state(parser_state) new_parser_state, expected_next_tokens = parse_to_end(parser_state) assert expected_next_tokens == {"NAME"} @@ -252,6 +260,7 @@ def test_parse_from_partial_match(): (parser_state,) = create_pmatch_parser_states( lp, terminals_to_states, term_type, ptoken, first_pmatch ) + parser_state = copy_parser_state(parser_state) new_parser_state, expected_next_tokens = parse_to_end(parser_state) assert not expected_next_tokens @@ -261,6 +270,7 @@ def test_parse_from_partial_match(): (parser_state,) = create_pmatch_parser_states( lp, terminals_to_states, term_type, ptoken, first_pmatch ) + parser_state = copy_parser_state(parser_state) with pytest.raises(UnexpectedToken): parse_to_end(parser_state)