From 618c11c7a8c639c687f8574c79ecfab3b62b9d7f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 11 Oct 2022 14:40:41 +0200 Subject: [PATCH] Temporarily disable inplace for multiple-output Composites --- aesara/tensor/rewriting/elemwise.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py index 37b4a534da..743507dd82 100644 --- a/aesara/tensor/rewriting/elemwise.py +++ b/aesara/tensor/rewriting/elemwise.py @@ -59,6 +59,14 @@ def print_profile(cls, stream, prof, level=0): for n in sorted(ndim.keys()): print(blanc, n, ndim[n], file=stream) + def candidate_input_idxs(self, node): + if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: + # TODO: Implement specialized InplaceCompositeOptimizer with logic + # needed to correctly assign inplace for multi-output Composites + return [] + else: + return range(len(node.outputs)) + def apply(self, fgraph): r""" @@ -149,7 +157,7 @@ def apply(self, fgraph): baseline = op.inplace_pattern candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline + i for i in self.candidate_input_idxs(node) if i not in baseline ] # node inputs that are Constant, already destroyed, # or fgraph protected inputs and fgraph outputs can't be used as @@ -167,7 +175,7 @@ def apply(self, fgraph): ] else: baseline = [] - candidate_outputs = list(range(len(node.outputs))) + candidate_outputs = self.candidate_input_idxs(node) # node inputs that are Constant, already destroyed, # fgraph protected inputs and fgraph outputs can't be used as inplace # target.