
Commit

Merge pull request #1169 from google-deepmind:inject-random-key-fix
PiperOrigin-RevId: 714980741
OptaxDev committed Jan 13, 2025
2 parents 831ddbf + f8142c7 commit 6ff3dca
Showing 2 changed files with 43 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optax/schedules/_inject.py
@@ -30,7 +30,8 @@

def _convert_floats(x, dtype):
  """Convert float-like inputs to dtype, rest pass through."""
-  if jax.dtypes.scalar_type_of(x) is float:
+  current_dtype = x.dtype if hasattr(x, 'dtype') else type(x)
+  if jax.dtypes.issubdtype(current_dtype, jnp.floating):
    return jnp.asarray(x, dtype=dtype)
  return x

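As a quick illustration of why the check above changes, here is a minimal, self-contained sketch of the new dtype test (not the optax implementation; `convert_floats_like` is an illustrative stand-in). `jax.dtypes.issubdtype` understands JAX's extended PRNG-key dtypes, so a `jax.random.key(...)` value is left untouched while ordinary floats are still converted:

import jax
import jax.numpy as jnp


def convert_floats_like(x, dtype):
  """Illustrative stand-in for the patched check: convert floats, pass keys through."""
  current_dtype = x.dtype if hasattr(x, 'dtype') else type(x)
  if jax.dtypes.issubdtype(current_dtype, jnp.floating):
    return jnp.asarray(x, dtype=dtype)
  return x


key = jax.random.key(17)
assert convert_floats_like(key, jnp.float32) is key  # PRNG key is left untouched.
assert convert_floats_like(0.1, jnp.bfloat16).dtype == jnp.bfloat16  # Float is converted.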
41 changes: 41 additions & 0 deletions optax/schedules/_inject_test.py
@@ -21,8 +21,10 @@
from absl.testing import parameterized
import chex
import jax
from jax import random
import jax.numpy as jnp
import numpy as np

from optax._src import base
from optax._src import clipping
from optax._src import transform
@@ -161,6 +163,45 @@ def test_numeric_static_args(self, static_args):

    assert not set(state.hyperparams.keys()).intersection(set(static_args))

  @chex.all_variants
  def test_prng_key_not_hyperparameter(self):
    """Check that random.key can be handled by :func:`inject_hyperparams`."""

    def random_noise_optimizer(
        key: chex.PRNGKey, scale: jax.Array
    ) -> base.GradientTransformation:
      def init_fn(params_like: base.Params) -> tuple[chex.PRNGKey,
                                                      jax.Array | float]:
        del params_like
        nonlocal key, scale
        return (key, scale)

      def update_fn(
          updates: base.Updates,
          state: tuple[chex.PRNGKey, jax.Array],
          params: None = None,
      ) -> tuple[base.Updates, tuple[chex.PRNGKey, jax.Array | float]]:
        del params
        key, scale = state
        keyit = iter(random.split(key, len(jax.tree.leaves(updates)) + 1))
        new_updates = jax.tree.map(
            lambda x: scale * random.normal(next(keyit), x.shape), updates
        )
        new_key = next(keyit)
        return new_updates, (new_key, scale)

      return base.GradientTransformation(init_fn, update_fn)

    optim = _inject.inject_hyperparams(random_noise_optimizer)(
        key=random.key(17), scale=1e-3
    )

    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    state = self.variant(optim.init)(params)
    _, state = self.variant(optim.update)(grads, state)
    del state
  @chex.all_variants
  @parameterized.named_parameters(
      ('bf16hyp f32param bf16grad', jnp.bfloat16, jnp.float32, jnp.bfloat16),
