Skip to content

Commit

Permalink
Require cloning and patching before calling parse_to_end
Browse files Browse the repository at this point in the history
Leaving the copy to the caller makes it easier to cut down on unnecessary copying of the parser state.
  • Loading branch information
brandonwillard committed Jul 6, 2023
1 parent 9368a8d commit 5224979
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
9 changes: 7 additions & 2 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,14 @@ def copy_ip(ip: "InteractiveParser") -> "InteractiveParser":


def parse_to_end(parser_state: ParserState) -> Tuple[ParserState, Set[str]]:
"""Continue parsing from the current parse state and return partial next tokens."""
"""Continue parsing from the current parse state and return partial next tokens.
parser_state = copy_parser_state(parser_state)
.. warning::
The parse state `parser_state` is updated in-place, and its lexers must
already be patched for it to work with this function. Either patch it
manually or pass it through `copy_parser_state` before calling this.
"""

expected_next_tokens: Set[str] = set()
try:
Expand Down
22 changes: 16 additions & 6 deletions tests/text/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,32 @@ def test_parse_to_end():
)

ip = pyparser.parse_interactive("x")
parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
parser_state = copy_parser_state(ip.parser_state)
parser_state, expected_next_tokens = parse_to_end(parser_state)
assert not parser_state.value_stack
assert expected_next_tokens == {"NAME"}

ip = pyparser.parse_interactive("x = '")
parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
parser_state = copy_parser_state(ip.parser_state)
parser_state, expected_next_tokens = parse_to_end(parser_state)
assert parser_state.value_stack[-1].type == "EQUAL"
assert expected_next_tokens == {"LONG_STRING", "STRING"}

ip = pyparser.parse_interactive("x = 'hi")
parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
parser_state = copy_parser_state(ip.parser_state)
parser_state, expected_next_tokens = parse_to_end(parser_state)
assert parser_state.value_stack[-1].type == "EQUAL"
assert expected_next_tokens == {"STRING"}

ip = pyparser.parse_interactive("x = ('hi")
parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
parser_state = copy_parser_state(ip.parser_state)
parser_state, expected_next_tokens = parse_to_end(parser_state)
assert parser_state.value_stack[-1].type == "LPAR"
assert expected_next_tokens == {"STRING"}

ip = pyparser.parse_interactive("def")
parser_state, expected_next_tokens = parse_to_end(ip.parser_state)
parser_state = copy_parser_state(ip.parser_state)
parser_state, expected_next_tokens = parse_to_end(parser_state)
assert not parser_state.value_stack
assert expected_next_tokens == {"NAME", "DEF"}

Expand Down Expand Up @@ -97,7 +102,7 @@ def test_sequential_parse_example():
start="file_input",
)
ip = pyparser.parse_interactive("")
parser_state = ip.parser_state
parser_state = copy_parser_state(ip.parser_state)

token_seq = ""
for i, token in enumerate(input_tokens):
Expand Down Expand Up @@ -243,6 +248,9 @@ def test_parse_from_partial_match():
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
# Copying also patches the lexers in the parse state, which
# `parse_to_end` now requires of its input
parser_state = copy_parser_state(parser_state)
new_parser_state, expected_next_tokens = parse_to_end(parser_state)
assert expected_next_tokens == {"NAME"}

Expand All @@ -252,6 +260,7 @@ def test_parse_from_partial_match():
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
parser_state = copy_parser_state(parser_state)
new_parser_state, expected_next_tokens = parse_to_end(parser_state)
assert not expected_next_tokens

Expand All @@ -261,6 +270,7 @@ def test_parse_from_partial_match():
(parser_state,) = create_pmatch_parser_states(
lp, terminals_to_states, term_type, ptoken, first_pmatch
)
parser_state = copy_parser_state(parser_state)
with pytest.raises(UnexpectedToken):
parse_to_end(parser_state)

Expand Down

0 comments on commit 5224979

Please sign in to comment.