From e6ff5834e263928835e40607ba522f163ab77b39 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 2 Aug 2023 14:06:53 -0500 Subject: [PATCH] Make parse tree/value computations optional --- outlines/text/parsing.py | 120 ++++++++++++++++++++++++++++++++----- tests/text/test_parsing.py | 23 +++++++ 2 files changed, 127 insertions(+), 16 deletions(-) diff --git a/outlines/text/parsing.py b/outlines/text/parsing.py index 161e76cfc..00a7a8a23 100644 --- a/outlines/text/parsing.py +++ b/outlines/text/parsing.py @@ -147,21 +147,35 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[FSM, Dict[int, int]]: class PartialParserConf(ParserConf): - __serialize_fields__ = "rules", "start", "parser_type", "deterministic" + __serialize_fields__ = ( + "rules", + "start", + "parser_type", + "deterministic", + "use_value_stack", + ) - def __init__(self, rules, callbacks, start, deterministic): + def __init__(self, rules, callbacks, start, deterministic, use_value_stack): super().__init__(rules, callbacks, start) self.deterministic = deterministic + self.use_value_stack = use_value_stack class PartialLark(Lark): - __serialize_fields__ = "parser", "rules", "options", "deterministic" + __serialize_fields__ = ( + "parser", + "rules", + "options", + "deterministic", + "use_value_stack", + ) def __init__(self, grammar, **options): # TODO: Could've extended `LarkOptions`, but all these extensions are # already way too much (and brittle). This library really needs a # complete refactoring. self.deterministic = options.pop("deterministic", False) + self.use_value_stack = options.pop("use_value_stack", False) options["regex"] = True super().__init__(grammar, **options) assert self.options.parser == "lalr" @@ -180,7 +194,11 @@ def _build_parser(self) -> "PartialParsingFrontend": self._prepare_callbacks() _validate_frontend_args(self.options.parser, self.options.lexer) parser_conf = PartialParserConf( - self.rules, self._callbacks, self.options.start, self.deterministic + self.rules, + self._callbacks, + self.options.start, + self.deterministic, + self.use_value_stack, ) # This is `_construct_parsing_frontend` expanded/inlined @@ -393,7 +411,12 @@ def to_tuple(v): zip(self._parse_table.states.keys(), new_states.keys()) ) - self.parser = PartialParser(self._parse_table, callbacks, debug) + self.parser = PartialParser( + self._parse_table, + callbacks, + debug, + use_value_stack=parser_conf.use_value_stack, + ) @classmethod def deserialize(cls, data, memo, callbacks, debug=False): @@ -404,16 +427,20 @@ def deserialize(cls, data, memo, callbacks, debug=False): class PartialParserState(ParserState): - def __copy__(self): - return type(self)( - self.parse_conf, - copy(self.lexer), - copy(self.state_stack), - deepcopy(self.value_stack), - ) + __slots__ = "use_value_stack" - def __repr__(self): - return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})" + def __init__( + self, + parse_conf, + lexer, + state_stack=None, + value_stack=None, + use_value_stack=False, + ): + super().__init__( + parse_conf, lexer, state_stack=state_stack, value_stack=value_stack + ) + self.use_value_stack = use_value_stack def feed_token(self, token, is_end=False): if token.type == "partial": @@ -438,16 +465,77 @@ def feed_token(self, token, is_end=False): ) return - super().feed_token(token, is_end=is_end) + if self.use_value_stack: + super().feed_token(token, is_end=is_end) + else: + self.feed_token_no_stack(token, is_end=is_end) + + def feed_token_no_stack(self, token, is_end=False): + """ + This is a copy of `ParserState.feed_token` with all the value stack + steps removed. Since we're not exactly parsing in order to obtain a + CST or anything similar, we can avoid the growing expense of tracking + the parse tree. + """ + state_stack = self.state_stack + states = self.parse_conf.states + end_state = self.parse_conf.end_state + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + assert arg != end_state + + if action is Shift: + # shift once and return + assert not is_end + state_stack.append(arg) + return + else: + # reduce+shift as many times as necessary + rule = arg + size = len(rule.expansion) + if size: + del state_stack[-size:] + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action is Shift + state_stack.append(new_state) + + if is_end and state_stack[-1] == end_state: + return + + def __copy__(self): + return type(self)( + self.parse_conf, + copy(self.lexer), + copy(self.state_stack), + deepcopy(self.value_stack), + use_value_stack=self.use_value_stack, + ) + + def __repr__(self): + return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})" class PartialParser(_Parser): + def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False): + super().__init__(parse_table, callbacks, debug=debug) + self.use_value_stack = use_value_stack + def parse( self, lexer, start, value_stack=None, state_stack=None, start_interactive=False ): parse_conf = ParseConf(self.parse_table, self.callbacks, start) parser_state = PartialParserState( - parse_conf, copy(lexer), state_stack, value_stack + parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack ) if start_interactive: return InteractiveParser(self, parser_state, parser_state.lexer) diff --git a/tests/text/test_parsing.py b/tests/text/test_parsing.py index fb13452f6..f4a08dd3b 100644 --- a/tests/text/test_parsing.py +++ b/tests/text/test_parsing.py @@ -37,6 +37,7 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 15) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack # End with an ignored token parser_state = lp.parse("x ") @@ -45,6 +46,7 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 1) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack # Could be a complete `=` or the start of a `==` parser_state = lp.parse("x =") @@ -55,6 +57,7 @@ def test_partial_parsing(): term_info.terminal_name == "EQUAL" for term_info in last_token.value.terminals_and_info ) + assert not parser_state.value_stack parser_state = lp.parse("x = '") assert parser_state.state_stack == [0, 58, 59] @@ -62,6 +65,7 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 6) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack parser_state = lp.parse("x = 'hi") assert parser_state.state_stack == [0, 58, 59] @@ -69,6 +73,7 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 6, 6, 6) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack parser_state = lp.parse("x = ('hi") assert parser_state.state_stack == [0, 58, 59, 254] @@ -76,6 +81,7 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 6, 6, 6) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack parser_state = lp.parse("def") assert parser_state.state_stack == [0] @@ -83,16 +89,19 @@ def test_partial_parsing(): assert last_token.type == "partial" assert last_token.value.fsm_state_seq == (0, 26, 99, 100) assert last_token.value.is_not_finished is True + assert not parser_state.value_stack # Now, try something incremental last_lexer_state = parser_state.lexer.state last_lexer_state.text += " blah()" lp.parse_from_state(parser_state, is_end=False) last_token = parser_state.lexer.state.last_token + assert not parser_state.value_stack last_lexer_state = parser_state.lexer.state last_valid_token = last_lexer_state.last_token assert last_valid_token.type == "RPAR" + assert not parser_state.value_stack # Something incremental and a little more complicated parser_state = lp.parse("x = 1\ndef foo(x):\n ") @@ -120,6 +129,20 @@ def test_partial_parsing(): with pytest.raises(UnexpectedToken): lp.parse("def \n") + lp = PartialLark.open_from_package( + "tests", + "partial_python.lark", + ["text"], + parser="lalr", + postlex=PartialPythonIndenter(), + start="file_input", + use_value_stack=True, + ) + parser_state = lp.parse("x = ('hi") + lp.parse_from_state(parser_state, is_end=False) + assert len(parser_state.state_stack) == 4 + assert parser_state.value_stack[-1].type == "LPAR" + def test_sequential_parse_example(): input_tokens = [