Skip to content

Commit

Permalink
Do not always remap storage in fgraph_to_python
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 10, 2022
1 parent d1af711 commit f24fac6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
12 changes: 3 additions & 9 deletions aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,6 @@ def fgraph_to_python(
*,
type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Apply]] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None,
Expand All @@ -704,10 +702,6 @@ def fgraph_to_python(
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order
The `order` argument to `map_storage`.
input_storage
The `input_storage` argument to `map_storage`.
output_storage
The `output_storage` argument to `map_storage`.
storage_map
The `storage_map` argument to `map_storage`.
fgraph_name
Expand All @@ -730,9 +724,9 @@ def fgraph_to_python(

if order is None:
order = fgraph.toposort()
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)

if storage_map is None:
_, _, storage_map = map_storage(fgraph, order)

unique_name = unique_name_generator([fgraph_name])

Expand Down
19 changes: 19 additions & 0 deletions tests/link/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs():
assert out_py()[0] is y.data


def test_fgraph_to_python_constant_inputs():
x = constant([1.0])
y = vector("y")

out = x + y
out_fg = FunctionGraph(outputs=[out], clone=False)

out_py = fgraph_to_python(out_fg, to_python, storage_map=None)

res = out_py(2.0)
assert res == (3.0,)

storage_map = {out: [None], x: [np.r_[2.0]], y: [None]}
out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map)

res = out_py(2.0)
assert res == (4.0,)


def test_unique_name_generator():

unique_names = unique_name_generator(["blah"], suffix_sep="_")
Expand Down

0 comments on commit f24fac6

Please sign in to comment.