From 4ab1a2b88ebbb80f237ac3bdbeb6f269033c894e Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sat, 17 Sep 2022 16:08:16 -0500 Subject: [PATCH] Add FunctionGraph callback checks to tests --- tests/graph/test_fg.py | 404 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 397 insertions(+), 7 deletions(-) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index d6975db647..c145a99e24 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -1,10 +1,13 @@ import pickle +from typing import Any, Dict, List, Tuple import numpy as np import pytest +from typing_extensions import Literal from aesara.configdefaults import config from aesara.graph.basic import NominalVariable +from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph from aesara.graph.utils import MissingInputError from tests.graph.utils import ( @@ -19,6 +22,32 @@ ) +class CallbackTracker(Feature): + def __init__(self): + self.callback_history: List[ + Tuple[ + Literal["attach", "detach", "import", "change_input", "prune"], + Tuple[Any, ...], + Dict[Any, Any], + ] + ] = [] + + def on_attach(self, *args, **kwargs): + self.callback_history.append(("attach", args, kwargs)) + + def on_detach(self, *args, **kwargs): + self.callback_history.append(("detach", args, kwargs)) + + def on_import(self, *args, **kwargs): + self.callback_history.append(("import", args, kwargs)) + + def on_change_input(self, *args, **kwargs): + self.callback_history.append(("change_input", args, kwargs)) + + def on_prune(self, *args, **kwargs): + self.callback_history.append(("prune", args, kwargs)) + + class TestFunctionGraph: def test_pickle(self): var1 = op1() @@ -61,7 +90,11 @@ def test_init(self): var2 = MyVariable("var2") var3 = op1(var1) var4 = op2(var3, var2) - fg = FunctionGraph([var1, var2], [var3, var4], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var4], clone=False, features=[cb_tracker] + ) + assert fg.inputs == [var1, var2] assert fg.outputs == [var3, var4] assert fg.apply_nodes == {var3.owner, var4.owner} @@ -73,6 +106,19 @@ def test_init(self): assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)] assert fg.get_clients(var4) == [("output", 1)] + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + varC = MyConstant("varC") var5 = op1(var1, varC) fg = FunctionGraph(outputs=[var3, var4, var5], clone=False) @@ -94,7 +140,10 @@ def test_remove_client(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) assert fg.variables == {var1, var2, var3, var4, var5} assert fg.get_clients(var2) == [ @@ -104,6 +153,25 @@ def test_remove_client(self): (var5.owner, 2), ] + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() + fg.remove_client(var2, (var4.owner, 1)) assert fg.get_clients(var2) == [ @@ -112,12 +180,16 @@ def test_remove_client(self): (var5.owner, 2), ] + assert len(cb_tracker.callback_history) == 0 + fg.remove_client(var1, (var3.owner, 1)) assert fg.get_clients(var1) == [] assert var4.owner in fg.apply_nodes + assert len(cb_tracker.callback_history) == 0 + # This next `remove_client` should trigger a complete removal of `var4`'s # variables and `Apply` node from the `FunctionGraph`. # @@ -132,6 +204,13 @@ def test_remove_client(self): assert var4.owner.tag.removed_by == ["testing"] assert not any(o in fg.variables for o in var4.owner.outputs) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ( + "prune", + (fg, var4.owner, "testing"), + {}, + ) + def test_import_node(self): var1 = MyVariable("var1") @@ -139,7 +218,29 @@ def test_import_node(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() var8 = MyVariable("var8") var6 = op2(var8) @@ -148,11 +249,16 @@ def test_import_node(self): fg.import_node(var6.owner) assert var8 not in fg.variables + assert len(cb_tracker.callback_history) == 0 fg.import_node(var6.owner, import_missing=True) assert var8 in fg.inputs assert var6.owner in fg.apply_nodes + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ("import", (fg, var6.owner, None), {}) + cb_tracker.callback_history.clear() + var7 = op2(var2) assert not hasattr(var7.owner.tag, "imported_by") fg.import_node(var7.owner) @@ -162,6 +268,9 @@ def test_import_node(self): assert var7.owner in fg.apply_nodes assert (var7.owner, 0) in fg.get_clients(var2) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ("import", (fg, var7.owner, None), {}) + def test_import_var(self): var1 = MyVariable("var1") @@ -200,7 +309,29 @@ def test_change_input(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() var6 = MyVariable2("var6") with pytest.raises(TypeError): @@ -209,6 +340,8 @@ def test_change_input(self): with pytest.raises(TypeError): fg.change_node_input(var5.owner, 1, var6) + assert len(cb_tracker.callback_history) == 0 + old_apply_nodes = set(fg.apply_nodes) old_variables = set(fg.variables) old_var5_clients = list(fg.get_clients(var5)) @@ -216,6 +349,8 @@ def test_change_input(self): # We're replacing with the same variable, so nothing should happen fg.change_node_input(var5.owner, 1, var2) + assert len(cb_tracker.callback_history) == 0 + assert old_apply_nodes == fg.apply_nodes assert old_variables == fg.variables assert old_var5_clients == fg.get_clients(var5) @@ -223,9 +358,35 @@ def test_change_input(self): # Perform a valid `Apply` node input change fg.change_node_input(var5.owner, 1, var1) - assert var5.owner.inputs[1] is var1 + assert var5.owner.inputs == [var4, var1, var2] + assert fg.outputs[1].owner == var5.owner assert (var5.owner, 1) not in fg.get_clients(var2) + assert len(cb_tracker.callback_history) == 1 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, var5.owner, 1, var2, var1), + {"reason": None}, + ) + cb_tracker.callback_history.clear() + + # Perform a valid `Apply` node input change that results in a + # node removal (i.e. `var4.owner`) + fg.change_node_input(var5.owner, 0, var1) + + assert var5.owner.inputs[0] is var1 + assert not fg.get_clients(var4) + assert var4.owner not in fg.apply_nodes + assert var4 not in fg.variables + + assert len(cb_tracker.callback_history) == 2 + assert cb_tracker.callback_history[0] == ("prune", (fg, var4.owner, None), {}) + assert cb_tracker.callback_history[1] == ( + "change_input", + (fg, var5.owner, 0, var4, var1), + {"reason": None}, + ) + @config.change_flags(compute_test_value="raise") def test_replace_test_value(self): @@ -254,18 +415,212 @@ def test_replace(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() with pytest.raises(TypeError): var0 = MyVariable2("var0") # The types don't match and one cannot be converted to the other fg.replace(var3, var0) + assert len(cb_tracker.callback_history) == 0 + # Test a basic replacement fg.replace_all([(var3, var1)]) assert var3 not in fg.variables assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var1, var2] + assert fg.outputs == [var1, var5] + + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, "output", 0, var3, var1), + {"reason": None}, + ) + assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) + assert cb_tracker.callback_history[2] == ( + "change_input", + (fg, var4.owner, 0, var3, var1), + {"reason": None}, + ) + + var3 = op1(var1) + var4 = op2(var3) + var5 = op3(var4) + cb_tracker = CallbackTracker() + fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + + # Test a replacement that would remove the replacement variable + # (i.e. `var3`) from the graph when the variable to be replaced + # (i.e. `var4`) is removed + fg.replace_all([(var4, var3)]) + + assert fg.apply_nodes == {var3.owner, var5.owner} + assert fg.inputs == [var1] + assert fg.outputs == [var5] + assert fg.variables == {var1, var3, var5} + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ("prune", (fg, var4.owner, None), {}), + ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ] + + var3 = op1(var1) + var4 = op2(var3) + var5 = op3(var4, var4) + cb_tracker = CallbackTracker() + fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + + # Test multiple `change_node_input` calls on the same node + fg.replace_all([(var4, var3)]) + + assert fg.apply_nodes == {var3.owner, var5.owner} + assert fg.inputs == [var1] + assert fg.outputs == [var5] + assert fg.variables == {var1, var3, var5} + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, var4.owner, None), {}), + ("change_input", (fg, var5.owner, 1, var4, var3), {"reason": None}), + ] + + def test_replace_outputs(self): + var1 = MyVariable("var1") + var2 = MyVariable("var2") + var3 = op1(var1) + var4 = op2(var2) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var4, var3], clone=False, features=[cb_tracker] + ) + + fg.replace_all([(var3, var1)]) + assert var3 not in fg.variables + + assert fg.apply_nodes == {var4.owner} + assert fg.outputs == [var1, var4, var1] + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("change_input", (fg, "output", 0, var3, var1), {"reason": None}), + ("prune", (fg, var3.owner, None), {}), + ("change_input", (fg, "output", 2, var3, var1), {"reason": None}), + ] + + def test_replace_contract(self): + x = MyVariable("x") + v1 = op1(x) + v2 = op1(v1) + v3 = op1(v2) + v4 = op1(v3) + + v1.name = "v1" + v2.name = "v2" + v3.name = "v3" + v4.name = "v4" + + cb_tracker = CallbackTracker() + fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + + # This replacement should produce a new `Apply` node that's equivalent + # to `v2` and try to replace `v3`'s node with that one. In other + # words, the replacement creates a new node that's already in the + # `FunctionGraph`. + # The end result is `v3 = v2`. + fg.replace_all([(v2, v1)]) + + assert v2 not in fg.variables + assert fg.clients == { + x: [(v1.owner, 0)], + v1: [(v3.owner, 0)], + v2: [], + v3: [(v4.owner, 0)], + v4: [("output", 0)], + } + assert fg.apply_nodes == {v4.owner, v3.owner, v1.owner} + assert v2 not in set(sum((n.outputs for n in fg.apply_nodes), [])) + + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ("prune", (fg, v2.owner, None), {}), + ("change_input", (fg, v3.owner, 0, v2, v1), {"reason": None}), + ] + + # Let's try the same thing at a different point in the chain + x = MyVariable("x") + v1 = op1(x) + v2 = op1(v1) + v3 = op1(v2) + v4 = op1(v3) + + v1.name = "v1" + v2.name = "v2" + v3.name = "v3" + v4.name = "v4" + + cb_tracker = CallbackTracker() + fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + + fg.replace_all([(v3, v2)]) + + assert v3 not in fg.variables + assert fg.clients == { + x: [(v1.owner, 0)], + v1: [(v2.owner, 0)], + v2: [(v4.owner, 0)], + v3: [], + v4: [("output", 0)], + } + assert fg.apply_nodes == {v4.owner, v2.owner, v1.owner} + assert v3 not in set(sum((n.outputs for n in fg.apply_nodes), [])) + + exp_res = [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ("prune", (fg, v3.owner, None), {}), + ("change_input", (fg, v4.owner, 0, v3, v2), {"reason": None}), + ] + assert cb_tracker.callback_history == exp_res def test_replace_verbose(self, capsys): @@ -288,7 +643,29 @@ def test_replace_circular(self): var3 = op1(var2, var1) var4 = op2(var3, var2) var5 = op3(var4, var2, var2) - fg = FunctionGraph([var1, var2], [var3, var5], clone=False) + cb_tracker = CallbackTracker() + fg = FunctionGraph( + [var1, var2], [var3, var5], clone=False, features=[cb_tracker] + ) + + assert len(cb_tracker.callback_history) == 4 + assert cb_tracker.callback_history[0] == ("attach", (fg,), {}) + assert cb_tracker.callback_history[1] == ( + "import", + (fg, var3.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[2] == ( + "import", + (fg, var4.owner, "init"), + {}, + ) + assert cb_tracker.callback_history[3] == ( + "import", + (fg, var5.owner, "init"), + {}, + ) + cb_tracker.callback_history.clear() fg.replace_all([(var3, var4)]) @@ -297,6 +674,19 @@ def test_replace_circular(self): assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var4, var2] + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history[0] == ( + "change_input", + (fg, "output", 0, var3, var4), + {"reason": None}, + ) + assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) + assert cb_tracker.callback_history[2] == ( + "change_input", + (fg, var4.owner, 0, var3, var4), + {"reason": None}, + ) + def test_replace_bad_state(self): var1 = MyVariable("var1")