diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst index 7749fbd6..c4107ef4 100644 --- a/docs/api/optimizer_wrappers.rst +++ b/docs/api/optimizer_wrappers.rst @@ -39,11 +39,6 @@ Masked update .. autofunction:: masked .. autoclass:: MaskedState -Maybe update -~~~~~~~~~~~~ -.. autofunction:: maybe_update -.. autoclass:: MaybeUpdateState - Multi-step update ~~~~~~~~~~~~~~~~~ .. autoclass:: MultiSteps diff --git a/optax/losses/__init__.py b/optax/losses/__init__.py index e94ea083..627b1afb 100644 --- a/optax/losses/__init__.py +++ b/optax/losses/__init__.py @@ -42,5 +42,5 @@ from optax.losses._regression import log_cosh from optax.losses._regression import squared_error from optax.losses._self_supervised import ntxent -from optax.losses._self_supervised import triplet_loss +from optax.losses._self_supervised import triplet_margin_loss from optax.losses._smoothing import smooth_labels diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py index be992da6..5c3a5e22 100644 --- a/optax/losses/_self_supervised.py +++ b/optax/losses/_self_supervised.py @@ -53,7 +53,7 @@ def ntxent( [0.8622629 ] [0.13612625]] >>> loss = optax.ntxent(out, labels) - >>> print("loss:", loss) + >>> print("loss:", loss) # doctest: +SKIP loss: 1.0986123 Args: @@ -120,7 +120,7 @@ def ntxent( return loss -def triplet_loss( +def triplet_margin_loss( anchors: chex.Array, positives: chex.Array, negatives: chex.Array, @@ -136,9 +136,10 @@ def triplet_loss( >>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]]) >>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]]) >>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]]) - >>> output = optax.triplet_loss(anchors, positives, negatives, margin=1.0) + >>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives, + ... margin=1.0) >>> print(output) - >>> Array([0.14142442, 0.14142442], dtype=float32) + [0.14142442 0.14142442] Args: anchors: An array of anchor embeddings, with shape [batch, feature_dim]. @@ -149,9 +150,9 @@ def triplet_loss( axis: The axis along which to compute the distances (default is -1). norm_degree: The norm degree for distance calculation (default is 2 for Euclidean distance). - margin: The minimum margin by which the positive distance should be + margin: The minimum margin by which the positive distance should be smaller than the negative distance. - eps: A small epsilon value to ensure numerical stability in the distance + eps: A small epsilon value to ensure numerical stability in the distance calculation. Returns: diff --git a/optax/losses/_self_supervised_test.py b/optax/losses/_self_supervised_test.py index ecabf321..8d4d4259 100644 --- a/optax/losses/_self_supervised_test.py +++ b/optax/losses/_self_supervised_test.py @@ -95,15 +95,15 @@ def setUp(self): } ]) def test_batched(self, anchor, positive, negative, margin): - def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6): + def testing_triplet_margin_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6): ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps) an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps) return jnp.maximum(ap_distance - an_distance + margin, 0) - handmade_result = testing_triplet_loss( + handmade_result = testing_triplet_margin_loss( a=anchor, p=positive, n=negative, margin=margin ) - result = self.variant(_self_supervised.triplet_loss)( + result = self.variant(_self_supervised.triplet_margin_loss)( anchor, positive, negative ) np.testing.assert_allclose(result, handmade_result, atol=1e-4) @@ -117,13 +117,13 @@ def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6): }, ]) def test_vmap(self, anchor, positive, negative): - original_loss = _self_supervised.triplet_loss(anchor, positive, - negative) + original_loss = _self_supervised.triplet_margin_loss(anchor, positive, + negative) anchor_batched = anchor.reshape(1, *anchor.shape) positive_batched = positive.reshape(1, *positive.shape) negative_batched = negative.reshape(1, *negative.shape) vmap_loss = self.variant( - jax.vmap(_self_supervised.triplet_loss, in_axes=(0, 0, 0)))( + jax.vmap(_self_supervised.triplet_margin_loss, in_axes=(0, 0, 0)))( anchor_batched, positive_batched, negative_batched) np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten() , atol=1e-4)