diff --git a/aesara/tensor/math_opt.py b/aesara/tensor/math_opt.py
index d2d05f508a..bc83e33808 100644
--- a/aesara/tensor/math_opt.py
+++ b/aesara/tensor/math_opt.py
@@ -6,10 +6,18 @@
 from functools import partial, reduce
 
 import numpy as np
 
+from etuples import etuple
+from kanren.assoccomm import assoc_flatten, associative, commutative
+from kanren.core import lall
+from kanren.graph import mapo
+from unification import vars as lvars
 import aesara.scalar.basic as aes
 import aesara.scalar.math as aes_math
+import aesara.tensor as at
+from aesara.compile import optdb
 from aesara.graph.basic import Constant, Variable
+from aesara.graph.kanren import KanrenRelationSub
 from aesara.graph.opt import (
     LocalOptGroup,
     LocalOptimizer,
@@ -19,6 +27,7 @@
     local_optimizer,
 )
 from aesara.graph.opt_utils import get_clients_at_depth
+from aesara.graph.optdb import EquilibriumDB
 from aesara.misc.safe_asarray import _asarray
 from aesara.raise_op import assert_op
 from aesara.tensor.basic import (
@@ -3523,6 +3532,122 @@ def local_reciprocal_1_plus_exp(fgraph, node):
         return out
 
 
+def distribute_mul_over_add(in_lv, out_lv):
+    from kanren import conso, eq, fact, heado, tailo
+
+    # Relates A * (x + y + z) to A * x + A * y + A * z (applied right-to-left below)
+    A_lv, add_term_lv, add_cdr_lv, mul_cdr_lv, add_flat_lv = lvars(5)
+    fact(commutative, at.mul)
+    fact(associative, at.mul)
+    fact(commutative, at.add)
+    fact(associative, at.add)
+    return lall(
+        # Make sure the input is an `at.mul`
+        eq(in_lv, etuple(at.mul, A_lv, add_term_lv)),
+        # Make sure the term being `at.mul`ed is an `add`
+        heado(at.add, add_term_lv),
+        # Flatten the associative pairings of `add` operations
+        assoc_flatten(add_term_lv, add_flat_lv),
+        # Get the flattened `add` arguments
+        tailo(add_cdr_lv, add_flat_lv),
+        # Build the output `add` from the `at.mul`ed arguments
+        conso(at.add, mul_cdr_lv, out_lv),
+        # Apply the `at.mul` to each of the flattened `add` arguments
+        mapo(lambda x, y: conso(at.mul, etuple(A_lv, x), y), add_cdr_lv, mul_cdr_lv),
+    )
+
+
+distribute_mul_over_add_opt = KanrenRelationSub(
+    lambda x, y: distribute_mul_over_add(y, x)
+)
+distribute_mul_over_add_opt.__name__ = distribute_mul_over_add.__name__
+
+
+def distribute_div_over_add(in_lv, out_lv):
+    from kanren import conso, eq, fact, heado, tailo
+
+    # Relates (x + y + z) / A to x / A + y / A + z / A (applied right-to-left below)
+    A_lv, add_term_lv, add_cdr_lv, div_cdr_lv, add_flat_lv = lvars(5)
+    fact(commutative, at.add)
+    fact(associative, at.add)
+    return lall(
+        # Make sure the input is an `at.true_div`
+        eq(in_lv, etuple(at.true_div, add_term_lv, A_lv)),
+        # Make sure the numerator is an `add`
+        heado(at.add, add_term_lv),
+        # Flatten the associative pairings of `add` operations
+        assoc_flatten(add_term_lv, add_flat_lv),
+        # Get the flattened `add` arguments
+        tailo(add_cdr_lv, add_flat_lv),
+        # Build the output `add` from the `at.true_div`ed arguments
+        conso(at.add, div_cdr_lv, out_lv),
+        # Apply the `at.true_div` to each of the flattened `add` arguments
+        mapo(
+            lambda x, y: conso(at.true_div, etuple(x, A_lv), y), add_cdr_lv, div_cdr_lv
+        ),
+    )
+
+
+distribute_div_over_add_opt = KanrenRelationSub(
+    lambda x, y: distribute_div_over_add(y, x)
+)
+distribute_div_over_add_opt.__name__ = distribute_div_over_add.__name__
+
+
+def distribute_mul_over_sub(in_lv, out_lv):
+    from kanren import eq, fact
+
+    fact(commutative, at.mul)
+    fact(associative, at.mul)
+    a_lv, x_lv, y_lv = lvars(3)
+
+    return lall(
+        # lhs == a * x - a * y
+        eq(
+            etuple(
+                at.sub,
+                etuple(at.mul, a_lv, x_lv),
+                etuple(at.mul, a_lv, y_lv),
+            ),
+            in_lv,
+        ),
+        # rhs == a * (x - y)
+        eq(
+            etuple(at.mul, a_lv, etuple(at.sub, x_lv, y_lv)),
+            out_lv,
+        ),
+    )
+
+
+distribute_mul_over_sub_opt = KanrenRelationSub(distribute_mul_over_sub)
+distribute_mul_over_sub_opt.__name__ = distribute_mul_over_sub.__name__
+
+
+def distribute_div_over_sub(in_lv, out_lv):
+    from kanren import eq
+
+    a_lv, x_lv, y_lv = lvars(3)
+    return lall(
+        # lhs == x / a - y / a
+        eq(
+            etuple(
+                at.sub,
+                etuple(at.true_div, x_lv, a_lv),
+                etuple(at.true_div, y_lv, a_lv),
+            ),
+            in_lv,
+        ),
+        # rhs == (x - y) / a
+        eq(
+            etuple(at.true_div, etuple(at.sub, x_lv, y_lv), a_lv),
+            out_lv,
+        ),
+    )
+
+
+distribute_div_over_sub_opt = KanrenRelationSub(distribute_div_over_sub)
+distribute_div_over_sub_opt.__name__ = distribute_div_over_sub.__name__
+
 # 1 - sigmoid(x) -> sigmoid(-x)
 local_1msigmoid = PatternSub(
     (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
@@ -3567,3 +3692,44 @@ def local_reciprocal_1_plus_exp(fgraph, node):
 )
 register_canonicalize(local_sigmoid_logit)
 register_specialize(local_sigmoid_logit)
+
+
+fastmath = EquilibriumDB()
+
+optdb.register("fastmath", fastmath, 0.05, "fast_run", "mul")
+
+fastmath.register(
+    "dist_mul_over_add_opt",
+    in2out(distribute_mul_over_add_opt, ignore_newtrees=True),
+    1,
+    "distribute_opts",
+    "fast_run",
+    "mul",
+)
+
+fastmath.register(
+    "dist_div_over_add_opt",
+    in2out(distribute_div_over_add_opt, ignore_newtrees=True),
+    1,
+    "distribute_opts",
+    "fast_run",
+    "div",
+)
+
+fastmath.register(
+    "dist_mul_over_sub_opt",
+    in2out(distribute_mul_over_sub_opt, ignore_newtrees=True),
+    1,
+    "distribute_opts",
+    "fast_run",
+    "mul",
+)
+
+fastmath.register(
+    "dist_div_over_sub_opt",
+    in2out(distribute_div_over_sub_opt, ignore_newtrees=True),
+    1,
+    "distribute_opts",
+    "fast_run",
+    "div",
+)
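Note on direction: `distribute_mul_over_add_opt` and `distribute_div_over_add_opt` hand the relation to `KanrenRelationSub` with its arguments swapped (`lambda x, y: distribute_mul_over_add(y, x)`), while the two `*_over_sub` relations already take the expanded form as `in_lv`. As registered, all four rewrites therefore replace the expanded expression with the factored one, e.g. `a * x + a * y` becomes `a * (x + y)`, which is what the new tests below check. The sketch below (illustrative only, not part of the patch) queries one of the relations at the term level, roughly the way `KanrenRelationSub` does after etuplizing the matched subgraph:

```python
import aesara.tensor as at
from etuples import etuple
from kanren import run, var

from aesara.tensor.math_opt import distribute_mul_over_sub

a = at.matrix("a")
x = at.vector("x")
y = at.vector("y")

# The expanded form `a * x - a * y`, written as an expression tuple.
expanded = etuple(at.sub, etuple(at.mul, a, x), etuple(at.mul, a, y))

# Ask miniKanren for the `out_lv` term that the relation pairs with this `in_lv`.
q = var()
results = run(1, q, distribute_mul_over_sub(expanded, q))

# Expect the factored term `a * (x - y)`, i.e.
# `e(at.mul, a, e(at.sub, x, y))`.
print(results)
```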
diff --git a/tests/tensor/test_math_opt.py b/tests/tensor/test_math_opt.py
index f5ec80dde7..1d0bd0742a 100644
--- a/tests/tensor/test_math_opt.py
+++ b/tests/tensor/test_math_opt.py
@@ -4549,3 +4549,49 @@ def logit_fn(x):
     fg = optimize(FunctionGraph([x], [out]))
     assert not list(fg.toposort())
     assert fg.inputs[0] is fg.outputs[0]
+
+
+class TestDistributiveOpts:
+    x_at = vector("x")
+    y_at = vector("y")
+    a_at = matrix("a")
+
+    @pytest.mark.parametrize(
+        "orig_operation, optimized_operation",
+        [
+            (a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
+            (x_at / a_at + y_at / a_at, (x_at + y_at) / a_at),
+            (a_at * x_at - a_at * y_at, a_at * (x_at - y_at)),
+            (x_at / a_at - y_at / a_at, (x_at - y_at) / a_at),
+        ],
+    )
+    def test_distributive_opts(self, orig_operation, optimized_operation):
+        fgraph = FunctionGraph([self.x_at, self.y_at, self.a_at], [orig_operation])
+        out_orig = fgraph.outputs[0]
+
+        fgraph_res = FunctionGraph(
+            [self.x_at, self.y_at, self.a_at], [optimized_operation]
+        )
+        out_res = fgraph_res.outputs[0]
+
+        fgraph_opt = optimize(fgraph)
+        out_opt = fgraph_opt.outputs[0]
+
+        assert all(
+            [
+                isinstance(out_orig.owner.op, Elemwise),
+                isinstance(out_res.owner.op, Elemwise),
+                isinstance(out_opt.owner.op, Elemwise),
+            ]
+        )
+
+        # The scalar `Op` originally at the output node (i.e. the one to be
+        # "collected"); it should differ from the optimized graph's outer scalar `Op`.
+        original_scalar_op = type(out_orig.owner.op.scalar_op)
+        # The outer scalar `Op` of the expected graph should match the outer
+        # scalar `Op` of the optimized graph.
+        resulting_scalar_op = type(out_res.owner.op.scalar_op)
+        optimized_scalar_op = type(out_opt.owner.op.scalar_op)
+
+        assert original_scalar_op != resulting_scalar_op
+        assert resulting_scalar_op == optimized_scalar_op
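As a rough end-to-end check of the `fastmath` registrations (outside the unit tests above), one can compile the expanded expression and inspect the optimized graph. This is only a sketch: whether the collected form is visible in the final printout also depends on the other `fast_run` rewrites that run afterwards (e.g. elemwise fusion may merge everything into a single `Composite`):

```python
import aesara
import aesara.tensor as at

a = at.matrix("a")
x = at.vector("x")
y = at.vector("y")

# Compile the expanded form; with `fastmath` registered under `fast_run`,
# `a * x + a * y` should be collected into `a * (x + y)` during optimization.
f = aesara.function([a, x, y], a * x + a * y)

# Print the optimized graph for manual inspection.
aesara.dprint(f)
```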