Skip to content

Commit

Permalink
Make parse tree/value computations optional
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 11, 2023
1 parent 77e1593 commit e6ff583
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 16 deletions.
120 changes: 104 additions & 16 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,35 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[FSM, Dict[int, int]]:


class PartialParserConf(ParserConf):
__serialize_fields__ = "rules", "start", "parser_type", "deterministic"
__serialize_fields__ = (
"rules",
"start",
"parser_type",
"deterministic",
"use_value_stack",
)

def __init__(self, rules, callbacks, start, deterministic):
def __init__(self, rules, callbacks, start, deterministic, use_value_stack):
super().__init__(rules, callbacks, start)
self.deterministic = deterministic
self.use_value_stack = use_value_stack


class PartialLark(Lark):
__serialize_fields__ = "parser", "rules", "options", "deterministic"
__serialize_fields__ = (
"parser",
"rules",
"options",
"deterministic",
"use_value_stack",
)

def __init__(self, grammar, **options):
# TODO: Could've extended `LarkOptions`, but all these extensions are
# already way too much (and brittle). This library really needs a
# complete refactoring.
self.deterministic = options.pop("deterministic", False)
self.use_value_stack = options.pop("use_value_stack", False)
options["regex"] = True
super().__init__(grammar, **options)
assert self.options.parser == "lalr"
Expand All @@ -180,7 +194,11 @@ def _build_parser(self) -> "PartialParsingFrontend":
self._prepare_callbacks()
_validate_frontend_args(self.options.parser, self.options.lexer)
parser_conf = PartialParserConf(
self.rules, self._callbacks, self.options.start, self.deterministic
self.rules,
self._callbacks,
self.options.start,
self.deterministic,
self.use_value_stack,
)

# This is `_construct_parsing_frontend` expanded/inlined
Expand Down Expand Up @@ -393,7 +411,12 @@ def to_tuple(v):
zip(self._parse_table.states.keys(), new_states.keys())
)

self.parser = PartialParser(self._parse_table, callbacks, debug)
self.parser = PartialParser(
self._parse_table,
callbacks,
debug,
use_value_stack=parser_conf.use_value_stack,
)

@classmethod
def deserialize(cls, data, memo, callbacks, debug=False):
Expand All @@ -404,16 +427,20 @@ def deserialize(cls, data, memo, callbacks, debug=False):


class PartialParserState(ParserState):
def __copy__(self):
return type(self)(
self.parse_conf,
copy(self.lexer),
copy(self.state_stack),
deepcopy(self.value_stack),
)
__slots__ = "use_value_stack"

def __repr__(self):
return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})"
def __init__(
self,
parse_conf,
lexer,
state_stack=None,
value_stack=None,
use_value_stack=False,
):
super().__init__(
parse_conf, lexer, state_stack=state_stack, value_stack=value_stack
)
self.use_value_stack = use_value_stack

def feed_token(self, token, is_end=False):
if token.type == "partial":
Expand All @@ -438,16 +465,77 @@ def feed_token(self, token, is_end=False):
)
return

super().feed_token(token, is_end=is_end)
if self.use_value_stack:
super().feed_token(token, is_end=is_end)
else:
self.feed_token_no_stack(token, is_end=is_end)

def feed_token_no_stack(self, token, is_end=False):
"""
This is a copy of `ParserState.feed_token` with all the value stack
steps removed. Since we're not exactly parsing in order to obtain a
CST or anything similar, we can avoid the growing expense of tracking
the parse tree.
"""
state_stack = self.state_stack
states = self.parse_conf.states
end_state = self.parse_conf.end_state

while True:
state = state_stack[-1]
try:
action, arg = states[state][token.type]
except KeyError:
expected = {s for s in states[state].keys() if s.isupper()}
raise UnexpectedToken(
token, expected, state=self, interactive_parser=None
)

assert arg != end_state

if action is Shift:
# shift once and return
assert not is_end
state_stack.append(arg)
return
else:
# reduce+shift as many times as necessary
rule = arg
size = len(rule.expansion)
if size:
del state_stack[-size:]

_action, new_state = states[state_stack[-1]][rule.origin.name]
assert _action is Shift
state_stack.append(new_state)

if is_end and state_stack[-1] == end_state:
return

def __copy__(self):
return type(self)(
self.parse_conf,
copy(self.lexer),
copy(self.state_stack),
deepcopy(self.value_stack),
use_value_stack=self.use_value_stack,
)

    def __repr__(self):
        # Shows only the lexer and state stack; the value stack is omitted.
        return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})"


class PartialParser(_Parser):
def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False):
super().__init__(parse_table, callbacks, debug=debug)
self.use_value_stack = use_value_stack

def parse(
self, lexer, start, value_stack=None, state_stack=None, start_interactive=False
):
parse_conf = ParseConf(self.parse_table, self.callbacks, start)
parser_state = PartialParserState(
parse_conf, copy(lexer), state_stack, value_stack
parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack
)
if start_interactive:
return InteractiveParser(self, parser_state, parser_state.lexer)
Expand Down
23 changes: 23 additions & 0 deletions tests/text/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_partial_parsing():
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 15)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

# End with an ignored token
parser_state = lp.parse("x ")
Expand All @@ -45,6 +46,7 @@ def test_partial_parsing():
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 1)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

# Could be a complete `=` or the start of a `==`
parser_state = lp.parse("x =")
Expand All @@ -55,44 +57,51 @@ def test_partial_parsing():
term_info.terminal_name == "EQUAL"
for term_info in last_token.value.terminals_and_info
)
assert not parser_state.value_stack

parser_state = lp.parse("x = '")
assert parser_state.state_stack == [0, 58, 59]
last_token = parser_state.lexer.state.last_token
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 6)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

parser_state = lp.parse("x = 'hi")
assert parser_state.state_stack == [0, 58, 59]
last_token = parser_state.lexer.state.last_token
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 6, 6, 6)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

parser_state = lp.parse("x = ('hi")
assert parser_state.state_stack == [0, 58, 59, 254]
last_token = parser_state.lexer.state.last_token
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 6, 6, 6)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

parser_state = lp.parse("def")
assert parser_state.state_stack == [0]
last_token = parser_state.lexer.state.last_token
assert last_token.type == "partial"
assert last_token.value.fsm_state_seq == (0, 26, 99, 100)
assert last_token.value.is_not_finished is True
assert not parser_state.value_stack

# Now, try something incremental
last_lexer_state = parser_state.lexer.state
last_lexer_state.text += " blah()"
lp.parse_from_state(parser_state, is_end=False)
last_token = parser_state.lexer.state.last_token
assert not parser_state.value_stack

last_lexer_state = parser_state.lexer.state
last_valid_token = last_lexer_state.last_token
assert last_valid_token.type == "RPAR"
assert not parser_state.value_stack

# Something incremental and a little more complicated
parser_state = lp.parse("x = 1\ndef foo(x):\n ")
Expand Down Expand Up @@ -120,6 +129,20 @@ def test_partial_parsing():
with pytest.raises(UnexpectedToken):
lp.parse("def \n")

lp = PartialLark.open_from_package(
"tests",
"partial_python.lark",
["text"],
parser="lalr",
postlex=PartialPythonIndenter(),
start="file_input",
use_value_stack=True,
)
parser_state = lp.parse("x = ('hi")
lp.parse_from_state(parser_state, is_end=False)
assert len(parser_state.state_stack) == 4
assert parser_state.value_stack[-1].type == "LPAR"


def test_sequential_parse_example():
input_tokens = [
Expand Down

0 comments on commit e6ff583

Please sign in to comment.