Updating optax to use jax.random.key instead of PRNGKey
PiperOrigin-RevId: 715397021
rdyro authored and OptaxDev committed Jan 14, 2025
1 parent 6ff3dca commit b05d247
Showing 15 changed files with 64 additions and 42 deletions.
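
For context, this commit swaps optax's uses of the legacy raw-uint32 key constructor jax.random.PRNGKey for JAX's new-style typed key arrays created with jax.random.key. A minimal sketch of the difference, not part of the diff itself (only standard JAX APIs are used):

import jax

legacy_key = jax.random.PRNGKey(0)  # raw uint32 array of shape (2,)
typed_key = jax.random.key(0)       # typed key array, scalar shape, PRNG dtype

# Both drive jax.random functions the same way.
sample = jax.random.normal(typed_key, (3,))

# Typed keys are recognized by their dtype and can be unwrapped to raw
# uint32 data when a numerical view is needed.
assert jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
raw = jax.random.key_data(typed_key)  # uint32 array of shape (2,)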
12 changes: 7 additions & 5 deletions optax/_src/alias_test.py
@@ -227,7 +227,9 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4)
with self.subTest('Equality of new optimizer states.'):
chex.assert_trees_all_close(
new_state_inject.inner_state, new_state, rtol=1e-4
otu.tree_unwrap_random_key_data(new_state_inject.inner_state),
otu.tree_unwrap_random_key_data(new_state),
rtol=1e-4,
)

@parameterized.product(
@@ -573,7 +575,7 @@ def zakharov(x, xnp):
class LBFGSTest(chex.TestCase):

def test_plain_preconditioning(self):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
d = 3
@@ -592,7 +594,7 @@ def test_plain_preconditioning(self):

@parameterized.product(idx=[0, 1, 2, 3])
def test_preconditioning_by_lbfgs_on_vectors(self, idx: int):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
d = 3
@@ -619,7 +621,7 @@ def test_preconditioning_by_lbfgs_on_vectors(self, idx: int):

@parameterized.product(idx=[0, 1, 2, 3])
def test_preconditioning_by_lbfgs_on_trees(self, idx: int):
key = jrd.PRNGKey(0)
key = jrd.key(0)
key_ws, key_us, key_vec = jrd.split(key, 3)
m = 4
shapes = ((3, 2), (5,))
@@ -721,7 +723,7 @@ def fun_(x):
def fun(x):
return otu.tree_sum(jax.tree.map(fun_, x))

key = jrd.PRNGKey(0)
key = jrd.key(0)
init_array = jrd.normal(key, (2, 4))
init_tree = (init_array[0], init_array[1])

6 changes: 3 additions & 3 deletions optax/_src/linear_algebra_test.py
@@ -57,7 +57,7 @@ def test_global_norm(self):

def test_power_iteration_cond_fun(self, dim=6):
"""Test the condition function for power iteration."""
matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
matrix = jax.random.normal(jax.random.key(0), (dim, dim))
matrix = matrix @ matrix.T
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)
dominant_eigenval = all_eigenval[-1]
@@ -102,7 +102,7 @@ def power_iteration(matrix, *, v0):
power_iteration = self.variant(power_iteration)

# create a random PSD matrix
matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
matrix = jax.random.normal(jax.random.key(0), (dim, dim))
matrix = matrix @ matrix.T
v0 = jnp.ones((dim,))

@@ -148,7 +148,7 @@ def test_power_iteration_mlp_hessian(
):
"""Test power_iteration on the Hessian of an MLP."""
mlp = MLP(num_outputs=output_dim, hidden_sizes=[input_dim, 8, output_dim])
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
key_params, key_input, key_output = jax.random.split(key, 3)
# initialize the mlp
params = mlp.init(key_params, jnp.ones(input_dim))
6 changes: 3 additions & 3 deletions optax/_src/linesearch_test.py
@@ -148,7 +148,7 @@ def test_linesearch(
seed,
):
"""Test backtracking linesearch (single update step)."""
key = jrd.PRNGKey(seed)
key = jrd.key(seed)
problem = get_problem(problem_name)
fn, input_shape = problem['fn'], problem['input_shape']
init_params = jrd.normal(key, input_shape)
@@ -270,7 +270,7 @@ def fn(params, x, y):

# Create artificial data
noise = 1e-3
key = jrd.PRNGKey(0)
key = jrd.key(0)
x_key, y_key, params_key = jrd.split(key, 3)
d, m, n = 2, 16, 2
xs = jrd.normal(x_key, (n, m, d))
@@ -446,7 +446,7 @@ def test_linesearch(self, problem_name: str, seed: int):
curv_rtol = 0.9
tol = 0.0

key = jrd.PRNGKey(seed)
key = jrd.key(seed)
params_key, precond_key = jrd.split(key, 2)
problem = get_problem(problem_name)
fn, input_shape = problem['fn'], problem['input_shape']
2 changes: 1 addition & 1 deletion optax/_src/utils_test.py
@@ -100,7 +100,7 @@ def test_sample_input_sequence_types(self, sample_type):
loc_shape = scale_shape = (2, 3)
loc, scale = self._get_loc_scale(loc_shape, scale_shape)
dist = utils.multi_normal(loc, scale)
samples = dist.sample(sample_shape, jax.random.PRNGKey(239))
samples = dist.sample(sample_shape, jax.random.key(239))
self.assertEqual(samples.shape, tuple(sample_shape) + loc_shape)

@parameterized.named_parameters([
2 changes: 1 addition & 1 deletion optax/losses/_fenchel_young_test.py
@@ -37,7 +37,7 @@ class FenchelYoungTest(chex.TestCase):
def test_fenchel_young_reg(self):
# Checks the behavior of the Fenchel-Young loss.
fy_loss = self.variant(_fenchel_young.make_fenchel_young_loss(logsumexp))
rng = jax.random.PRNGKey(0)
rng = jax.random.key(0)
rngs = jax.random.split(rng, 2)
theta_true = jax.random.uniform(rngs[0], (8, 5))
y_true = jax.vmap(jax.nn.softmax)(theta_true)
22 changes: 11 additions & 11 deletions optax/monte_carlo/control_variates_test.py
@@ -77,7 +77,7 @@ class DeltaControlVariateTest(chex.TestCase):
def test_quadratic_function(self, effective_mean, effective_log_scale):
data_dims = 20
num_samples = 10**6
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
log_scale = effective_log_scale * jnp.ones(
@@ -111,7 +111,7 @@ def test_polynomial_function(self, effective_mean, effective_log_scale):
params = [mean, log_scale]

dist = utils.multi_normal(*params)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
dist_samples = dist.sample((num_samples,), rng)
function = lambda x: jnp.sum(x**5)

@@ -131,7 +131,7 @@ def test_non_polynomial_function(self):
log_scale = jnp.ones(shape=(data_dims), dtype=jnp.float32)
params = [mean, log_scale]

rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
dist = utils.multi_normal(*params)
dist_samples = dist.sample((num_samples,), rng)
function = lambda x: jnp.sum(jnp.log(x**2))
@@ -167,7 +167,7 @@ def test_linear_function(self, effective_mean, effective_log_scale, decay):
params = [mean, log_scale]
function = lambda x: jnp.sum(weights * x)

rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
dist = utils.multi_normal(*params)
dist_samples = dist.sample((num_samples,), rng)

@@ -219,7 +219,7 @@ def test_linear_function_with_heuristic(
params = [mean, log_scale]
function = lambda x: jnp.sum(weights * x)

rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
dist = utils.multi_normal(*params)
dist_samples = dist.sample((num_samples,), rng)

@@ -274,7 +274,7 @@ def test_linear_function_zero_debias(
params = [mean, log_scale]
function = lambda x: jnp.sum(weights * x)

rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
dist = utils.multi_normal(*params)
dist_samples = dist.sample((num_samples,), rng)

@@ -346,7 +346,7 @@ def test_quadratic_function(

params = [mean, log_scale]
function = lambda x: jnp.sum(x**2)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

jacobians = _cv_jac_variant(self.variant)(
function,
@@ -429,7 +429,7 @@ def test_cubic_function(

params = [mean, log_scale]
function = lambda x: jnp.sum(x**3)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

jacobians = _cv_jac_variant(self.variant)(
function,
@@ -516,7 +516,7 @@ def test_forth_power_function(

params = [mean, log_scale]
function = lambda x: jnp.sum(x**4)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

jacobians = _cv_jac_variant(self.variant)(
function,
@@ -613,7 +613,7 @@ def test_weighted_linear_function(

params = [mean, log_scale]
function = lambda x: jnp.sum(weights * x)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
cv_rng, ge_rng = jax.random.split(rng)

jacobians = _cv_jac_variant(self.variant)(
@@ -702,7 +702,7 @@ def test_non_polynomial_function(

params = [mean, log_scale]
function = lambda x: jnp.log(jnp.sum(x**2))
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
cv_rng, ge_rng = jax.random.split(rng)

jacobians = _cv_jac_variant(self.variant)(
14 changes: 7 additions & 7 deletions optax/monte_carlo/stochastic_gradient_estimators_test.py
@@ -99,7 +99,7 @@ def test_constant_function(self, estimator, constant):

effective_log_scale = 0.0
log_scale = effective_log_scale * _ones(data_dims)
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

jacobians = _estimator_variant(self.variant, estimator)(
lambda x: jnp.array(constant),
@@ -144,7 +144,7 @@ def test_linear_function(
):
data_dims = 3
num_samples = _estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

mean = effective_mean * _ones(data_dims)
log_scale = effective_log_scale * _ones(data_dims)
@@ -185,7 +185,7 @@ def test_quadratic_function(
):
data_dims = 3
num_samples = _estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

mean = effective_mean * _ones(data_dims)
log_scale = effective_log_scale * _ones(data_dims)
@@ -233,7 +233,7 @@ def test_weighted_linear(
self, estimator, effective_mean, effective_log_scale, weights
):
num_samples = _weighted_estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

mean = jnp.array(effective_mean)
log_scale = jnp.array(effective_log_scale)
@@ -280,7 +280,7 @@ def test_weighted_quadratic(
self, estimator, effective_mean, effective_log_scale, weights
):
num_samples = _weighted_estimator_to_num_samples[estimator]
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

mean = jnp.array(effective_mean, dtype=jnp.float32)
log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
@@ -342,7 +342,7 @@ def testNonPolynomialFunctionConsistencyWithPathwise(
self, effective_mean, effective_log_scale, function, coupling
):
num_samples = 10**5
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)
measure_rng, pathwise_rng = jax.random.split(rng)

mean = jnp.array(effective_mean, dtype=jnp.float32)
@@ -405,7 +405,7 @@ class MeasuredValuedEstimatorsTest(chex.TestCase):
@parameterized.parameters([True, False])
def test_raises_error_for_non_gaussian(self, coupling):
num_samples = 10**5
rng = jax.random.PRNGKey(1)
rng = jax.random.key(1)

function = lambda x: jnp.sum(x) ** 2

2 changes: 1 addition & 1 deletion optax/perturbations/_make_pert_test.py
@@ -44,7 +44,7 @@ def setUp(self):
super().setUp()
rng = np.random.RandomState(0)

self.rng_jax = jax.random.PRNGKey(0)
self.rng_jax = jax.random.key(0)
self.num_samples = 1_000
self.num_samples_small = 1_000
self.sigma = 0.5
6 changes: 3 additions & 3 deletions optax/schedules/_inject_test.py
@@ -15,7 +15,7 @@
"""Tests for methods in `inject.py`."""

import functools
from typing import NamedTuple
from typing import NamedTuple, Union

from absl.testing import absltest
from absl.testing import parameterized
@@ -171,7 +171,7 @@ def random_noise_optimizer(
key: chex.PRNGKey, scale: jax.Array
) -> base.GradientTransformation:
def init_fn(params_like: base.Params) -> tuple[chex.PRNGKey,
jax.Array | float]:
Union[jax.Array, float]]:
del params_like
nonlocal key, scale
return (key, scale)
@@ -180,7 +180,7 @@ def update_fn(
updates: base.Updates,
state: tuple[chex.PRNGKey, jax.Array],
params: None = None,
) -> tuple[base.Updates, tuple[chex.PRNGKey, jax.Array | float]]:
) -> tuple[base.Updates, tuple[chex.PRNGKey, Union[jax.Array, float]]]:
del params
key, scale = state
keyit = iter(random.split(key, len(jax.tree.leaves(updates)) + 1))
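
The annotation change in this file replaces the PEP 604 syntax jax.Array | float with typing.Union, which also evaluates on Python versions older than 3.10; the motivation is not stated in the commit, so that reading is an assumption. A small illustration (the Scale alias and scaled function are hypothetical, not from the diff):

from typing import Union

import jax
import jax.numpy as jnp

# Equivalent to `jax.Array | float`, but also evaluates on Python < 3.10,
# where the `X | Y` form between classes raises a TypeError.
Scale = Union[jax.Array, float]

def scaled(value: Scale, x: jax.Array) -> jax.Array:
  return value * x

print(scaled(0.5, jnp.ones(3)))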
2 changes: 1 addition & 1 deletion optax/second_order/_hessian_test.py
@@ -51,7 +51,7 @@ def __call__(self, x):
return x

net = MLP()
self.parameters = net.init({'params': jax.random.PRNGKey(0)}, self.data)[
self.parameters = net.init({'params': jax.random.key(0)}, self.data)[
'params'
]

4 changes: 2 additions & 2 deletions optax/transforms/_accumulation_test.py
@@ -162,7 +162,7 @@ def test_multi_steps(self):
data = jnp.ones([batch_size, x_size])
loss = Loss()

params = loss.init({'params': jax.random.PRNGKey(0)}, data)['params']
params = loss.init({'params': jax.random.key(0)}, data)['params']

def loss_apply(params, data):
return loss.apply({'params': params}, data)
@@ -359,7 +359,7 @@ def compare_dtypes(tree1, tree2):
):
# Initialize parameters with current combination of dtypes
params = loss.init(
{'params': jax.random.PRNGKey(0)}, data, param_dtype=param_dtype
{'params': jax.random.key(0)}, data, param_dtype=param_dtype
)['params']
opt_state = opt.init(params)
ms_opt_state = ms_opt_init(params)
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
@@ -20,6 +20,7 @@
from optax.tree_utils._casting import tree_dtype
from optax.tree_utils._random import tree_random_like
from optax.tree_utils._random import tree_split_key_like
from optax.tree_utils._random import tree_unwrap_random_key_data
from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
18 changes: 18 additions & 0 deletions optax/tree_utils/_random.py
@@ -73,3 +73,21 @@ def tree_random_like(
target_tree,
keys_tree,
)


def tree_unwrap_random_key_data(input_tree: chex.ArrayTree) -> chex.ArrayTree:
  """Unwrap random.key objects in a tree for numerical comparison.

  Args:
    input_tree: a tree of arrays and random.key objects.

  Returns:
    a tree of arrays and random.key_data objects.
  """

  def _unwrap_random_key_data(x):
    if (isinstance(x, jax.Array)
        and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)):
      return jax.random.key_data(x)
    return x

  return jax.tree.map(_unwrap_random_key_data, input_tree)
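
A sketch of how the new helper can be used, mirroring the updated assertion in optax/_src/alias_test.py above; the state trees here are hypothetical stand-ins for optimizer states that carry a PRNG key:

import chex
import jax
import jax.numpy as jnp
import optax.tree_utils as otu

# Two hypothetical optimizer states mixing ordinary arrays with typed PRNG keys.
state_a = {'count': jnp.zeros([]), 'rng_key': jax.random.key(0)}
state_b = {'count': jnp.zeros([]), 'rng_key': jax.random.key(0)}

# chex cannot compare typed key leaves numerically; unwrapping replaces each
# key with its raw uint32 key data so the trees can be checked for closeness.
chex.assert_trees_all_close(
    otu.tree_unwrap_random_key_data(state_a),
    otu.tree_unwrap_random_key_data(state_b),
    rtol=1e-4,
)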