diff --git a/README.md b/README.md index 99b92c720..a3aa1ea17 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ Currently implements the following classes and functions: - The [geometry](ott/geometry) folder describes tools that to encapsulate the essential ingredients of OT problems: measures and cost functions. - - The `CostFn` class in [costs.py](ott/geometry/costs.py) and its descendants define cost functions between points. A few simple costs are considered, `Euclidean` between vectors, and `Bures`, between a pair of mean vector and covariance (p.d.) matrix. + - The `CostFn` class in [costs.py](ott/geometry/costs.py) and its descendants define cost functions between points. A few simple costs are considered, `SqEuclidean` between vectors, and `Bures`, between a pair of mean vector and covariance (p.d.) matrix. - The `Geometry` class in [geometry.py](ott/geometry/geometry.py) describes a cost structure between two measures. That cost structure is accessed through various member functions, either used when running the Sinkhorn algorithm (typically kernel multiplications, or log-sum-exp row/column-wise application) or after (to apply the OT matrix to a vector). diff --git a/docs/geometry.rst b/docs/geometry.rst index e13f9c885..c2c438f0c 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -50,6 +50,7 @@ Cost Functions :toctree: _autosummary costs.CostFn + costs.SqEuclidean costs.Euclidean costs.Cosine costs.Bures diff --git a/docs/notebooks/Sinkhorn_Barycenters.ipynb b/docs/notebooks/Sinkhorn_Barycenters.ipynb index febfce418..db3116b76 100644 --- a/docs/notebooks/Sinkhorn_Barycenters.ipynb +++ b/docs/notebooks/Sinkhorn_Barycenters.ipynb @@ -425,7 +425,7 @@ "id": "tqg86SFQvzXC" }, "source": [ - "### Euclidean barycenter, for reference" + "### SqEuclidean barycenter, for reference" ] }, { diff --git a/docs/notebooks/introduction_grid.ipynb b/docs/notebooks/introduction_grid.ipynb index 8769d1f08..402041b86 100644 --- a/docs/notebooks/introduction_grid.ipynb +++ b/docs/notebooks/introduction_grid.ipynb @@ -309,8 +309,8 @@ "outputs": [], "source": [ "@jax.tree_util.register_pytree_node_class\n", - "class EuclideanTimes2(costs.CostFn):\n", - " \"\"\"The cost function corresponding to the squared euclidean distance times 2.\"\"\"\n", + "class SqEuclideanTimes2(costs.CostFn):\n", + " \"\"\"The cost function corresponding to the squared SqEuclidean distance times 2.\"\"\"\n", "\n", " def norm(self, x):\n", " return jnp.sum(x**2, axis=-1) * 2\n", @@ -319,7 +319,7 @@ " return -2 * jnp.sum(x * y) * 2\n", "\n", "\n", - "cost_fns = [EuclideanTimes2(), costs.Euclidean()]" + "cost_fns = [SqEuclideanTimes2(), costs.SqEuclidean()]" ] }, { diff --git a/docs/notebooks/point_clouds.ipynb b/docs/notebooks/point_clouds.ipynb index 596d41f29..530917f69 100644 --- a/docs/notebooks/point_clouds.ipynb +++ b/docs/notebooks/point_clouds.ipynb @@ -85,7 +85,7 @@ "source": [ "## Computes the regularized optimal transport\n", "\n", - "To compute the transport matrix between the two point clouds, one can define a `PointCloud` geometry (which by default uses `ott.geometry.costs.Euclidean` for cost function), then call the `sinkhorn` function, and build the transport matrix from the optimized potentials." + "To compute the transport matrix between the two point clouds, one can define a `PointCloud` geometry (which by default uses `ott.geometry.costs.SqEuclidean` for cost function), then call the `sinkhorn` function, and build the transport matrix from the optimized potentials." ] }, { @@ -235,7 +235,7 @@ " y: jnp.ndarray,\n", " a: jnp.ndarray,\n", " b: jnp.ndarray,\n", - " cost_fn=ott.geometry.costs.Euclidean(),\n", + " cost_fn=ott.geometry.costs.SqEuclidean(),\n", " num_iter: int = 101,\n", " dump_every: int = 10,\n", " learning_rate: float = 0.2,\n", @@ -5972,7 +5972,9 @@ "source": [ "from IPython import display\n", "\n", - "ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Euclidean())\n", + "ots = optimize(\n", + " x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.SqEuclidean()\n", + ")\n", "fig = plt.figure(figsize=(8, 5))\n", "plott = ott.tools.plot.Plot(fig=fig)\n", "anim = plott.animate(ots, frame_rate=4)\n", diff --git a/ott/core/bar_problems.py b/ott/core/bar_problems.py index ee214ca6d..68c29f325 100644 --- a/ott/core/bar_problems.py +++ b/ott/core/bar_problems.py @@ -42,7 +42,7 @@ class BarycenterProblem: weights: Array of shape ``[num_measures,]`` containing the weights of the measures. cost_fn: Cost function used. If `None`, - use :class:`~ott.geometry.costs.Euclidean` cost. + use :class:`~ott.geometry.costs.SqEuclidean` cost. epsilon: Epsilon regularization used to solve reg-OT problems. debiased: **Currently not implemented.** Whether the problem is debiased, in the sense that @@ -75,7 +75,7 @@ def __init__( raise ValueError("Specify weights if `y` is already segmented.") self._b = b self._weights = weights - self.cost_fn = costs.Euclidean() if cost_fn is None else cost_fn + self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn self.epsilon = epsilon self.debiased = debiased self._kwargs = kwargs @@ -309,7 +309,7 @@ def update_features(self, transports: jnp.ndarray, """Update the barycenter features in the fused case :cite:`vayer:19`. Uses :cite:`cuturi:14` eq. 8, and is implemented only - for the squared :class:`~ott.geometry.costs.Euclidean` cost. + for the squared :class:`~ott.geometry.costs.SqEuclidean` cost. Args: transports: Transport maps of shape @@ -330,7 +330,7 @@ def update_features(self, transports: jnp.ndarray, transports = transports * inv_a[None, :, None] if self._loss_name == "sqeucl": - cost = costs.Euclidean() + cost = costs.SqEuclidean() return jnp.sum( weights * barycentric_projection(transports, y_fused, cost), axis=0 ) diff --git a/ott/core/potentials.py b/ott/core/potentials.py index 5a1beb42e..0170825cc 100644 --- a/ott/core/potentials.py +++ b/ott/core/potentials.py @@ -162,7 +162,7 @@ def callback(x: jnp.ndarray) -> float: cost = pointcloud.PointCloud( jnp.atleast_2d(x), y, - cost_fn=self._geom._cost_fn, + cost_fn=self._geom.cost_fn, power=self._geom.power, epsilon=1.0 # epsilon is not used ).cost_matrix diff --git a/ott/geometry/costs.py b/ott/geometry/costs.py index f98c573ec..fae64a596 100644 --- a/ott/geometry/costs.py +++ b/ott/geometry/costs.py @@ -92,7 +92,16 @@ def tree_unflatten(cls, aux_data, children): @jax.tree_util.register_pytree_node_class class Euclidean(CostFn): - """Squared Euclidean distance CostFn.""" + """Euclidean distance.""" + + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + """Compute Euclidean norm.""" + return jnp.linalg.norm(x - y) + + +@jax.tree_util.register_pytree_node_class +class SqEuclidean(CostFn): + """Squared Euclidean distance.""" def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]: """Compute squared Euclidean norm for vector.""" diff --git a/ott/geometry/grid.py b/ott/geometry/grid.py index fdd9f346a..25e9809eb 100644 --- a/ott/geometry/grid.py +++ b/ott/geometry/grid.py @@ -101,7 +101,7 @@ def __init__( raise ValueError('Input either grid_size t-uple or grid locations x.') if cost_fns is None: - cost_fns = [costs.Euclidean()] + cost_fns = [costs.SqEuclidean()] self.cost_fns = cost_fns self.kwargs = { 'num_a': self.num_a, diff --git a/ott/geometry/pointcloud.py b/ott/geometry/pointcloud.py index 64c10482b..9b518090e 100644 --- a/ott/geometry/pointcloud.py +++ b/ott/geometry/pointcloud.py @@ -65,7 +65,7 @@ def __init__( x: jnp.ndarray, y: Optional[jnp.ndarray] = None, cost_fn: Optional[costs.CostFn] = None, - power: float = 2.0, + power: float = 1.0, batch_size: Optional[int] = None, scale_cost: Union[bool, int, float, Literal['mean', 'max_norm', 'max_bound', 'max_cost', @@ -75,9 +75,9 @@ def __init__( super().__init__(**kwargs) self.x = x self.y = self.x if y is None else y - self._cost_fn = costs.Euclidean() if cost_fn is None else cost_fn + self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn self.power = power - self._axis_norm = 0 if callable(self._cost_fn.norm) else None + self._axis_norm = 0 if callable(self.cost_fn.norm) else None if batch_size is not None: assert batch_size > 0, f"`batch_size={batch_size}` must be positive." self._batch_size = batch_size @@ -86,13 +86,13 @@ def __init__( @property def _norm_x(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: - return self._cost_fn.norm(self.x) + return self.cost_fn.norm(self.x) return 0. @property def _norm_y(self) -> Union[float, jnp.ndarray]: if self._axis_norm == 0: - return self._cost_fn.norm(self.y) + return self.cost_fn.norm(self.y) return 0. @property @@ -125,7 +125,7 @@ def is_symmetric(self) -> bool: @property def is_squared_euclidean(self) -> bool: - return isinstance(self._cost_fn, costs.Euclidean) and self.power == 2.0 + return isinstance(self.cost_fn, costs.SqEuclidean) and self.power == 1.0 @property def is_online(self) -> bool: @@ -163,7 +163,7 @@ def inv_scale_cost(self) -> float: "the cost matrix with the online mode is not implemented." ) if self._scale_cost == 'max_norm': - if self._cost_fn.norm is not None: + if self.cost_fn.norm is not None: return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max()) return 1.0 if self._scale_cost == 'max_bound': @@ -183,11 +183,11 @@ def inv_scale_cost(self) -> float: raise ValueError(f'Scaling {self._scale_cost} not implemented.') def _compute_cost_matrix(self) -> jnp.ndarray: - cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y) + cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y) if self._axis_norm is not None: cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :] - if self.power != 2.0: - cost_matrix = jnp.abs(cost_matrix) ** (0.5 * self.power) + if self.power != 1.0: + cost_matrix = jnp.abs(cost_matrix) ** self.power return cost_matrix def apply_lse_kernel( @@ -212,7 +212,7 @@ def body0(carry, i: int): self._norm_y, (i * self.batch_size,), (self.batch_size,) ) h_res, h_sgn = app( - self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self._cost_fn, + self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self.cost_fn, self.power, self.inv_scale_cost ) return carry, (h_res, h_sgn) @@ -230,7 +230,7 @@ def body1(carry, i: int): self._norm_x, (i * self.batch_size,), (self.batch_size,) ) h_res, h_sgn = app( - self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self._cost_fn, + self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self.cost_fn, self.power, self.inv_scale_cost ) return carry, (h_res, h_sgn) @@ -240,12 +240,12 @@ def finalize(i: int): norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:] return app( self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:] return app( self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) if not self.is_online: @@ -297,12 +297,12 @@ def apply_kernel( if axis == 0: return app( self.x, self.y, self._norm_x, self._norm_y, scaling, eps, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) if axis == 1: return app( self.y, self.x, self._norm_y, self._norm_x, scaling, eps, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) def transport_from_potentials( @@ -318,7 +318,7 @@ def transport_from_potentials( ) return transport( self.y, self.x, self._norm_y, self._norm_x, g, f, self.epsilon, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) def transport_from_scalings( @@ -334,7 +334,7 @@ def transport_from_scalings( ) return transport( self.y, self.x, self._norm_y, self._norm_x, v, u, self.epsilon, - self._cost_fn, self.power, self.inv_scale_cost + self.cost_fn, self.power, self.inv_scale_cost ) def apply_cost( @@ -387,12 +387,12 @@ def _apply_cost( arr = arr.reshape(-1, 1) if axis == 0: return app( - self.x, self.y, self._norm_x, self._norm_y, arr, self._cost_fn, + self.x, self.y, self._norm_x, self._norm_y, arr, self.cost_fn, self.power, self.inv_scale_cost, fn ) if axis == 1: return app( - self.y, self.x, self._norm_y, self._norm_x, arr, self._cost_fn, + self.y, self.x, self._norm_y, self._norm_x, arr, self.cost_fn, self.power, self.inv_scale_cost, fn ) else: @@ -464,7 +464,7 @@ def body0(carry, i: int): else: norm_y = self._leading_slice(self._norm_y, i) h_res = app( - self.x, y, self._norm_x, norm_y, vec, self._cost_fn, self.power, + self.x, y, self._norm_x, norm_y, vec, self.cost_fn, self.power, scale_cost ) return carry, h_res @@ -477,7 +477,7 @@ def body1(carry, i: int): else: norm_x = self._leading_slice(self._norm_x, i) h_res = app( - self.y, x, self._norm_y, norm_x, vec, self._cost_fn, self.power, + self.y, x, self._norm_y, norm_x, vec, self.cost_fn, self.power, scale_cost ) return carry, h_res @@ -486,12 +486,12 @@ def finalize(i: int): if batch_for_y: norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:] return app( - self.x, self.y[i:], self._norm_x, norm_y, vec, self._cost_fn, + self.x, self.y[i:], self._norm_x, norm_y, vec, self.cost_fn, self.power, scale_cost ) norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:] return app( - self.y, self.x[i:], self._norm_y, norm_x, vec, self._cost_fn, + self.y, self.x[i:], self._norm_y, norm_x, vec, self.cost_fn, self.power, scale_cost ) @@ -532,9 +532,9 @@ def finalize(i: int): ) def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray: - """Compute barycenter of points in self.x using weights, valid for p=2.0.""" - assert self.power == 2.0, self.power - return self._cost_fn.barycenter(self.x, weights) + """Compute barycenter of points in self.x using weights, valid for p=1.0.""" + assert self.power == 1.0, self.power + return self.cost_fn.barycenter(self.x, weights) @classmethod def prepare_divergences( @@ -560,7 +560,7 @@ def prepare_divergences( def tree_flatten(self): # passing self.power in aux_data to be able to condition on it. - return ([self.x, self.y, self._src_mask, self._tgt_mask, self._cost_fn], { + return ([self.x, self.y, self._src_mask, self._tgt_mask, self.cost_fn], { 'epsilon': self._epsilon_init, 'relative_epsilon': self._relative_epsilon, 'scale_epsilon': self._scale_epsilon, @@ -577,14 +577,14 @@ def tree_unflatten(cls, aux_data, children): ) def _cosine_to_sqeucl(self) -> 'PointCloud': - assert isinstance(self._cost_fn, costs.Cosine), type(self._cost_fn) - assert self.power == 2, self.power + assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn) + assert self.power == 1.0, self.power (x, y, *args, _), aux_data = self.tree_flatten() x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) y = y / jnp.linalg.norm(y, axis=-1, keepdims=True) # TODO(michalk8): find a better way aux_data["scale_cost"] = 2. / self.inv_scale_cost - cost_fn = costs.Euclidean() + cost_fn = costs.SqEuclidean() return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn]) def to_LRCGeometry( @@ -767,8 +767,8 @@ def _transport_from_scalings_xy( def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost): one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None]) cost = norm_x + norm_y + one_line_pairwise(x, y) - if cost_pow != 2.0: - cost = jnp.abs(cost) ** (0.5 * cost_pow) + if cost_pow != 1.0: + cost = jnp.abs(cost) ** cost_pow return cost * scale_cost diff --git a/ott/tools/k_means.py b/ott/tools/k_means.py index 844e1f4cf..335d16dc8 100644 --- a/ott/tools/k_means.py +++ b/ott/tools/k_means.py @@ -358,7 +358,7 @@ def k_means( Args: geom: Point cloud of shape ``[n, ndim]`` to cluster. If passed as an array, - :class:`~ott.geometry.costs.Euclidean` cost is assumed. + :class:`~ott.geometry.costs.SqEuclidean` cost is assumed. k: The number of clusters. weights: The weights of input points. These weights are considered when computing the centroids and inertia. If ``None``, use uniform weights. @@ -388,7 +388,7 @@ def k_means( 0] >= k, f"Cannot cluster `{geom.shape[0]}` points into `{k}` clusters." if isinstance(geom, jnp.ndarray): geom = pointcloud.PointCloud(geom) - if isinstance(geom._cost_fn, costs.Cosine): + if isinstance(geom.cost_fn, costs.Cosine): geom = geom._cosine_to_sqeucl() assert geom.is_squared_euclidean diff --git a/ott/tools/plot.py b/ott/tools/plot.py index e8183e8f1..fec2e4037 100644 --- a/ott/tools/plot.py +++ b/ott/tools/plot.py @@ -43,15 +43,15 @@ class Plot: possibilities to create animations as matplotlib.animation.FuncAnimation, which can in turned be saved to disk at will. There are two design principles here: 1) we do not rely on saving to/loading from disk to create animations - 2) we try as much as possible to disentangle the transport problem(s) from the - its visualization(s), leveraging the transport.Transport interface. + 2) we try as much as possible to disentangle the transport problem(s) from + its visualization(s). """ def __init__( self, fig: Optional[plt.Figure] = None, ax: Optional[plt.Axes] = None, - cost_threshold: float = 0.0, + cost_threshold: float = -1.0, # should be negative for animations. scale: int = 200, show_lines: bool = True, cmap: str = 'cool' @@ -167,7 +167,7 @@ def animate( return animation.FuncAnimation( self.fig, lambda i: self.update(transports[i]), - np.arange(1, len(transports)), + np.arange(0, len(transports)), init_func=lambda: self.update(transports[0]), interval=1000 / frame_rate, blit=True diff --git a/ott/tools/segment_sinkhorn.py b/ott/tools/segment_sinkhorn.py index 690bf74e8..b770e06be 100644 --- a/ott/tools/segment_sinkhorn.py +++ b/ott/tools/segment_sinkhorn.py @@ -65,7 +65,7 @@ def segment_sinkhorn( max_measure_size: Total size of measures after padding. Should ideally be set to an upper bound on points clouds processed with the segment interface. Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.Euclidean`. + cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. segment_ids_x: **1st interface** The segment ID for which each row of x belongs. This is a similar interface to `jax.ops.segment_sum`. segment_ids_y: **1st interface** The segment ID for which each row of y diff --git a/ott/tools/sinkhorn_divergence.py b/ott/tools/sinkhorn_divergence.py index b3564622d..faa3e1ba0 100644 --- a/ott/tools/sinkhorn_divergence.py +++ b/ott/tools/sinkhorn_divergence.py @@ -228,7 +228,7 @@ def segment_sinkhorn_divergence( set to an upper bound on points clouds processed with the segment interface. Should also be smaller than total length of `x` or `y`. Providing this number is required for JIT compilation to work. - cost_fn: Cost function, defaults to :class:`~ott.core.costs.Euclidean`. + cost_fn: Cost function, defaults to :class:`~ott.core.costs.SqEuclidean`. segment_ids_x: **1st interface** The segment ID for which each row of x belongs. This is a similar interface to :func:`jax.ops.segment_sum`. segment_ids_y: **1st interface** The segment ID for which each row of y diff --git a/tests/core/sinkhorn_diff_test.py b/tests/core/sinkhorn_diff_test.py index f8ab740b1..7e60a8309 100644 --- a/tests/core/sinkhorn_diff_test.py +++ b/tests/core/sinkhorn_diff_test.py @@ -24,7 +24,7 @@ from ott.core import implicit_differentiation as implicit_lib from ott.core import linear_problems, sinkhorn -from ott.geometry import geometry, grid, pointcloud +from ott.geometry import costs, geometry, grid, pointcloud from ott.tools import transport @@ -232,13 +232,13 @@ def loss_fn(cm): np.testing.assert_array_equal(jnp.isnan(custom_grad), False) @pytest.mark.fast.with_args( - "lse_mode,implicit_differentiation,min_iter,max_iter,epsilon", + "lse_mode,implicit_differentiation,min_iter,max_iter,epsilon,power", [ - (True, True, 0, 2000, 1e-3), - (True, True, 1000, 1000, 1e-3), - (True, False, 1000, 1000, 1e-2), - (True, False, 0, 2000, 1e-2), - (False, True, 0, 2000, 1e-2), + (True, True, 0, 2000, 1e-3, 1.0), + (True, True, 1000, 1000, 1e-3, 1.0), + (True, False, 1000, 1000, 1e-2, 2.0), + (True, False, 0, 2000, 1e-2, 2.0), + (False, True, 0, 2000, 1e-2, 1.0), ], ids=[ "lse-implicit", "lse-implicit-force_scan", "lse-backprop-force_scan", @@ -247,13 +247,8 @@ def loss_fn(cm): only_fast=[0, 1], ) def test_gradient_sinkhorn_euclidean( - self, - rng: jnp.ndarray, - lse_mode: bool, - implicit_differentiation: bool, - min_iter: int, - max_iter: int, - epsilon: float, + self, rng: jnp.ndarray, lse_mode: bool, implicit_differentiation: bool, + min_iter: int, max_iter: int, epsilon: float, power: float ): """Test gradient w.r.t. locations x of reg-ot-cost.""" # TODO(cuturi): ensure scaling mode works with backprop. @@ -270,9 +265,16 @@ def test_gradient_sinkhorn_euclidean( b = b.at[3].set(0) a = a / jnp.sum(a) b = b / jnp.sum(b) + # Adding some near-zero distances to test proper handling with power==1.0 + y = y.at[0].set(x[0, :] + 1e-3) def loss_fn(x, y): - geom = pointcloud.PointCloud(x, y, epsilon=epsilon) + geom = pointcloud.PointCloud( + x, + y, + epsilon=epsilon, + cost_fn=costs.SqEuclidean() if power == 2.0 else costs.Euclidean() + ) out = sinkhorn.sinkhorn( geom, a, @@ -283,7 +285,7 @@ def loss_fn(x, y): max_iterations=max_iter, jit=False ) - return out.reg_ot_cost, (geom, out.f, out.g) + return out.reg_ot_cost, out delta = jax.random.normal(keys[0], (n, d)) delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) @@ -291,17 +293,24 @@ def loss_fn(x, y): # first calculation of gradient loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True) - (loss_value, aux), grad_loss = loss_and_grad(x, y) + (loss_value, out), grad_loss = loss_and_grad(x, y) custom_grad = jnp.sum(delta * grad_loss) assert not jnp.isnan(loss_value) np.testing.assert_array_equal(grad_loss.shape, x.shape) np.testing.assert_array_equal(jnp.isnan(grad_loss), False) - # second calculation of gradient - tm = aux[0].transport_from_potentials(aux[1], aux[2]) - tmp = 2 * tm[:, :, None] * (x[:, None, :] - y[None, :, :]) - grad_x = jnp.sum(tmp, 1) - other_grad = jnp.sum(delta * grad_x) + # second calculation of gradient, only valid for power=2.0 + tm = out.matrix + if power == 2.0: + tmp = 2 * tm[:, :, None] * (x[:, None, :] - y[None, :, :]) + grad_x = jnp.sum(tmp, 1) + other_grad = jnp.sum(delta * grad_x) + if power == 1.0: + tmp = tm[:, :, None] * (x[:, None, :] - y[None, :, :]) + norms = jnp.linalg.norm(x[:, None] - y[None, :], axis=-1) + tmp /= norms[:, :, None] + 1e-8 # to stabilize when computed by hand + grad_x = jnp.sum(tmp, 1) + other_grad = jnp.sum(delta * grad_x) # third calculation of gradient loss_delta_plus, _ = loss_fn(x + eps * delta, y) diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py index 580b334d9..ae3d04255 100644 --- a/tests/core/sinkhorn_test.py +++ b/tests/core/sinkhorn_test.py @@ -209,12 +209,12 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): self.x, self.y, epsilon=eps, batch_size=7 ) online_geom_euc = pointcloud.PointCloud( - self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, batch_size=10 + self.x, self.y, cost_fn=costs.SqEuclidean(), epsilon=eps, batch_size=10 ) batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps) batch_geom_euc = pointcloud.PointCloud( - self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps + self.x, self.y, cost_fn=costs.SqEuclidean(), epsilon=eps ) out_online = sinkhorn.sinkhorn( diff --git a/tests/geometry/geometry_pointcloud_apply_test.py b/tests/geometry/geometry_pointcloud_apply_test.py index bc31770a6..0be595e63 100644 --- a/tests/geometry/geometry_pointcloud_apply_test.py +++ b/tests/geometry/geometry_pointcloud_apply_test.py @@ -37,10 +37,10 @@ def test_apply_cost_and_kernel(self, rng: jnp.ndarray): vec0 = jax.random.normal(keys[2], (n, b)) vec1 = jax.random.normal(keys[3], (m, b)) - geom = pointcloud.PointCloud(x, y, power=2, batch_size=3) + geom = pointcloud.PointCloud(x, y, batch_size=3) prod0_online = geom.apply_cost(vec0, axis=0) prod1_online = geom.apply_cost(vec1, axis=1) - geom = pointcloud.PointCloud(x, y, power=2, batch_size=None) + geom = pointcloud.PointCloud(x, y, batch_size=None) prod0 = geom.apply_cost(vec0, axis=0) prod1 = geom.apply_cost(vec1, axis=1) geom = geometry.Geometry(cost) @@ -105,7 +105,7 @@ def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1): pc = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine()) arr = jnp.ones((pc.shape[0],)) if axis == 0 else jnp.ones((pc.shape[1],)) - assert pc._cost_fn.norm is None + assert pc.cost_fn.norm is None with pytest.raises( AssertionError, match=r"Cost matrix is not a squared Euclidean\." ): diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index defcfd2c6..416a4a849 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -24,7 +24,7 @@ def make_blobs( if cost_fn is None: pass elif cost_fn == 'sqeucl': - X = pointcloud.PointCloud(X, cost_fn=costs.Euclidean()) + X = pointcloud.PointCloud(X, cost_fn=costs.SqEuclidean()) elif cost_fn == 'cosine': X = pointcloud.PointCloud(X, cost_fn=costs.Cosine()) else: