Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replaced seed with key #1167

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open

Conversation

Tomas542
Copy link

@Tomas542 Tomas542 commented Jan 6, 2025

Replaced all seed values with key for uniformity of style with jax.random

@Tomas542
Copy link
Author

Tomas542 commented Jan 6, 2025

@rdyro 3 month later, but I've finished it)

@Tomas542
Copy link
Author

Tomas542 commented Jan 6, 2025

Also should close issue #1137 after PR got accepted

@carlosgmartin
Copy link
Contributor

@Tomas542 JAX now recommends using random.key over random.PRNGKey, which will eventually be deprecated. I suggest editing this PR to use the former.

@Tomas542
Copy link
Author

Tomas542 commented Jan 8, 2025

@carlosgmartin yes, I know, but there are some problems with this in tests where NumPy checks for type compatibility. The new keys are of type key<fry> or something like that. It is possible iteratively just replace some of the old styles with the new ones without rewriting the tests if necessary. Tell me, I'll do it

P.S. Accidentally, I've made a mistake in hungarian_algorithm_test, where there were no such problem.

@rdyro rdyro self-requested a review January 8, 2025 19:38
@rdyro
Copy link
Collaborator

rdyro commented Jan 8, 2025

The changes look great so far!

@carlosgmartin is right that we should move to random.key and you're right @Tomas542 that it breaks the current hyperparameter validation funtion in optax.

Let me see if we can fix the optax hyperparameter validation before this PR so that we can stick to random.key

Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, the changes look great! The PRNGKey fix is in now, so can you rebase and use jax.random.key in place of PRNGKey? The exception is chex.PRNGKey which should stay as is, thanks!

@@ -30,7 +30,7 @@ class HungarianAlgorithmTest(parameterized.TestCase):
m=[0, 1, 2, 4, 8, 16],
)
def test_hungarian_algorithm(self, n, m):
key = jrd.key(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to change to PRNGKey now

@@ -91,7 +91,7 @@ def test_hungarian_algorithm(self, n, m):
m=[0, 1, 2, 4],
)
def test_hungarian_algorithm_vmap(self, k, n, m):
key = jrd.key(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to change to PRNGKey now

@@ -106,7 +106,7 @@ def test_hungarian_algorithm_vmap(self, k, n, m):
assert j.shape == (k, r)

def test_hungarian_algorithm_jit(self):
key = jrd.key(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to change to PRNGKey now

l2_norm_clip=jnp.finfo(jnp.float32).max, noise_multiplier=0.0, seed=0
l2_norm_clip=jnp.finfo(jnp.float32).max,
noise_multiplier=0.0,
key=jrd.PRNGKey(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jrd.key(0) please instead of PRNGKey

l2_norm_clip=l2_norm_clip, noise_multiplier=0.0, seed=42
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
key=jrd.PRNGKey(42)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jrd.key(0) please instead of PRNGKey

'opt_kwargs': {'learning_rate': 1.0, 'eta': 1e-4},
'opt_kwargs': {
'learning_rate': 1.0,
'key': jax.random.PRNGKey(0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.random.key(0) please instead of PRNGKey

) -> base.GradientTransformation:
"""Aggregates gradients based on the DPSGD algorithm.

Args:
l2_norm_clip: maximum L2 norm of the per-example gradients.
noise_multiplier: ratio of standard deviation to the clipping norm.
seed: initial seed used for the jax.random.PRNGKey
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.random.key(0) please instead of PRNGKey

if key is None:
raise ValueError(
"differentially_private_aggregate optimizer requires specifying key: "
"differentially_private_aggregate(..., key=jax.random.PRNGKey(0))"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.random.key(0) please instead of PRNGKey

if key is None:
raise ValueError(
"dpsgd optimizer requires specifying key: "
"dpsgd(..., key=jax.random.PRNGKey(0))"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.random.key(0) please instead of PRNGKey

seed = 314
noise = _adding.add_noise(eta, gamma, seed)
noise_unit = _adding.add_noise(1.0, 0.0, seed)
key = jax.random.PRNGKey(314)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.random.key(0) please instead of PRNGKey

@Tomas542
Copy link
Author

Hm, could someone help with docs? I can't understand what's the problem with tests. I know that CUDA Error and AttributeError: maybe_update are ok. But now there are moded-gpt and triplet loss errors, and tests fails, but I didn't change any of it.

@rdyro
Copy link
Collaborator

rdyro commented Jan 15, 2025

Sorry, the doc failures is on our side (our CI is briefly broken :( ), this should be fixed soon and you can rerun your tests

@Tomas542
Copy link
Author

Sorry, the doc failures is on our side (our CI is briefly broken :( ), this should be fixed soon and you can rerun your tests

ok, but what about error in state utils?
FAILED test_venv/lib/python3.9/site-packages/optax/tree_utils/_state_utils_test.py::StateUtilsTest::test_tree_get - TypeError: equal does not accept dtypes key<fry>, int32. Should I skip this file and return old style PRNGKey? You've changed it, but tests still don't pass

@rdyro
Copy link
Collaborator

rdyro commented Jan 16, 2025

The new JAX key implementation key is not directly comparable, for now, to compare states containing random keys you need this new tree_util in main (you'll need to rebase your branch):

then in _state_utils_test.py

from optax.tree_utils import _random

...

      chex.assert_trees_all_equal(
          _random.tree_unwrap_random_key_data(noise_state), 
          _random.tree_unwrap_random_key_data(expected_result)
      )

...

      chex.assert_trees_all_equal(
          _random.tree_unwrap_random_key_data(new_state), 
          _random.tree_unwrap_random_key_data(expected_result)
      )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants