diff --git a/aesara/tensor/basic_opt.py b/aesara/tensor/basic_opt.py
index 1760537381..0d3b8b25b0 100644
--- a/aesara/tensor/basic_opt.py
+++ b/aesara/tensor/basic_opt.py
@@ -2,11 +2,11 @@
 import logging
 import sys
-import time
 import traceback
 from collections import defaultdict
 from io import StringIO
-from typing import Optional
+from itertools import chain
+from typing import List, Optional, Tuple
 
 import numpy as np
 
@@ -18,8 +18,10 @@ from aesara.graph.basic import (
     Apply,
     Constant,
+    Node,
     Variable,
     ancestors,
+    clone_replace,
     equal_computations,
     io_toposort,
 )
@@ -3116,11 +3118,7 @@ def elemwise_max_input_fct(node):
 
 
 class FusionOptimizer(GlobalOptimizer):
-    """Graph optimizer that simply runs local fusion operations.
-
-    TODO: This is basically a `EquilibriumOptimizer`; we should just use that.
-
-    """
+    """Graph optimizer that fuses consecutive Elemwise operations."""
 
     def __init__(self, local_optimizer):
         super().__init__()
@@ -3129,38 +3127,199 @@ def __init__(self, local_optimizer):
     def add_requirements(self, fgraph):
         fgraph.attach_feature(ReplaceValidate())
 
+    def find_fuseable_subgraphs(
+        self, fg: FunctionGraph
+    ) -> List[Tuple[List[Variable], List[Variable]]]:
+        """Find all subgraphs in a FunctionGraph that can be fused together.
+
+        Returns
+        -------
+        A list of ``(inputs, outputs)`` pairs, one per independent subgraph.
+        """
+
+        def find_leaf_elemwise_vars(node: Node):
+            # Only consider nodes with a single output
+            if len(node.outputs) != 1:
+                return []
+
+            if isinstance(node.op, Elemwise):
+                return [node.outputs[0]]
+
+            upstream_leaf_elemwise_vars = (
+                find_leaf_elemwise_vars(inp.owner)
+                for inp in node.inputs
+                if inp.owner is not None
+            )
+            # Flatten the nested lists of leaf variables
+            return list(chain.from_iterable(upstream_leaf_elemwise_vars))
+
+        def find_root_consecutive_elemwise_vars(node: Node):
+            # TODO: Do not walk across broadcasted elemwises
+            # TODO: Add special C-code check
+            root_elemwise_vars = []
+            for inp in node.inputs:
+                if (
+                    inp.owner
+                    and isinstance(inp.owner.op, Elemwise)
+                    and len(inp.owner.outputs) == 1
+                    # Do not merge Elemwise Ops that don't have the same
+                    # broadcastable pattern, to avoid duplicated computations
+                    and inp.type.broadcastable == node.outputs[0].type.broadcastable
+                ):
+                    # Try further upstream
+                    root_elemwise_vars.extend(
+                        find_root_consecutive_elemwise_vars(inp.owner)
+                    )
+                else:
+                    root_elemwise_vars.append(inp)
+            return root_elemwise_vars
+
+        elemwise_outputs = set()
+        for out in fg.outputs:
+            if out.owner is not None:
+                elemwise_outputs.update(find_leaf_elemwise_vars(out.owner))
+
+        if not elemwise_outputs:
+            return []
+
+        elemwise_inputs: dict = {
+            out: find_root_consecutive_elemwise_vars(out.owner)
+            for out in elemwise_outputs
+        }
+
+        # Filter out isolated elemwise nodes
+        elemwise_outputs = [
+            out for out in elemwise_outputs if elemwise_inputs[out] != out.owner.inputs
+        ]
+        if not elemwise_outputs:
+            return []
+
+        # Separate subgraphs that share no inputs whatsoever
+        disjoint_elemwise_outputs = [[elemwise_outputs.pop(0)]]
+        for next_out in elemwise_outputs:
+            disjoint = True
+            for prev_outs in disjoint_elemwise_outputs:
+                for prev_out in prev_outs:
+                    if (
+                        set(elemwise_inputs[next_out])
+                        & set(elemwise_inputs[prev_out])
+                    ):
+                        prev_outs.append(next_out)
+                        disjoint = False
+                        break
+                if not disjoint:
+                    break
+            if disjoint:
+                disjoint_elemwise_outputs.append([next_out])
+
+        disjoint_elemwise_inputs = []
+        for outs in disjoint_elemwise_outputs:
+            inps = []
+            for out in outs:
+                for inp in elemwise_inputs[out]:
+                    if inp not in inps:
+                        inps.append(inp)
+            disjoint_elemwise_inputs.append(inps)
+
+        fuseable_subgraphs = list(
+            zip(disjoint_elemwise_inputs, disjoint_elemwise_outputs)
+        )
+
+        # Recurse into the graphs upstream of the collected inputs, so that
+        # Elemwise chains separated by non-Elemwise nodes are also found
+        inputs = []
+        for inps in elemwise_inputs.values():
+            for inp in inps:
+                if inp not in inputs:
+                    inputs.append(inp)
+        upstream_fg = FunctionGraph(outputs=inputs, clone=False)
+        fuseable_subgraphs.extend(self.find_fuseable_subgraphs(upstream_fg))
+
+        return fuseable_subgraphs
+
+    def elemwise_to_scalar(self, inputs, outputs):
+        replace_inputs = [(inp, inp.type()) for inp in inputs]
+        outputs = clone_replace(outputs, replace=replace_inputs)
+
+        inputs = [inp for _, inp in replace_inputs]
+        fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
+        middle_inputs = []
+
+        scalar_inputs = [
+            aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
+        ]
+        middle_scalar_inputs = []
+
+        for node in fg.toposort():
+            node_scalar_inputs = []
+            for inp in node.inputs:
+                if inp in inputs:
+                    node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
+                elif inp in middle_inputs:
+                    node_scalar_inputs.append(
+                        middle_scalar_inputs[middle_inputs.index(inp)]
+                    )
+                else:
+                    new_scalar_input = aes.get_scalar_type(
+                        inp.type.dtype
+                    ).make_variable()
+                    node_scalar_inputs.append(new_scalar_input)
+                    middle_scalar_inputs.append(new_scalar_input)
+                    middle_inputs.append(inp)
+
+            new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
+            middle_scalar_inputs.append(new_scalar_node.outputs[0])
+            middle_inputs.append(node.outputs[0])
+
+        scalar_outputs = [
+            middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
+        ]
+        return scalar_inputs, scalar_outputs
+
     def apply(self, fgraph):
-        did_something = True
-        nb_iter = 0
         nb_replacement = 0
         nb_inconsistency_replace = 0
-        time_toposort = 0
+
         if fgraph.profile:
             validate_before = fgraph.profile.validate_time
             callbacks_before = fgraph.execute_callbacks_times.copy()
             callback_before = fgraph.execute_callbacks_time
-        while did_something:
-            t0 = time.time()
-            nodelist = list(fgraph.toposort())
-            time_toposort += time.time() - t0
-            nodelist.reverse()
-            did_something = False
-            for node in nodelist:
-                # Don't try to fuse node that have already been fused.
-                if node in fgraph.apply_nodes:
-                    new_outputs = self.optimizer(fgraph, node)
-                    if new_outputs:
-                        assert len(new_outputs) == len(node.outputs)
-                        try:
-                            fgraph.replace_all_validate(
-                                list(zip(node.outputs, new_outputs)),
-                                reason=self.__class__.__name__,
-                            )
-                            did_something = True
-                            nb_replacement += 1
-                        except InconsistencyError:
-                            nb_inconsistency_replace += 1
-            nb_iter += 1
+
+        max_inputs = elemwise_max_input_fct(None)
+        for inputs, outputs in self.find_fuseable_subgraphs(fgraph):
+            # TODO: If we care about Python mode, we should try to fuse the
+            # largest possible subgraphs based on the number of inputs,
+            # instead of just failing as the previous implementation did
+            if len(inputs) > max_inputs:
+                _logger.warning(
+                    "Loop fusion failed because the resulting node would exceed "
+                    "the kernel argument limit."
+                )
+                continue
+            scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
+            composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))(
+                *inputs
+            )
+            if not isinstance(composite_outputs, list):
+                composite_outputs = [composite_outputs]
+
+            try:
+                fgraph.replace_all_validate(
+                    list(zip(outputs, composite_outputs)),
+                    reason=self.__class__.__name__,
+                )
+                nb_replacement += 1
+            except InconsistencyError:
+                nb_inconsistency_replace += 1
 
         if fgraph.profile:
             validate_time = fgraph.profile.validate_time - validate_before
@@ -3175,19 +3334,21 @@ def apply(self, fgraph):
             validate_time = None
             callback_time = None
             callbacks_time = {}
+
         return (
             self,
-            nb_iter,
+            1,  # nb_iter
            nb_replacement,
             nb_inconsistency_replace,
             validate_time,
             callback_time,
             callbacks_time,
-            time_toposort,
+            0,  # toposort_time
         )
 
     @staticmethod
     def print_profile(stream, prof, level=0):
+        # TODO: Update this
         blanc = " " * level
         print(blanc, "FusionOptimizer", file=stream)
         print(blanc, " nb_iter", prof[1], file=stream)
diff --git a/tests/tensor/test_basic_opt.py b/tests/tensor/test_basic_opt.py
index d8e07d0a06..6620d4e2d7 100644
--- a/tests/tensor/test_basic_opt.py
+++ b/tests/tensor/test_basic_opt.py
@@ -1,4 +1,3 @@
-import contextlib
 import copy
 
 import numpy as np
@@ -1157,11 +1156,8 @@ def impl(self, x):
 
     @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
     def test_test_values(self, test_value):
-        """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
-
-        The test values we're talking about are the ones used when C implementations
-        are checked.
-
+        """Make sure that `local_elemwise_fusion_op` uses test values correctly
+        when they have zero dimensions.
         """
 
         opts = OptimizationQuery(
@@ -1181,27 +1177,33 @@ def test_test_values(self, test_value):
         y.tag.test_value = test_value
         z.tag.test_value = test_value
 
-        if test_value.size == 0:
-            cm = pytest.raises(ValueError)
-        else:
-            cm = contextlib.suppress()
-
         with config.change_flags(
             compute_test_value="raise", compute_test_value_opt="raise"
         ):
             out = x * y + z
-            with cm:
-                f = function([x, y, z], out, mode=mode)
+            f = function([x, y, z], out, mode=mode)
 
-        if test_value.size != 0:
-            # Confirm that the fusion happened
-            assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
-            assert len(f.maker.fgraph.toposort()) == 1
+        # Confirm that the fusion happened
+        assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
+        assert len(f.maker.fgraph.toposort()) == 1
 
-            x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
-            assert np.array_equal(
-                f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
-            )
+        assert np.array_equal(
+            f.maker.fgraph.outputs[0].tag.test_value,
+            np.full_like(test_value, 2.0),
+        )
+
+    def test_multiple_outputs(self):
+        x = vector("x")
+        y = exp(x / 4)
+        w = y * 2
+        z = y + 2
+
+        f = aesara.function([x], [w, z])
+        assert len(f.maker.fgraph.apply_nodes) == 1
+
+        r = f([0, 0])
+        assert np.allclose(r[0], [2, 2])
+        assert np.allclose(r[1], [3, 3])
 
 
 class TimesN(aes.basic.UnaryScalarOp):
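
For reference, a minimal sketch (not part of the patch) of the end-to-end behavior this optimizer targets. It assumes a standard Aesara install whose default compilation mode includes the fusion pass (the pass may be skipped when no C compiler is available); the variable names and the particular ops are illustrative only:

```python
import aesara
import aesara.tensor as at
import numpy as np

from aesara.scalar.basic import Composite

# A chain of consecutive Elemwise operations (sin, cos, add, exp), all with
# the same broadcastable pattern: candidates for fusion into one Composite.
x = at.vector("x")
out = at.exp(at.sin(x) + at.cos(x))

f = aesara.function([x], out)  # the default mode runs the fusion optimizer

# After fusion, the compiled graph should consist of a single Elemwise node
# whose scalar_op is a Composite computing exp(sin(x) + cos(x)) element-wise.
(node,) = f.maker.fgraph.apply_nodes
assert isinstance(node.op.scalar_op, Composite)

# exp(sin(0) + cos(0)) == exp(1) == e
np.testing.assert_allclose(f(np.zeros(3)), np.full(3, np.e))
```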
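And a rough sketch of the scalar-graph construction that `elemwise_to_scalar` automates: for each tensor input it creates a scalar variable, rebuilds the node-by-node computation on those scalars, and wraps the result in a `Composite` that a single `Elemwise` then applies. The hand-built equivalent for `out = (x + y) * 2` would look roughly like this (variable names are hypothetical):

```python
import aesara.scalar as aes
import aesara.tensor as at

from aesara.tensor.elemwise import Elemwise

# Scalar stand-ins for the tensor inputs
sx = aes.float64("sx")
sy = aes.float64("sy")

# The scalar computation, then the Composite wrapping it
s_out = (sx + sy) * 2
comp = aes.Composite([sx, sy], [s_out])

# A single Elemwise node computing (x + y) * 2 element-wise
x = at.vector("x")
y = at.vector("y")
fused = Elemwise(comp)(x, y)
```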