Skip to content

Commit

Permalink
Use diskcache instead of perscache
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinPicard committed Dec 30, 2023
1 parent 3c29617 commit d032396
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 41 deletions.
81 changes: 70 additions & 11 deletions outlines/caching.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,79 @@
import asyncio
import hashlib
import os
from typing import Callable, Optional

from perscache import Cache, NoCache
from perscache.serializers import JSONSerializer
from perscache.storage import LocalFileStorage
import cloudpickle
import torch
from diskcache import Cache

home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
memory = Cache(serializer=JSONSerializer(), storage=LocalFileStorage(cache_dir))
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True


def cache(ignore: Optional[str] = None):
def cache_fn(fn: Callable):
return memory.cache(ignore=ignore)(fn)
def hash_data(*data) -> str:
    """Return an md5 hex digest identifying the given positional arguments.

    Each argument is serialized with cloudpickle and folded into a single
    digest; hashing the pickled bytes keeps the cache key small even when
    the arguments include large tensors. Torch tensors are moved to the CPU
    and converted to numpy arrays before serialization.
    """
    digest = hashlib.md5()  # nosec B303
    for item in data:
        if isinstance(item, torch.Tensor):
            # Tensors are not serialized directly; normalize to a CPU numpy
            # array first so the resulting bytes are device-independent.
            item = item.cpu().numpy()
        digest.update(cloudpickle.dumps(item))
    return digest.hexdigest()


def cache(key_function: Optional[Callable] = None):
    """Caching decorator for memoizing function calls.

    The cache key is created based on the values returned by the
    ``key_function`` callable if provided, or based on the arguments of the
    decorated function directly otherwise. Works for both synchronous and
    coroutine functions. When caching is disabled (see ``disable_cache``),
    the decorated function is called directly without touching the cache.

    Parameters
    ----------
    key_function
        A callable used to generate a unique key for each function call. It
        is called with the arguments of the decorated function as arguments
        and must return an iterable of values to hash.

    Returns
    -------
    A decorator function that can be applied to other functions.
    """

    def decorator(cached_function: Callable):
        def _get_key(args, kwargs) -> str:
            # Single source of truth for key derivation, shared by the sync
            # and async wrappers.
            if key_function is not None:
                return hash_data(*key_function(*args, **kwargs))
            return hash_data(*args, **kwargs)

        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            cache_key = _get_key(args, kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            cache_key = _get_key(args, kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        # Pick the wrapper matching the decorated function's calling
        # convention so awaiting behavior is preserved.
        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        return wrapper

    return decorator


def get_cache():
Expand Down Expand Up @@ -51,11 +110,11 @@ def disable_cache():
>>> cache.disable()
"""
global memory
memory = NoCache()
global _caching_enabled
_caching_enabled = False


def clear_cache():
    """Erase the cache completely.

    Removes every entry from the on-disk diskcache store. No ``global``
    statement is needed since we only call a method on the existing object.
    """
    memory.clear()
39 changes: 17 additions & 22 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,18 @@ def __init__(
regex_string: str,
tokenizer: "Tokenizer",
):
@cache()
def func_cache_key_args(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[str, list]:
"""Return the values that will be used to create the cache key of create_states_mapping"""
cacheable_vocabulary = sorted(tokenizer.vocabulary.values())
return (regex_string, cacheable_vocabulary)

@cache(func_cache_key_args)
def create_states_mapping(
regex_string: str, cachable_vocabulary: List[Tuple[str, int]]
) -> Tuple[dict, list, list]:
"""
Create the variables related to the mapping between states and tokens
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related the mapping between stzates and tokens
The parameters of the function are used for caching purposes
"""
regex_pattern = interegular.parse_pattern(regex_string)
Expand All @@ -144,27 +150,16 @@ def create_states_mapping(
final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
return states_to_token_maps, list(empty_token_ids), list(final_states)

def convert_dict_items_to_int(item: dict) -> dict:
"""Recursively converts the keys/values of a dict to integers"""
if not isinstance(item, dict):
return int(item)
return {
int(key): convert_dict_items_to_int(value)
for key, value in item.items()
}
return states_to_token_maps, empty_token_ids, final_states

(
self.states_to_token_maps,
self.empty_token_ids,
self.final_states,
) = create_states_mapping(regex_string, tokenizer)
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
cachable_vocabulary = sorted(self.vocabulary)
states_to_token_maps, empty_token_ids, final_states = create_states_mapping(
regex_string, cachable_vocabulary
)
self.states_to_token_maps = convert_dict_items_to_int(states_to_token_maps)
self.empty_token_ids = set(empty_token_ids)
self.final_states = set(final_states)

def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
"lark",
"nest_asyncio",
"numpy",
"perscache",
"cloudpickle",
"diskcache",
"pydantic>=2.0",
"scipy",
"torch>=2.1",
Expand Down Expand Up @@ -103,7 +104,8 @@ module = [
"mamba_ssm.*",
"nest_asyncio",
"numpy.*",
"perscache.*",
"cloudpickle.*",
"diskcache.*",
"pydantic.*",
"pytest",
"referencing.*",
Expand Down
10 changes: 4 additions & 6 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import tempfile
from pathlib import Path

import perscache
import diskcache
import pytest


Expand Down Expand Up @@ -35,19 +34,18 @@ def test_cache(refresh_environment):
import outlines

memory = outlines.get_cache()
assert memory.storage.location == Path(tempdir)
assert memory.directory == tempdir

yield outlines.caching.cache()

memory.storage.clear()
memory.clear()


def test_get_cache(test_cache):
import outlines

memory = outlines.get_cache()
assert isinstance(memory, perscache.Cache)
assert isinstance(memory.storage, perscache.storage.LocalFileStorage)
assert isinstance(memory, diskcache.Cache)

# If the cache is enabled then the size
# of `store` should not increase the
Expand Down

0 comments on commit d032396

Please sign in to comment.