From b942d607c476bc486ffe8d168661be3783540014 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Wed, 1 May 2024 15:42:34 -0700 Subject: [PATCH] add mu_dtype guard --- src/levanter/optim/schedulefree_adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/optim/schedulefree_adam.py b/src/levanter/optim/schedulefree_adam.py index bf03f2cce..231db6396 100644 --- a/src/levanter/optim/schedulefree_adam.py +++ b/src/levanter/optim/schedulefree_adam.py @@ -86,7 +86,7 @@ def _adam_gradient_transform( A `GradientTransformation` object. """ - mu_dtype = jax.canonicalize_dtype(mu_dtype) + mu_dtype = jax.canonicalize_dtype(mu_dtype) if mu_dtype is not None else None def init_fn(params): z = jax.tree_util.tree_map(jnp.copy, params) # schedule-free z