
speed up scipy optimize #8

Open
jacobpennington opened this issue Sep 8, 2022 · 2 comments

@jacobpennington (Collaborator)

Try out the suggestion here re: using the JAX library to compute the cost function gradient and providing that information to scipy (for their specific example, they quote a ~5000x speedup):
https://stackoverflow.com/questions/68507176/faster-scipy-optimizations

A little more involved than I first thought, since np.<whatever> operations have to be replaced with jnp.<whatever>. But aside from a few caveats, like not using in-place operations, most Layer implementations would otherwise be identical, so this could be added as a backend. It would be much simpler than TF: just define evaluate_jax and still use scipy, with a hook to use the gradient version. Without configuring GPU usage this would still be slower than TF, but it may be a good intermediate option that's still much faster than vanilla scipy/numpy and easier for new users to work with.
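A minimal sketch of the pattern, hedged: the cost function below is a toy stand-in (not a NEMS Layer), but it shows the idea of JIT-compiling the cost and its `jax.grad` gradient once, then handing both to `scipy.optimize.minimize`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

def cost(x):
    # Toy least-squares-style cost; note the jnp ops and no in-place updates.
    return jnp.sum((jnp.tanh(x) - 0.5) ** 2)

# Compile the cost and its gradient once, then hand both to scipy.
cost_jit = jax.jit(cost)
grad_jit = jax.jit(jax.grad(cost))

x0 = np.zeros(10)
result = minimize(
    lambda x: float(cost_jit(x)),           # scipy expects a python float
    x0,
    jac=lambda x: np.asarray(grad_jit(x)),  # and a numpy gradient array
    method='L-BFGS-B',
)
```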

Separately, try adding numba to the standard scipy evaluates (http://numba.pydata.org/). Unlike JAX, it's supposed to work with standard numpy, so the improvements may be simple to integrate.
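A hedged sketch of what that could look like; `stp_like_eval` is a made-up stand-in for a NEMS evaluate, not actual code from this repo:

```python
import numpy as np
from numba import njit

@njit
def stp_like_eval(x, tau, u):
    # Toy recursive depression-style update, written as an explicit
    # python loop -- exactly the pattern numba compiles well.
    out = np.empty_like(x)
    d = 1.0
    for i in range(x.shape[0]):
        d += (1.0 - d) / tau - u * d * x[i]
        out[i] = x[i] * d
    return out

# First call triggers compilation; subsequent calls run at native speed.
y = stp_like_eval(np.random.rand(10000), 50.0, 0.01)
```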

@jacobpennington jacobpennington self-assigned this Sep 8, 2022
@jacobpennington jacobpennington added the enhancement New feature or request label Sep 8, 2022
@jacobpennington (Collaborator, Author)

Notes so far on Numba: it worked great for the STP revision, with a ~1000x speedup vs. the old non-quick algorithm and a ~40x speedup vs. the quick_eval algorithm. It adds one extra dependency, but so far that seems worth it for those speedups, pending more testing to make sure this doesn't interfere with optimization (it shouldn't) and that the outputs are numerically close enough for a variety of inputs.

Other @njit options to try (a short sketch follows the docs link below):

- nogil=True: releases the global interpreter lock so that multiple threads can run simultaneously. This one requires some thought, but it can speed up side-effect-free checks (detecting NaNs, for example) that are safe to run asynchronously.
- cache=True: saves compiled functions to __pycache__ (or a fallback directory) to speed up subsequent runs. The default, cache=False, means functions get compiled again every time the program is run, so first uses will always be slower. Not a big factor for optimization, but turning this on for things like STP could make post-fit analyses faster (like loading and plotting a model), although those are generally fast enough already.
- parallel=True: automatically parallelizes many numpy functions and other operations that already support parallelization.

https://numba.readthedocs.io/en/stable/reference/jit-compilation.html
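A hedged illustration of those flags in combination; the function bodies here are placeholders, not NEMS code:

```python
import numpy as np
from numba import njit

@njit(nogil=True, cache=True)
def has_nan(x):
    # nogil=True lets several threads run this check concurrently;
    # cache=True persists the compiled function across program runs.
    for v in x.ravel():
        if np.isnan(v):
            return True
    return False

@njit(parallel=True)
def scaled_sum(x):
    # parallel=True auto-parallelizes supported numpy reductions.
    return np.sum(x * 2.0)
```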

@jacobpennington (Collaborator, Author)

For FIR:

  1. Should be able to get rid of the loop over output channels with proper reshaping.
  2. Try routines from scipy.ndimage? They're supposed to have some additional optimizations for specific use cases.
  3. Alternatively, code up the custom 1D filtering in cython or numba as simple nested for loops (see the sketch below). Those should still be easy enough to read, and may speed things up by skipping a lot of the options/checks that generic convolution functions include.

It's not clear how much these changes would affect performance.
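A hedged numba sketch of option 3; the shape conventions (time on the first axis, a single channel) and the name `fir_1d` are assumptions for illustration, not the actual NEMS FIR layer:

```python
import numpy as np
from numba import njit

@njit
def fir_1d(signal, coefficients):
    # signal: (time,); coefficients: (ntaps,). Causal filtering with an
    # implicit zero-padded history and none of scipy's boundary options.
    n = signal.shape[0]
    ntaps = coefficients.shape[0]
    out = np.zeros(n)
    for t in range(n):
        for k in range(min(ntaps, t + 1)):
            out[t] += coefficients[k] * signal[t - k]
    return out

y = fir_1d(np.random.rand(500), np.array([0.5, 0.3, 0.2]))
```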
