Added KanrenRelationSub for distributive rewrites
kc611 committed Jan 9, 2022
1 parent f2be969 commit 7b03c00
Showing 2 changed files with 212 additions and 0 deletions.
166 changes: 166 additions & 0 deletions aesara/tensor/math_opt.py
@@ -6,10 +6,18 @@
from functools import partial, reduce

import numpy as np
from etuples import etuple
from kanren.assoccomm import assoc_flatten, associative, commutative
from kanren.core import lall
from kanren.graph import mapo
from unification import vars as lvars

import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math
import aesara.tensor as at
from aesara.compile import optdb
from aesara.graph.basic import Constant, Variable
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import (
    LocalOptGroup,
    LocalOptimizer,
@@ -19,6 +27,7 @@
    local_optimizer,
)
from aesara.graph.opt_utils import get_clients_at_depth
from aesara.graph.optdb import EquilibriumDB
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op
from aesara.tensor.basic import (
@@ -3523,6 +3532,122 @@ def local_reciprocal_1_plus_exp(fgraph, node):
    return out


def distribute_mul_over_add(in_lv, out_lv):
    from kanren import conso, eq, fact, heado, tailo

    # This does the optimization A * (x + y + z) = A * x + A * y + A * z
    A_lv, add_term_lv, add_cdr_lv, mul_cdr_lv, add_flat_lv = lvars(5)
    fact(commutative, at.mul)
    fact(associative, at.mul)
    fact(commutative, at.add)
    fact(associative, at.add)
    return lall(
        # Make sure the input is an `at.mul`
        eq(in_lv, etuple(at.mul, A_lv, add_term_lv)),
        # Make sure the term being `at.mul`ed is an `at.add`
        heado(at.add, add_term_lv),
        # Flatten the associative pairings of `add` operations
        assoc_flatten(add_term_lv, add_flat_lv),
        # Get the flattened `add` arguments
        tailo(add_cdr_lv, add_flat_lv),
        # Add all the `at.mul`ed arguments and set the output
        conso(at.add, mul_cdr_lv, out_lv),
        # Apply the `at.mul` to all the flattened `add` arguments
        mapo(lambda x, y: conso(at.mul, etuple(A_lv, x), y), add_cdr_lv, mul_cdr_lv),
    )


distribute_mul_over_add_opt = KanrenRelationSub(
    lambda x, y: distribute_mul_over_add(y, x)
)
distribute_mul_over_add_opt.__name__ = distribute_mul_over_add.__name__
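
A quick, hedged sketch (not part of this commit) of what the relation computes when queried directly with kanren; the variable names and the exact structure of the reified result are illustrative assumptions. Note that `distribute_mul_over_add_opt` passes the arguments swapped (`lambda x, y: distribute_mul_over_add(y, x)`), so the rewrite applied to graphs runs in the collecting direction, e.g. `A*x + A*y -> A*(x + y)`, which is what the new tests below exercise.

from kanren import run, var

A = at.scalar("A")
x = at.scalar("x")
y = at.scalar("y")
out = var()
# The "distribute" direction: query with A * (x + y) as an expression tuple
expr = etuple(at.mul, A, etuple(at.add, x, y))
results = run(1, out, distribute_mul_over_add(expr, out))
# `results[0]` should reify to an expression tuple equivalent to
# add(mul(A, x), mul(A, y))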


def distribute_div_over_add(in_lv, out_lv):
    from kanren import conso, eq, fact, heado, tailo

    # This does the optimization (x + y + z) / A = x / A + y / A + z / A
    A_lv, add_term_lv, add_cdr_lv, div_cdr_lv, add_flat_lv = lvars(5)
    fact(commutative, at.add)
    fact(associative, at.add)
    return lall(
        # Make sure the input is an `at.true_div`
        eq(in_lv, etuple(at.true_div, add_term_lv, A_lv)),
        # Make sure the numerator being divided is an `at.add`
        heado(at.add, add_term_lv),
        # Flatten the associative pairings of `add` operations
        assoc_flatten(add_term_lv, add_flat_lv),
        # Get the flattened `add` arguments
        tailo(add_cdr_lv, add_flat_lv),
        # Add all the divided arguments and set the output
        conso(at.add, div_cdr_lv, out_lv),
        # Apply the `at.true_div` to all the flattened `add` arguments
        mapo(
            lambda x, y: conso(at.true_div, etuple(x, A_lv), y), add_cdr_lv, div_cdr_lv
        ),
    )


distribute_div_over_add_opt = KanrenRelationSub(
    lambda x, y: distribute_div_over_add(y, x)
)
distribute_div_over_add_opt.__name__ = distribute_div_over_add.__name__


def distribute_mul_over_sub(in_lv, out_lv):
    from kanren import eq, fact

    fact(commutative, at.mul)
    fact(associative, at.mul)
    a_lv, x_lv, y_lv = lvars(3)

    return lall(
        # lhs == a * x - a * y
        eq(
            etuple(
                at.sub,
                etuple(at.mul, a_lv, x_lv),
                etuple(at.mul, a_lv, y_lv),
            ),
            in_lv,
        ),
        # rhs == a * (x - y)
        eq(
            etuple(at.mul, a_lv, etuple(at.sub, x_lv, y_lv)),
            out_lv,
        ),
    )


distribute_mul_over_sub_opt = KanrenRelationSub(distribute_mul_over_sub)
distribute_mul_over_sub_opt.__name__ = distribute_mul_over_sub.__name__
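
In the same hedged spirit (again, not part of this commit), the fixed-arity pattern above can be queried directly, and no flattening step is needed:

from kanren import run, var

a = at.scalar("a")
x = at.scalar("x")
y = at.scalar("y")
out = var()
# a * x - a * y as an expression tuple
expr = etuple(at.sub, etuple(at.mul, a, x), etuple(at.mul, a, y))
results = run(1, out, distribute_mul_over_sub(expr, out))
# `results[0]` should reify to an expression tuple equivalent to mul(a, sub(x, y))

Since this optimizer passes the relation through unswapped, the graph rewrite runs in the same direction: `a*x - a*y -> a*(x - y)`.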


def distribute_div_over_sub(in_lv, out_lv):
    from kanren import eq

    a_lv, x_lv, y_lv = lvars(3)
    return lall(
        # lhs == x / a - y / a
        eq(
            etuple(
                at.sub,
                etuple(at.true_div, x_lv, a_lv),
                etuple(at.true_div, y_lv, a_lv),
            ),
            in_lv,
        ),
        # rhs == (x - y) / a
        eq(
            etuple(at.true_div, etuple(at.sub, x_lv, y_lv), a_lv),
            out_lv,
        ),
    )


distribute_div_over_sub_opt = KanrenRelationSub(distribute_div_over_sub)
distribute_div_over_sub_opt.__name__ = distribute_div_over_sub.__name__

# 1 - sigmoid(x) -> sigmoid(-x)
local_1msigmoid = PatternSub(
    (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")),
@@ -3567,3 +3692,44 @@ def local_reciprocal_1_plus_exp(fgraph, node):
)
register_canonicalize(local_sigmoid_logit)
register_specialize(local_sigmoid_logit)


fastmath = EquilibriumDB()

optdb.register("fastmath", fastmath, 0.05, "fast_run", "mul")

fastmath.register(
    "dist_mul_over_add_opt",
    in2out(distribute_mul_over_add_opt, ignore_newtrees=True),
    1,
    "distribute_opts",
    "fast_run",
    "mul",
)

fastmath.register(
    "dist_div_over_add_opt",
    in2out(distribute_div_over_add_opt, ignore_newtrees=True),
    1,
    "distribute_opts",
    "fast_run",
    "div",
)

fastmath.register(
    "dist_mul_over_sub_opt",
    in2out(distribute_mul_over_sub_opt, ignore_newtrees=True),
    1,
    "distribute_opts",
    "fast_run",
    "mul",
)

fastmath.register(
    "dist_div_over_sub_opt",
    in2out(distribute_div_over_sub_opt, ignore_newtrees=True),
    1,
    "distribute_opts",
    "fast_run",
    "div",
)
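
Because `fastmath` is registered in `optdb` under the `fast_run` tag, these rewrites should fire during ordinary compilation. A hedged end-to-end sketch (not part of this commit; the exact optimized graph may differ):

import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.vector("y")
a = at.matrix("a")

# Under the default FAST_RUN mode, `a * x + a * y` is expected to be
# collected into a single multiplication of `a` with `(x + y)` by
# `dist_mul_over_add_opt`.
f = aesara.function([x, y, a], a * x + a * y)
aesara.dprint(f)  # inspect the optimized graph
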
46 changes: 46 additions & 0 deletions tests/tensor/test_math_opt.py
@@ -4549,3 +4549,49 @@ def logit_fn(x):
    fg = optimize(FunctionGraph([x], [out]))
    assert not list(fg.toposort())
    assert fg.inputs[0] is fg.outputs[0]


class TestDistributiveOpts:
    x_at = vector("x")
    y_at = vector("y")
    a_at = matrix("a")

    @pytest.mark.parametrize(
        "orig_operation, optimized_operation",
        [
            (a_at * x_at + a_at * y_at, a_at * (x_at + y_at)),
            (x_at / a_at + y_at / a_at, (x_at + y_at) / a_at),
            (a_at * x_at - a_at * y_at, a_at * (x_at - y_at)),
            (x_at / a_at - y_at / a_at, (x_at - y_at) / a_at),
        ],
    )
    def test_distributive_opts(self, orig_operation, optimized_operation):
        fgraph = FunctionGraph([self.x_at, self.y_at, self.a_at], [orig_operation])
        out_orig = fgraph.outputs[0]

        fgraph_res = FunctionGraph(
            [self.x_at, self.y_at, self.a_at], [optimized_operation]
        )
        out_res = fgraph_res.outputs[0]

        fgraph_opt = optimize(fgraph)
        out_opt = fgraph_opt.outputs[0]

        assert all(
            [
                isinstance(out_orig.owner.op, Elemwise),
                isinstance(out_res.owner.op, Elemwise),
                isinstance(out_opt.owner.op, Elemwise),
            ]
        )

        # The scalar `Op` originally in the output node (i.e. the `Op` to be
        # "collected"); it should differ from the outer scalar `Op` in the
        # optimized graph.
        original_scalar_op = type(out_orig.owner.op.scalar_op)
        # The outer scalar `Op` in the reference graph should match the outer
        # scalar `Op` in the optimized graph.
        resulting_scalar_op = type(out_res.owner.op.scalar_op)
        optimized_scalar_op = type(out_opt.owner.op.scalar_op)

        assert original_scalar_op != resulting_scalar_op
        assert resulting_scalar_op == optimized_scalar_op
