diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py index 1973277da2..b1b2e4bcac 100644 --- a/aesara/tensor/rewriting/elemwise.py +++ b/aesara/tensor/rewriting/elemwise.py @@ -1,19 +1,27 @@ import sys -import time -from collections import defaultdict -from typing import Optional +from collections import defaultdict, deque +from functools import lru_cache +from typing import Any, Dict, Generator, List, Tuple +from typing import cast as typing_cast from warnings import warn import aesara import aesara.scalar.basic as aes from aesara import compile from aesara.configdefaults import config -from aesara.graph.basic import Apply, Constant, io_toposort +from aesara.graph import FunctionGraph +from aesara.graph.basic import ( + Apply, + Constant, + Variable, + ancestors, + clone_replace, + io_toposort, +) from aesara.graph.features import ReplaceValidate -from aesara.graph.op import compute_test_value, get_test_value from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter from aesara.graph.rewriting.db import SequenceDB -from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError +from aesara.graph.utils import InconsistencyError from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.exceptions import NotScalarConstantError @@ -531,337 +539,435 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): return rval -def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): - r"""Create a recursive function that fuses `Elemwise` `Op`\s. +def elemwise_max_input_fct(node): + # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. + if not config.cxx: + return 31 + return 1024 - The basic idea is that we loop through an `Elemwise` node's inputs, find - other `Elemwise` nodes, determine the scalars input types for all of the - `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types - and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a - new "fused" `Elemwise`. - It's parameterized in order to work for `Elemwise` `Op`\s. +class FusionOptimizer(GraphRewriter): + """Graph optimizer that fuses consecutive Elemwise operations.""" - Parameters - ---------- - op_class : type - `Elemwise` class (the one that we want to fuse) - max_input_fct : callable - A function that returns the maximum number of inputs that this `Elemwise` - can take. - On the CPU we limit to 32 input variables since that is the maximum - NumPy support. + def __init__(self, local_optimizer=None): + # TODO: Figure out what to do with this + super().__init__() + self.optimizer = local_optimizer - maker: callable - A function with the signature ``(node, *args)`` that constructs an - `op_class` instance (e.g. ``op_class(*args)``). + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) - """ - if maker is None: + @staticmethod + def elemwise_to_scalar(inputs, outputs): + replace_inputs = [(inp, inp.clone()) for inp in inputs] + outputs = clone_replace(outputs, replace=replace_inputs) + # print("elemwise_to_scalar replaced outputs:") + # aesara.dprint(outputs, print_type=True) - def maker(node, scalar_op): - return op_class(scalar_op) + inputs = [inp for _, inp in replace_inputs] + fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) + middle_inputs = [] - def local_fuse(fgraph, node): - r"""Fuse `Elemwise` `Op`\s in a node. 
+ scalar_inputs = [ + aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + ] + middle_scalar_inputs = [] + + for node in fg.toposort(): + node_scalar_inputs = [] + for inp in node.inputs: + if inp in inputs: + node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) + elif inp in middle_inputs: + node_scalar_inputs.append( + middle_scalar_inputs[middle_inputs.index(inp)] + ) + else: + new_scalar_input = aes.get_scalar_type( + inp.type.dtype + ).make_variable() + node_scalar_inputs.append(new_scalar_input) + middle_scalar_inputs.append(new_scalar_input) + middle_inputs.append(inp) + + new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) + middle_scalar_inputs.append(new_scalar_node.outputs[0]) + middle_inputs.append(node.outputs[0]) + + scalar_outputs = [ + middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs + ] + return scalar_inputs, scalar_outputs - As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the - same shape. + def apply(self, fgraph): + nb_replacement = 0 - For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C - compiler do the cast. + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callbacks_before = fgraph.execute_callbacks_times.copy() + callback_before = fgraph.execute_callbacks_time - The number of dimensions is validated at call time by Aesara itself. + max_inputs = elemwise_max_input_fct(None) - """ - # TODO: use broadcast flag? + def find_next_fuseable_subgraph( + fg: FunctionGraph, + ) -> Generator[Tuple[List[Variable], List[Variable]], None, None]: + """Find all subgraphs in a FunctionGraph that can be fused together - # TODO: don't do this rewrite as a `NodeRewriter`. - # Analyze the graph in terms of elemwise subgraphs, and then - # replace each subgraph with a Composite version. + Yields + ------- + List of inputs and outputs that determine subgraphs which can be fused. This + method assumes that such replacement is done across iterations of the + generator. + """ - # TODO: use malloc and copy to transfer arguments that don't - # fit within the parameter space of 256 bytes - # - # TODO: Merge with multiple output to merge when an inputs - # have multiple clients. This can't be done with a `NodeRewriter` - - # TODO: Related: Support composites with multiple outputs - - # TODO: Use Composite to combine Elemwise and Reduce - # operations. We have to loop over the data anyway... might - # as well sum it up while we're at it (this can be trickier - # than i'm making it seound here. The data-traversal should be - # done contiguously, and the summing-up might not be easy or - # worthwhile if the summation axis doesn't line up with a - # contiguous dimension) - - if type(node.op) is not op_class: - return False - - if len(node.outputs) > 1: - # We don't support fusion for nodes with multiple outputs. - return - - inputs = [] # inputs of the new Elemwise op. - s_inputs = [] # inputs of the new scalar op used by the Composite. - # Inputs of the new scalar op that represents the current node. - s_g = [] - - # There is a hard limit of 256 bytes for the formal argument list to a - # GPU kernel function. - max_nb_input = max_input_fct(node) - # The number of inputs to the new fused op if we do not fuse more - # inputs. - new_nb_input = len(node.inputs) - # Did we fuse something? - # Needed as we can fuse unary op that don't change the number of - # inputs. - # And there is a case where the inputs are the same as the current - # node. 
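# --- Illustrative sketch (not part of the diff): what `elemwise_to_scalar`
# computes and how `apply` then uses it.  Every tensor variable in an
# Elemwise-only subgraph gets a scalar stand-in of the same dtype, the graph is
# rebuilt with the underlying scalar ops, and the result is wrapped in a single
# `aes.Composite`.  A minimal hand-built version of that result for
# `exp(x + y)`, using the scalar Ops `aes.add`/`aes.exp` from
# `aesara.scalar.basic` (this is a sketch of the outcome, not the rewrite
# itself):
import aesara.scalar.basic as aes
import aesara.tensor as at
from aesara.tensor.elemwise import Elemwise

x = at.vector("x")
y = at.vector("y")
# Scalar stand-ins for the tensor inputs (same dtypes).
sx = aes.get_scalar_type(x.type.dtype).make_variable()
sy = aes.get_scalar_type(y.type.dtype).make_variable()
# Rebuild exp(x + y) at the scalar level.
s_out = aes.exp(aes.add(sx, sy))
# One fused tensor op: an Elemwise wrapping a scalar Composite.
fused = Elemwise(aes.Composite([sx, sy], [s_out]))(x, y)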
That won't change the number of inputs of the new op. - fused = False - - for i in node.inputs: - scalar_node: Optional[Apply] = None - # Will store inputs of the fused node that are not currently inputs - # of the node we want to create (to avoid duplicating inputs). - tmp_input = [] - # Same as tmp_input, but for scalars. - tmp_scalar = [] - - # We should not check the number of inputs here - # As fusing op don't always change the number of input. - # If a variable is used as multiple into to the same node, - # we still want to fusion. So we take the set. - if ( - i.owner - and isinstance(i.owner.op, op_class) - and len({n for n, idx in fgraph.clients[i]}) == 1 - and - # Do not merge elemwise that don't have the same - # broadcastable pattern to don't redo duplicate - # computation due to broadcast. - i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable - ): - try: - tmp_s_input = [] - # we should not put duplicate input into s_inputs and inputs - for ii in i.owner.inputs: - if ii in inputs: - tmp_s_input.append(s_inputs[inputs.index(ii)]) - elif ii in tmp_input: - tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) - else: - tmp = aes.get_scalar_type(ii.type.dtype).make_variable() - - try: - tv = get_test_value(ii) - # Sometimes the original inputs have - # zero-valued shapes in some dimensions, which - # implies that this whole scalar thing doesn't - # make sense (i.e. we're asking for the scalar - # value of an entry in a zero-dimensional - # array). - # This will eventually lead to an error in the - # `compute_test_value` call below when/if - # `config.compute_test_value_opt` is enabled - # (for debugging, more or less) - tmp.tag.test_value = tv.item() - except (TestValueError, ValueError): - pass - - tmp_s_input.append(tmp) - tmp_input.append(ii) - tmp_scalar.append(tmp_s_input[-1]) - - # Use the `Op.make_node` interface in case `Op.__call__` - # has been customized - scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) - - if config.compute_test_value_opt != "off": - # This is required because `Op.make_node` won't do it - compute_test_value(scalar_node) - - # If the scalar_op doesn't have a C implementation, we skip - # its fusion to allow fusion of the other ops - i.owner.op.scalar_op.c_code( - scalar_node, - "test_presence_of_c_code", - ["x" for x in i.owner.inputs], - ["z" for z in i.owner.outputs], - {"fail": "%(fail)s"}, - ) - - except (NotImplementedError, MethodNotDefined): + @lru_cache(maxsize=None) + def elemwise_scalar_op_has_c_code(node: Apply) -> bool: + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + else: warn( ( - "Rewrite warning: " - f"The Op {i.owner.op.scalar_op} does not provide a C implementation." + "Optimization Warning: " + f"The Op {node.op.scalar_op} does not provide a C implementation." " As well as being potentially slow, this also disables " "loop fusion." ) ) - scalar_node = None - - # Compute the number of inputs in case we fuse this input. - # We subtract 1 because we replace the existing input with the new - # inputs from `tmp_input`. - new_nb_input_ = new_nb_input + len(tmp_input) - 1 - - # If the new input is already an input of the current node, it was - # already counted when `new_nb_input` was initialized to - # len(node.inputs). - # This can happen when a variable is used both by the Elemwise to - # fuse and the current node. 
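# --- Illustrative sketch (not part of the diff): the `lru_cache` pattern used
# by `elemwise_scalar_op_has_c_code` above.  The fuseable/unfuseable client
# maps query the same Apply node many times, so memoizing the check keeps the
# (potentially slow) C-code probe and its warning to one execution per node.
# A toy stand-in with hypothetical op names:
from functools import lru_cache
from warnings import warn


@lru_cache(maxsize=None)
def has_c_code(op_name: str) -> bool:
    # Imagine an expensive probe here; the cache guarantees at most one probe
    # (and at most one warning) per distinct op.
    supported = not op_name.startswith("py_")
    if not supported:
        warn(f"{op_name} has no C implementation; fusion disabled for it.")
    return supported


assert has_c_code("add") and has_c_code("add")  # second call is a cache hit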
- for x in tmp_input: - if x in node.inputs: - new_nb_input_ -= 1 - - if scalar_node and (new_nb_input_ <= max_nb_input): - fused = True - new_nb_input = new_nb_input_ - inputs.extend(tmp_input) - s_inputs.extend(tmp_scalar) - s_g.extend(scalar_node.outputs) - else: - # We must support the case where the same variable appears many - # times within the inputs - if inputs.count(i) == node.inputs.count(i): - s = s_inputs[inputs.index(i)] - else: - s = aes.get_scalar_type(i.type.dtype).make_variable() - if config.compute_test_value_opt != "off": - try: - v = get_test_value(i) - # See the zero-dimensional test value situation - # described above. - s.tag.test_value = v.item() - except (TestValueError, ValueError): - pass - - inputs.append(i) - s_inputs.append(s) - s_g.append(s) - - if not fused: - return False - - if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): - # TODO FIXME: This shouldn't be a generic `Exception` - raise Exception( - "Something has gone wrong with the elemwise fusion rewrite; skipping." - ) - - s_new_out = node.op.scalar_op(*s_g, return_list=True) - try: - s_new_out[0].owner.op.c_code( - s_new_out[0].owner, - "test_presence_of_c_code", - ["x" for x in s_g], - ["z" for x in s_new_out], - {"fail": "%(fail)s"}, - ) - except (NotImplementedError, MethodNotDefined): - name = str(s_new_out[0].owner.op) - warn( - ( - "Rewrite warning: " - f"The Op {name} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." + return False + + # We start by creating two maps, 1) from each node to each potentially + # fuseable client (both nodes must be single output Elemwise with same + # broadcast type) and 2) from each node to each certainly unfuseable + # client (those that don't fit into 1)) + fuseable_clients: Dict[Any, List[Any]] = defaultdict(list) + unfuseable_clients: Dict[Any, List[Any]] = defaultdict(list) + for out, clients in fg.clients.items(): + out_maybe_fuseable = ( + out.owner + and isinstance(out.owner.op, Elemwise) + # and not isinstance(out.owner.op.scalar_op, aes.Composite) + and len(out.owner.outputs) == 1 + and elemwise_scalar_op_has_c_code(out.owner) ) - ) - return False - - # create the composite op. - composite_op = aes.Composite(s_inputs, s_new_out) - - # create the new node. - # Do not call make_node to have test_value - new_node = maker(node, composite_op)(*inputs).owner - - assert len(new_node.outputs) == 1 - assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype - - if len(new_node.inputs) > max_nb_input: - warn( - "Loop fusion failed because the resulting node " - "would exceed the kernel argument limit." - ) - return False - - # we fuse as many that we can at the same time to make debug mode faster - # debug mode will be faster as it won't test all intermediate step. 
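# --- Illustrative sketch (not part of the diff): the two client maps built
# above, on a toy DAG.  Variables are plain strings, `clients` maps a variable
# to the nodes that consume it, and `fuseable(producer_out, client)` stands in
# for the real test (producer is a single-output Elemwise, client is a
# single-output Elemwise with the same broadcastable pattern, C code exists).
from collections import defaultdict

clients = {"a": ["exp1", "sum1"], "exp1_out": ["mul1"]}
elemwise_like = {"exp1", "mul1"}  # "sum1" plays the role of a reduction


def fuseable(producer_out: str, client: str) -> bool:
    # "_out" marks variables produced by an Elemwise node in this toy example.
    return producer_out.endswith("_out") and client in elemwise_like


fuseable_clients, unfuseable_clients = defaultdict(list), defaultdict(list)
for out, outs_clients in clients.items():
    for client in outs_clients:
        (fuseable_clients if fuseable(out, client) else unfuseable_clients)[
            out
        ].append(client)

# fuseable_clients == {"exp1_out": ["mul1"]}; "a" -> ["exp1", "sum1"] is
# unfuseable because "a" is a root input here (and "sum1" is not Elemwise-like).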
- while True: - ret = local_fuse(fgraph, new_node) - if ret is not False and ret is not None: - assert len(ret) == len(new_node.outputs) - assert len(ret) == 1 - new_node = ret[0].owner - else: - break - - return new_node.outputs - - return local_fuse - + for client, _ in clients: + if ( + out_maybe_fuseable + and not isinstance(client, str) # "output" + and isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, aes.Composite) + and len(client.outputs) == 1 + and out.type.broadcastable + == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + if client not in fuseable_clients[out]: + fuseable_clients[out].append(client) + else: + if client not in unfuseable_clients[out]: + unfuseable_clients[out].append(client) + + visited_nodes = set() + while True: + # print( + # "fuseable_clients:", + # { + # k: [out for v_ in v for out in v_.outputs] + # for k, v in fuseable_clients.items() + # }, + # ) + # print( + # "unfuseable_clients:", + # { + # k: [out for v_ in v if v_ != "output" for out in v_.outputs] + # for k, v in unfuseable_clients.items() + # }, + # ) + + # We walk through the apply nodes looking for one that has at least one + # candidate fuseable client + toposort = fg.toposort() + starting_nodes = set(toposort) + for starting_node in toposort: + if starting_node in visited_nodes: + continue -def elemwise_max_input_fct(node): - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. - if not config.cxx: - return 31 - return 1024 + starting_out = starting_node.outputs[0] + if not fuseable_clients.get(starting_out): + # print(f"\n> Skipping {out} as it has no fuseable clients") + visited_nodes.add(starting_node) + continue + subgraph_inputs: List[Variable] = [] + subgraph_outputs: List[Variable] = [] + unfuseable_clients_subgraph = set() + # Manually "deepcopy" clients mapping as those will be altered in place. + # Cannot use `copy.deepcopy` because that would also copy the Aesara variables. + fuseable_clients_temp: Dict[Any, List[Any]] = defaultdict(list) + unfuseable_clients_temp: Dict[Any, List[Any]] = defaultdict(list) + fuseable_clients_temp.update( + { + out: [client for client in clients] + for out, clients in fuseable_clients.items() + } + ) + unfuseable_clients_temp.update( + { + out: [client for client in clients] + for out, clients in unfuseable_clients.items() + } + ) + fuseable_nodes_to_visit = deque([starting_node]) + # We now try to expand as much as possible towards the potentially + # fuseable clients and ancestors to detect the largest possible + # subgraph that can be Composed together into a single `Op`. The + # largest issue to watch out is for cyclical dependencies, where + # some inputs or clients may depend on other nodes of the same + # subgraph via a path that cannot be included in the Composite + # (unfuseable) + while fuseable_nodes_to_visit: + next_node = fuseable_nodes_to_visit.popleft() + visited_nodes.add(next_node) + next_out = next_node.outputs[0] + # print(f"\t{next_out=}, {subgraph_inputs=}, {subgraph_outputs=}, {fuseable_nodes_to_visit=}") + + # Node must become an output if it is to be fused. 
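# --- Illustrative sketch (not part of the diff): the deque worklist discipline
# used in the loop above.  Forward expansion (newly discovered fuseable clients
# or ancestors) goes to the back with `append`, while nodes that must be
# revisited because the subgraph boundary changed go to the front with
# `appendleft`, so corrections are processed before further growth.
from collections import deque

worklist = deque(["start"])
order = []
while worklist:
    node = worklist.popleft()
    order.append(node)
    if node == "start":
        worklist.append("forward_client")        # explore later
        worklist.appendleft("revisit_ancestor")  # fix up first
assert order == ["start", "revisit_ancestor", "forward_client"]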
+ must_become_output = ( + next_out not in fuseable_clients_temp + or next_out in unfuseable_clients_temp + ) -local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) + # We have backtracked to this node, and it may no longer be a + # viable output + if must_become_output and next_out in subgraph_outputs: + subgraph_outputs.remove(next_out) + # unfuseable_clients_subgraph = ( + # unfuseable_clients_subgraph + # - get_unfuseable_clients(unfuseable_clients_temp, out) + # ) + + required_unfuseable_inputs = [ + inp + for inp in next_node.inputs + if next_node in unfuseable_clients_temp.get(inp, []) + ] + + new_required_unfuseable_inputs = [ + inp + for inp in required_unfuseable_inputs + if inp not in subgraph_inputs + ] + + # print(f"\t\t{new_required_unfuseable_inputs=}, {required_unfuseable_inputs=}, {unfuseable_clients_subgraph=}") + must_backtrack = False + if new_required_unfuseable_inputs and subgraph_outputs: + # We need to check that any new ancestors required by this node + # do not depend on other outputs of the same subgraph, via + # an unfuseable path. + if any( + a in unfuseable_clients_subgraph + for a in ancestors( + [next_out], blockers=subgraph_outputs + ) + ): + # print("\t > Cannot fuse due to non-fuseable ancestor dependency in same subgraph") + must_backtrack = True + + if not must_backtrack: + implied_unfuseable_clients = { + c + for client in unfuseable_clients_temp.get(next_out, []) + if client != "output" + for c in client.outputs + } + + new_implied_unfuseable_clients = [ + client + for client in implied_unfuseable_clients + if client not in unfuseable_clients_subgraph + ] + + if new_implied_unfuseable_clients and subgraph_inputs: + # We need to check that any ancestors of the subgraph do not depend + # on other clients of this node, via an unfuseable path. 
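# --- Illustrative sketch (not part of the diff): why the `ancestors` checks
# above are needed.  With `e = exp(x)`, `s = sum(e)` (a reduction, hence
# unfuseable) and `m = e * s`, fusing `exp` and `mul` into one Composite would
# need `s` as an input while also producing `e` as an output -- but `s`
# depends on `e`, so the fused node would feed itself.  A toy dependency check
# over hand-written ancestor sets catches this:
toy_ancestors = {"e": {"x"}, "s": {"e", "x"}, "m": {"e", "s", "x"}}

candidate_outputs = {"e", "m"}  # outputs of the would-be Composite
candidate_inputs = {"x", "s"}   # inputs it would require

creates_cycle = any(
    out in toy_ancestors[inp]
    for inp in candidate_inputs
    if inp in toy_ancestors
    for out in candidate_outputs
)
assert creates_cycle  # "s" is downstream of "e": fusion must back off here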
+ if any( + a in new_implied_unfuseable_clients + for a in ancestors(subgraph_inputs) + ): + # print("\t > Cannot fuse due to non-fuseable client dependency in same subgraph") + must_backtrack = True + + if must_backtrack: + for inp in next_node.inputs: + if ( + inp.owner in visited_nodes + # next_node could have the same input repeated + and next_node in fuseable_clients_temp[inp] + ): + fuseable_clients_temp[inp].remove(next_node) + unfuseable_clients_temp[inp].append(next_node) + # print(f"\t\t: Will have to revisit {inp} as it must now become an output of subgraph") + fuseable_nodes_to_visit.appendleft(inp.owner) + + for client in fuseable_clients_temp[next_out]: + if client in visited_nodes: + # MyPy does not know that fuseable clients can never be `output` clients + client = typing_cast(Apply, client) + fuseable_clients_temp[next_out].remove(client) + unfuseable_clients_temp[next_out].append(client) + # print(f"\t\t: Will have to revisit {client} as current node must now become an input of subgraph") + fuseable_nodes_to_visit.appendleft(client) + + # Revisit node at a later time + visited_nodes.remove(next_node) + continue + + for inp in new_required_unfuseable_inputs: + # Node could require the same new input multiple times + if inp not in subgraph_inputs: + subgraph_inputs.append(inp) + + if must_become_output: + # print("\t\tMust become output!") + subgraph_outputs.append(next_out) + # This node is now a "definite" part of the fused graph + unfuseable_clients_subgraph.update( + new_implied_unfuseable_clients + ) + for inp in sorted( + ( + inp + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + # No need to check if inp.owner is not None, as that + # would by definition be a required_unfuseable_input + and inp.owner not in visited_nodes + ) + ), + key=lambda inp: toposort.index(inp.owner), + reverse=True, + ): + # Expand through unvisited fuseable ancestors + fuseable_nodes_to_visit.appendleft(inp.owner) -class FusionOptimizer(GraphRewriter): - """Graph rewriter that simply runs node fusion operations. + for next_node in sorted( + fuseable_clients_temp.get(next_out, []), + key=lambda node: toposort.index(node), + ): + # Expand through unvisited fuseable clients + if next_node not in visited_nodes: + fuseable_nodes_to_visit.append(next_node) - TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that. + # print(f"\t~ final fused subgraph: {subgraph_inputs=}, {subgraph_outputs=}") - """ + # Don't yield if final subgraph is just the original Elemwise + if ( + len(subgraph_outputs) == 1 + and ( + len(subgraph_outputs[0].owner.inputs) + == len(subgraph_inputs) + ) + and ( + set(subgraph_outputs[0].owner.inputs) + == set(subgraph_inputs) + ) + ): + # print(f"\t! final fused subgraph is just the original elemwise") + # Update fuseable mappings + # No input was actually fuseable + for inp in starting_node.inputs: + if ( + inp in fuseable_clients + and starting_node in fuseable_clients[inp] + ): + fuseable_clients[inp].remove(starting_node) + unfuseable_clients[inp].append(starting_node) + # No client was actually fuseable + for client in fuseable_clients.pop(starting_out, []): + unfuseable_clients[starting_out].append(client) + + else: + yield subgraph_inputs, subgraph_outputs + + # This is where we avoid repeated work by using a stateful + # generator. 
For large models (as in `TestFusion.test_big_fusion`) + # this can provide huge speedups + + # Update fuseable mappings + next_nodes = fg.apply_nodes + (new_composite_node,) = next_nodes - starting_nodes + dropped_nodes = starting_nodes - next_nodes + + # Remove intermediate Composite nodes from mappings + for dropped_node in dropped_nodes: + (dropped_out,) = dropped_node.outputs + fuseable_clients.pop(dropped_out, None) + unfuseable_clients.pop(dropped_out, None) + visited_nodes.remove(dropped_node) + + # Any input is now definitely unfuseable + for inp in subgraph_inputs: + if inp in fuseable_clients: + new_fuseable_clients = [ + client + for client in fuseable_clients[inp] + if client not in dropped_nodes + ] + if new_fuseable_clients: + fuseable_clients[inp] = new_fuseable_clients + else: + fuseable_clients.pop(inp) + unfuseable_clients[inp] = [ + client + for client in unfuseable_clients[inp] + if client not in dropped_nodes + ] + [new_composite_node] + + # Any client is now definitely unfuseable + for out in new_composite_node.outputs: + unfuseable_clients[out] = [ + client for client, _ in fg.clients[out] + ] + visited_nodes.add(new_composite_node) + break + else: # nobreak + return - def __init__(self, node_rewriter): - super().__init__() - self.node_rewriter = node_rewriter + # aesara.dprint(fgraph, print_type=True) + for res in find_next_fuseable_subgraph(fgraph): + # print(f">> >> Start of iteration {nb_replacement}: {len(fgraph.apply_nodes)=}") + if res is None: + # print("<< No further fuseable subgraph found") + break + inputs, outputs = res - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) + if len(inputs) > max_inputs: + warn( + "Loop fusion failed because the resulting node would exceed " + "the kernel argument limit." + ) + break - def apply(self, fgraph): - did_something = True - nb_iter = 0 - nb_replacement = 0 - nb_inconsistency_replace = 0 - time_toposort = 0 - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callbacks_before = fgraph.execute_callbacks_times.copy() - callback_before = fgraph.execute_callbacks_time - while did_something: - t0 = time.time() - nodelist = list(fgraph.toposort()) - time_toposort += time.time() - t0 - nodelist.reverse() - did_something = False - for node in nodelist: - # Don't try to fuse node that have already been fused. - if node in fgraph.apply_nodes: - new_outputs = self.node_rewriter(fgraph, node) - if new_outputs: - assert len(new_outputs) == len(node.outputs) - try: - fgraph.replace_all_validate( - list(zip(node.outputs, new_outputs)), - reason=self.__class__.__name__, - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - nb_inconsistency_replace += 1 - nb_iter += 1 + scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) + composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))( + *inputs + ) + if not isinstance(composite_outputs, list): + composite_outputs = [composite_outputs] + for old_out, composite_out in zip(outputs, composite_outputs): + if old_out.name: + composite_out.name = old_out.name + + # print(f"{outputs=},\n{composite_outputs=},\n{inputs=}") + fgraph.replace_all_validate( + list(zip(outputs, composite_outputs)), + reason=self.__class__.__name__, + ) + # print(f"<< < | |TensorConstant{0.0} |D + |A """ ).lstrip()
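# --- Illustrative usage sketch (not part of the diff).  The optimizer above is
# what the standard Elemwise-fusion pass runs, so compiling a chain of
# element-wise expressions with the default mode (and fusion enabled, e.g. a C
# compiler available) is enough to see consecutive Elemwise nodes collapse into
# a single Elemwise{Composite{...}}; the exact printed graph is not asserted
# here.
import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.vector("y")
out = at.exp(x + y) * at.log(x + y)

fn = aesara.function([x, y], out)  # default mode includes the fusion pass
aesara.dprint(fn)                  # expect one Elemwise{Composite{...}} node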