From 3b5ef7525b91b3cbac8e24ce8f0b915884e42c04 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Tue, 28 Feb 2023 09:19:27 +0530 Subject: [PATCH] Prototype of cache numba-jit functions --- aesara/link/numba/dispatch/basic.py | 89 ++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index ddf6697639..b173ab2765 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -1,10 +1,16 @@ +import hashlib import operator +import pathlib +import pickle +import shelve import warnings from contextlib import contextmanager -from functools import singledispatch +from functools import singledispatch, wraps from textwrap import dedent +from types import ModuleType from typing import TYPE_CHECKING, Callable, Optional, Union, cast +import dill import numba import numba.np.unsafe.ndarray as numba_ndarray import numpy as np @@ -353,6 +359,87 @@ def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable: return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) +numba_cache_index = pathlib.PurePath(config.compiledir, "numba_cache_index") +numba_db = shelve.open(numba_cache_index.as_posix()) + + +def make_node_key(node: "Apply") -> Optional[str]: + """Create a cache key for `node`. + TODO: Currently this works only with Apply Node + """ + if not isinstance(node, Apply): + return None + key = (node.op,) + key += tuple(inp.type for inp in node.inputs) + key += tuple(inp.type for inp in node.outputs) + + key = hashlib.sha256(pickle.dumps(key)).hexdigest() + + return key + + +def check_cache(node_key: str): + """Check disk-backed cache.""" + return numba_db.get(node_key) + + +def add_to_cache(node_key: str, numba_py_fn) -> Callable: + """Add the numba generated function to the cache.""" + module_file_base = ( + pathlib.PurePath(config.compiledir, node_key).with_suffix(".py").as_posix() + ) + cache_module = ModuleType(node_key) + + # Create a temporary module for the generated source + cache_module.source = numba_py_fn + dill.dump_module(module_file_base, module=cache_module) + + # Load the function from the persisted module + numba_py_fn = dill.load_module(module_file_base).source + + # Add the function to numba_cache database + numba_db[node_key] = numba_py_fn + + return numba_py_fn + + +def persist_py_code(func) -> Callable: + """Persist a Numba JIT-able Python function. + Parameters + ========== + func + An `Op` dispatch function that returns the source for + a Python function that will be `numba.njit`ed. + """ + + @wraps(func) + def _func(obj, node, **kwargs): + node_key = make_node_key(node) + numba_py_fn = None + if node_key: + numba_py_fn = check_cache(node_key) + + if node_key is None or numba_py_fn is None: + # We could only ever return the function source in our dispatch + # implementations. That way, we can compile directly to the on-disk + # modules only once. + numba_py_fn = func(obj, node, **kwargs) + + # This will determine on-disk module name to be generated for + # `numba_py_src` and return the corresponding Python function + # object using steps similar to + # `aesara.link.utils.compile_function_src`. + if node_key: + numba_py_fn = add_to_cache(node_key, numba_py_fn) + + # TODO: Presently numba_py_fn is already jitted. + # numba_fn = numba_njit(numba_py_fn) + return numba_py_fn + + return _func + + +@persist_py_code @singledispatch def _numba_funcify( obj,