diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py index 08449bfc55..2442e19838 100644 --- a/aesara/tensor/rewriting/elemwise.py +++ b/aesara/tensor/rewriting/elemwise.py @@ -1,19 +1,19 @@ import sys -import time from collections import defaultdict -from typing import Optional +from itertools import chain +from typing import List, Tuple from warnings import warn import aesara import aesara.scalar.basic as aes from aesara import compile from aesara.configdefaults import config -from aesara.graph.basic import Apply, Constant, io_toposort +from aesara.graph import FunctionGraph +from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort from aesara.graph.features import ReplaceValidate -from aesara.graph.op import compute_test_value, get_test_value from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter from aesara.graph.rewriting.db import SequenceDB -from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError +from aesara.graph.utils import InconsistencyError from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.exceptions import NotScalarConstantError @@ -523,281 +523,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): return rval -def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): - r"""Create a recursive function that fuses `Elemwise` `Op`\s. - - The basic idea is that we loop through an `Elemwise` node's inputs, find - other `Elemwise` nodes, determine the scalars input types for all of the - `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types - and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a - new "fused" `Elemwise`. - - It's parameterized in order to work for `Elemwise` `Op`\s. - - Parameters - ---------- - op_class : type - `Elemwise` class (the one that we want to fuse) - max_input_fct : callable - A function that returns the maximum number of inputs that this `Elemwise` - can take. - On the CPU we limit to 32 input variables since that is the maximum - NumPy support. - - maker: callable - A function with the signature ``(node, *args)`` that constructs an - `op_class` instance (e.g. ``op_class(*args)``). - - """ - if maker is None: - - def maker(node, scalar_op): - return op_class(scalar_op) - - def local_fuse(fgraph, node): - r"""Fuse `Elemwise` `Op`\s in a node. - - As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the - same shape. - - For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C - compiler do the cast. - - The number of dimensions is validated at call time by Aesara itself. - - """ - # TODO: use broadcast flag? - - # TODO: don't do this rewrite as a `NodeRewriter`. - # Analyze the graph in terms of elemwise subgraphs, and then - # replace each subgraph with a Composite version. - - # TODO: use malloc and copy to transfer arguments that don't - # fit within the parameter space of 256 bytes - # - # TODO: Merge with multiple output to merge when an inputs - # have multiple clients. This can't be done with a `NodeRewriter` - - # TODO: Related: Support composites with multiple outputs - - # TODO: Use Composite to combine Elemwise and Reduce - # operations. We have to loop over the data anyway... might - # as well sum it up while we're at it (this can be trickier - # than i'm making it seound here. The data-traversal should be - # done contiguously, and the summing-up might not be easy or - # worthwhile if the summation axis doesn't line up with a - # contiguous dimension) - - if type(node.op) is not op_class: - return False - - if len(node.outputs) > 1: - # We don't support fusion for nodes with multiple outputs. - return - - inputs = [] # inputs of the new Elemwise op. - s_inputs = [] # inputs of the new scalar op used by the Composite. - # Inputs of the new scalar op that represents the current node. - s_g = [] - - # There is a hard limit of 256 bytes for the formal argument list to a - # GPU kernel function. - max_nb_input = max_input_fct(node) - # The number of inputs to the new fused op if we do not fuse more - # inputs. - new_nb_input = len(node.inputs) - # Did we fuse something? - # Needed as we can fuse unary op that don't change the number of - # inputs. - # And there is a case where the inputs are the same as the current - # node. That won't change the number of inputs of the new op. - fused = False - - for i in node.inputs: - scalar_node: Optional[Apply] = None - # Will store inputs of the fused node that are not currently inputs - # of the node we want to create (to avoid duplicating inputs). - tmp_input = [] - # Same as tmp_input, but for scalars. - tmp_scalar = [] - - # We should not check the number of inputs here - # As fusing op don't always change the number of input. - # If a variable is used as multiple into to the same node, - # we still want to fusion. So we take the set. - if ( - i.owner - and isinstance(i.owner.op, op_class) - and len({n for n, idx in fgraph.clients[i]}) == 1 - and - # Do not merge elemwise that don't have the same - # broadcastable pattern to don't redo duplicate - # computation due to broadcast. - i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable - ): - try: - tmp_s_input = [] - # we should not put duplicate input into s_inputs and inputs - for ii in i.owner.inputs: - if ii in inputs: - tmp_s_input.append(s_inputs[inputs.index(ii)]) - elif ii in tmp_input: - tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) - else: - tmp = aes.get_scalar_type(ii.type.dtype).make_variable() - - try: - tv = get_test_value(ii) - # Sometimes the original inputs have - # zero-valued shapes in some dimensions, which - # implies that this whole scalar thing doesn't - # make sense (i.e. we're asking for the scalar - # value of an entry in a zero-dimensional - # array). - # This will eventually lead to an error in the - # `compute_test_value` call below when/if - # `config.compute_test_value_opt` is enabled - # (for debugging, more or less) - tmp.tag.test_value = tv.item() - except (TestValueError, ValueError): - pass - - tmp_s_input.append(tmp) - tmp_input.append(ii) - tmp_scalar.append(tmp_s_input[-1]) - - # Use the `Op.make_node` interface in case `Op.__call__` - # has been customized - scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) - - if config.compute_test_value_opt != "off": - # This is required because `Op.make_node` won't do it - compute_test_value(scalar_node) - - # If the scalar_op doesn't have a C implementation, we skip - # its fusion to allow fusion of the other ops - i.owner.op.scalar_op.c_code( - scalar_node, - "test_presence_of_c_code", - ["x" for x in i.owner.inputs], - ["z" for z in i.owner.outputs], - {"fail": "%(fail)s"}, - ) - - except (NotImplementedError, MethodNotDefined): - warn( - ( - "Rewrite warning: " - f"The Op {i.owner.op.scalar_op} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - ) - scalar_node = None - - # Compute the number of inputs in case we fuse this input. - # We subtract 1 because we replace the existing input with the new - # inputs from `tmp_input`. - new_nb_input_ = new_nb_input + len(tmp_input) - 1 - - # If the new input is already an input of the current node, it was - # already counted when `new_nb_input` was initialized to - # len(node.inputs). - # This can happen when a variable is used both by the Elemwise to - # fuse and the current node. - for x in tmp_input: - if x in node.inputs: - new_nb_input_ -= 1 - - if scalar_node and (new_nb_input_ <= max_nb_input): - fused = True - new_nb_input = new_nb_input_ - inputs.extend(tmp_input) - s_inputs.extend(tmp_scalar) - s_g.extend(scalar_node.outputs) - else: - # We must support the case where the same variable appears many - # times within the inputs - if inputs.count(i) == node.inputs.count(i): - s = s_inputs[inputs.index(i)] - else: - s = aes.get_scalar_type(i.type.dtype).make_variable() - if config.compute_test_value_opt != "off": - try: - v = get_test_value(i) - # See the zero-dimensional test value situation - # described above. - s.tag.test_value = v.item() - except (TestValueError, ValueError): - pass - - inputs.append(i) - s_inputs.append(s) - s_g.append(s) - - if not fused: - return False - - if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): - # TODO FIXME: This shouldn't be a generic `Exception` - raise Exception( - "Something has gone wrong with the elemwise fusion rewrite; skipping." - ) - - s_new_out = node.op.scalar_op(*s_g, return_list=True) - try: - s_new_out[0].owner.op.c_code( - s_new_out[0].owner, - "test_presence_of_c_code", - ["x" for x in s_g], - ["z" for x in s_new_out], - {"fail": "%(fail)s"}, - ) - except (NotImplementedError, MethodNotDefined): - name = str(s_new_out[0].owner.op) - warn( - ( - "Rewrite warning: " - f"The Op {name} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - ) - return False - - # create the composite op. - composite_op = aes.Composite(s_inputs, s_new_out) - - # create the new node. - # Do not call make_node to have test_value - new_node = maker(node, composite_op)(*inputs).owner - - assert len(new_node.outputs) == 1 - assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype - - if len(new_node.inputs) > max_nb_input: - warn( - "Loop fusion failed because the resulting node " - "would exceed the kernel argument limit." - ) - return False - - # we fuse as many that we can at the same time to make debug mode faster - # debug mode will be faster as it won't test all intermediate step. - while True: - ret = local_fuse(fgraph, new_node) - if ret is not False and ret is not None: - assert len(ret) == len(new_node.outputs) - assert len(ret) == 1 - new_node = ret[0].owner - else: - break - - return new_node.outputs - - return local_fuse - - def elemwise_max_input_fct(node): # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. if not config.cxx: @@ -805,55 +530,232 @@ def elemwise_max_input_fct(node): return 1024 -local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) - - class FusionOptimizer(GraphRewriter): - """Graph rewriter that simply runs node fusion operations. - - TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that. + """Graph optimizer that fuses consecutive Elemwise operations.""" - """ - - def __init__(self, node_rewriter): + def __init__(self, local_optimizer=None): + # TODO: Figure out what to do with this super().__init__() - self.node_rewriter = node_rewriter + self.optimizer = local_optimizer def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) + def find_fuseable_subgraphs( + self, fg: FunctionGraph + ) -> List[Tuple[List[Variable], List[Variable]]]: + """Find all subgraphs in a FunctionGraph that can be fused together + + Returns + ------- + List of independent subgraphs inputs and outputs + """ + + def elemwise_scalar_op_has_c_code(node: Apply): + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + else: + warn( + ( + "Optimization Warning: " + f"The Op {node.op.scalar_op} does not provide a C implementation." + " As well as being potentially slow, this also disables " + "loop fusion." + ) + ) + return False + + def find_leaf_elemwise_vars(node: Apply): + # Only consider nodes with single outputs + if len(node.outputs) != 1: + return [] + + # TODO: This will raise a warning even if Fusion wouldn't be applicable + # (i.e., when we have an isolated elemwise node) + if isinstance(node.op, Elemwise) and elemwise_scalar_op_has_c_code(node): + return [node.outputs[0]] + + # In this case we didn't yet find an appropriate Elemwise node, + # keep searching upstream + upstream_leaf_elemwise_vars = ( + find_leaf_elemwise_vars(inp.owner) + for inp in node.inputs + if inp.owner is not None + ) + # Flatten root variables + return list(chain.from_iterable(upstream_leaf_elemwise_vars)) + + def find_root_consecutive_elemwise_vars(node): + root_elemwise_vars = [] + for inp in node.inputs: + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and len(inp.owner.outputs) == 1 + # Do not merge Elemwise Ops that don't have the same + # broadcastable pattern to avoid duplicated computations + and inp.type.broadcastable == node.outputs[0].type.broadcastable + # TODO: We should specialize FusionOptimizer for different + # backends. This does not matter for non-C backends + and elemwise_scalar_op_has_c_code(inp.owner) + ): + # Try further upstream + root_elemwise_vars.extend( + find_root_consecutive_elemwise_vars(inp.owner) + ) + else: + root_elemwise_vars.append(inp) + return root_elemwise_vars + + # aesara.dprint(fg) + elemwise_outputs = [] + for out in fg.outputs: + if out.owner is not None: + for leaf in find_leaf_elemwise_vars(out.owner): + if leaf not in elemwise_outputs: + elemwise_outputs.append(leaf) + + # print(f"{elemwise_outputs=}") + if not elemwise_outputs: + return [] + + elemwise_inputs = { + out: find_root_consecutive_elemwise_vars(out.owner) + for out in elemwise_outputs + } + # print(f"{elemwise_inputs=}") + + # Filter out isolated elemwise nodes + # TODO: Don't filter if they have shared inputs with another output + elemwise_outputs = [ + out for out in elemwise_outputs if elemwise_inputs[out] != out.owner.inputs + ] + # print(f"{elemwise_outputs=}") + if not elemwise_outputs: + return [] + + # Separate subgraphs that share no inputs whatsoever + disjoint_elemwise_outputs = [[elemwise_outputs.pop(0)]] + for next_out in elemwise_outputs: + disjoint = True + for prev_outs in disjoint_elemwise_outputs: + for prev_out in prev_outs: + if any( + set(elemwise_inputs[next_out]) & set(elemwise_inputs[prev_out]) + ): + prev_outs.append(next_out) + disjoint = False + break + if not disjoint: + break + if disjoint: + disjoint_elemwise_outputs.append([next_out]) + # print(f"{disjoint_elemwise_outputs=}") + + disjoint_elemwise_inputs = [] + for outs in disjoint_elemwise_outputs: + inps = [] + for out in outs: + for inp in elemwise_inputs[out]: + if inp not in inps: + inps.append(inp) + disjoint_elemwise_inputs.append(inps) + # print(f"{disjoint_elemwise_inputs=}") + + fuseable_subgraphs = [ + (inps, outs) + for inps, outs in zip(disjoint_elemwise_inputs, disjoint_elemwise_outputs) + ] + + # Call function in the inputs + inputs = [] + for inps in elemwise_inputs.values(): + for inp in inps: + if inp not in inputs: + inputs.append(inp) + # print(f"{inputs=}") + # print(" ") + upstream_fg = FunctionGraph(outputs=inputs, clone=False) + fuseable_subgraphs.extend(self.find_fuseable_subgraphs(upstream_fg)) + + return fuseable_subgraphs + + def elemwise_to_scalar(self, inputs, outputs): + replace_inputs = [(inp, inp.type()) for inp in inputs] + outputs = clone_replace(outputs, replace=replace_inputs) + + inputs = [inp for _, inp in replace_inputs] + fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) + middle_inputs = [] + + scalar_inputs = [ + aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + ] + middle_scalar_inputs = [] + + # print(f"{fg.toposort()=}") + for node in fg.toposort(): + node_scalar_inputs = [] + for inp in node.inputs: + if inp in inputs: + node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) + elif inp in middle_inputs: + node_scalar_inputs.append( + middle_scalar_inputs[middle_inputs.index(inp)] + ) + else: + new_scalar_input = aes.get_scalar_type( + inp.type.dtype + ).make_variable() + node_scalar_inputs.append(new_scalar_input) + middle_scalar_inputs.append(new_scalar_input) + middle_inputs.append(inp) + + new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) + middle_scalar_inputs.append(new_scalar_node.outputs[0]) + middle_inputs.append(node.outputs[0]) + + scalar_outputs = [ + middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs + ] + return scalar_inputs, scalar_outputs + def apply(self, fgraph): - did_something = True - nb_iter = 0 nb_replacement = 0 nb_inconsistency_replace = 0 - time_toposort = 0 + if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time - while did_something: - t0 = time.time() - nodelist = list(fgraph.toposort()) - time_toposort += time.time() - t0 - nodelist.reverse() - did_something = False - for node in nodelist: - # Don't try to fuse node that have already been fused. - if node in fgraph.apply_nodes: - new_outputs = self.node_rewriter(fgraph, node) - if new_outputs: - assert len(new_outputs) == len(node.outputs) - try: - fgraph.replace_all_validate( - list(zip(node.outputs, new_outputs)), - reason=self.__class__.__name__, - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - nb_inconsistency_replace += 1 - nb_iter += 1 + + max_inputs = elemwise_max_input_fct(None) + for inputs, outputs in self.find_fuseable_subgraphs(fgraph): + # TODO: If we care about Python mode, we should try to fuse the + # largest possible subgraphs based on number of inputs, instead + # of just failing like we used to do before + if len(inputs) > max_inputs: + warn( + "Loop fusion failed because the resulting node would exceed " + "the kernel argument limit." + ) + continue + scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) + composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))( + *inputs + ) + if not isinstance(composite_outputs, list): + composite_outputs = [composite_outputs] + + try: + # print(f"{outputs=}, {composite_outputs=}") + fgraph.replace_all_validate( + list(zip(outputs, composite_outputs)), + reason=self.__class__.__name__, + ) + nb_replacement += 1 + except InconsistencyError: + nb_inconsistency_replace += 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -868,21 +770,23 @@ def apply(self, fgraph): validate_time = None callback_time = None callbacks_time = {} + return ( self, - nb_iter, + 1, # nb_iter nb_replacement, nb_inconsistency_replace, validate_time, callback_time, callbacks_time, - time_toposort, + 0, # toposort_time ) - @classmethod - def print_profile(cls, stream, prof, level=0): + @staticmethod + def print_profile(stream, prof, level=0): + # TODO: Update this blanc = " " * level - print(blanc, cls.__name__, file=stream) + print(blanc, "FusionOptimizer", file=stream) print(blanc, " nb_iter", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) @@ -900,7 +804,7 @@ def print_profile(cls, stream, prof, level=0): if config.tensor__local_elemwise_fusion: fuse_seqopt.register( "composite_elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), + FusionOptimizer(), "fast_run", "fusion", position=1, @@ -918,7 +822,7 @@ def print_profile(cls, stream, prof, level=0): else: compile.optdb.register( # type: ignore "elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), + FusionOptimizer(), "fusion", "local_elemwise_fusion", "FusionOptimizer", diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 51958402ff..acfd1f642a 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1,5 +1,3 @@ -import contextlib - import numpy as np import pytest @@ -998,6 +996,7 @@ def test_big_fusion(self): for node in dlogp.maker.fgraph.toposort() ) + @pytest.mark.xfail(reason="Fails due to #1244") def test_add_mul_fusion_precedence(self): """Test that additions and multiplications are "fused together" before a `Composite` `Op` is introduced. This fusion is done by canonicalization @@ -1074,11 +1073,8 @@ def impl(self, x): @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) def test_test_values(self, test_value): - """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. - - The test values we're talking about are the ones used when C implementations - are checked. - + """Make sure that `local_elemwise_fusion_op` uses test values correctly + when they have zero dimensions. """ rewrites = RewriteDatabaseQuery( @@ -1098,27 +1094,20 @@ def test_test_values(self, test_value): y.tag.test_value = test_value z.tag.test_value = test_value - if test_value.size == 0: - cm = pytest.raises(ValueError) - else: - cm = contextlib.suppress() - with config.change_flags( compute_test_value="raise", compute_test_value_opt="raise" ): out = x * y + z - with cm: - f = function([x, y, z], out, mode=mode) + f = function([x, y, z], out, mode=mode) - if test_value.size != 0: - # Confirm that the fusion happened - assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) - assert len(f.maker.fgraph.toposort()) == 1 + # Confirm that the fusion happened + assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) + assert len(f.maker.fgraph.toposort()) == 1 - x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs - assert np.array_equal( - f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] - ) + assert np.array_equal( + f.maker.fgraph.outputs[0].tag.test_value, + np.full_like(test_value, 2.0), + ) def test_not_fusing_broadcasted_subgraphs(self): # There are some cases in self.test_elemwise_fusion, but this test @@ -1148,6 +1137,40 @@ def test_not_fusing_broadcasted_subgraphs(self): aes.mul, } + def test_multiple_outputs(self): + x = vector("x") + y = exp(x / 4) + w = y * 2 + z = y + 2 + + f = aesara.function([x], [w, z]) + aesara.dprint(f) + assert len(f.maker.fgraph.apply_nodes) == 1 + r = f([0, 0]) + assert np.allclose(r[0], [2, 2]) + assert np.allclose(r[1], [3, 3]) + + @pytest.mark.xfail(reason="Not implemented yet") + def test_multiple_outputs_fused_root_elemwise(self): + """Test that a root elemwise output (single layer) is reused when + there is another fused output""" + + # By default, we do not introduce Composite for single layers of Elemwise + x = at.vector("x") + out1 = at.cos(x) + f = aesara.function([x], out1) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, aes.Cos) + + # However, when it can be composed with another output, we should not + # compute that root Elemwise twice + out2 = at.log(out1) + f = aesara.function([x], [out1, out2]) + nodes = tuple(f.maker.fgraph.apply_nodes) + assert len(nodes) == 1 + assert isinstance(nodes[0].op.scalar_op, Composite) + class TimesN(aes.basic.UnaryScalarOp): """