diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py index 5eb9ed788..d0a8f9858 100644 --- a/src/levanter/trainer_state.py +++ b/src/levanter/trainer_state.py @@ -53,11 +53,12 @@ class TrainerState(eqx.Module, Generic[M]): optimizer: GradientTransformation = eqx.field(static=True) opt_state: OptState training_key: PRNGKeyArray - model_averaging: ModelAveraging[M] | None = None is_trainable: FilterTree = eqx.field(static=True) mp: jmp.Policy = eqx.field(static=True) + model_averaging: ModelAveraging[M] | None = None + @property def int_step(self) -> int: """