Replace custom Softmax* Ops with Aesara graphs #682

Comments
Here's a simple example that shows the gradient of a manually constructed softmax function and demonstrates one of the issues that's currently being solved by the custom Softmax* Ops:

import numpy as np
import aesara
import aesara.tensor as at
from aesara.graph.opt_utils import optimize_graph
def softmax(x):
    return at.exp(x) / at.exp(x).sum(axis=-1, keepdims=True)
x = at.vector("x")
y = softmax(x)
(y_g,) = aesara.grad(y.sum(), [x])
y_g_opt = optimize_graph(y_g, clone=False)
aesara.dprint(y_g_opt)
# Elemwise{add,no_inplace} [id A] ''
# |Elemwise{true_div,no_inplace} [id B] ''
# | |Elemwise{exp} [id C] ''
# | | |x [id D]
# | |InplaceDimShuffle{x} [id E] ''
# | |Sum{acc_dtype=float64} [id F] ''
# | |Elemwise{exp} [id C] ''
# |Elemwise{mul} [id G] ''
# |InplaceDimShuffle{x} [id H] ''
# | |Sum{acc_dtype=float64} [id I] ''
# | |Elemwise{true_div,no_inplace} [id J] ''
# | |Elemwise{mul,no_inplace} [id K] ''
# | | |TensorConstant{(1,) of -1.0} [id L]
# | | |Elemwise{exp} [id C] ''
# | |Elemwise{mul,no_inplace} [id M] ''
# | |InplaceDimShuffle{x} [id E] ''
# | |InplaceDimShuffle{x} [id E] ''
# |Elemwise{exp} [id C] ''
import black
print(black.format_str(aesara.pprint(y_g_opt), mode=black.Mode()))
# (
# (exp(x) / sum(exp(x), axis=None))
# + (
# sum(
# (([-1.0] * exp(x)) / (sum(exp(x), axis=None) * sum(exp(x), axis=None))),
# axis=None,
# )
# * exp(x)
# )
# )
# Compare against the gradient of the existing Softmax Op
from aesara.tensor.nnet import softmax

x_val = np.array([1e-18, 1e18, -1e18])

aesara.grad(softmax(x).sum(), [x])[0].eval({x: x_val})
# array([0., 0., 0.])

# The gradient of the manually constructed graph produces NaNs instead
y_g_opt.eval({x: x_val})
# array([nan, nan, nan])
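
Just to illustrate the shift-by-the-max trick discussed below, here is a minimal sketch (the stable_softmax name is only for illustration; it reuses x and x_val from the snippet above):

# Subtract the per-row max before exponentiating; mathematically equivalent
# to the softmax definition above, but the exp can no longer overflow.
def stable_softmax(x):
    z = x - at.max(x, axis=-1, keepdims=True)
    e = at.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

stable_softmax(x).eval({x: x_val})
# expected: array([0., 1., 0.]) -- the forward pass no longer overflows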
Just to clarify, what we ultimately need for things like this is equational unification. Simple terms like the graph above can be matched fairly directly, but to handle more general terms, we need to incorporate equational properties like lifting/sinking and associativity-commutativity (AC).

While we can attempt to account for these equational properties via canonicalization/normalization, they will not necessarily produce the expressions we desire in every case, so we need them to be available in more localized contexts, like these kinds of special-purpose softmax rewrites. Also, using these properties in the context of matching/unification avoids some of the problems that their use in canonicalization doesn't (e.g. numerical concerns), because unification doesn't need to alter/clone graphs, although it does need to search large spaces more than once in some cases (e.g. equations with AC properties).

Anyway, this is the reason I'm always pushing for unification-based rewriting; in other words, this is directly related to #523.
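
For a concrete picture of what AC-aware matching buys, here is a minimal kanren sketch over plain tuple "expressions" (the add/mul tokens below are stand-ins rather than Aesara Ops, following kanren's documented assoccomm usage):

from kanren import fact, run, var
from kanren.assoccomm import eq_assoccomm as eq
from kanren.assoccomm import associative, commutative

# Dummy operation tokens standing in for graph operations
add = "add"
mul = "mul"

# Register the equational (associative/commutative) properties
fact(commutative, mul)
fact(commutative, add)
fact(associative, mul)
fact(associative, add)

# Logic variables that capture sub-terms
a, b = var(), var()

pattern = (mul, (add, 1, a), b)  # (1 + a) * b
expr = (mul, 2, (add, 3, 1))     # 2 * (3 + 1)

print(run(0, (a, b), eq(pattern, expr)))
# e.g. ((3, 2),) -- the match only succeeds because the AC properties are used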
There are some improvements in pythological/kanren#40 that I would like to put in place (e.g. better scalability, less redundant matching, etc.), but the current functionality should work for basic things right now.
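
For basic syntactic matching, the logical-unification package alone is enough; a tiny sketch on tuple terms (again not actual Aesara graphs):

from unification import unify, var

a = var("a")

# A "something divided by the sum of that same something" pattern
pattern = ("true_div", a, ("sum", a))
term = ("true_div", ("exp", "x"), ("sum", ("exp", "x")))

print(unify(pattern, term, {}))
# {~a: ('exp', 'x')}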
Both the softmax and log_softmax graphs are easy to identify and replace with the numerically stable versions that shift by the max.
The issues I found concerned the gradients of both ops (as well as the gradient of SoftmaxGrad), which introduce new softmax terms and would also need the shift by the max to become stable. These are difficult to match because they can have different patterns depending on which gradients are actually being requested.
You can see that the existing rewrites seem to concern mostly the gradients, and the old Theano issue I linked (Theano/Theano#4452) was about not having a rewrite to match the gradient of the softmax when the specialized Op was not used from the beginning.
I also checked what would happen if softmax and log_softmax returned the numerically stable graph immediately, but the Aesara-generated gradients were still unstable.
Originally posted by @ricardoV94 in #673 (comment)
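
A sketch of the kind of check described in the quoted comment, assuming the x, x_val, and imports from the first snippet (stable_log_softmax is only an illustrative name, not an Aesara function):

# Build a shift-by-the-max log_softmax graph directly...
def stable_log_softmax(x):
    z = x - at.max(x, axis=-1, keepdims=True)
    return z - at.log(at.sum(at.exp(z), axis=-1, keepdims=True))

# ...and inspect the gradient Aesara derives from it at extreme inputs.
(g,) = aesara.grad(stable_log_softmax(x).sum(), [x])
g.eval({x: x_val})
# Per the comment above, the generated gradient graph is not itself rewritten
# for stability, so it can still underflow/overflow here.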