Skip to content

Commit

Permalink
allow loading partial checkpoints
Browse files: browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 21, 2025
1 parent e9bbd98 commit 8102b96
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/levanter/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -343,6 +343,7 @@ def initial_state(
mesh=self.device_mesh,
subpath="model",
do_load=True,
allow_partial=self.config.allow_partial_checkpoint,
)()
model_init = jax.tree_util.Partial(lambda m: m, loaded_model)

Expand All @@ -369,6 +370,7 @@ def init_state_and_model(model_init, training_key):
mesh=self.device_mesh,
is_checkpointed=saveable_train_state,
do_load=load_checkpoint,
allow_partial=self.config.allow_partial_checkpoint,
)(model_init, training_key)

return state
Expand Down Expand Up @@ -629,6 +631,9 @@ class TrainerConfig:
load_checkpoint_path: Optional[str] = None
"""can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from
allow_partial_checkpoint: bool = False
"""If True, we allow loading a checkpoint that doesn't have all the parameters in the model.
Missing parameters are initialized from the model_init function."""

jax_config: Mapping[str, JsonAtom] = field(
default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)
Expand Down

0 comments on commit 8102b96

Please sign in to comment.