diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index 1cd304f5e5..fe4caf08d0 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -377,6 +377,9 @@ def perform(*inputs): @numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): + + _ = kwargs.pop("storage_map", None) + fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) if len(op.fgraph.outputs) == 1: diff --git a/aesara/link/numba/dispatch/scalar.py b/aesara/link/numba/dispatch/scalar.py index 28031ea988..08dd5f1a10 100644 --- a/aesara/link/numba/dispatch/scalar.py +++ b/aesara/link/numba/dispatch/scalar.py @@ -221,6 +221,9 @@ def clip(_x, _min, _max): @numba_funcify.register(Composite) def numba_funcify_Composite(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) + + _ = kwargs.pop("storage_map", None) + composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) diff --git a/aesara/link/utils.py b/aesara/link/utils.py index 2a593e9af1..c7d93fa7f1 100644 --- a/aesara/link/utils.py +++ b/aesara/link/utils.py @@ -678,8 +678,6 @@ def fgraph_to_python( *, type_conversion_fn: Callable = lambda x, **kwargs: x, order: Optional[List[Apply]] = None, - input_storage: Optional["InputStorageType"] = None, - output_storage: Optional["OutputStorageType"] = None, storage_map: Optional["StorageMapType"] = None, fgraph_name: str = "fgraph_to_python", global_env: Optional[Dict[Any, Any]] = None, @@ -704,10 +702,6 @@ def fgraph_to_python( ``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``. order The `order` argument to `map_storage`. - input_storage - The `input_storage` argument to `map_storage`. - output_storage - The `output_storage` argument to `map_storage`. storage_map The `storage_map` argument to `map_storage`. fgraph_name @@ -730,9 +724,9 @@ def fgraph_to_python( if order is None: order = fgraph.toposort() - input_storage, output_storage, storage_map = map_storage( - fgraph, order, input_storage, output_storage, storage_map - ) + + if storage_map is None: + storage_map = {} unique_name = unique_name_generator([fgraph_name]) @@ -752,10 +746,13 @@ def fgraph_to_python( node_input_names = [] for i in node.inputs: local_input_name = unique_name(i) - if storage_map[i][0] is not None or isinstance(i, Constant): + input_storage = storage_map.setdefault( + i, [None if not isinstance(i, Constant) else i.data] + ) + if input_storage[0] is not None or isinstance(i, Constant): # Constants need to be assigned locally and referenced global_env[local_input_name] = type_conversion_fn( - storage_map[i][0], variable=i, storage=storage_map[i], **kwargs + input_storage[0], variable=i, storage=input_storage, **kwargs ) # TODO: We could attempt to use the storage arrays directly # E.g. `local_input_name = f"{local_input_name}[0]"` @@ -763,20 +760,24 @@ def fgraph_to_python( node_output_names = [unique_name(v) for v in node.outputs] - assign_comment_str = f"{indent(str(node), '# ')}" assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" - body_assigns.append(f"{assign_comment_str}\n{assign_str}") + assign_comment_str = f"{indent(str(node), '# ')}" + assign_block_str = f"{assign_comment_str}\n{assign_str}" + body_assigns.append(assign_block_str) # Handle `Constant`-only outputs (these don't have associated `Apply` # nodes, so the above isn't applicable) for out in fgraph.outputs: if isinstance(out, Constant): - local_input_name = unique_name(out) - if local_input_name not in global_env: - global_env[local_input_name] = type_conversion_fn( - storage_map[out][0], + local_output_name = unique_name(out) + if local_output_name not in global_env: + output_storage = storage_map.setdefault( + out, [None if not isinstance(out, Constant) else out.data] + ) + global_env[local_output_name] = type_conversion_fn( + output_storage[0], variable=out, - storage=storage_map[out], + storage=output_storage, **kwargs, ) @@ -794,7 +795,7 @@ def fgraph_to_python( fgraph_def_src = dedent( f""" def {fgraph_name}({", ".join(fgraph_input_names)}): - {indent(joined_body_assigns, " " * 4)} +{indent(joined_body_assigns, " " * 4)} return {fgraph_return_src} """ ).strip() diff --git a/tests/link/test_utils.py b/tests/link/test_utils.py index 407d399552..96bc3c3673 100644 --- a/tests/link/test_utils.py +++ b/tests/link/test_utils.py @@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs(): assert out_py()[0] is y.data +def test_fgraph_to_python_constant_inputs(): + x = constant([1.0]) + y = vector("y") + + out = x + y + out_fg = FunctionGraph(outputs=[out], clone=False) + + out_py = fgraph_to_python(out_fg, to_python, storage_map=None) + + res = out_py(2.0) + assert res == (3.0,) + + storage_map = {out: [None], x: [np.r_[2.0]], y: [None]} + out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map) + + res = out_py(2.0) + assert res == (4.0,) + + def test_unique_name_generator(): unique_names = unique_name_generator(["blah"], suffix_sep="_")