Skip to content

Commit

Permalink
Refactor find_partial_matches so that it returns full sequences
Browse files Browse the repository at this point in the history
This refactoring also removed the need for the antecedent mapping option in
`map_partial_states_to_vocab`.
  • Loading branch information
brandonwillard committed Jul 13, 2023
1 parent d4f5b62 commit 39e7302
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 114 deletions.
66 changes: 16 additions & 50 deletions outlines/text/parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import ChainMap, defaultdict
from copy import copy
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -282,23 +281,23 @@ def find_partial_matches(
Returns
-------
A set of tuples corresponding to each valid starting state in the FSM.
The first element of each tuple contains either ``None`` or an integer
A set of tuples corresponding to each valid starting state in the FSM. The
first element of each tuple contains either ``None`` or an integer
indicating the position in `input_string` at which the FSM terminated. The
second element is a tuple of the states visited during execution of the
FSM.
second element is the tuple of states visited during execution of the FSM
plus the next, unvisited transition state.
"""
if len(input_string) == 0 or input_string[0] not in fsm.alphabet:
return set()

trans_key = fsm.alphabet[input_string[0]]

# TODO: We could probably memoize this easily (i.e. no need to recompute
# paths shared by different starting states)
# TODO: We could probably reuse parts of the computed paths when computing
# results for multiple starting points.
def _partial_match(
trans: Dict[int, int]
) -> Optional[Tuple[Optional[int], Tuple[int, ...]]]:
) -> Tuple[Optional[int], Optional[Tuple[int, ...]]]:
fsm_map = ChainMap({fsm.initial: trans}, fsm.map)
state = fsm.initial
accepted_states: Tuple[int, ...] = ()
Expand All @@ -313,27 +312,27 @@ def _partial_match(
if state in fsm.finals:
i -= 1
break
return None
return None, None

state = fsm_map[state][trans_key]

accepted_states += (state,)

terminated = state in fsm.finals
if not terminated and state == fsm.initial:
return None
return None, None

return None if not terminated else i, accepted_states

res = set()
transition_maps = (
fsm.map.values() if start_state is None else [fsm.map[start_state]]
fsm.map if start_state is None else {start_state: fsm.map[start_state]}
)
for trans in transition_maps:
for state, trans in transition_maps.items():
if trans_key in trans:
path = _partial_match(trans)
n_matched, path = _partial_match(trans)
if path is not None:
res.add(path)
res.add((n_matched, (state,) + path))

return res

Expand All @@ -346,7 +345,7 @@ def terminals_to_fsms(lp: Lark) -> Dict[str, FSM]:
pattern = interegular.parse_pattern(terminal.pattern.to_regexp())
# TODO: Use `pyparser.terminals[0].pattern.flags`?
try:
fsm = pattern.to_fsm()
fsm = pattern.to_fsm().reduce()
except Unsupported:
fsm = None

Expand All @@ -358,7 +357,6 @@ def terminals_to_fsms(lp: Lark) -> Dict[str, FSM]:
def map_partial_states_to_vocab(
vocabulary: Iterable[str],
terminals_to_fsms_map: Dict[str, FSM],
map_to_antecedents: bool = False,
partial_match_filter: Callable[
[str, Optional[int], Tuple[int, ...]], bool
] = lambda *args: True,
Expand All @@ -375,10 +373,6 @@ def map_partial_states_to_vocab(
The vocabulary composed of strings.
terminals_to_fsms_map
Terminal symbol names mapped to FSMs, as provided by `terminals_to_fsms`.
map_to_antecedents
When ``True``, return a map with keys that are the antecedent partial
parse states. In other words, this is a map that can be used to
determine valid next tokens given a parse state.
partial_match_filter
A callable that determines which partial matches to keep. The first
argument is the string being matched, the rest are the unpacked partial
Expand All @@ -400,41 +394,13 @@ def map_partial_states_to_vocab(
if partial_match_filter(vocab_string, end_idx, state_seq):
pstate_to_vocab[(symbol_name, state_seq[0])].add(i)

if not map_to_antecedents:
return pstate_to_vocab

# Partial parse states to their valid next/transition states
ts_pstate_to_substates = dict(
chain.from_iterable(
[
((symbol_name, s), {(symbol_name, v) for v in ts.values()})
for s, ts in fsm.map.items()
]
for symbol_name, fsm in terminals_to_fsms_map.items()
)
)

# Reverse the state transitions map
# TODO: We could construct this more directly.
rev_ts_pstate_to_substates = defaultdict(set)
for pstate, to_pstates in ts_pstate_to_substates.items():
for to_pstate in to_pstates:
rev_ts_pstate_to_substates[to_pstate].add(pstate)

# A version of `pstate_to_vocab` that is keyed on states that *transition to*
# the original keys of `pstate_to_vocab`.
_pstate_to_vocab: DefaultDict[PartialParseState, Set[int]] = defaultdict(set)
for pstate, vocab in pstate_to_vocab.items():
for next_pstate in rev_ts_pstate_to_substates[pstate]:
_pstate_to_vocab[next_pstate] |= vocab

if final_state_string_idx is not None:
# Allow transitions to EOS from all terminals FSM states
for symbol_name, fsm in terminals_to_fsms_map.items():
for state in fsm.finals:
_pstate_to_vocab[(symbol_name, state)].add(final_state_string_idx)
pstate_to_vocab[(symbol_name, state)].add(final_state_string_idx)

return _pstate_to_vocab
return pstate_to_vocab


def terminals_to_lalr_states(lp: Lark) -> DefaultDict[str, Set[int]]:
Expand Down
128 changes: 64 additions & 64 deletions tests/text/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,34 +129,54 @@ def test_sequential_parse_example():

def test_partial_match():
name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
name_fsm = name_pattern.to_fsm()
name_fsm = name_pattern.to_fsm().reduce()
assert name_fsm.initial == 0

def_pattern = interegular.parse_pattern("def")
def_fsm = def_pattern.to_fsm()
def_fsm = def_pattern.to_fsm().reduce()
assert def_fsm.initial == 0

assert find_partial_matches(def_fsm, "def") == {(2, (1, 2, 3))}
assert find_partial_matches(def_fsm, "de") == {(None, (1, 2))}
assert find_partial_matches(def_fsm, "d") == {(None, (1,))}
assert find_partial_matches(def_fsm, "def") == {(2, (0, 1, 2, 3))}
assert find_partial_matches(def_fsm, "de") == {(None, (0, 1, 2))}
assert find_partial_matches(def_fsm, "d") == {(None, (0, 1))}
assert find_partial_matches(def_fsm, "") == set()
assert find_partial_matches(def_fsm, "df") == set()
assert find_partial_matches(def_fsm, "ef") == {(1, (2, 3))}
assert find_partial_matches(def_fsm, "e") == {(None, (2,))}
assert find_partial_matches(def_fsm, "f") == {(0, (3,))}
assert find_partial_matches(def_fsm, "ef foo") == {(1, (2, 3))}
assert find_partial_matches(def_fsm, "ef") == {(1, (1, 2, 3))}
assert find_partial_matches(def_fsm, "e") == {(None, (1, 2))}
assert find_partial_matches(def_fsm, "f") == {(0, (2, 3))}
assert find_partial_matches(def_fsm, "ef foo") == {(1, (1, 2, 3))}

# This string has a `DEF` token in it, but should ultimately not lex one
assert find_partial_matches(def_fsm, "defb") == {(2, (1, 2, 3))}
assert find_partial_matches(def_fsm, "defb") == {(2, (0, 1, 2, 3))}

# `NAME` can have multiple start states for this input
assert find_partial_matches(name_fsm, "d") == {(0, (1,)), (0, (2,))}
assert find_partial_matches(name_fsm, "d") == {
(0, (0, 1)),
(0, (1, 1)),
}
# Not this case
assert find_partial_matches(name_fsm, "1d") == {(1, (2, 2))}
assert find_partial_matches(name_fsm, "1d") == {(1, (1, 1, 1))}

assert find_partial_matches(name_fsm, "blah") == {
(3, (1, 2, 2, 2)),
(3, (2, 2, 2, 2)),
(3, (0, 1, 1, 1, 1)),
(3, (1, 1, 1, 1, 1)),
}

float_pattern = interegular.parse_pattern(
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))"
)
float_fsm = float_pattern.to_fsm().reduce()

# XXX: It looks like there's a lot of set/frozenset usage that prevents us
# from adequately reproducing the exact state sequences in this case.
# It seems to stem from `_CharGroup`s and the FSM map construction process.
res = find_partial_matches(float_fsm, ".")
assert {v[0] for v in res} == {0, 0, None}
# Make sure that the terminated sequences actually end in final states
assert all(v[1][-1] in float_fsm.finals for v in res if v[0] == 0)
# Make sure that the non-terminated sequences don't end in final states
assert all(v[1][-1] not in float_fsm.finals for v in res if v[0] != 0)


def test_map_partial_states_to_vocab_python():
pyparser = Lark.open_from_package(
Expand All @@ -174,54 +194,45 @@ def test_map_partial_states_to_vocab_python():
k: v for k, v in symbol_names_and_fsms.items() if k in test_symbols
}

vocabulary = ["d", "e", "ef foo", "f ", " "]
assert len(symbol_names_and_fsms["DEF"].states) == 4
assert len(symbol_names_and_fsms["NAME"].states) == 2
assert len(symbol_names_and_fsms["__IGNORE_0"].states) == 2

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, False
)
vocabulary = ["d", "e", "ef foo", "f ", " ", "1d", "<EOS>"]

assert dict(pstate_to_vocab) == {
("__IGNORE_0", 2): {4},
("__IGNORE_0", 1): {4},
("NAME", 2): {0, 1, 2, 3},
("NAME", 1): {0, 1, 2, 3},
("DEF", 1): {0},
("DEF", 2): {1, 2},
("DEF", 3): {3},
}

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, True
)
pstate_to_vocab = map_partial_states_to_vocab(vocabulary, symbol_names_and_fsms)

assert dict(pstate_to_vocab) == {
("__IGNORE_0", 1): {4},
("__IGNORE_0", 2): {4},
("__IGNORE_0", 0): {4},
("NAME", 1): {0, 1, 2, 3},
("NAME", 2): {0, 1, 2, 3},
("__IGNORE_0", 1): {4},
("NAME", 0): {0, 1, 2, 3},
("NAME", 1): {0, 1, 2, 3, 5},
("DEF", 0): {0},
("DEF", 1): {1, 2},
("DEF", 2): {3},
}

vocabulary = list(vocabulary) + ["<EOS>"]
pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, symbol_names_and_fsms, True, final_state_string="<EOS>"
vocabulary, symbol_names_and_fsms, final_state_string="<EOS>"
)

assert dict(pstate_to_vocab) == {
("__IGNORE_0", 1): {4, 5},
("__IGNORE_0", 2): {4, 5},
("__IGNORE_0", 0): {4},
("NAME", 1): {0, 1, 2, 3, 5},
("NAME", 2): {0, 1, 2, 3, 5},
("__IGNORE_0", 0): {
4,
},
("__IGNORE_0", 1): {4, 6},
("NAME", 0): {0, 1, 2, 3},
("DEF", 0): {0},
("NAME", 1): {0, 1, 2, 3, 5, 6},
("DEF", 0): {
0,
},
("DEF", 1): {1, 2},
("DEF", 2): {3},
("DEF", 3): {5},
("DEF", 2): {
3,
},
("DEF", 3): {
6,
},
}


Expand Down Expand Up @@ -286,7 +297,7 @@ def test_parse_from_partial_match():
def test_map_partial_states_to_vocab_regex():
regex_string = r"([0-9]+([.][0-9]*)?|[.][0-9]+)"
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm = regex_pattern.simplify().to_fsm()
regex_fsm = regex_pattern.to_fsm().reduce()

vocabulary = [
"1.",
Expand All @@ -312,19 +323,15 @@ def partial_match_filter(string, end_idx, state_seq):
return True

pstate_to_vocab = map_partial_states_to_vocab(
vocabulary, {"FLOAT": regex_fsm}, True, partial_match_filter, "<EOS>"
vocabulary, {"FLOAT": regex_fsm}, partial_match_filter, "<EOS>"
)

assert tuple(pstate_to_vocab.values()) == (
{0, 1, 2, 3, 4, 5, 8},
{0, 1, 2, 3, 4, 5, 8, 12},
assert sorted(pstate_to_vocab.values(), key=lambda x: -len(x)) == [
{0, 1, 2, 3, 4, 5, 8, 12},
{1, 5, 8, 12},
{1, 5, 8, 12},
{1, 5, 8, 12},
{0, 1, 2, 3, 4, 5, 8},
{1, 5, 8, 12},
{1, 5, 8},
)
]

pstate_to_vocab = {k: tuple(v) for k, v in pstate_to_vocab.items()}

Expand All @@ -348,16 +355,9 @@ def partial_match_filter(string, end_idx, state_seq):

sample_seq += next_sample

# Parse the entire sampled sequence/string
# TODO: We could continue from the previous parse state, but this is
# easier for now and only for demonstration purposes.
partial_matches = find_partial_matches(
regex_fsm, sample_seq, start_state=regex_fsm.initial
)

# Use the/a longest match
pmatch = max(
partial_matches, key=lambda x: x[0] if x[0] is not None else -1
# Continue matching from where we left off
(pmatch,) = find_partial_matches(
regex_fsm, next_sample, start_state=pstate[-1]
)

# Create the next state
Expand Down

0 comments on commit 39e7302

Please sign in to comment.