Skip to content

Commit

Permalink
add mu_dtype guard
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed May 1, 2024
1 parent ae6c90c commit b942d60
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/levanter/optim/schedulefree_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _adam_gradient_transform(
A `GradientTransformation` object.
"""

mu_dtype = jax.canonicalize_dtype(mu_dtype)
mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None

def init_fn(params):
z = jax.tree_util.tree_map(jnp.copy, params) # schedule-free z
Expand Down

0 comments on commit b942d60

Please sign in to comment.