Skip to content

Commit

Permalink
Modification to the hash_arguments function of caching
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinPicard authored and rlouf committed Jan 11, 2024
1 parent 32d84dd commit 865d456
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
23 changes: 9 additions & 14 deletions outlines/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Callable, Optional

import cloudpickle
import torch
from diskcache import Cache

home_dir = os.path.expanduser("~")
Expand All @@ -13,15 +12,11 @@
_caching_enabled = True


def hash_data(*data) -> str:
"""Pickles and hashes all the data passed to it as args.
Pickling and then hashing significantly reduces the size of the key especially when dealing with large tensors.
"""
result = hashlib.md5() # nosec B303
for datum in data:
if isinstance(datum, torch.Tensor):
datum = datum.cpu().numpy()
result.update(cloudpickle.dumps(datum))
def hash_arguments(*args, **kwargs) -> str:
    """Create a hash out of the args and kwargs provided"""
    # md5 is used only to build cache keys, not for security purposes.
    digest = hashlib.md5()  # nosec B303
    # kwargs are sorted by key so that call-site ordering does not change the hash.
    entries = list(args)
    entries.extend(sorted(kwargs.items()))
    for entry in entries:
        digest.update(cloudpickle.dumps(entry))
    return digest.hexdigest()


Expand All @@ -45,9 +40,9 @@ def wrapper(*args, **kwargs):
return cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_data(*key_args)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_data(*args, **kwargs)
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = cached_function(*args, **kwargs)
Expand All @@ -59,9 +54,9 @@ async def async_wrapper(*args, **kwargs):
return await cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_data(*key_args)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_data(*args, **kwargs)
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = await cached_function(*args, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ def __init__(
):
def func_cache_key_args(
    regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[str, tuple]:
    """Return the values that will be used to create the cache key of create_states_mapping

    The vocabulary is converted to a tuple of (token, id) pairs sorted by token
    so the key component is hashable and deterministic regardless of the
    dictionary's insertion order.
    """
    cacheable_vocabulary = tuple(sorted(tokenizer.vocabulary.items()))
    return (regex_string, cacheable_vocabulary)

@cache(func_cache_key_args)
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related the mapping between stzates and tokens
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
regex_pattern = interegular.parse_pattern(regex_string)
Expand Down

0 comments on commit 865d456

Please sign in to comment.