Improvements to the neural dual solver (#219)
* Initial batch of updates to neuraldual code.

This swaps the order of f and g potentials and adds:

1. an option to add a conjugate solver that fine-tunes g's prediction
   when updating f, along with a default solver using JaxOpt's
   LBFGS optimizer
2. amortization modes to learn g
3. an option to update both potentials with the same batch (in parallel)
4. the option to model the *gradient* of the g potential
   with a neural network and a simple MLP implementing this
5. a callback so that training progress can be monitored
   (a usage sketch follows this list)
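A minimal sketch of how these pieces fit together. The module, class, and
function names come from the docs changes in this commit; the argument
names, defaults, and mixture names are assumptions and may differ from the
merged code:

    # Hedged sketch, not a verbatim example from this PR.
    from ott.problems.nn import dataset
    from ott.solvers.nn import conjugate_solvers, models, neuraldual

    # Gaussian-mixture samplers; the mixture names here are assumed.
    train, valid, dim_data = dataset.create_gaussian_mixture_samplers(
        name_source="simple", name_target="circle"
    )

    neural_f = models.ICNN(dim_data=dim_data, dim_hidden=[128, 128])
    # returns_potential=False: the MLP models the *gradient* of g (item 4).
    neural_g = models.MLP(dim_hidden=[128, 128], returns_potential=False)

    solver = neuraldual.W2NeuralDual(
        dim_data,
        neural_f,
        neural_g,
        # item 1: an LBFGS conjugate solver fine-tunes g's prediction when
        # updating f; item 2: g is amortized by regressing onto that result.
        conjugate_solver=conjugate_solvers.FenchelConjugateLBFGS(),
        amortization_loss="regression",
    )
    # Dataset unpacks into (source_iter, target_iter) pairs (assumed).
    potentials = solver(*train, *valid)  # learned dual potentials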

* icnn: skip the activation only on the first layer

* set default amortization loss to regression

* Improve notebook

1. Update f ICNN size from [64, 64, 64, 64] to [128, 128],
   use leaky ReLU activations, and increase init_std to 1.0
2. Use an MLP to model the gradient of g
3. Use the default JaxOpt LBFGS conjugate solver and
   update g to regress onto it
4. Increase batch size from 1000 to 10000
5. Sample the data in batches rather than sequentially
6. Tweak Adam (use weight decay and a cosine schedule;
   see the optax sketch after this list)
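For item 6, the kind of optimizer described can be built with optax; the
schedule and decay values below are illustrative placeholders, not the
notebook's actual settings:

    import optax

    # Adam with decoupled weight decay (AdamW) and a cosine decay schedule.
    lr_schedule = optax.cosine_decay_schedule(
        init_value=1e-3, decay_steps=20_000
    )
    optimizer_f = optax.adamw(learning_rate=lr_schedule, weight_decay=1e-5)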

* update icnn to use the posdef potential last

* X,Y -> source,target

* don't fine-tune when the conjugate solver is not provided

* provides_gradient->returns_potential

* NeuralDualSolver->W2NeuralDual

* ConjugateSolverLBFGS->FenchelConjugateLBFGS

* rm wandb in logging comment

* remove line split

* BFGS->LBFGS

* address review comments from @michalk8

* also permit the MLP to model the potential (in addition to the gradient)

* add alternating update directions option and ability to initialize with existing parameters

* add option to fine-tune g when getting the potentials

* ->back_and_forth

* fix typing

* add back the W2 distance estimate and penalty

* fix logging with back-and-forth

* add weight term

* default MLP to returns_potential=True

* fix newline

* update icnn_inits nb

* bump down init_std

* make bib style consistent

* ->W2NeuralDual

* pass on docs

* limit flax version

* neural_dual nb: use two examples, finalize potentials with stable hyper-params

* update pin for jax/flax testing dependency

* address review comments from @michalk8

* fix potential type hints

* re-run nb with latest code

* address review comments from @michalk8

* potential_{value,gradient} -> functions

* add latest nb (no outputs)

* fix typos

* plot_ot_map: keep source/target measures fixed

* fix types

* address review comments from @michalk8 and @marcocuturi

* consolidate models

* inverse->forward

* support back_and_forth when using a gradient mapping, improve docs

* move plotting code

* update nbs

* update intro text

* conjugate_solver -> conjugate_solvers

* address review comments from @michalk8

* update docstring

* create problems.nn.dataset, add more tests

* fix indentation

* polish docs and restructure dataset to address @michalk8's comments

* move plotting code back to ott.problems.linear.potentials

* fix test and dataloaders->dataset

* update notebooks to latest code

* bump up batch size and update notebooks

* address review comments from @michalk8 (sorry for multiple force-pushes)
bamos authored Feb 13, 2023
1 parent de2c1d7 commit 40bf4af
Show file tree
Hide file tree
Showing 21 changed files with 1,745 additions and 960 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
@@ -66,6 +66,8 @@
     "scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
     "pot": ("https://pythonot.github.io/", None),
+    "jaxopt": ("https://jaxopt.github.io/stable", None),
+    "matplotlib": ("https://matplotlib.org/stable/", None),
 }
 
 master_doc = 'index'
5 changes: 3 additions & 2 deletions docs/index.rst
@@ -20,8 +20,9 @@ To achieve this, ``OTT`` rests on two families of tools:
   :cite:`scetbon:21` algorithms, and moving up towards Gromov-Wasserstein
   :cite:`memoli:11,peyre:16`;
 - the second family consists in *continuous* solvers, using suitable neural
-  architectures :cite:`amos:17` coupled with SGD type estimators
-  :cite:`makkuva:20,korotin:21`.
+  architectures such as an MLP or input-convex neural network
+  :cite:`amos:17` coupled with SGD-type estimators
+  :cite:`makkuva:20,korotin:21,amos:23`.
 
 Installation
 ------------
1 change: 1 addition & 0 deletions docs/problems/index.rst
@@ -11,3 +11,4 @@ ott.problems
 
     linear
     quadratic
+    nn
15 changes: 15 additions & 0 deletions docs/problems/nn.rst
@@ -0,0 +1,15 @@
+ott.problems.nn
+===============
+.. currentmodule:: ott.problems.nn
+.. automodule:: ott.problems.nn
+
+.. TODO(marcocuturi): maybe add some text here
+
+Dataset
+-------
+.. autosummary::
+    :toctree: _autosummary
+
+    dataset.create_gaussian_mixture_samplers
+    dataset.Dataset
+    dataset.GaussianMixture
34 changes: 34 additions & 0 deletions docs/references.bib
@@ -704,6 +704,14 @@ @ARTICLE{chen:20
   doi={10.1109/TPAMI.2019.2908635}
 }
 
+@inproceedings{amos:23,
+  title={On amortizing convex conjugates for optimal transport},
+  author={Amos, Brandon},
+  booktitle={International Conference on Learning Representations},
+  year={2023},
+  url={https://arxiv.org/abs/2210.12153}
+}
+
 @ARTICLE{schreck:15,
   author={Schreck, Amandine and Fort, Gersende and Le Corff, Sylvain and Moulines, Eric},
   journal={IEEE Journal of Selected Topics in Signal Processing},
@@ -740,3 +748,29 @@ @article{zou:05
   eprint = {https://rss.onlinelibrary.wiley.com/doi/pdf/10.1111/j.1467-9868.2005.00503.x},
   year = {2005}
 }
+
+@article{jacobs:20,
+  title={A fast approach to optimal transport: The back-and-forth method},
+  author={Jacobs, Matt and L{\'e}ger, Flavien},
+  journal={Numerische Mathematik},
+  volume={146},
+  number={3},
+  pages={513--544},
+  year={2020},
+  publisher={Springer}
+}
+
+
+@phdthesis{bertsekas:71,
+  title={Control of uncertain systems with a set-membership description of the uncertainty},
+  author={Bertsekas, Dimitri P},
+  year={1971},
+  school={Massachusetts Institute of Technology}
+}
+
+@book{danskin:67,
+  title={The Theory of Max-Min and its Application to Weapons Allocation Problems},
+  author={Danskin, John M},
+  publisher={Springer},
+  year={1967}
+}
20 changes: 15 additions & 5 deletions docs/solvers/nn.rst
@@ -10,12 +10,22 @@ Neural Dual
 .. autosummary::
     :toctree: _autosummary
 
-    neuraldual.NeuralDualSolver
+    neuraldual.W2NeuralDual
 
-ICNN
-----
+Models
+------
 .. autosummary::
     :toctree: _autosummary
 
-    icnn.ICNN
-    layers.PositiveDense
+    models.ModelBase
+    models.ICNN
+    models.MLP
+
+Conjugate Solvers
+-----------------
+.. autosummary::
+    :toctree: _autosummary
+
+    conjugate_solvers.ConjugateResults
+    conjugate_solvers.FenchelConjugateSolver
+    conjugate_solvers.FenchelConjugateLBFGS
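
For context on the new Conjugate Solvers section: these classes numerically
evaluate the Fenchel (convex) conjugate of a potential, whose standard
definition is

    f^\star(y) = \sup_x \, \langle x, y \rangle - f(x)

and the LBFGS variant refines g's prediction (per item 1 of the commit
message) by approximately solving this inner maximization with JaxOpt.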
