From 8102b961b466ad1473525ab607711c68e584d001 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 21 Jan 2025 14:36:16 -0800 Subject: [PATCH] allow loading partial checkpoints --- src/levanter/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 82f32422a..7984e59b7 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -343,6 +343,7 @@ def initial_state( mesh=self.device_mesh, subpath="model", do_load=True, + allow_partial=self.config.allow_partial_checkpoint, )() model_init = jax.tree_util.Partial(lambda m: m, loaded_model) @@ -369,6 +370,7 @@ def init_state_and_model(model_init, training_key): mesh=self.device_mesh, is_checkpointed=saveable_train_state, do_load=load_checkpoint, + allow_partial=self.config.allow_partial_checkpoint, )(model_init, training_key) return state @@ -629,6 +631,9 @@ class TrainerConfig: load_checkpoint_path: Optional[str] = None """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path.""" initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from + allow_partial_checkpoint: bool = False + """If True, we allow loading a checkpoint that doesn't have all the parameters in the model. + Missing parameters are initialized from the model_init function.""" jax_config: Mapping[str, JsonAtom] = field( default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)