From 53bb10c6eb12abff3eb083b5abe129684ad514e2 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Thu, 2 Mar 2023 08:48:50 +0530 Subject: [PATCH] Fix mypy issues --- aesara/link/numba/dispatch/basic.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index b173ab2765..7ff10e3807 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -356,34 +356,36 @@ def numba_const_convert(data, dtype=None, **kwargs): def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable: """Convert `obj` to a Numba-JITable object.""" - return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) + return cast( + Callable, _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) + ) numba_cache_index = pathlib.PurePath(config.compiledir, "numba_cache_index") numba_db = shelve.open(numba_cache_index.as_posix()) -def make_node_key(node: "Apply") -> Optional[str]: +def make_node_key(node): """Create a cache key for `node`. TODO: Currently this works only with Apply Node """ if not isinstance(node, Apply): return None key = (node.op,) - key += tuple(inp.type for inp in node.inputs) + key = tuple(inp.type for inp in node.inputs) key += tuple(inp.type for inp in node.outputs) - key = hashlib.sha256(pickle.dumps(key)).hexdigest() + hash_key = hashlib.sha256(pickle.dumps(key)).hexdigest() - return key + return hash_key -def check_cache(node_key: str): +def check_cache(node_key): """Check disk-backed cache.""" return numba_db.get(node_key) -def add_to_cache(node_key: str, numba_py_fn) -> Callable: +def add_to_cache(node_key, numba_py_fn): """Add the numba generated function to the cache.""" module_file_base = ( pathlib.PurePath(config.compiledir, node_key).with_suffix(".py").as_posix() @@ -403,7 +405,7 @@ def add_to_cache(node_key: str, numba_py_fn) -> Callable: return numba_py_fn -def persist_py_code(func) -> Callable: +def persist_py_code(func): """Persist a Numba JIT-able Python function. Parameters ========== @@ -418,7 +420,6 @@ def _func(obj, node, **kwargs): numba_py_fn = None if node_key: numba_py_fn = check_cache(node_key) - if node_key is None or numba_py_fn is None: # We could only ever return the function source in our dispatch # implementations. That way, we can compile directly to the on-disk