Skip to content

Commit

Permalink
Euclidean / SqEuclidean and changes in power (#157)
Browse files Browse the repository at this point in the history
* Euclidean / SqEuclidean and changes in power

* fix, following expose cost_fn in PointCloud

* fix power=2 that were still there

* fix

* fix in plot, *needed* to update lines in anims.
  • Loading branch information
marcocuturi authored Oct 19, 2022
1 parent 438eb11 commit 8646692
Show file tree
Hide file tree
Showing 18 changed files with 105 additions and 84 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Currently implements the following classes and functions:

- The [geometry](ott/geometry) folder describes tools that to encapsulate the essential ingredients of OT problems: measures and cost functions.

- The `CostFn` class in [costs.py](ott/geometry/costs.py) and its descendants define cost functions between points. A few simple costs are considered, `Euclidean` between vectors, and `Bures`, between a pair of mean vector and covariance (p.d.) matrix.
- The `CostFn` class in [costs.py](ott/geometry/costs.py) and its descendants define cost functions between points. A few simple costs are considered, `SqEuclidean` between vectors, and `Bures`, between a pair of mean vector and covariance (p.d.) matrix.

- The `Geometry` class in [geometry.py](ott/geometry/geometry.py) describes a cost structure between two measures. That cost structure is accessed through various member functions, either used when running the Sinkhorn algorithm (typically kernel multiplications, or log-sum-exp row/column-wise application) or after (to apply the OT matrix to a vector).

Expand Down
1 change: 1 addition & 0 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Cost Functions
:toctree: _autosummary

costs.CostFn
costs.SqEuclidean
costs.Euclidean
costs.Cosine
costs.Bures
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/Sinkhorn_Barycenters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@
"id": "tqg86SFQvzXC"
},
"source": [
"### Euclidean barycenter, for reference"
"### SqEuclidean barycenter, for reference"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions docs/notebooks/introduction_grid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@
"outputs": [],
"source": [
"@jax.tree_util.register_pytree_node_class\n",
"class EuclideanTimes2(costs.CostFn):\n",
" \"\"\"The cost function corresponding to the squared euclidean distance times 2.\"\"\"\n",
"class SqEuclideanTimes2(costs.CostFn):\n",
" \"\"\"The cost function corresponding to the squared SqEuclidean distance times 2.\"\"\"\n",
"\n",
" def norm(self, x):\n",
" return jnp.sum(x**2, axis=-1) * 2\n",
Expand All @@ -319,7 +319,7 @@
" return -2 * jnp.sum(x * y) * 2\n",
"\n",
"\n",
"cost_fns = [EuclideanTimes2(), costs.Euclidean()]"
"cost_fns = [SqEuclideanTimes2(), costs.SqEuclidean()]"
]
},
{
Expand Down
8 changes: 5 additions & 3 deletions docs/notebooks/point_clouds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"source": [
"## Computes the regularized optimal transport\n",
"\n",
"To compute the transport matrix between the two point clouds, one can define a `PointCloud` geometry (which by default uses `ott.geometry.costs.Euclidean` for cost function), then call the `sinkhorn` function, and build the transport matrix from the optimized potentials."
"To compute the transport matrix between the two point clouds, one can define a `PointCloud` geometry (which by default uses `ott.geometry.costs.SqEuclidean` for cost function), then call the `sinkhorn` function, and build the transport matrix from the optimized potentials."
]
},
{
Expand Down Expand Up @@ -235,7 +235,7 @@
" y: jnp.ndarray,\n",
" a: jnp.ndarray,\n",
" b: jnp.ndarray,\n",
" cost_fn=ott.geometry.costs.Euclidean(),\n",
" cost_fn=ott.geometry.costs.SqEuclidean(),\n",
" num_iter: int = 101,\n",
" dump_every: int = 10,\n",
" learning_rate: float = 0.2,\n",
Expand Down Expand Up @@ -5972,7 +5972,9 @@
"source": [
"from IPython import display\n",
"\n",
"ots = optimize(x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.Euclidean())\n",
"ots = optimize(\n",
" x, y, a, b, num_iter=100, cost_fn=ott.geometry.costs.SqEuclidean()\n",
")\n",
"fig = plt.figure(figsize=(8, 5))\n",
"plott = ott.tools.plot.Plot(fig=fig)\n",
"anim = plott.animate(ots, frame_rate=4)\n",
Expand Down
8 changes: 4 additions & 4 deletions ott/core/bar_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class BarycenterProblem:
weights: Array of shape ``[num_measures,]`` containing the weights of the
measures.
cost_fn: Cost function used. If `None`,
use :class:`~ott.geometry.costs.Euclidean` cost.
use :class:`~ott.geometry.costs.SqEuclidean` cost.
epsilon: Epsilon regularization used to solve reg-OT problems.
debiased: **Currently not implemented.**
Whether the problem is debiased, in the sense that
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
raise ValueError("Specify weights if `y` is already segmented.")
self._b = b
self._weights = weights
self.cost_fn = costs.Euclidean() if cost_fn is None else cost_fn
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self.epsilon = epsilon
self.debiased = debiased
self._kwargs = kwargs
Expand Down Expand Up @@ -309,7 +309,7 @@ def update_features(self, transports: jnp.ndarray,
"""Update the barycenter features in the fused case :cite:`vayer:19`.
Uses :cite:`cuturi:14` eq. 8, and is implemented only
for the squared :class:`~ott.geometry.costs.Euclidean` cost.
for the squared :class:`~ott.geometry.costs.SqEuclidean` cost.
Args:
transports: Transport maps of shape
Expand All @@ -330,7 +330,7 @@ def update_features(self, transports: jnp.ndarray,
transports = transports * inv_a[None, :, None]

if self._loss_name == "sqeucl":
cost = costs.Euclidean()
cost = costs.SqEuclidean()
return jnp.sum(
weights * barycentric_projection(transports, y_fused, cost), axis=0
)
Expand Down
2 changes: 1 addition & 1 deletion ott/core/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def callback(x: jnp.ndarray) -> float:
cost = pointcloud.PointCloud(
jnp.atleast_2d(x),
y,
cost_fn=self._geom._cost_fn,
cost_fn=self._geom.cost_fn,
power=self._geom.power,
epsilon=1.0 # epsilon is not used
).cost_matrix
Expand Down
11 changes: 10 additions & 1 deletion ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,16 @@ def tree_unflatten(cls, aux_data, children):

@jax.tree_util.register_pytree_node_class
class Euclidean(CostFn):
"""Squared Euclidean distance CostFn."""
"""Euclidean distance."""

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute Euclidean norm."""
return jnp.linalg.norm(x - y)


@jax.tree_util.register_pytree_node_class
class SqEuclidean(CostFn):
"""Squared Euclidean distance."""

def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
"""Compute squared Euclidean norm for vector."""
Expand Down
2 changes: 1 addition & 1 deletion ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
raise ValueError('Input either grid_size t-uple or grid locations x.')

if cost_fns is None:
cost_fns = [costs.Euclidean()]
cost_fns = [costs.SqEuclidean()]
self.cost_fns = cost_fns
self.kwargs = {
'num_a': self.num_a,
Expand Down
66 changes: 33 additions & 33 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
x: jnp.ndarray,
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
power: float = 2.0,
power: float = 1.0,
batch_size: Optional[int] = None,
scale_cost: Union[bool, int, float,
Literal['mean', 'max_norm', 'max_bound', 'max_cost',
Expand All @@ -75,9 +75,9 @@ def __init__(
super().__init__(**kwargs)
self.x = x
self.y = self.x if y is None else y
self._cost_fn = costs.Euclidean() if cost_fn is None else cost_fn
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self.power = power
self._axis_norm = 0 if callable(self._cost_fn.norm) else None
self._axis_norm = 0 if callable(self.cost_fn.norm) else None
if batch_size is not None:
assert batch_size > 0, f"`batch_size={batch_size}` must be positive."
self._batch_size = batch_size
Expand All @@ -86,13 +86,13 @@ def __init__(
@property
def _norm_x(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self._cost_fn.norm(self.x)
return self.cost_fn.norm(self.x)
return 0.

@property
def _norm_y(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self._cost_fn.norm(self.y)
return self.cost_fn.norm(self.y)
return 0.

@property
Expand Down Expand Up @@ -125,7 +125,7 @@ def is_symmetric(self) -> bool:

@property
def is_squared_euclidean(self) -> bool:
return isinstance(self._cost_fn, costs.Euclidean) and self.power == 2.0
return isinstance(self.cost_fn, costs.SqEuclidean) and self.power == 1.0

@property
def is_online(self) -> bool:
Expand Down Expand Up @@ -163,7 +163,7 @@ def inv_scale_cost(self) -> float:
"the cost matrix with the online mode is not implemented."
)
if self._scale_cost == 'max_norm':
if self._cost_fn.norm is not None:
if self.cost_fn.norm is not None:
return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max())
return 1.0
if self._scale_cost == 'max_bound':
Expand All @@ -183,11 +183,11 @@ def inv_scale_cost(self) -> float:
raise ValueError(f'Scaling {self._scale_cost} not implemented.')

def _compute_cost_matrix(self) -> jnp.ndarray:
cost_matrix = self._cost_fn.all_pairs_pairwise(self.x, self.y)
cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y)
if self._axis_norm is not None:
cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :]
if self.power != 2.0:
cost_matrix = jnp.abs(cost_matrix) ** (0.5 * self.power)
if self.power != 1.0:
cost_matrix = jnp.abs(cost_matrix) ** self.power
return cost_matrix

def apply_lse_kernel(
Expand All @@ -212,7 +212,7 @@ def body0(carry, i: int):
self._norm_y, (i * self.batch_size,), (self.batch_size,)
)
h_res, h_sgn = app(
self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self._cost_fn,
self.x, y, self._norm_x, norm_y, f, g_, eps, vec, self.cost_fn,
self.power, self.inv_scale_cost
)
return carry, (h_res, h_sgn)
Expand All @@ -230,7 +230,7 @@ def body1(carry, i: int):
self._norm_x, (i * self.batch_size,), (self.batch_size,)
)
h_res, h_sgn = app(
self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self._cost_fn,
self.y, x, self._norm_y, norm_x, g, f_, eps, vec, self.cost_fn,
self.power, self.inv_scale_cost
)
return carry, (h_res, h_sgn)
Expand All @@ -240,12 +240,12 @@ def finalize(i: int):
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(
self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(
self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)

if not self.is_online:
Expand Down Expand Up @@ -297,12 +297,12 @@ def apply_kernel(
if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, scaling, eps,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)
if axis == 1:
return app(
self.y, self.x, self._norm_y, self._norm_x, scaling, eps,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)

def transport_from_potentials(
Expand All @@ -318,7 +318,7 @@ def transport_from_potentials(
)
return transport(
self.y, self.x, self._norm_y, self._norm_x, g, f, self.epsilon,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)

def transport_from_scalings(
Expand All @@ -334,7 +334,7 @@ def transport_from_scalings(
)
return transport(
self.y, self.x, self._norm_y, self._norm_x, v, u, self.epsilon,
self._cost_fn, self.power, self.inv_scale_cost
self.cost_fn, self.power, self.inv_scale_cost
)

def apply_cost(
Expand Down Expand Up @@ -387,12 +387,12 @@ def _apply_cost(
arr = arr.reshape(-1, 1)
if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, arr, self._cost_fn,
self.x, self.y, self._norm_x, self._norm_y, arr, self.cost_fn,
self.power, self.inv_scale_cost, fn
)
if axis == 1:
return app(
self.y, self.x, self._norm_y, self._norm_x, arr, self._cost_fn,
self.y, self.x, self._norm_y, self._norm_x, arr, self.cost_fn,
self.power, self.inv_scale_cost, fn
)
else:
Expand Down Expand Up @@ -464,7 +464,7 @@ def body0(carry, i: int):
else:
norm_y = self._leading_slice(self._norm_y, i)
h_res = app(
self.x, y, self._norm_x, norm_y, vec, self._cost_fn, self.power,
self.x, y, self._norm_x, norm_y, vec, self.cost_fn, self.power,
scale_cost
)
return carry, h_res
Expand All @@ -477,7 +477,7 @@ def body1(carry, i: int):
else:
norm_x = self._leading_slice(self._norm_x, i)
h_res = app(
self.y, x, self._norm_y, norm_x, vec, self._cost_fn, self.power,
self.y, x, self._norm_y, norm_x, vec, self.cost_fn, self.power,
scale_cost
)
return carry, h_res
Expand All @@ -486,12 +486,12 @@ def finalize(i: int):
if batch_for_y:
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(
self.x, self.y[i:], self._norm_x, norm_y, vec, self._cost_fn,
self.x, self.y[i:], self._norm_x, norm_y, vec, self.cost_fn,
self.power, scale_cost
)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(
self.y, self.x[i:], self._norm_y, norm_x, vec, self._cost_fn,
self.y, self.x[i:], self._norm_y, norm_x, vec, self.cost_fn,
self.power, scale_cost
)

Expand Down Expand Up @@ -532,9 +532,9 @@ def finalize(i: int):
)

def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray:
"""Compute barycenter of points in self.x using weights, valid for p=2.0."""
assert self.power == 2.0, self.power
return self._cost_fn.barycenter(self.x, weights)
"""Compute barycenter of points in self.x using weights, valid for p=1.0."""
assert self.power == 1.0, self.power
return self.cost_fn.barycenter(self.x, weights)

@classmethod
def prepare_divergences(
Expand All @@ -560,7 +560,7 @@ def prepare_divergences(

def tree_flatten(self):
# passing self.power in aux_data to be able to condition on it.
return ([self.x, self.y, self._src_mask, self._tgt_mask, self._cost_fn], {
return ([self.x, self.y, self._src_mask, self._tgt_mask, self.cost_fn], {
'epsilon': self._epsilon_init,
'relative_epsilon': self._relative_epsilon,
'scale_epsilon': self._scale_epsilon,
Expand All @@ -577,14 +577,14 @@ def tree_unflatten(cls, aux_data, children):
)

def _cosine_to_sqeucl(self) -> 'PointCloud':
assert isinstance(self._cost_fn, costs.Cosine), type(self._cost_fn)
assert self.power == 2, self.power
assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn)
assert self.power == 1.0, self.power
(x, y, *args, _), aux_data = self.tree_flatten()
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
# TODO(michalk8): find a better way
aux_data["scale_cost"] = 2. / self.inv_scale_cost
cost_fn = costs.Euclidean()
cost_fn = costs.SqEuclidean()
return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn])

def to_LRCGeometry(
Expand Down Expand Up @@ -767,8 +767,8 @@ def _transport_from_scalings_xy(
def _cost(x, y, norm_x, norm_y, cost_fn, cost_pow, scale_cost):
one_line_pairwise = jax.vmap(cost_fn.pairwise, in_axes=[0, None])
cost = norm_x + norm_y + one_line_pairwise(x, y)
if cost_pow != 2.0:
cost = jnp.abs(cost) ** (0.5 * cost_pow)
if cost_pow != 1.0:
cost = jnp.abs(cost) ** cost_pow
return cost * scale_cost


Expand Down
4 changes: 2 additions & 2 deletions ott/tools/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def k_means(
Args:
geom: Point cloud of shape ``[n, ndim]`` to cluster. If passed as an array,
:class:`~ott.geometry.costs.Euclidean` cost is assumed.
:class:`~ott.geometry.costs.SqEuclidean` cost is assumed.
k: The number of clusters.
weights: The weights of input points. These weights are considered when
computing the centroids and inertia. If ``None``, use uniform weights.
Expand Down Expand Up @@ -388,7 +388,7 @@ def k_means(
0] >= k, f"Cannot cluster `{geom.shape[0]}` points into `{k}` clusters."
if isinstance(geom, jnp.ndarray):
geom = pointcloud.PointCloud(geom)
if isinstance(geom._cost_fn, costs.Cosine):
if isinstance(geom.cost_fn, costs.Cosine):
geom = geom._cosine_to_sqeucl()
assert geom.is_squared_euclidean

Expand Down
Loading

0 comments on commit 8646692

Please sign in to comment.