Skip to content

Commit

Permalink
Correct mistake on defining norm
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Nov 18, 2024
1 parent 4fd262d commit 8ac3ea5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,22 +348,23 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102


@jtu.register_pytree_node_class
class PNorm(TICost):
r""":math:`p`-norm between vectors.
class EuclideanP(TICost):
r""":math:`p`-power of Euclidean norm.
Uses custom implementation of `norm` to avoid `NaN` values when
differentiating the norm of :math:`x-x`.
Args:
p: Power of the p-norm in :math:`[1, +\infty)`.
p: Power used to raise Euclidean norm, in :math:`[1, +\infty)`.
"""

def __init__(self, p: float):
super().__init__()
self.p = p

def h(self, z: jnp.ndarray) -> float: # noqa: D102
return mu.norm(z, self.p) / self.p
# Computed by raising squared-norm to p/2.
return mu.norm(z) ** (self.p / 2.)

def tree_flatten(self): # noqa: D102
return (), (self.p,)
Expand Down
2 changes: 1 addition & 1 deletion src/ott/tools/unreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
Returns:
The p-Wasserstein distance between these point clouds.hungarian
"""
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p))
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p))
cost, _ = hungarian(geom)
return cost ** 1. / p
2 changes: 1 addition & 1 deletion tests/tools/unreg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_wass(self, rng: jax.Array, p: float):
n, m, dim = 12, 12, 5
rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNorm(p=p))
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p))
cost_hung, _ = unreg.hungarian(geom)
w_p = unreg.wassdis_p(x, y, p)
np.testing.assert_allclose(w_p, cost_hung ** 1. / p, rtol=1e-3, atol=1e-3)
Expand Down

0 comments on commit 8ac3ea5

Please sign in to comment.