From 076bd98232a0e24d79e13a67348d0b5671d00281 Mon Sep 17 00:00:00 2001 From: Robin Picard Date: Mon, 8 Jan 2024 13:56:41 +0100 Subject: [PATCH] Use the function's arguments to cache create_states_mapping --- outlines/fsm/fsm.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py index 7fc7a41e9..a8dc89c64 100644 --- a/outlines/fsm/fsm.py +++ b/outlines/fsm/fsm.py @@ -103,16 +103,9 @@ def __init__( regex_string: str, tokenizer: "Tokenizer", ): - def func_cache_key_args( - regex_string: str, tokenizer: "Tokenizer" - ) -> Tuple[str, tuple]: - """Return the values that will be used to create the cache key of create_states_mapping""" - cacheable_vocabulary = tuple(sorted(tokenizer.vocabulary.items())) - return (regex_string, cacheable_vocabulary) - - @cache(func_cache_key_args) + @cache() def create_states_mapping( - regex_string: str, tokenizer: "Tokenizer" + regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int], ...] ) -> Tuple[dict, set, set]: """Create the variables related to the mapping between states and tokens The parameters of the function are used for caching purpose @@ -144,7 +137,9 @@ def create_states_mapping( self.states_to_token_maps, self.empty_token_ids, self.final_states, - ) = create_states_mapping(regex_string, tokenizer) + ) = create_states_mapping( + regex_string, tuple(sorted(tokenizer.vocabulary.items())) + ) self.num_tokens_generated = 0 self.vocabulary = tokenizer.vocabulary.values() self.end_token_id = tokenizer.eos_token_id