Add axis argument to Softmax and related Ops #673

Conversation
Codecov Report

@@            Coverage Diff             @@
##             main     #673      +/-   ##
==========================================
+ Coverage   77.61%   77.66%   +0.05%
==========================================
  Files         152      152
  Lines       46915    46952      +37
  Branches    10883    10891       +8
==========================================
+ Hits        36413    36466      +53
+ Misses       7895     7888       -7
+ Partials     2607     2598       -9
These sound good.
It would be good to know exactly why/if we should have special Ops for these composite operations nowadays.
Do they necessarily perform better than their equivalents built using other Ops? Assuming that the constituent Ops all have C/JAX/Numba implementations, it would be surprising to find (and good to be aware of) a significant difference, because their backend implementations don't clearly imply any advantages.
If not, we need to consider replacing these with simple helper functions; otherwise, we'll have a lot of redundant custom code to maintain.
I explored that option. It was not so much a question of performance as of pattern-matching difficulties, and of the numerical stability problems that follow from forms not being captured. This was especially relevant for SoftmaxGrad, which can have many different patterns depending on which grads are being requested. This is related to all those tests that are being skipped. Thinking about it, we might be able to get rid of log_softmax and use the expression based on logsumexp that scipy also uses. That would already be something; I'll check that.
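As a point of reference, here is a minimal NumPy/SciPy sketch of that logsumexp-based formulation (this mirrors what scipy.special.log_softmax computes; it is illustrative only and not the Aesara implementation):

```python
import numpy as np
from scipy.special import logsumexp

def log_softmax_ref(x, axis=-1):
    # log(softmax(x)) rewritten as x - logsumexp(x); logsumexp shifts by the
    # max internally, so large inputs never overflow in exp().
    return x - logsumexp(x, axis=axis, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(log_softmax_ref(x))  # finite values: [[-2.4076  -1.4076  -0.4076]]
# The naive form overflows (exp(1000) -> inf) and produces nan:
print(np.log(np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)))
```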
I am waiting until all other changes are sorted out / reviewed so as not to suffer through interactive rebases.
Both of these are important enough to address directly and immediately; otherwise, the library will simply not improve or advance. The approach taken by … Regarding the latter issue, where is this demonstrated?
What latter issue are you referring to?
The numerical stability and forms that aren't being captured.
I just looked over the … The only one that looks like it might require any real pattern matching in the absence of those … This optimization could easily be replaced by simple pattern matching. Also, anything that aspires to do this kind of thing beyond a few simple cases will ultimately need to use more general functional properties anyway (e.g. employ monotonic/convex function closure identities from which a soft-max/plus, sigmoid, etc., could be easily derived).
Both the softmax and log_softmax graphs are easy to identify and replace by the numerically stable versions that shift by the max. The issues I found concerned the gradients of both ops (as well as the gradient of SoftmaxGrad), which introduce new softmax terms and would also need the shift by the max to become stable. These are difficult to match because they can have different patterns depending on which gradients are actually being requested. You can see that the existing rewrites seem to concern mostly the gradients, and the old Theano issue I linked (Theano/Theano#4452) was concerned about not having a rewrite to match the gradient of the softmax when the specialized Op was not being used from the beginning. I also checked what would happen if softmax and log_softmax returned the numerically stable graph immediately, but the Aesara-generated gradients were still unstable. I am not saying it's impossible, just that for the specific goal of adding axes, it turned out to be less trouble to update the Ops and rewrites. These changes do not preclude a future replacement by symbolic graphs when the gradient pattern-matching issues are figured out.
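For reference, the max-shift stabilization under discussion looks like this in plain NumPy (a minimal sketch of the numerical trick, not the Aesara graph rewrite itself):

```python
import numpy as np

def softmax_stable(x, axis=-1):
    # softmax is invariant to subtracting a constant along `axis`, so shifting
    # by the max changes nothing mathematically but keeps exp() from overflowing.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

x = np.array([[1000.0, 1001.0, 1002.0]])
print(softmax_stable(x))  # [[0.09003057 0.24472847 0.66524096]]
# The unshifted form overflows (exp(1000) -> inf) and yields nan:
print(np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True))
```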
Yes, this is the detail I was looking for. If this doesn't have an issue, it needs one.
If I'm understanding Theano/Theano#4452 correctly, it doesn't sound like a good solution for this at all. It seems like it's akin to using … For instance, we would need to carefully orchestrate the compilation phases to make sure that rewrites are applied after this proposed …
It might only seem like less trouble when the scope is narrow; however, if we address whatever problems underlie this issue, we are very likely to improve more things than just the …
No, but the time we spend creating and maintaining these work-arounds takes away from priority improvements like that gradient issue. Remember, if these rewrites can stabilize graphs that aren't explicit …
I created an issue for this: #682. We can continue the conversation there. In the meantime we can still merge this, but we need to give priority to solving that issue.
@ricardoV94, do you want to move the …
I will move them in a PR immediately after this one.
TODO:
- Move Softmax/LogSoftmax contents from nnet/basic.py to tensor/softmax.py? Will do in follow-up PR
- Rename logsoftmax -> log_softmax for consistency with scipy (with deprecated alias)? Will do in follow-up PR

Closes #183
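A hypothetical usage sketch of the new axis argument (the import path and exact call signature are assumed here from the discussion above; softmax still lives under aesara.tensor.nnet until the planned move to tensor/softmax.py):

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.tensor.nnet import softmax  # current location, pending the planned move

x = at.tensor3("x")
# Assumed behaviour after this PR: normalize over the requested axis
# instead of always over the last axis of a matrix input.
y = softmax(x, axis=1)

f = aesara.function([x], y)
data = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(aesara.config.floatX)
out = f(data)
assert np.allclose(out.sum(axis=1), 1.0)  # probabilities sum to one along the softmax axis
```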