-
-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move nested static numba-jit functions #1438
base: main
Are you sure you want to change the base?
Changes from 7 commits
2e10a82
d99dda8
1c29a6a
febc082
d8b19d1
2038a5f
2b6bfc7
34739a9
58ccd1b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -378,6 +378,13 @@ def add_basic_configvars(): | |||||
in_c_key=False, | ||||||
) | ||||||
|
||||||
config.add( | ||||||
"DISABLE_NUMBA_CACHE", | ||||||
("Disable numba caching in the backend"), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
BoolParam(False), | ||||||
in_c_key=False, | ||||||
) | ||||||
|
||||||
|
||||||
def _is_gt_0(x): | ||||||
return x > 0 | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import hashlib | ||
import operator | ||
import pickle | ||
import warnings | ||
from contextlib import contextmanager | ||
from functools import singledispatch | ||
from textwrap import dedent | ||
from typing import TYPE_CHECKING, Callable, Optional, Union, cast | ||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, cast | ||
|
||
import numba | ||
import numba.np.unsafe.ndarray as numba_ndarray | ||
|
@@ -350,7 +352,63 @@ def numba_const_convert(data, dtype=None, **kwargs): | |
|
||
def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable: | ||
"""Convert `obj` to a Numba-JITable object.""" | ||
return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) | ||
numba_py_fn = None | ||
if config.DISABLE_NUMBA_CACHE: | ||
numba_py_fn = _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) | ||
else: | ||
node_key = make_node_key(node) | ||
|
||
if node_key: | ||
numba_py_fn = check_cache(node_key) | ||
if node_key is None or numba_py_fn is None: | ||
# We could only ever return the function source in our dispatch | ||
# implementations. That way, we can compile directly to the on-disk | ||
# modules only once. | ||
numba_py_fn = _numba_funcify( | ||
obj, node=node, storage_map=storage_map, **kwargs | ||
) | ||
|
||
# This will determine on-disk module name to be generated for | ||
# `numba_py_src` and return the corresponding Python function | ||
# object using steps similar to | ||
# `aesara.link.utils.compile_function_src`. | ||
if node_key: | ||
numba_py_fn = add_to_cache(node_key, numba_py_fn) | ||
|
||
# TODO: Presently numba_py_fn is already jitted. | ||
# numba_fn = numba_njit(numba_py_fn) | ||
return cast(Callable, numba_py_fn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI: I think the tests are failing because we haven't finished refactoring the rest of the code so that it's aware of |
||
|
||
|
||
numba_db: Dict[str, Callable] = {} | ||
|
||
|
||
def make_node_key(node): | ||
"""Create a cache key for `node`. | ||
TODO: Currently this works only with Apply Node | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
if not isinstance(node, Apply): | ||
return None | ||
# TODO: Add a stronger hashing mechanism | ||
key = str(node) | ||
# key = (node.op,) | ||
# key = tuple(inp.type for inp in node.inputs) | ||
# key += tuple(inp.type for inp in node.outputs) | ||
|
||
hash_key = hashlib.sha256(pickle.dumps(key)).hexdigest() | ||
|
||
return hash_key | ||
|
||
|
||
def check_cache(node_key): | ||
"""Check disk-backed cache.""" | ||
return numba_db.get(node_key, None) | ||
|
||
|
||
def add_to_cache(node_key, numba_py_fn): | ||
"""Add the numba generated function to the cache.""" | ||
numba_db[node_key] = numba_py_fn | ||
return numba_py_fn | ||
|
||
|
||
@singledispatch | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to change this to something like
DISABLE_NUMBA_PYTHON_IR_CACHING
.