Skip to content

Commit

Permalink
Fix bug in Composite when multiple outputs are identical
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo Vieira committed Oct 27, 2022
1 parent e2d80eb commit cc409b8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
15 changes: 15 additions & 0 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4165,6 +4165,21 @@ def init_fgraph(self):
"The fgraph to Composite must be exclusively"
" composed of ScalarOp instances."
)

# Clone identical outputs that have been merged
if len(set(fgraph.outputs)) != len(self.outputs):
old_outputs = fgraph.outputs
new_outputs = []
for output in old_outputs:
if output not in new_outputs:
new_outputs.append(output)
else:
node = output.owner
output_idx = node.outputs.index(output)
new_output = node.clone().outputs[output_idx]
new_outputs.append(new_output)
fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)

self.fgraph = fgraph

def __init__(self, inputs, outputs):
Expand Down
11 changes: 11 additions & 0 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def test_many_outputs(self):
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]

def test_identical_outputs(self):
x, y, z = floats("xyz")
e0 = x + y + z
e1 = x + y + z
e2 = x / y
C = Composite([x, y, z], [e0, e1, e2])
c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs)
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5]

def test_composite_printing(self):
x, y, z = floats("xyz")
e0 = x + y + z
Expand Down

0 comments on commit cc409b8

Please sign in to comment.