Skip to content

Commit

Permalink
Prototype of cache numba-jit functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Feb 28, 2023
1 parent d7b2c98 commit 3b5ef75
Showing 1 changed file with 88 additions and 1 deletion.
89 changes: 88 additions & 1 deletion aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import hashlib
import operator
import pathlib
import pickle
import shelve
import warnings
from contextlib import contextmanager
from functools import singledispatch
from functools import singledispatch, wraps
from textwrap import dedent
from types import ModuleType
from typing import TYPE_CHECKING, Callable, Optional, Union, cast

import dill
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
Expand Down Expand Up @@ -353,6 +359,87 @@ def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)


# Path (inside `config.compiledir`) of the shelf that indexes cached functions.
numba_cache_index = pathlib.PurePath(config.compiledir) / "numba_cache_index"
# The shelf is opened once, at import time, and used by the cache helpers in
# this module.
# NOTE(review): no explicit `numba_db.close()` is visible in this section --
# confirm the shelf is closed/synced somewhere before interpreter exit.
numba_db = shelve.open(numba_cache_index.as_posix())


def make_node_key(node: "Apply") -> Optional[str]:
    """Create a cache key for `node`.

    The key is the SHA-256 hex digest of the pickled tuple made of the node's
    `Op` followed by the types of its inputs and the types of its outputs.

    TODO: Currently this works only with Apply Node

    Parameters
    ==========
    node
        The node to build a key for.

    Returns
    =======
    The hex digest as a string, or ``None`` when `node` is not an `Apply`
    instance or when its key components cannot be pickled (in which case the
    node is simply treated as uncacheable).
    """
    if not isinstance(node, Apply):
        return None

    key_parts = (
        node.op,
        *(inp.type for inp in node.inputs),
        *(out.type for out in node.outputs),
    )

    try:
        serialized = pickle.dumps(key_parts)
    except (pickle.PicklingError, AttributeError, TypeError):
        # Not every object can be pickled (e.g. ones holding lambdas, local
        # classes or locks).  The key only serves an optional cache, so a
        # failure here should disable caching for this node rather than
        # abort the whole compilation.
        return None

    return hashlib.sha256(serialized).hexdigest()


def check_cache(node_key: str):
    """Look up `node_key` in the disk-backed cache.

    Returns whatever is stored in `numba_db` under `node_key`, or ``None``
    when there is no entry for that key.
    """
    cached_entry = numba_db.get(node_key, None)
    return cached_entry


def add_to_cache(node_key: str, numba_py_fn) -> Callable:
    """Store `numba_py_fn` in the on-disk cache under `node_key`.

    The function is attached to a throwaway module named after `node_key`,
    that module is written to ``<config.compiledir>/<node_key>.py`` with
    `dill.dump_module`, and the function is then read back from that file
    with `dill.load_module`.  The reloaded function is recorded in `numba_db`
    and returned.

    Parameters
    ==========
    node_key
        The cache key; also used as the module name and file stem.
    numba_py_fn
        The function object to persist.

    Returns
    =======
    The function object obtained by reloading the persisted module.
    """
    # NOTE: despite the ``.py`` suffix, this file is produced by
    # `dill.dump_module`, not written out as source text by this function.
    dump_path = pathlib.PurePath(config.compiledir, node_key).with_suffix(".py")
    dump_file = dump_path.as_posix()

    # A temporary module whose only payload is the function, exposed as the
    # `source` attribute
    holder = ModuleType(node_key)
    holder.source = numba_py_fn
    dill.dump_module(dump_file, module=holder)

    # Round-trip through the file so that what gets cached (and returned) is
    # the persisted version of the function
    reloaded_fn = dill.load_module(dump_file).source

    # Index the reloaded function in the numba_cache database
    numba_db[node_key] = reloaded_fn

    return reloaded_fn


def persist_py_code(func) -> Callable:
    """Persist a Numba JIT-able Python function.

    The returned wrapper computes a cache key for the `node` it is called
    with (see `make_node_key`).  When a key is available, the disk-backed
    cache is consulted first and `func` is only called on a miss, after which
    its result is stored through `add_to_cache`.  When no key can be made,
    `func` is called directly and nothing is cached.

    Parameters
    ==========
    func
        An `Op` dispatch function that returns the source for
        a Python function that will be `numba.njit`ed.

    Returns
    =======
    A wrapper around `func` with the call signature
    ``(obj, node=None, **kwargs)``.  `functools.wraps` is applied, so
    attributes set on `func` are carried over to the wrapper.
    """

    @wraps(func)
    def _func(obj, node=None, **kwargs):
        # `node` defaults to `None` so that the wrapper can also be called
        # without a node; `make_node_key` returns `None` for such input.
        node_key = make_node_key(node)

        if not node_key:
            # No usable key: bypass the cache entirely.
            return func(obj, node, **kwargs)

        numba_py_fn = check_cache(node_key)

        if numba_py_fn is None:
            # We could only ever return the function source in our dispatch
            # implementations. That way, we can compile directly to the
            # on-disk modules only once.
            #
            # `add_to_cache` determines the on-disk module name from
            # `node_key` and returns the corresponding Python function
            # object using steps similar to
            # `aesara.link.utils.compile_function_src`.
            numba_py_fn = add_to_cache(node_key, func(obj, node, **kwargs))

        # TODO: Presently numba_py_fn is already jitted.
        # numba_fn = numba_njit(numba_py_fn)
        return numba_py_fn

    return _func


@persist_py_code
@singledispatch
def _numba_funcify(
obj,
Expand Down

0 comments on commit 3b5ef75

Please sign in to comment.