diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 2ac70ed563..875057b147 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -121,7 +121,7 @@ jobs:
         run: |
           mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
           if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
-          mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
+          mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib dill
           pip install -e ./
           mamba list && pip freeze
           python -c 'import aesara; print(aesara.config.__str__(print_doc=False))'
diff --git a/aesara/configdefaults.py b/aesara/configdefaults.py
index b1e914b2a9..b46bdf57a2 100644
--- a/aesara/configdefaults.py
+++ b/aesara/configdefaults.py
@@ -378,6 +378,13 @@ def add_basic_configvars():
         in_c_key=False,
     )
 
+    config.add(
+        "DISABLE_NUMBA_PYTHON_IR_CACHING",
+        ("Disable caching of the Aesara-generated Python IR used by the Numba backend"),
+        BoolParam(False),
+        in_c_key=False,
+    )
+
 
 def _is_gt_0(x):
     return x > 0
diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py
index ddf6697639..f2e3c69471 100644
--- a/aesara/link/numba/dispatch/basic.py
+++ b/aesara/link/numba/dispatch/basic.py
@@ -1,9 +1,10 @@
+import hashlib
 import operator
 import warnings
 from contextlib import contextmanager
 from functools import singledispatch
 from textwrap import dedent
-from typing import TYPE_CHECKING, Callable, Optional, Union, cast
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, cast
 
 import numba
 import numba.np.unsafe.ndarray as numba_ndarray
@@ -350,7 +351,69 @@ def numba_const_convert(data, dtype=None, **kwargs):
 
 def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
-    """Convert `obj` to a Numba-JITable object."""
-    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+    """Convert `obj` to a Numba-JITable object.
+
+    Unless `config.DISABLE_NUMBA_PYTHON_IR_CACHING` is set, the function built
+    for an `Apply` node is memoized in the in-memory `numba_db` cache under the
+    key returned by `make_node_key`.  When no key can be made for `node`, the
+    conversion is always delegated to `_numba_funcify` and nothing is cached.
+    """
+    if config.DISABLE_NUMBA_PYTHON_IR_CACHING:
+        return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+    node_key = make_node_key(node)
+    if node_key is None:
+        # There is nothing to key the cache on, so convert without caching.
+        return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+    numba_py_fn = check_cache(node_key)
+    if numba_py_fn is None:
+        # TODO: The dispatch implementations could return only the function
+        # source, so that it is compiled to an on-disk module exactly once
+        # (using steps similar to `aesara.link.utils.compile_function_src`).
+        numba_py_fn = add_to_cache(
+            node_key,
+            _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs),
+        )
+
+    # TODO: Presently `numba_py_fn` is already jitted.
+    return cast(Callable, numba_py_fn)
+
+
+# In-memory cache of converted functions, keyed by `make_node_key` digests.
+numba_db: Dict[str, Callable] = {}
+
+
+def make_node_key(node):
+    """Create a cache key for `node`.
+
+    Returns the SHA-256 hex digest of a string built from `str(node)` and the
+    types of the node's inputs and outputs, or `None` when `node` is not an
+    `Apply` instance (such calls are not cached).
+    """
+    if not isinstance(node, Apply):
+        return None
+
+    # NOTE(review): `str(node)` is assumed not to always spell out variable
+    # types (e.g. for named variables), so the types are added explicitly to
+    # keep differently-typed nodes from sharing a cache entry -- confirm.
+    # TODO: Add a stronger hashing mechanism.
+    key_parts = [str(node)]
+    key_parts.extend(str(var.type) for var in node.inputs)
+    key_parts.extend(str(var.type) for var in node.outputs)
+
+    # Hash the encoded text directly; a `str` does not need to be pickled.
+    return hashlib.sha256("\n".join(key_parts).encode("utf-8")).hexdigest()
+
+
+def check_cache(node_key):
+    """Return the function cached in `numba_db` for `node_key`, or `None`."""
+    return numba_db.get(node_key)
+
+
+def add_to_cache(node_key, numba_py_fn):
+    """Store `numba_py_fn` in `numba_db` under `node_key` and return it."""
+    numba_db[node_key] = numba_py_fn
+    return numba_py_fn
 
 
 @singledispatch
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 15799a3134..3a84b0bb1e 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -164,17 +164,18 @@ def inner_vec(*args):
         mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
     ]
 
-    with contextlib.ExitStack() as stack:
-        for ctx in mocks:
-            stack.enter_context(ctx)
-
-        aesara_numba_fn = function(
-            fn_inputs,
-            fn_outputs,
-            mode=mode,
-            accept_inplace=True,
-        )
-        _ = aesara_numba_fn(*inputs)
+    with config.change_flags(DISABLE_NUMBA_PYTHON_IR_CACHING=True):
+        with contextlib.ExitStack() as stack:
+            for ctx in mocks:
+                stack.enter_context(ctx)
+
+            aesara_numba_fn = function(
+                fn_inputs,
+                fn_outputs,
+                mode=mode,
+                accept_inplace=True,
+            )
+            _ = aesara_numba_fn(*inputs)
 
 
 def compare_numba_and_py(
@@ -999,16 +1000,20 @@ def test_config_options_cached():
     x = at.dvector()
 
     with config.change_flags(numba__cache=True):
-        aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
-        assert not isinstance(
-            numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
-        )
+        with config.change_flags(DISABLE_NUMBA_PYTHON_IR_CACHING=True):
+            aesara_numba_fn = function([x], x * 2, mode=numba_mode)
+            numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
+            assert not isinstance(
+                numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
+            )
 
     with config.change_flags(numba__cache=False):
-        aesara_numba_fn = function([x], x * 2, mode=numba_mode)
-        numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
-        assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
+        with config.change_flags(DISABLE_NUMBA_PYTHON_IR_CACHING=True):
+            aesara_numba_fn = function([x], x * 2, mode=numba_mode)
+            numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
+            assert isinstance(
+                numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
+            )
 
 
 def test_scalar_return_value_conversion():