Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create authored and brandonwillard committed Mar 8, 2023
1 parent 29367ca commit 170dc60
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,34 +356,36 @@ def numba_const_convert(data, dtype=None, **kwargs):

def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
"""Convert `obj` to a Numba-JITable object."""
return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
return cast(
Callable, _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
)


numba_cache_index = pathlib.PurePath(config.compiledir, "numba_cache_index")
numba_db = shelve.open(numba_cache_index.as_posix())


def make_node_key(node: "Apply") -> Optional[str]:
def make_node_key(node):
"""Create a cache key for `node`.
TODO: Currently this works only with Apply Node
"""
if not isinstance(node, Apply):
return None
key = (node.op,)
key += tuple(inp.type for inp in node.inputs)
key = tuple(inp.type for inp in node.inputs)
key += tuple(inp.type for inp in node.outputs)

key = hashlib.sha256(pickle.dumps(key)).hexdigest()
hash_key = hashlib.sha256(pickle.dumps(key)).hexdigest()

return key
return hash_key


def check_cache(node_key: str):
def check_cache(node_key):
"""Check disk-backed cache."""
return numba_db.get(node_key)


def add_to_cache(node_key: str, numba_py_fn) -> Callable:
def add_to_cache(node_key, numba_py_fn):
"""Add the numba generated function to the cache."""
module_file_base = (
pathlib.PurePath(config.compiledir, node_key).with_suffix(".py").as_posix()
Expand All @@ -403,7 +405,7 @@ def add_to_cache(node_key: str, numba_py_fn) -> Callable:
return numba_py_fn


def persist_py_code(func) -> Callable:
def persist_py_code(func):
"""Persist a Numba JIT-able Python function.
Parameters
==========
Expand All @@ -418,7 +420,6 @@ def _func(obj, node, **kwargs):
numba_py_fn = None
if node_key:
numba_py_fn = check_cache(node_key)

if node_key is None or numba_py_fn is None:
# We could only ever return the function source in our dispatch
# implementations. That way, we can compile directly to the on-disk
Expand Down

0 comments on commit 170dc60

Please sign in to comment.