Skip to content

Commit

Permalink
optax doctests fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715861186
  • Loading branch information
rdyro authored and OptaxDev committed Jan 15, 2025
1 parent b05d247 commit 3632938
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 18 deletions.
5 changes: 0 additions & 5 deletions docs/api/optimizer_wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ Masked update
.. autofunction:: masked
.. autoclass:: MaskedState

Maybe update
~~~~~~~~~~~~
.. autofunction:: maybe_update
.. autoclass:: MaybeUpdateState

Multi-step update
~~~~~~~~~~~~~~~~~
.. autoclass:: MultiSteps
Expand Down
2 changes: 1 addition & 1 deletion optax/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_loss
from optax.losses._self_supervised import triplet_margin_loss
from optax.losses._smoothing import smooth_labels
13 changes: 7 additions & 6 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ntxent(
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
>>> print("input:", x) # doctest: +SKIP
input: [[-0.9155995 1.5534698 ]
[ 0.2623586 -1.5908985 ]
[-0.15977189 0.480501 ]
Expand All @@ -47,13 +47,13 @@ def ntxent(
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
>>> print("Embeddings:", out) # doctest: +SKIP
Embeddings: [[-1.0076267]
[-1.2960069]
[-1.1829865]
[-1.3485558]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
>>> print("loss:", loss) # doctest: +SKIP
loss: 1.0986123
Args:
Expand Down Expand Up @@ -120,7 +120,7 @@ def ntxent(
return loss


def triplet_loss(
def triplet_margin_loss(
anchors: chex.Array,
positives: chex.Array,
negatives: chex.Array,
Expand All @@ -136,9 +136,10 @@ def triplet_loss(
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.triplet_loss(anchors, positives, negatives, margin=1.0)
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
... margin=1.0)
>>> print(output)
>>> Array([0.14142442, 0.14142442], dtype=float32)
[0.14142442 0.14142442]
Args:
anchors: An array of anchor embeddings, with shape [batch, feature_dim].
Expand Down
12 changes: 6 additions & 6 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def setUp(self):
}
])
def test_batched(self, anchor, positive, negative, margin):
def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
def testing_triplet_margin_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
return jnp.maximum(ap_distance - an_distance + margin, 0)

handmade_result = testing_triplet_loss(
handmade_result = testing_triplet_margin_loss(
a=anchor, p=positive, n=negative, margin=margin
)
result = self.variant(_self_supervised.triplet_loss)(
result = self.variant(_self_supervised.triplet_margin_loss)(
anchor, positive, negative
)
np.testing.assert_allclose(result, handmade_result, atol=1e-4)
Expand All @@ -117,13 +117,13 @@ def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
},
])
def test_vmap(self, anchor, positive, negative):
original_loss = _self_supervised.triplet_loss(anchor, positive,
negative)
original_loss = _self_supervised.triplet_margin_loss(anchor, positive,
negative)
anchor_batched = anchor.reshape(1, *anchor.shape)
positive_batched = positive.reshape(1, *positive.shape)
negative_batched = negative.reshape(1, *negative.shape)
vmap_loss = self.variant(
jax.vmap(_self_supervised.triplet_loss, in_axes=(0, 0, 0)))(
jax.vmap(_self_supervised.triplet_margin_loss, in_axes=(0, 0, 0)))(
anchor_batched, positive_batched, negative_batched)
np.testing.assert_allclose(vmap_loss.flatten(), original_loss.flatten()
, atol=1e-4)
Expand Down

0 comments on commit 3632938

Please sign in to comment.