From 83c4d3a37f6d340e3ad21e5d93209f51c4bfe557 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Sun, 26 May 2024 03:03:24 -0500
Subject: [PATCH] ensure byte fsm unicode_type compatibility by prefixing hex-bytes with \x00

---
 outlines/fsm/regex.py   | 54 +++++++++++++++++++++++++++++++----------
 tests/fsm/test_regex.py | 53 ++++++++++++++++++++++++++++------------
 2 files changed, 79 insertions(+), 28 deletions(-)

diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index a396b8e9d..5ee526362 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -196,7 +196,7 @@ def transition_trie_setdefault(
 
 
 def byte_symbol(byte: int) -> str:
-    return f"{byte:02X}" if byte >= 0x80 else chr(byte)
+    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)
 
 
 def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
@@ -416,7 +416,7 @@ def _walk_fsm(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    input_string: Sequence[str],
+    input_string: str,
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -424,7 +424,21 @@ def _walk_fsm(
     accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
     last_final_idx: int = numba.uint64(0)
 
-    for i, symbol in enumerate(input_string):
+    # Iterate over symbols (characters and null-prefixed two-hex-character bytes).
+    # By default, each symbol is a single unicode character; however, if
+    # input_string[i] == '\x00', the next two characters in input_string are
+    # the hex representation of a byte.
+    i = 0
+    while i < len(input_string):
+        # A null-byte prefix marks a hex byte representation, unless it is the
+        # last character, in which case it is a trailing null-byte symbol.
+        if input_string[i] == "\x00" and i != len(input_string) - 1:
+            symbol = input_string[i : i + 3]
+            i += 3
+        else:
+            symbol = input_string[i]
+            i += 1
+
         trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
 
         new_state = fsm_transitions.get((state, trans_key))
@@ -438,11 +452,11 @@ def _walk_fsm(
         state = new_state
 
         if state in fsm_finals:
-            last_final_idx = numba.uint64(i + 1)
+            last_final_idx = numba.uint64(i)
 
         accepted_states.append(_nonoptional(state))
 
-    if full_match and last_final_idx - 1 != i:
+    if full_match and last_final_idx != i:
         return numba.typed.List.empty_list(numba.int64)
 
     return accepted_states
@@ -450,7 +464,7 @@ def _walk_fsm(
 
 def walk_fsm(
     fsm: BetterFSM,
-    input_string: Sequence[str],
+    input_string: str,
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -464,7 +478,17 @@ def walk_fsm(
     alphabet_anything_value = fsm.alphabet.anything_value
     fsm_transitions = fsm.flat_transition_map
 
-    for i, symbol in enumerate(input_string):
+    # See _walk_fsm() for an explanation of the symbol iteration.
+    i = 0
+    while i < len(input_string):
+        # A null-byte prefix marks a hex byte representation, unless the input
+        # string itself is a lone null byte (then the symbol is that null byte).
+        if input_string[i] == "\x00" and input_string != "\x00":
+            symbol = input_string[i : i + 3]
+            i += 3
+        else:
+            symbol = input_string[i]
+            i += 1
         trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
 
         new_state = fsm_transitions.get((state, trans_key))
@@ -478,11 +502,11 @@ def walk_fsm(
         state = new_state
 
         if state in fsm_finals:
-            last_final_idx = i + 1
+            last_final_idx = i
 
         accepted_states.append(state)
 
-    if full_match and last_final_idx - 1 != i:
+    if full_match and last_final_idx != i:
         return []
 
     return accepted_states
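The hunks above switch the byte-level alphabet from bare two-hex-character symbols to "\x00"-prefixed ones, so a token can be carried as a single numba unicode_type string and still be split back into FSM symbols. The sketch below is not part of the patch: it restates the encoding and the new symbol iteration in plain Python without numba, and the helper names encode_token and iter_symbols are illustrative only.

def byte_symbol(byte: int) -> str:
    # Same convention as the patched byte_symbol above: ASCII bytes stay as
    # single characters, bytes >= 0x80 become "\x00" plus two hex digits.
    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def encode_token(text: str) -> str:
    # Illustrative helper: encode a token as one string of symbols.
    return "".join(byte_symbol(b) for b in text.encode("utf-8"))


def iter_symbols(input_string: str):
    # Mirrors the while-loop added to _walk_fsm: a "\x00" that is not the last
    # character introduces a three-character hex-byte symbol.
    i = 0
    while i < len(input_string):
        if input_string[i] == "\x00" and i != len(input_string) - 1:
            yield input_string[i : i + 3]
            i += 3
        else:
            yield input_string[i]
            i += 1


assert encode_token("a😂") == "a\x00F0\x009F\x0098\x0082"
assert list(iter_symbols(encode_token("a😂"))) == ["a", "\x00F0", "\x009F", "\x0098", "\x0082"]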
@@ -652,7 +676,7 @@ def state_scan_tokens(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
@@ -669,7 +693,11 @@ def state_scan_tokens(
             False,
         )
 
-        if state_seq is not None and len(state_seq) < len(token):
+        if token == "\x00":
+            token_length = 1
+        else:
+            token_length = len(token) - 2 * token.count("\x00")
+        if state_seq is not None and len(state_seq) < token_length:
             continue
 
         for token_id in token_ids:
@@ -680,7 +708,7 @@ def state_scan_tokens(
 
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
 ) -> Dict[int, Set[Tuple[int, int]]]:
     """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
 
@@ -768,7 +796,7 @@ def gpt2_unicode_to_bytes():
 @lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
-) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
+) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
     """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
     empty_token_ids = set()
     vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index 2fc8a5384..ce8ed6f8f 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -25,7 +25,7 @@ def identity(s):
 
 
 def to_bytes(s):
-    return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")]
+    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]
 
 
 def walk_fsm_numba(
@@ -115,19 +115,27 @@ def test_walk_fsm_multi_bytes(function, transform):
     str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)
 
-    res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("😂")), regex_fsm.initial, full_match=True)
+    )
     assert res[-1:] == (1,)
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=False
+        )
     )
     assert res[-1:] == (1,)
 
-    res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("!")), regex_fsm.initial, full_match=True)
+    )
     assert res == tuple()
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=True
+        )
     )
     assert res == tuple()
 
@@ -304,15 +312,15 @@ def test_create_fsm_index_end_to_end():
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token = "".join(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        vocabulary_nb.append((token, token_ids_np))
 
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
 
@@ -326,28 +334,34 @@ def test_create_fsm_index_end_to_end_multi_byte():
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
 
+    merge_symbols = lambda byte_hexs: "".join(
+        ["\x00" + b if len(b) == 2 else b for b in byte_hexs]
+    )
+
     vocabulary = {
         "blah": numba.typed.List([0]),
         "😈a": numba.typed.List([1]),
         "😇": numba.typed.List([2]),
         "😍": numba.typed.List([3]),
-        ("F0", "9F", "98", "8D"): numba.typed.List([4]),  # '😍'
+        merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]),  # '😍'
         " 😍": numba.typed.List([5]),
-        (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]),  # ' 😍'
-        (" ", "F0", "9F", "98"): numba.typed.List([7]),  # ' 😍' incomplete
+        merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]),  # ' 😍'
+        merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
+            [7]
+        ),  # ' 😍' incomplete
         "": numba.typed.List([8]),
     }
 
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
            (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
         vocabulary_nb.append((token_tuple_np, token_ids_np))
 
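A standalone illustration, not part of the diff, of the symbol-count arithmetic used in state_scan_tokens above and of what the test's merge_symbols helper produces (plain Python, no numba assumed):

# As in the test above: two-character hex strings get a "\x00" prefix,
# ordinary characters are kept unchanged.
merge_symbols = lambda byte_hexs: "".join(
    ["\x00" + b if len(b) == 2 else b for b in byte_hexs]
)

token = merge_symbols((" ", "F0", "9F", "98", "8D"))  # byte-level form of ' 😍'
assert token == " \x00F0\x009F\x0098\x008D"

# Five FSM symbols (one character plus four bytes) but thirteen characters; each
# "\x00" accounts for two extra characters, so the symbol count is recovered as:
assert len(token) - 2 * token.count("\x00") == 5

# A lone "\x00" token is special-cased to length 1 in state_scan_tokens, because a
# trailing or lone null byte is a symbol in its own right rather than a hex prefix.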
@@ -356,7 +370,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
     assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}
 
 
-def test_create_fsm_index_tokenizer():
+@pytest.mark.parametrize(
+    "hf_tokenizer_uri",
+    [
+        "gpt2",
+        "microsoft/phi-2",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "NousResearch/Hermes-2-Pro-Llama-3-8B",
+    ],
+)
+def test_create_fsm_index_tokenizer(hf_tokenizer_uri):
     # The combined regular expressions of a lexer state in a Python grammar
     regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
 
@@ -371,7 +394,7 @@ def test_create_fsm_index_tokenizer():
     num_bytes_fsm_states = len(bytes_fsm.states)
     assert num_bytes_fsm_states == 235
 
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri)
     tokenizer = TransformerTokenizer(tokenizer)
 
     states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(