propagate sinkdiv to Monge gap implementation (#597)
* propagate `sinkdiv` to Monge gap implementation

* add to docs

* remove linkcheck

* remove problematic link

* another attempt...
marcocuturi authored Nov 18, 2024
1 parent 86e8f2c commit fd18299
Showing 5 changed files with 12 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
@@ -141,6 +141,7 @@
"https://doi.org/10.1145/2516971.2516977",
"https://doi.org/10.1145/2766963",
"https://keras.io/examples/nlp/pretrained_word_embeddings/",
+"https://proceedings.neurips.cc/",
]
linkcheck_report_timeouts_as_broken = False

1 change: 1 addition & 0 deletions docs/tools.rst
@@ -25,6 +25,7 @@ Sinkhorn Divergence
.. autosummary::
:toctree: _autosummary

+sinkhorn_divergence.sinkdiv
sinkhorn_divergence.sinkhorn_divergence
sinkhorn_divergence.SinkhornDivergenceOutput
sinkhorn_divergence.segment_sinkhorn_divergence
9 changes: 4 additions & 5 deletions docs/tutorials/neural/200_Monge_Gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"\n",
"The first requirement (efficiency) can be quantified with the **Monge gap** $\\mathcal{M}_\\mu^c$, a non-negative regularizer defined through $\\mu$ and $c$, and which takes as an argument any map $T : \\mathbb{R}^d \\rightarrow \\mathbb{R}^d$. The value $\\mathcal{M}_\\mu^c(T)$ quantifies how $T$ moves mass efficiently between $\\mu$ and $T \\sharp \\mu$, and only cancels $\\mathcal{M}_\\mu^c(T) = 0$ i.f.f. $T$ is optimal between $\\mu$ and $T \\sharp \\mu$ for the cost $c$.\n",
"\n",
"The second requirement (landing on $\\nu$) is then simply handled using a fitting loss $\\Delta$ between $T \\sharp \\mu$ and $\\nu$. This can be measured, e.g., using a {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`. Introducing a regularization strength $\\lambda_\\mathrm{MG} > 0$, looking for a Monge map can be reformulated as finding a $T$ that minimizes:\n",
"The second requirement (landing on $\\nu$) is then simply handled using a fitting loss $\\Delta$ between $T \\sharp \\mu$ and $\\nu$. This can be measured, e.g., using the Sinkhorn divergence, {func}`~ott.tools.sinkhorn_divergence.sinkdiv`. Introducing a regularization strength $\\lambda_\\mathrm{MG} > 0$, looking for a Monge map can be reformulated as finding a $T$ that minimizes:\n",
"\n",
"$$\n",
"\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n",
@@ -324,8 +324,7 @@
"$$\n",
"\\min_{T:\\mathbb{R}^d \\rightarrow \\mathbb{R}^d} \\Delta(T\\sharp \\mu, \\nu) + \\lambda_\\mathrm{MG} \\mathcal{M}_\\mu^c(T)\n",
"$$\n",
-"For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the {func}`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` with the {class}`squared Euclidean cost <ott.geometry.costs.SqEuclidean>`\n",
-"The function considers a ground cost function `cost_fn` (corresponding to $c$), as well as the `epsilon` regularization parameters to compute approximated Wasserstein distances, both for fitting and regularizer."
+"For all fittings, we use $\\Delta = S_{\\varepsilon, \\ell_2^2}$, the :term:`Sinkhorn divergence`, {func}`~ott.tools.sinkhorn_divergence.sinkdiv` with the {class}`squared Euclidean cost <ott.geometry.costs.SqEuclidean>` :term:`ground cost` function `cost_fn` (corresponding to $c$), as well as the `epsilon` regularization parameters to compute approximated Wasserstein distances, both for fitting and regularizer."
]
},
{
@@ -359,8 +358,8 @@
"\n",
" @jax.jit\n",
" def fitting_loss(x, y):\n",
-" div, out = sinkhorn_divergence.sinkhorn_divergence(\n",
-" pointcloud.PointCloud, x, y, epsilon=epsilon_fitting, static_b=True\n",
+" div, out = sinkhorn_divergence.sinkdiv(\n",
+" x, y, epsilon=epsilon_fitting, static_b=True\n",
" )\n",
" return div, out.n_iters\n",
"\n",
7 changes: 4 additions & 3 deletions src/ott/neural/methods/monge_gap.py
@@ -184,7 +184,7 @@ class MongeGapEstimator:
sets of points.
For instance, :math:`\Delta` can be the
-:func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`
+:func:`~ott.tools.sinkhorn_divergence.sinkdiv`
and :math:`R` the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`
:cite:`uscidda:23` for a given cost function :math:`c`.
In that case, it estimates a :math:`c`-OT map, i.e. a map :math:`T`
@@ -260,7 +260,8 @@ def setup(
def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
"""Regularizer added to the fitting loss.
-Can be, e.g. the :func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.
+Can be, e.g. the
+:func:`~ott.neural.methods.monge_gap.monge_gap_from_samples`.
If no regularizer is passed for solver instantiation,
or regularization weight :attr:`regularizer_strength` is 0,
return 0 by default along with an empty set of log values.
@@ -273,7 +274,7 @@ def regularizer(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
def fitting_loss(self) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
"""Fitting loss to fit the marginal constraint.
-Can be, e.g. :func:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.
+Can be, e.g. :func:`~ott.tools.sinkhorn_divergence.sinkdiv`.
If no fitting_loss is passed for solver instantiation, return 0 by default,
and no log values.
"""
5 changes: 2 additions & 3 deletions tests/neural/methods/monge_gap_test.py
@@ -20,7 +20,7 @@
import numpy as np

from ott import datasets
-from ott.geometry import costs, pointcloud, regularizers
+from ott.geometry import costs, regularizers
from ott.neural.methods import monge_gap
from ott.neural.networks import potentials
from ott.tools import sinkhorn_divergence
@@ -144,8 +144,7 @@ def fitting_loss(
mapped_samples: jnp.ndarray,
) -> Optional[float]:
r"""Sinkhorn divergence fitting loss."""
-div, _ = sinkhorn_divergence.sinkhorn_divergence(
-pointcloud.PointCloud,
+div, _ = sinkhorn_divergence.sinkdiv(
x=samples,
y=mapped_samples,
)
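
The notebook text touched by this commit describes the Monge gap as a non-negative regularizer that vanishes only when T is optimal between mu and its push-forward. A minimal pure-Python sketch of that property follows, assuming the definition from the cited paper (`uscidda:23`): the average displacement cost of T minus the optimal transport cost between the samples and their images. This is not the library's implementation. `monge_gap_1d` is a hypothetical helper, it is restricted to 1D points with the squared Euclidean cost (where sorting gives the exact optimal matching), and it uses exact transport where the library relies on an entropy-regularized approximation.

```python
from typing import Callable, Sequence


def monge_gap_1d(xs: Sequence[float], T: Callable[[float], float]) -> float:
    """Hypothetical helper: Monge gap of T on 1D samples, squared Euclidean cost.

    Computes mean_i (x_i - T(x_i))^2 minus the exact optimal transport cost
    between the samples and their images. In 1D with a convex cost, the
    optimal matching pairs the sorted samples with the sorted images.
    """
    ys = [T(x) for x in xs]
    n = len(xs)
    # Average cost actually paid by the map T.
    displacement = sum((x - y) ** 2 for x, y in zip(xs, ys)) / n
    # Optimal cost: match the i-th smallest sample to the i-th smallest image.
    optimal = sum((x - y) ** 2 for x, y in zip(sorted(xs), sorted(ys))) / n
    return displacement - optimal


xs = [0.0, 1.0, 2.0, 3.0]
gap_increasing = monge_gap_1d(xs, lambda x: 2.0 * x + 1.0)  # monotone map
gap_flip = monge_gap_1d(xs, lambda x: -x)  # order-reversing map
print(gap_increasing, gap_flip)  # prints "0.0 5.0"
```

The increasing map already coincides with the sorted matching, so its gap is 0. The order-reversing map pays an average cost of 14 where the optimal matching pays 9, which gives a strictly positive gap of 5.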
