149 custom optimization loop #151

Merged: jgallowa07 merged 1 commit into main from 149_custom_optimization_loop on Apr 9, 2024

Conversation

jgallowa07 (Member) commented Mar 26, 2024

This PR addresses #149

Currently, the custom update loop has been added, along with a `Model.iter_error` property that holds the error per iteration for the last call to `Model.fit()`. It also updates `multidms.model_collection.fit_models` to conform to the new fitting heuristic.

TODO:

  • Add Model.plot_iter_error()
  • Add ModelCollection.plot_iter_error()
  • Test model fitting on simulations
  • Test model fitting on spike data
  • Update docs
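
For reference, here is a minimal sketch of the kind of custom update loop described above: it records `state.error` at every iteration, assuming a jaxopt `ProximalGradient` solver with a lasso prox. The `loss`, `X`, `y`, and `lasso` names are illustrative, not the actual multidms internals.

```python
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso


def loss(params, X, y):
    """Unpenalized objective (the lasso penalty is handled by the prox)."""
    return jnp.mean((X @ params - y) ** 2)


def fit_with_error_trace(X, y, lasso=1e-3, tol=1e-4, maxiter=100_000):
    """Run proximal gradient manually, recording state.error per iteration."""
    solver = ProximalGradient(
        fun=loss, prox=prox_lasso, tol=tol, maxiter=maxiter
    )
    params = jnp.zeros(X.shape[1])
    state = solver.init_state(params, hyperparams_prox=lasso, X=X, y=y)
    iter_error = []
    for _ in range(maxiter):
        params, state = solver.update(params, state, hyperparams_prox=lasso, X=X, y=y)
        iter_error.append(float(state.error))  # step-norm convergence metric
        if state.error < tol:  # "converged" in the sense discussed below
            break
    return params, jnp.array(iter_error)
```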

wsdewitt (Contributor) commented:
@jgallowa07 I think we would still want to look at trajectories for our objective function over iterates, not necessarily the error metric (which is a norm on the step taken in parameter space).
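
(To make the distinction concrete: proximal-gradient solvers typically report a fixed-point-residual-style error on the order of $\lVert x_{k+1} - x_k \rVert_2 / \eta_k$, where $\eta_k$ is the step size, i.e. a norm on the step taken, rather than the objective value $L(x_k)$ itself. The exact definition jaxopt uses may differ in detail, but the point is the same.)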

jgallowa07 (Member, Author) commented Mar 26, 2024

@wsdewitt I think I agree, but are you suggesting that we don't need to look at error at all?

For context, here are the state.error trajectories over iterations for our simulations:

[Screenshot from 2024-03-26 12-18-37: state.error trajectories on simulated data]

This is with a tolerance of $1e-03$.

So a few questions come to mind:

  1. Is the notion of "converged" (error < tol) even useful, given that the error is so dependent on the model and hyperparameters (e.g. lasso and ridge penalties)?
  2. Are you suggesting we look at the no-penalty(?) objective loss trajectory, $L'(X)$, or the change in loss, $\Delta L'(X)$? (See the sketch after this list.)
  3. Related to question 2 above, should "look" be something users do by eye, with plotting infrastructure, to decide whether a model converged, or should we define convergence criteria with respect to something else?
  4. Is it possible that we should be defining an "outer" objective function that relies on validation data, similar to this jax example? If so, this could take a bit longer to implement.
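
Re: question 2, recording the unpenalized loss (and its per-iteration change) alongside `state.error` would be a small extension of the custom loop. A sketch, assuming a jaxopt `ProximalGradient` solver built on an unpenalized `fun(params, X, y)`; the names are illustrative, not the multidms internals:

```python
import jax.numpy as jnp


def fit_with_traces(solver, fun, params, lasso, X, y, maxiter=100_000, tol=1e-4):
    """Record the unpenalized loss and state.error at every iterate.

    `solver` is assumed to be a jaxopt.ProximalGradient built on `fun`.
    """
    state = solver.init_state(params, hyperparams_prox=lasso, X=X, y=y)
    iter_loss, iter_error = [], []
    for _ in range(maxiter):
        params, state = solver.update(params, state, hyperparams_prox=lasso, X=X, y=y)
        iter_loss.append(float(fun(params, X, y)))  # L'(x_k), no penalty term
        iter_error.append(float(state.error))       # step-norm criterion
        if state.error < tol:
            break
    # change in loss between iterates, if that's the quantity to threshold on
    delta_loss = jnp.abs(jnp.diff(jnp.array(iter_loss)))
    return params, jnp.array(iter_loss), jnp.array(iter_error), delta_loss
```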

jgallowa07 (Member, Author) commented Mar 27, 2024

These figures are from the simulated data, run for 100K iterations with FISTA acceleration and a tolerance of $1e-4$. None of these models met the tolerance threshold (state.error < tol), i.e. they "did not converge". Note that past the first 5000 steps I'm only showing every 1000th step, to keep resolution at the low end.

Loss (objective without penalty); this is the value itself, not a $\Delta$ between iterations:

[Screenshot from 2024-03-26 17-47-02: loss trajectories]

And error (state.error); this is the $\Delta$ between iterations, taking into account the step size with acceleration:

[Screenshot from 2024-03-26 17-49-37: state.error trajectories]

wsdewitt (Contributor) commented:
Thanks Jared! Indeed, the model isn't converging. Given the oscillatory loss and error traces, I'm guessing it's a line search (step size) issue. Can you try increasing line search iterations to maxls=50? If that doesn't help, you might try turning off acceleration.
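
In case it's useful, the two knobs being suggested, as I understand them, on the jaxopt `ProximalGradient` constructor (reusing the illustrative `loss` from the sketch above):

```python
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

# 1) give the backtracking line search more iterations per step
solver_more_ls = ProximalGradient(
    fun=loss, prox=prox_lasso, tol=1e-4, maxiter=100_000, maxls=50
)

# 2) plain (non-accelerated) proximal gradient, i.e. FISTA acceleration off
solver_no_accel = ProximalGradient(
    fun=loss, prox=prox_lasso, tol=1e-4, maxiter=100_000, acceleration=False
)
```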

jgallowa07 (Member, Author) commented Mar 27, 2024

@wsdewitt

  • maxls = 50 seems to have had very little effect on the results.

  • Turning acceleration off really slows down convergence; the objective loss is still slightly sloping downward even after 100K iterations.

HOWEVER, it seems double precision (re #150) is exactly what we needed. Here's a direct comparison of single- vs. double-precision error trajectories on the simulation data. I'll test this out on the spike data first, but if the results look good, I think we should be good to clean up and merge this PR.

[Screenshot from 2024-03-27 13-15-09: single- vs. double-precision error trajectories]
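
For anyone following along, the double-precision switch referenced here is JAX's global x64 flag, which has to be set before any arrays are created:

```python
# Enable 64-bit floats globally in JAX; set this before creating any arrays.
# (This is the standard JAX flag, independent of the changes in this PR.)
import jax

jax.config.update("jax_enable_x64", True)
```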

jgallowa07 force-pushed the 149_custom_optimization_loop branch from 10a6569 to 379805e on April 9, 2024 at 21:10
jgallowa07 (Member, Author) commented:
I've moved the prototype notebook to [notebooks/param_transform_ref_equivariance_prototype.ipynb](notebooks/param_transform_ref_equivariance_prototype.ipynb) and squashed the commits. With that, I think we can merge this PR so that we may branch off main for implementing the parameter transformation and reference-equivariance fitting.

jgallowa07 merged commit 8e30b43 into main on Apr 9, 2024
6 checks passed