Skip to content

Commit

Permalink
Separate recursive importing from single node importing in FunctionGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 14, 2022
1 parent 1cce2b0 commit 3c665a5
Showing 1 changed file with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions aesara/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def import_var(

if isinstance(var.type, NullType):
raise TypeError(
f"Computation graph contains a NaN. {var.type.why_null}"
f"Computation graph contains a null type: {var} {var.type.why_null}"
)
if import_missing:
self.add_input(var)
Expand All @@ -327,7 +327,7 @@ def import_node(
reason: Optional[str] = None,
import_missing: bool = False,
) -> None:
"""Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs.
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters
----------
Expand All @@ -347,42 +347,62 @@ def import_node(
# to know where to stop going down.)
new_nodes = io_toposort(self.variables, apply_node.outputs)

if check:
for node in new_nodes:
for var in node.inputs:
if (
var.owner is None
and not isinstance(var, AtomicVariable)
and var not in self.inputs
):
if import_missing:
self.add_input(var)
else:
error_msg = (
f"Input {node.inputs.index(var)} ({var})"
" of the graph (indices start "
f"from 0), used to compute {node}, was not "
"provided and not given a value. Use the "
"Aesara flag exception_verbosity='high', "
"for more information on this error."
)
raise MissingInputError(error_msg, variable=var)

for node in new_nodes:
assert node not in self.apply_nodes
self.apply_nodes.add(node)
if not hasattr(node.tag, "imported_by"):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
for output in node.outputs:
self.setup_var(output)
self.variables.add(output)
for i, input in enumerate(node.inputs):
if input not in self.variables:
self.setup_var(input)
self.variables.add(input)
self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason)
self._import_node(
node, check=check, reason=reason, import_missing=import_missing
)

def _import_node(
self,
apply_node: Apply,
check: bool = True,
reason: Optional[str] = None,
import_missing: bool = False,
) -> None:
"""Import a single node.
See `FunctionGraph.import_node`.
"""
assert apply_node not in self.apply_nodes

for i, inp in enumerate(apply_node.inputs):
if (
check
and inp.owner is None
and not isinstance(inp, AtomicVariable)
and inp not in self.inputs
):
if import_missing:
self.add_input(inp)
else:
error_msg = (
f"Input {apply_node.inputs.index(inp)} ({inp})"
" of the graph (indices start "
f"from 0), used to compute {apply_node}, was not "
"provided and not given a value. Use the "
"Aesara flag exception_verbosity='high', "
"for more information on this error."
)
raise MissingInputError(error_msg, variable=inp)

if inp not in self.variables:
self.setup_var(inp)
self.variables.add(inp)

self.add_client(inp, (apply_node, i))

for output in apply_node.outputs:
self.setup_var(output)
self.variables.add(output)

self.apply_nodes.add(apply_node)

if not hasattr(apply_node.tag, "imported_by"):
apply_node.tag.imported_by = []

apply_node.tag.imported_by.append(str(reason))

self.execute_callbacks("on_import", apply_node, reason)

def change_node_input(
self,
Expand Down

0 comments on commit 3c665a5

Please sign in to comment.