diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py index 08449bfc55..1973277da2 100644 --- a/aesara/tensor/rewriting/elemwise.py +++ b/aesara/tensor/rewriting/elemwise.py @@ -52,6 +52,14 @@ def print_profile(cls, stream, prof, level=0): for n in sorted(ndim.keys()): print(blanc, n, ndim[n], file=stream) + def candidate_input_idxs(self, node): + if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: + # TODO: Implement specialized InplaceCompositeOptimizer with logic + # needed to correctly assign inplace for multi-output Composites + return [] + else: + return range(len(node.outputs)) + def apply(self, fgraph): r""" @@ -142,7 +150,7 @@ def apply(self, fgraph): baseline = op.inplace_pattern candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline + i for i in self.candidate_input_idxs(node) if i not in baseline ] # node inputs that are Constant, already destroyed, # or fgraph protected inputs and fgraph outputs can't be used as @@ -160,7 +168,7 @@ def apply(self, fgraph): ] else: baseline = [] - candidate_outputs = list(range(len(node.outputs))) + candidate_outputs = self.candidate_input_idxs(node) # node inputs that are Constant, already destroyed, # fgraph protected inputs and fgraph outputs can't be used as inplace # target.