Replace custom Softmax* Ops with Aesara graphs #682

Open
brandonwillard opened this issue Nov 27, 2021 · 3 comments
Labels
enhancement, graph rewriting, help wanted

Comments

@brandonwillard
Member

Both the softmax and log_softmax graphs are easy to identify and replace with the numerically stable versions that shift by the max.

The issues I found concerned the gradients of both Ops (as well as the gradient of SoftmaxGrad), which introduce new softmax terms that would also need the shift by the max to be stable. These are difficult to match because they can take different forms depending on which gradients are actually being requested.

You can see that the existing rewrites mostly concern the gradients, and the old Theano issue I linked (Theano/Theano#4452) was about the lack of a rewrite to match the gradient of the softmax when the specialized Op was not used from the beginning.

I also checked what would happen if softmax and log_softmax returned the numerically stable graph immediately, but the Aesara-generated gradients were still unstable.

Originally posted by @ricardoV94 in #673 (comment)
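
For reference, the shift-by-max stabilization mentioned above looks roughly like the following. This is only an illustrative sketch, not the actual rewrite; the function names are made up for demonstration:

import aesara.tensor as at


def stable_softmax(x):
    # Subtracting the row-wise max before exponentiating prevents overflow
    # in exp without changing the result.
    z = x - at.max(x, axis=-1, keepdims=True)
    e = at.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def stable_log_softmax(x):
    # The same shift, expressed directly in log-space.
    z = x - at.max(x, axis=-1, keepdims=True)
    return z - at.log(at.exp(z).sum(axis=-1, keepdims=True))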

@brandonwillard
Member Author

brandonwillard commented Nov 27, 2021

Here's a simple example that shows the gradient of a manually constructed softmax function and demonstrates one of the issues currently being solved by the custom Ops:

import numpy as np

import aesara
import aesara.tensor as at

from aesara.graph.opt_utils import optimize_graph


def softmax(x):
    # A "naive" softmax built from basic Ops (i.e. without the shift by the max)
    return at.exp(x) / at.exp(x).sum(axis=-1, keepdims=True)


x = at.vector("x")
y = softmax(x)

# Gradient of the sum of the manually constructed softmax
(y_g,) = aesara.grad(y.sum(), [x])

# Apply the default rewrites to the gradient graph
y_g_opt = optimize_graph(y_g, clone=False)


# Debug-print the optimized gradient graph
aesara.dprint(y_g_opt)
# Elemwise{add,no_inplace} [id A] ''   
#  |Elemwise{true_div,no_inplace} [id B] ''   
#  | |Elemwise{exp} [id C] ''   
#  | | |x [id D]
#  | |InplaceDimShuffle{x} [id E] ''   
#  |   |Sum{acc_dtype=float64} [id F] ''   
#  |     |Elemwise{exp} [id C] ''   
#  |Elemwise{mul} [id G] ''   
#    |InplaceDimShuffle{x} [id H] ''   
#    | |Sum{acc_dtype=float64} [id I] ''   
#    |   |Elemwise{true_div,no_inplace} [id J] ''   
#    |     |Elemwise{mul,no_inplace} [id K] ''   
#    |     | |TensorConstant{(1,) of -1.0} [id L]
#    |     | |Elemwise{exp} [id C] ''   
#    |     |Elemwise{mul,no_inplace} [id M] ''   
#    |       |InplaceDimShuffle{x} [id E] ''   
#    |       |InplaceDimShuffle{x} [id E] ''   
#    |Elemwise{exp} [id C] ''   

import black
print(black.format_str(aesara.pprint(y_g_opt), mode=black.Mode()))
# (
#     (exp(x) / sum(exp(x), axis=None))
#     + (
#         sum(
#             (([-1.0] * exp(x)) / (sum(exp(x), axis=None) * sum(exp(x), axis=None))),
#             axis=None,
#         )
#         * exp(x)
#     )
# )

# Compare with the specialized Softmax Op, whose gradient is numerically stable
from aesara.tensor.nnet import softmax

x_val = np.array([1e-18, 1e18, -1e18])

aesara.grad(softmax(x).sum(), [x])[0].eval({x: x_val})
# array([0., 0., 0.])

# The gradient of the manually constructed graph overflows and produces NaNs
y_g_opt.eval({x: x_val})
# array([nan, nan, nan])

@brandonwillard added the enhancement and help wanted labels on Nov 27, 2021
@brandonwillard
Member Author

brandonwillard commented Nov 28, 2021

Just to clarify, what we ultimately need for things like this is equational unification.

Terms like y_g_opt can easily be "matched" using syntactic unification (e.g. the kind currently available via PatternSub and logical-unification, and demonstrated in this symbolic-pymc walkthrough), but y_g_opt only represents one particular term, and slight variations won't be handled.
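
As a rough illustration of that limitation, here's what syntactic unification looks like on plain tuple "terms" (not actual Aesara graphs) using the logical-unification package; the term encodings here are just made up for the example:

from unification import unify, var

x = var("x")

# A pattern resembling exp(x) / sum(exp(x)); unification succeeds only when
# the term has exactly the same structure.
pattern = ("true_div", ("exp", x), ("sum", ("exp", x)))

unify(pattern, ("true_div", ("exp", "t"), ("sum", ("exp", "t"))))
# {~x: 't'}

# An algebraically equivalent but syntactically different term does not match.
unify(pattern, ("mul", ("exp", "t"), ("reciprocal", ("sum", ("exp", "t")))))
# False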

To handle more general terms, we need to incorporate equational properties like lifting/sinking and associativity-commutativity (AC). For instance, when one considers how y_g_opt could be rewritten in terms of softmax functions, rewrites like sum(b[i] / a, (i, ...)) -> sum(b[i], (i, ...)) / a and (b * c) / a -> (b * (c / a)) come to mind.
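
(For reference, with s = exp(x) / sum(exp(x)), the pretty-printed gradient above works out algebraically to s - s * sum(s), so there is a softmax-based form to recover.)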

While we can attempt to account for these equational properties via canonicalization/normalization, they will not necessarily produce the expressions we desire in every case, so we need them to be available in more localized contexts—like these kinds of special-purpose softmax rewrites.

Also, using these properties during matching/unification avoids some of the problems that arise when they're used for canonicalization (e.g. numerical concerns), because unification doesn't need to alter/clone graphs—although it may need to search large spaces more than once in some cases (e.g. equations with AC properties).

Anyway, the reason I'm always pushing for kanren is that it's a framework that already "hosts" equational unification. For example, due to its stream processing (of unification-based "goals") and relational nature, AC properties can be handled quite well, and kanren already contains support for them.
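
To give a rough idea (again on tuple terms rather than Aesara graphs), kanren's associative-commutative matching looks something like the following, adapted from kanren's own examples:

from kanren import fact, run, var
from kanren.assoccomm import eq_assoccomm
from kanren.assoccomm import commutative, associative

add = "add"
mul = "mul"

# Declare that the stand-in operators are commutative and associative.
fact(commutative, add)
fact(commutative, mul)
fact(associative, add)
fact(associative, mul)

x, y = var("x"), var("y")

pattern = (mul, (add, 1, x), y)
expr = (mul, 2, (add, 3, 1))

# The pattern matches despite the different argument order and grouping.
run(0, (x, y), eq_assoccomm(pattern, expr))
# ((3, 2),)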

In other words, this is directly related to #523.

@brandonwillard
Member Author

The test_KanrenRelationSub_filters test in #523 provides a complete example of AC unification via kanren (i.e. it will match against A.dot((x + y) + z), A.dot((y + x) + z), A.dot(x + (y + z)), etc.).

There are some improvements in pythological/kanren#40 that I would like to put in place (e.g. better scalability, less redundant matching, etc.), but the current functionality should work for basic things right now.
