Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optax doctests fix #1176

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/api/optimizer_wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ Masked update
.. autofunction:: masked
.. autoclass:: MaskedState

Maybe update
~~~~~~~~~~~~
.. autofunction:: maybe_update
.. autoclass:: MaybeUpdateState

Multi-step update
~~~~~~~~~~~~~~~~~
.. autoclass:: MultiSteps
Expand Down
2 changes: 1 addition & 1 deletion optax/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_loss
from optax.losses._self_supervised import triplet_margin_loss
from optax.losses._smoothing import smooth_labels
13 changes: 7 additions & 6 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def ntxent(
[0.8622629 ]
[0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
>>> print("loss:", loss) # doctest: +SKIP
loss: 1.0986123

Args:
Expand Down Expand Up @@ -120,7 +120,7 @@ def ntxent(
return loss


def triplet_loss(
def triplet_margin_loss(
anchors: chex.Array,
positives: chex.Array,
negatives: chex.Array,
Expand All @@ -136,9 +136,10 @@ def triplet_loss(
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.triplet_loss(anchors, positives, negatives, margin=1.0)
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
... margin=1.0)
>>> print(output)
>>> Array([0.14142442, 0.14142442], dtype=float32)
[0.14142442 0.14142442]

Args:
anchors: An array of anchor embeddings, with shape [batch, feature_dim].
Expand All @@ -149,9 +150,9 @@ def triplet_loss(
axis: The axis along which to compute the distances (default is -1).
norm_degree: The norm degree for distance calculation (default is 2 for
Euclidean distance).
margin: The minimum margin by which the positive distance should be
margin: The minimum margin by which the positive distance should be
smaller than the negative distance.
eps: A small epsilon value to ensure numerical stability in the distance
eps: A small epsilon value to ensure numerical stability in the distance
calculation.

Returns:
Expand Down
12 changes: 6 additions & 6 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def setUp(self):
}
])
def test_batched(self, anchor, positive, negative, margin):
def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
def testing_triplet_margin_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
return jnp.maximum(ap_distance - an_distance + margin, 0)

handmade_result = testing_triplet_loss(
handmade_result = testing_triplet_margin_loss(
a=anchor, p=positive, n=negative, margin=margin
)
result = self.variant(_self_supervised.triplet_loss)(
result = self.variant(_self_supervised.triplet_margin_loss)(
anchor, positive, negative
)
np.testing.assert_allclose(result, handmade_result, atol=1e-4)
Expand All @@ -117,13 +117,13 @@ def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
},
])
def test_vmap(self, anchor, positive, negative):
original_loss = _self_supervised.triplet_loss(anchor, positive,
negative)
original_loss = _self_supervised.triplet_margin_loss(anchor, positive,
negative)
anchor_batched = anchor.reshape(1, *anchor.shape)
positive_batched = positive.reshape(1, *positive.shape)
negative_batched = negative.reshape(1, *negative.shape)
vmap_loss = self.variant(
jax.vmap(_self_supervised.triplet_loss, in_axes=(0, 0, 0)))(
jax.vmap(_self_supervised.triplet_margin_loss, in_axes=(0, 0, 0)))(
anchor_batched, positive_batched, negative_batched)
np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten()
, atol=1e-4)
Expand Down
Loading