
Add axis argument to Softmax and related Ops #673

Merged: 7 commits into aesara-devs:main on Dec 13, 2021

Conversation

@ricardoV94 (Contributor) commented Nov 24, 2021

TODO:

  • Add axis to Softmax
  • Add axis to SoftmaxGrad
  • Add axis to LogSoftmax
  • Update Numba dispatch
  • Update JAX dispatch
  • Generalize optimizations for Softmax
  • Generalize optimizations for SoftmaxGrad
  • Generalize optimizations for LogSoftmax
  • Move Softmax/LogSoftmax contents from nnet/basic.py to tensor/softmax.py? (will do in a follow-up PR)
  • Rename logsoftmax -> log_softmax for consistency with scipy (with a deprecated alias)? (will do in a follow-up PR)

Closes #183
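
For context, a rough sketch of the intended usage once this lands (assuming the helper is still exposed as aesara.tensor.nnet.softmax and forwards the new axis argument):

```python
import aesara.tensor as at
from aesara.tensor.nnet import softmax

x = at.tensor3("x")

# with this PR the normalization axis can be chosen explicitly, instead of the
# Op assuming a row-wise softmax over a 2D input
y = softmax(x, axis=-1)
```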

@ricardoV94 changed the title from "Add axis to Softmax and SoftmaxGrad Ops" to "Add axis to Softmax and related Ops" Nov 24, 2021
@ricardoV94 force-pushed the Softmax_axis branch 6 times, most recently from 96c0b0f to ad3bb57, November 24, 2021 15:16
codecov bot commented Nov 24, 2021

Codecov Report

Merging #673 (9ce41da) into main (34375f4) will increase coverage by 0.05%.
The diff coverage is 97.98%.


@@            Coverage Diff             @@
##             main     #673      +/-   ##
==========================================
+ Coverage   77.61%   77.66%   +0.05%     
==========================================
  Files         152      152              
  Lines       46915    46952      +37     
  Branches    10883    10891       +8     
==========================================
+ Hits        36413    36466      +53     
+ Misses       7895     7888       -7     
+ Partials     2607     2598       -9     
Impacted files (coverage Δ):
aesara/tensor/nnet/basic.py 80.06% <97.22%> (+1.93%) ⬆️
aesara/link/jax/dispatch.py 80.27% <100.00%> (+0.31%) ⬆️
aesara/link/numba/dispatch/elemwise.py 97.78% <100.00%> (+0.23%) ⬆️

@ricardoV94 force-pushed the Softmax_axis branch 4 times, most recently from 5a7dbf4 to f867087, November 24, 2021 19:26
@brandonwillard added the "enhancement" (New feature or request) label Nov 25, 2021
@ricardoV94 force-pushed the Softmax_axis branch 3 times, most recently from f79e832 to c77daf4, November 25, 2021 08:15
@ricardoV94 force-pushed the Softmax_axis branch 2 times, most recently from 5106acf to 0f260ae, November 25, 2021 09:59
@ricardoV94 marked this pull request as ready for review November 25, 2021 10:00
@ricardoV94 force-pushed the Softmax_axis branch 6 times, most recently from e50b335 to e6542eb, November 25, 2021 12:37
@brandonwillard previously approved these changes Nov 26, 2021
(a review comment on aesara/tensor/nnet/basic.py is now outdated and resolved)
@brandonwillard (Member) commented:

> • Move Softmax/LogSoftmax contents from nnet/basic.py to tensor/softmax.py?
> • Rename logsoftmax -> log_softmax for consistency with scipy (with deprecated alias)?

These sound good.

@brandonwillard (Member) left a review comment:

It would be good to know exactly why/if we should have special Ops for these composite operations nowadays.

Do they necessarily perform better than their equivalents that are built using other Ops? Assuming that the constituent Ops all have C/JAX/Numba implementations, it would be surprising to find—and good to be aware of—a significant difference, because their backend implementations don't clearly imply any advantages.

If not, we need to consider replacing these with simple helper functions; otherwise, we'll have a lot of redundant custom code to maintain.
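
For illustration, the kind of helper function suggested here could be a thin composition of existing Ops; this is a sketch only, not code from this PR:

```python
import aesara.tensor as at

def softmax_helper(x, axis=-1):
    # naive composition of existing elemwise/reduction Ops; a stabilizing
    # rewrite would still need to recognize this pattern (and its gradient)
    # and introduce the usual shift by the max
    e = at.exp(x)
    return e / at.sum(e, axis=axis, keepdims=True)
```

Whether the existing rewrites can reliably recognize this form, and especially its gradient, is the question taken up in the comments below.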

@ricardoV94 (Contributor, Author) commented Nov 26, 2021

> It would be good to know exactly why/if we should have special Ops for these composite operations nowadays.
>
> Do they necessarily perform better than their equivalents that are built using other Ops?

I explored that option. It was not so much a question of performance as of pattern-matching difficulties, and the numerical instability that results from forms not being captured.

This was especially relevant for SoftmaxGrad, which can have many different patterns depending on which grads are being requested. This is related to all those tests that are being skipped.

Thinking about it, we might be able to get rid of the LogSoftmax Op and use the expression based on logsumexp that scipy also uses. That would already be something; I'll check that.
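
For reference, the scipy-style logsumexp expression mentioned here amounts to the following (a sketch written with plain Aesara Ops, not the implementation in this PR):

```python
import aesara.tensor as at

def log_softmax_sketch(x, axis=-1):
    # log_softmax(x) = x - logsumexp(x); subtracting the max first keeps the
    # exponentials bounded, mirroring what scipy.special.log_softmax does
    x_shifted = x - at.max(x, axis=axis, keepdims=True)
    return x_shifted - at.log(at.sum(at.exp(x_shifted), axis=axis, keepdims=True))
```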

@ricardoV94 (Contributor, Author) commented:

> • Move Softmax/LogSoftmax contents from nnet/basic.py to tensor/softmax.py?
> • Rename logsoftmax -> log_softmax for consistency with scipy (with deprecated alias)?
>
> These sound good.

I am waiting until all the other changes are sorted out and reviewed, so as not to suffer through interactive rebases.

@brandonwillard (Member) commented:

> I explored that option. It was not so much a question of performance as of pattern-matching difficulties, and the numerical instability that results from forms not being captured.

Both of these are important enough to address directly and immediately; otherwise, the library will simply not improve or advance. The approach taken by Softmax and SoftmaxGrad is essentially antithetical to the design and intent of this library. In other words, if something as simple as Softmax cannot be implemented sufficiently using this library itself, then we have much bigger and more important problems.

Regarding the latter issue, where is this demonstrated?

@ricardoV94 (Contributor, Author) commented:

What latter issue are you referring to?

@brandonwillard (Member) commented:

> What latter issue are you referring to?

The numerical stability and forms that aren't being captured.

@brandonwillard (Member) commented:

I just looked over the *Softmax*-related optimizations, and there don't appear to be any that substantially depend on the encapsulation provided by those specialized Op classes. The few optimizations that are in place only serve to replace those specialized Ops with other specialized Ops (e.g. turn log(softmax(...)) into a LogSoftmax, turn softmax(sum(...)) into a softmax_w_bias(...), simplify log(SoftmaxGrad), etc.)

The only one that looks like it might require any real pattern matching in the absence of those Ops is the misplaced and misnamed local_argmax_pushdown specialization. It appears to lift MaxAndArgmaxs through monotonic Ops only when the arg-max part of the output is being used in a graph.

This optimization could easily be replaced by simple pattern matching. Also, anything that aspires to do this kind of thing beyond a few simple cases will ultimately need to use more general functional properties anyway (e.g. employ monotonic/convex function closure identities from which a soft-max/plus, sigmoid, etc., could be easily derived).
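
As a concrete instance of the monotonicity identity mentioned above: softmax is order-preserving along the reduced axis, so the argmax of its output equals the argmax of its input, which is the fact local_argmax_pushdown exploits. A NumPy check, for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 7))

def softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# softmax preserves the ordering along the reduced axis, so the argmax is
# unchanged; lifting the argmax through the softmax is therefore safe
assert np.array_equal(np.argmax(softmax(x), axis=-1), np.argmax(x, axis=-1))
```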

@ricardoV94 (Contributor, Author) commented Nov 27, 2021

Both the softmax and log_softmax graphs are easy to identify and replace with the numerically stable versions that shift by the max.

The issues I found concerned the gradients of both Ops (as well as the gradient of SoftmaxGrad), which introduce new softmax terms that would also need the shift by the max to become stable. These are difficult to match because they can take different patterns depending on which gradients are actually being requested.
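
For background, the new softmax terms come from the Jacobian of the softmax itself,

```
\frac{\partial\, \mathrm{softmax}(x)_i}{\partial x_j}
  = \mathrm{softmax}(x)_i \,\bigl(\delta_{ij} - \mathrm{softmax}(x)_j\bigr),
```

so any gradient that passes through a softmax re-introduces softmax factors, and those factors need the same shift by the max to remain stable.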

You can see that the existing rewrites mostly concern the gradients, and the old Theano issue I linked (Theano/Theano#4452) was about not having a rewrite to match the gradient of the softmax when the specialized Op was not used from the beginning.

I also checked what would happen if softmax and log_softmax returned the numerically stable graph immediately, but the Aesara-generated gradients were still unstable.
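
One way to probe that claim (a sketch, not code from this PR; the outcome depends on which rewrites run during compilation):

```python
import numpy as np
import aesara
import aesara.tensor as at

x = at.matrix("x")

# a manually "stabilized" log-softmax graph, built from plain Ops
x_shifted = x - at.max(x, axis=-1, keepdims=True)
log_sm = x_shifted - at.log(at.sum(at.exp(x_shifted), axis=-1, keepdims=True))

# let Aesara derive the gradient graph, then evaluate it at extreme inputs to
# check whether the generated gradient is itself numerically stable
g = aesara.grad(log_sm.sum(), x)
f = aesara.function([x], g)
print(f(np.array([[1e3, 0.0, -1e3]], dtype=x.dtype)))
```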

I am not saying it's impossible, just that for the specific goal of adding axes, it turned out to be less trouble to update the Ops and rewrites.

These changes do not preclude a future replacement by symbolic graphs when the gradient pattern matching issues are figured out.

@brandonwillard (Member) commented Nov 27, 2021

> The issues I found concerned the gradients of both Ops (as well as the gradient of SoftmaxGrad), which introduce new softmax terms that would also need the shift by the max to become stable. These are difficult to match because they can take different patterns depending on which gradients are actually being requested.

Yes, this is the detail I was looking for. If this doesn't have an issue, it needs one.

> You can see that the existing rewrites mostly concern the gradients, and the old Theano issue I linked (Theano/Theano#4452) was about not having a rewrite to match the gradient of the softmax when the specialized Op was not used from the beginning.

If I'm understanding Theano/Theano#4452 correctly, it doesn't sound like a good solution for this at all. It seems like it's akin to using OpFromGraph to replace the need for pattern matching. These kinds of things only shift the work to different areas, which more often than not doesn't actually imply less work, just different work.

For instance, we would need to carefully orchestrate the compilation phases to make sure that rewrites are applied after this proposed GradOp is unfolded/expanded. Doing so requires multiple re-runs of entire optimization phases (e.g. more canonicalization passes), and that's the last direction we want to go just for some mild short-term conveniences—if anything, it's something we need to avoid.

> I am not saying it's impossible, just that for the specific goal of adding axes, it turned out to be less trouble to update the Ops and rewrites.

It might only seem like less trouble when the scope is narrow; however, if we address whatever problems underlie this issue, we are very likely to improve more things than just the Softmax Op, as well as avoid future pitfalls, like time spent on Op-specific extensions and maintenance (e.g. writing those C/Numba/JAX implementations and their tests).

> These changes do not preclude a future replacement by symbolic graphs when the gradient pattern matching issues are figured out.

No, but the time we spend creating and maintaining these work-arounds takes away from priority improvements like that gradient issue. Remember, if these rewrites can stabilize graphs that aren't explicit Softmax and SoftmaxGrad Ops, then they'll be able to stabilize more graphs, and that's what we want.

@brandonwillard (Member) commented:

I created an issue for this: #682. We can continue the conversation there; otherwise, we can still merge this in the meantime, but we need to give priority to solving that issue.

@brandonwillard (Member) commented Dec 13, 2021

@ricardoV94, do you want to move the Ops out of nnet in this PR or another?

@ricardoV94 (Contributor, Author) commented Dec 13, 2021

I will move them in a PR immediately after this one.

@brandonwillard merged commit bca9a38 into aesara-devs:main Dec 13, 2021
@brandonwillard changed the title from "Add axis to Softmax and related Ops" to "Add axis argument to Softmax and related Ops" Dec 13, 2021
@ricardoV94 deleted the Softmax_axis branch December 16, 2021 09:40
Labels: enhancement (New feature or request)
Linked issue: Implement vector-proof Softmax function (#183)