Skip to content

Commit

Permalink
Use objmode in scipy.special without numba-scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian Seyboldt committed Jul 24, 2022
1 parent 8763981 commit 24a84eb
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 17 deletions.
6 changes: 6 additions & 0 deletions aesara/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,12 @@ def add_numba_configvars():
BoolParam(True),
in_c_key=False,
)
config.add(
"numba_scipy",
("Enable usage of the numba_scipy package for special functions",),
BoolParam(True),
in_c_key=False,
)


def _default_compiledirname():
Expand Down
10 changes: 7 additions & 3 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,8 @@ def numba_typify(data, dtype=None, **kwargs):
return data


@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""

warnings.warn(
f"Numba will use object mode to run {op}'s perform method",
UserWarning,
Expand Down Expand Up @@ -371,6 +369,12 @@ def perform(*inputs):
return perform


@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node."""
return generate_fallback_impl(op, node, storage_map, **kwargs)


@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
Expand Down
15 changes: 13 additions & 2 deletions aesara/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
OR,
XOR,
Add,
Composite,
IntDiv,
Mean,
Mul,
Expand All @@ -40,6 +41,7 @@
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.type import scalar


@singledispatch
Expand Down Expand Up @@ -424,8 +426,17 @@ def axis_apply_fn(x):

@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):

scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
# Creating a new scalar node is more involved and unnecessary
# if the scalar_op is composite, as the fgraph already contains
# all the necessary information.
scalar_node = None
if not isinstance(op.scalar_op, Composite):
scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)

scalar_op_fn = numba_funcify(
op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
elemwise_fn_name = elemwise_fn.__name__

Expand Down
28 changes: 25 additions & 3 deletions aesara/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import warnings
from functools import reduce
from typing import List

Expand All @@ -10,7 +11,11 @@
from aesara.compile.ops import ViewOp
from aesara.graph.basic import Variable
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import create_numba_signature, numba_funcify
from aesara.link.numba.dispatch.basic import (
create_numba_signature,
generate_fallback_impl,
numba_funcify,
)
from aesara.link.utils import (
compile_function_src,
get_name_for_object,
Expand All @@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# compiling the same Numba function over and over again?

scalar_func_name = op.nfunc_spec[0]
scalar_func = None

if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]

use_numba_scipy = config.numba_scipy
if use_numba_scipy:
try:
import numba_scipy # noqa: F401
except ImportError:
use_numba_scipy = False
if not use_numba_scipy:
warnings.warn(
"Native numba versions of scipy functions might be "
"avalable if numba-scipy is installed.",
UserWarning,
)
scalar_func = generate_fallback_impl(op, node, **kwargs)
else:
func_package = np

if "." in scalar_func_name:
if scalar_func is not None:
pass
elif "." in scalar_func_name:
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
Expand Down Expand Up @@ -220,7 +242,7 @@ def clip(_x, _min, _max):

@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
signature = create_numba_signature(op.fgraph, force_scalar=True)
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
Expand Down
28 changes: 19 additions & 9 deletions tests/link/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def test_box_unbox(input, wrapper_fn, check_fn):
assert check_fn(res, input)


@pytest.mark.parametrize(
"numba_scipy",
[True, False],
)
@pytest.mark.parametrize(
"inputs, input_vals, output_fn, exc",
[
Expand Down Expand Up @@ -352,6 +356,12 @@ def test_box_unbox(input, wrapper_fn, check_fn):
lambda x: at.erfc(x),
None,
),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.erfcx(x),
None,
),
(
[at.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)],
Expand Down Expand Up @@ -393,17 +403,17 @@ def test_box_unbox(input, wrapper_fn, check_fn):
),
],
)
def test_Elemwise(inputs, input_vals, output_fn, exc):
def test_Elemwise(numba_scipy, inputs, input_vals, output_fn, exc):
with config.change_flags(numba_scipy=numba_scipy):
outputs = output_fn(*inputs)

outputs = output_fn(*inputs)

out_fg = FunctionGraph(
outputs=[outputs] if not isinstance(outputs, list) else outputs
)
out_fg = FunctionGraph(
outputs=[outputs] if not isinstance(outputs, list) else outputs
)

cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(out_fg, input_vals)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(out_fg, input_vals)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 24a84eb

Please sign in to comment.