diff --git a/lingvo/jax/test_utils.py b/lingvo/jax/test_utils.py index 01ef355c8..590a1b083 100644 --- a/lingvo/jax/test_utils.py +++ b/lingvo/jax/test_utils.py @@ -112,14 +112,14 @@ def to_tf(x: Any) -> JTensor: return tf.nest.map_structure(to_tf, x_nmap) -def apply(layer, layer_vars, method, *args, context_p=None, seed=123, **kwags): +def apply(layer, layer_vars, method, *args, context_p=None, seed=123, **kwargs): prng_key = jax.random.PRNGKey(seed=seed) with base_layer.JaxContext.new_context( params=context_p, prng_key=prng_key, global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context: jax_context.bind(layer, layer.vars_to_flax_vars(layer_vars)) - return method(*args, **kwags) + return method(*args, **kwargs) def replace_jax_transformer_ffwd_vars_to_tf(