Improvements to the neural dual solver (#219)
* Initial batch of updates to the neuraldual code. This swaps the order of the f and g potentials and adds (see the conjugate note and usage sketch below):
  1. an option to attach a conjugate solver that fine-tunes g's prediction when updating f, along with a default solver using JaxOpt's LBFGS optimizer
  2. amortization modes for learning g
  3. an option to update both potentials with the same batch (in parallel)
  4. the option to model the *gradient* of the g potential with a neural network, and a simple MLP implementing this
  5. a callback so that training progress can be monitored
* icnn: skip the activation only on the first layer
* set default amortization loss to regression
* Improve notebook:
  1. Update the f ICNN size from [64, 64, 64, 64] to [128, 128], use LReLU activations, and increase init_std to 1.0
  2. Use an MLP to model the gradient of g
  3. Use the default JaxOpt LBFGS conjugate solver and update g to regress onto it
  4. Increase the batch size from 1000 to 10000
  5. Sample the data in batches rather than sequentially
  6. Tweak Adam (use weight decay and a cosine schedule)
* update icnn to use the posdef potential last
* X,Y -> source,target
* don't fine-tune when the conjugate solver is not provided
* provides_gradient -> returns_potential
* NeuralDualSolver -> W2NeuralDual
* ConjugateSolverLBFGS -> FenchelConjugateLBFGS
* rm wandb in logging comment
* remove line split
* BFGS -> LBFGS
* address review comments from @michalk8
* also permit the MLP to model the potential (in addition to the gradient)
* add an alternating update directions option and the ability to initialize with existing parameters
* add option to fine-tune g when getting the potentials
* -> back_and_forth
* fix typing
* add back the W2 distance estimate and penalty
* fix logging with back-and-forth
* add weight term
* default MLP to returns_potential=True
* fix newline
* update icnn_inits nb
* bump down init_std
* make bib style consistent
* -> W2NeuralDual
* pass on docs
* limit flax version
* neural_dual nb: use two examples, finalize potentials with stable hyper-params
* update pin for the jax/flax testing dependency
* address review comments from @michalk8
* fix potential type hints
* re-run nb with the latest code
* address review comments from @michalk8
* potential_{value,gradient} -> functions
* add latest nb (no outputs)
* fix typos
* plot_ot_map: keep source/target measures fixed
* fix types
* address review comments from @michalk8 and @marcocuturi
* consolidate models
* inverse -> forward
* support back_and_forth when using a gradient mapping, improve docs
* move plotting code
* update nbs
* update intro text
* conjugate_solver -> conjugate_solvers
* address review comments from @michalk8
* update docstring
* create problems.nn.dataset, add more tests
* fix indentation
* polish docs and restructure dataset to address @michalk8's comments
* move plotting code back to ott.problems.linear.potentials
* fix test and dataloaders -> dataset
* update notebooks to latest code
* bump up batch size and update notebooks
* address review comments from @michalk8 (sorry for multiple force-pushes)
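For context on the conjugate-solver and amortization items in the first bullet, here is a brief restatement inferred from the commit text and standard W2 duality; it is a sketch, not part of the change set itself.

```latex
% g is trained to approximate the Fenchel conjugate of f:
%   f^*(y) = sup_x <x, y> - f(x).
g(y) \;\approx\; f^{*}(y) \;=\; \sup_{x}\, \langle x, y \rangle - f(x),
\qquad
x^{\star}(y) \;=\; \operatorname*{arg\,max}_{x}\, \langle x, y \rangle - f(x).
% When f is updated, the LBFGS conjugate solver refines g's prediction
% x_0 = \nabla g(y) towards x^*(y); the "regression" amortization mode
% then fits g (or its gradient network) to the fine-tuned solution.
```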
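A minimal usage sketch assembling the pieces named above (W2NeuralDual, FenchelConjugateLBFGS, the ICNN and MLP models, returns_potential, and the notebook hyper-parameters). Module paths, constructor arguments, and the call signature are assumptions based on the commit text, not a confirmed API.

```python
# Hypothetical sketch; only the class/flag names come from the commit message,
# the module paths and keyword names are assumptions.
from ott.solvers.nn import conjugate_solvers, models, neuraldual

# f is an ICNN with the updated notebook hyper-parameters ([128, 128], init_std=1.0).
neural_f = models.ICNN(dim_data=2, dim_hidden=[128, 128], init_std=1.0)
# g's *gradient* is modeled by a plain MLP (returns_potential=False).
neural_g = models.MLP(dim_hidden=[128, 128], returns_potential=False)

solver = neuraldual.W2NeuralDual(
    dim_data=2,
    neural_f=neural_f,
    neural_g=neural_g,
    # default LBFGS-based solver that fine-tunes g's conjugate prediction
    conjugate_solver=conjugate_solvers.FenchelConjugateLBFGS(),
    amortization_loss="regression",  # regress g onto the fine-tuned conjugate
    parallel_updates=True,           # update both potentials on the same batch
)

# samplers yield batches (e.g. of size 10000) from the source and target measures
potentials = solver(train_source, train_target, valid_source, valid_target)
print(potentials.distance(eval_source, eval_target))  # W2 distance estimate
```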