Skip to content

Commit

Permalink
Partial checkpoints (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Jan 22, 2025
1 parent 5581062 commit 6ad37c9
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 37 deletions.
52 changes: 26 additions & 26 deletions config/harness/eval_llama3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@ eval_harness:
task_spec:
- task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
num_fewshot: 10
- task: agieval_lsat_ar # 3-shot tests in legal domain
num_fewshot: 3
- task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
num_fewshot: 10
- task: arc_challenge # a (harder) version of arc_easy
num_fewshot: 10
- task: boolq # answer yes/no questions based on a passage
num_fewshot: 10
- task: copa # use causal reasoning to predict the correct outcome of a given scenario
num_fewshot: 0
- task: hellaswag # 4-way multiple choice commonsense reasoning dataset
num_fewshot: 0
task_alias: hellaswag_0shot
- task: hellaswag # 4-way multiple choice commonsense reasoning dataset
num_fewshot: 10
task_alias: hellaswag_10shot
- task: lambada # predict the endings of text passages
num_fewshot: 0
- task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
num_fewshot: 0
- task: piqa # 2-way multiple choice questions testing everyday physical commonsense
num_fewshot: 10
- task: wsc273 # Winograd Schema Challenge
num_fewshot: 0
- task: winogrande # Winograd challenge, extended to more domains
num_fewshot: 0
# - task: agieval_lsat_ar # 3-shot tests in legal domain
# num_fewshot: 3
# - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science
# num_fewshot: 10
# - task: arc_challenge # a (harder) version of arc_easy
# num_fewshot: 10
# - task: boolq # answer yes/no questions based on a passage
# num_fewshot: 10
# - task: copa # use causal reasoning to predict the correct outcome of a given scenario
# num_fewshot: 0
# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
# num_fewshot: 0
# task_alias: hellaswag_0shot
# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset
# num_fewshot: 10
# task_alias: hellaswag_10shot
# - task: lambada # predict the endings of text passages
# num_fewshot: 0
# - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning
# num_fewshot: 0
# - task: piqa # 2-way multiple choice questions testing everyday physical commonsense
# num_fewshot: 10
# - task: wsc273 # Winograd Schema Challenge
# num_fewshot: 0
# - task: winogrande # Winograd challenge, extended to more domains
# num_fewshot: 0
# requires generation
## - task: squadv2 # reading comprehension benchmark
# num_fewshot: 10
Expand Down
11 changes: 10 additions & 1 deletion src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def load_checkpoint(
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
allow_partial: bool = False,
) -> M:
"""
Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
Expand All @@ -367,6 +368,7 @@ def load_checkpoint(
discover_latest: whether to discover the latest checkpoint in the given path
axis_mapping: the axis mapping to use for loading the checkpoint
mesh: the mesh to use for loading the checkpoint
allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.
Returns:
the loaded checkpoint, with the same structure as the exemplar tree
Expand Down Expand Up @@ -397,7 +399,9 @@ def load_checkpoint(

ser, non_ser = equinox.partition(tree, is_jax_array_like)
try:
tree = tree_deserialize_leaves_tensorstore(checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh)
tree = tree_deserialize_leaves_tensorstore(
checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
)
tree = equinox.combine(tree, non_ser)
return tree
except: # noqa
Expand Down Expand Up @@ -445,6 +449,7 @@ def load_checkpoint_or_initialize(
donate_args: FilterSpec = True,
donate_kwargs: Optional[FilterSpec] = None,
do_load: Optional[bool] = None,
allow_partial: bool = False,
) -> Callable[Sig, M]:
"""
Load a checkpoint from a given path. If discover_latest is True, then the latest checkpoint
Expand Down Expand Up @@ -476,6 +481,7 @@ def load_checkpoint_or_initialize(
donate_args: a FilterSpec that specifies which arguments to donate to init_fn if we need to initialize
donate_kwargs: a FilterSpec that specifies which kwargs to donate to init_fn if we need to initialize
do_load: if True, always load the checkpoint. If False, always initialize. If None, load if the checkpoint exists, otherwise initialize
allow_partial: if True, allow partial loading of the checkpoint. If False, all parameters must be present in the checkpoint.
Returns:
A function that takes the same arguments as init_fn, but loads the checkpoint if it exists and returns the
Expand All @@ -493,6 +499,8 @@ def load_checkpoint_or_initialize(
)
def init_and_merge(state, *args, **kwargs):
init_state = init_fn(*args, **kwargs)
# remove all ShapeDTypeStructs from the state
state = equinox.filter(state, lambda x: not isinstance(x, jax.ShapeDtypeStruct))
return equinox.combine(state, init_state)

def load_or_init(*args, **kwargs):
Expand All @@ -516,6 +524,7 @@ def load_or_init(*args, **kwargs):
discover_latest=discover_latest,
axis_mapping=axis_mapping,
mesh=mesh,
allow_partial=allow_partial,
)
except FileNotFoundError:
if do_load is True:
Expand Down
51 changes: 42 additions & 9 deletions src/levanter/tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from haliax.partitioning import ResourceMapping
from haliax.util import is_named_array

from levanter.utils import jax_utils
from levanter.utils import fsspec_utils, jax_utils


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,6 +119,8 @@ def tree_deserialize_leaves_tensorstore(
axis_mapping: Optional[ResourceMapping] = None,
mesh: Optional[Mesh] = None,
manager: Optional[array_ser.GlobalAsyncCheckpointManager] = None,
*,
allow_missing: bool = False,
):
"""
Deserializes a PyTree of Arrays and NamedArrays from a Tensorstore checkpoint, returning a pytree with the same shape
Expand All @@ -132,6 +134,7 @@ def tree_deserialize_leaves_tensorstore(
axis_mapping: optional, the axis mapping for the NamedArrays (if they are not yet arrays)
mesh: optional, the mesh for the NamedArrays (if they are not yet arrays)
manager: optional, the checkpoint manager to use. If not provided, a new one will be created
allow_missing: if True, missing leaves will be allowed and kept as-is
Returns:
A pytree with the same shape as the exemplar pytree, but with the arrays deserialized from the checkpoint
Expand All @@ -151,26 +154,56 @@ def tree_deserialize_leaves_tensorstore(
shardings_leaves, shardings_structure = jtu.tree_flatten(shardings, is_leaf=_is_named_or_none)

assert len(shardings_leaves) == len(paths)

# ok, so, jax really doesn't want any Nones in the leaves here, so we need to temporarily partition the pytree
real_indices = [i for i, x in enumerate(shardings_leaves) if x is not None]
real_leaves = [x for x in shardings_leaves if x is not None]
real_paths = [paths[i] for i in real_indices]
paths_to_load = []
indices_to_load = []
shardings_to_load = []

missing_paths = []
missing_indices = []

for i in real_indices:
path = paths[i]

if not fsspec_utils.exists(path):
missing_paths.append(path)
missing_indices.append(i)
continue

assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}"
paths_to_load.append(path)
indices_to_load.append(i)
shardings_to_load.append(shardings_leaves[i])

# ok now check for missing paths
if missing_paths:
if not allow_missing:
raise FileNotFoundError(f"Missing paths: {missing_paths}")
else:
to_log = f"Several keys were missing from the checkpoint directory {checkpoint_dir}:"
leaf_paths = jtu.tree_leaves(leaf_key_paths, is_leaf=_is_named_or_none)
for i in missing_indices:
to_log += f"\n - {leaf_paths[i]}"
logger.warning(to_log)

deser_leaves = manager.deserialize_with_paths(shardings=shardings_to_load, paths=paths_to_load)

deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths)
# now we need to recreate the original structure

out_leaves = [None] * len(shardings_leaves)
for i, x in zip(real_indices, deser_leaves):
out_leaves = jax.tree_leaves(pytree, is_leaf=_is_named_or_none)
assert len(out_leaves) == len(shardings_leaves)
# out_leaves = [None] * len(shardings_leaves)
for i, x in zip(indices_to_load, deser_leaves):
out_leaves[i] = x

deser_arrays = jtu.tree_unflatten(shardings_structure, out_leaves)

# deser_arrays only has arrays, but we need named arrays for at least some.
# deser_arrays only has arrays for the deserialized arrays, but we need named arrays for at least some.
# The original pytree has the structure we want, so we'll use that to rebuild the named arrays
def _rebuild_named_array(like, array):
if is_named_array(array):
return array

if is_named_array(like):
return hax.NamedArray(array, like.axes)
else:
Expand Down
5 changes: 5 additions & 0 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def initial_state(
mesh=self.device_mesh,
subpath="model",
do_load=True,
allow_partial=self.config.allow_partial_checkpoint,
)()
model_init = jax.tree_util.Partial(lambda m: m, loaded_model)

Expand All @@ -369,6 +370,7 @@ def init_state_and_model(model_init, training_key):
mesh=self.device_mesh,
is_checkpointed=saveable_train_state,
do_load=load_checkpoint,
allow_partial=self.config.allow_partial_checkpoint,
)(model_init, training_key)

return state
Expand Down Expand Up @@ -629,6 +631,9 @@ class TrainerConfig:
load_checkpoint_path: Optional[str] = None
"""can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from
allow_partial_checkpoint: bool = False
"""If True, we allow loading a checkpoint that doesn't have all the parameters in the model.
Missing parameters are initialized from the model_init function."""

jax_config: Mapping[str, JsonAtom] = field(
default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,40 @@ def init_fn(key):
jax.tree_util.tree_leaves(arrays_only(eqx.filter(loaded, is_checkpointed))),
jax.tree_util.tree_leaves(arrays_only(eqx.filter(model1, is_checkpointed))),
)


def test_load_from_checkpoint_allows_partial_checkpoints():
    """With allow_partial=True, fields absent from the saved checkpoint are
    filled in from the init function instead of raising.

    A model is saved with ``b=None`` (so only ``a`` is checkpointed), then
    loaded with an init that produces both fields: ``a`` must come from the
    checkpoint and ``b`` from the fresh initialization.
    """
    In = Axis("in", 2)
    Out = Axis("out", 1)

    class MyModule(eqx.Module):
        a: hax.NamedArray
        b: hax.NamedArray | None

    def init_fn(key, use_b):
        key_a, key_b = jax.random.split(key)
        maybe_b = hax.random.normal(key_b, (In, Out)) if use_b else None
        return MyModule(a=hax.random.normal(key_a, (In, Out)), b=maybe_b)

    key_saved = jax.random.PRNGKey(0)
    key_fresh = jax.random.PRNGKey(1)

    # saved_model has no `b`, so the checkpoint on disk only contains `a`
    saved_model = init_fn(key_saved, False)
    # fresh_model is what init_fn would produce for key_fresh — reference for `b`
    fresh_model = init_fn(key_fresh, True)

    is_checkpointed = True

    mesh = jax.sharding.Mesh(jax.devices(), ("devices",))
    with mesh, tempfile.TemporaryDirectory() as tmpdir:
        save_checkpoint(eqx.filter(saved_model, is_checkpointed), step=0, checkpoint_path=tmpdir)

        loaded = load_checkpoint_or_initialize(
            init_fn,
            tmpdir,
            is_checkpointed=is_checkpointed,
            allow_partial=True,
        )(key_fresh, True)

        # no abstract placeholders may survive the merge
        abstract_leaves = jax.tree_util.tree_leaves(
            eqx.filter(loaded, lambda x: isinstance(x, ShapeDtypeStruct))
        )
        assert not any(abstract_leaves)
        # `a` restored from disk; `b` was missing and must come from init
        assert hax.all(hax.equal(loaded.a, saved_model.a))
        assert loaded.b is not None
        assert hax.all(hax.equal(loaded.b, fresh_model.b))
21 changes: 20 additions & 1 deletion tests/test_tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,24 @@ class MyModule(eqx.Module):
m3 = MyModule(a=hax.zeros(A), b=hax.ones(A))
with TemporaryDirectory() as tmpdir:
tree_serialize_leaves_tensorstore(tmpdir, m2)
with pytest.raises(ValueError):
with pytest.raises(FileNotFoundError):
tree_deserialize_leaves_tensorstore(tmpdir, m3)


def test_tensorstore_ok_with_missing():
    """allow_missing=True keeps exemplar values for leaves that were never
    serialized, rather than raising FileNotFoundError.

    Only ``b`` is written to the checkpoint (``a`` is None at save time); on
    deserialization ``a`` must retain the exemplar's value while ``b`` is
    restored from disk.
    """
    device_mesh = jax.sharding.Mesh(jax.devices(), ("device",))
    with device_mesh:
        A = hax.Axis("A", 10)

        class MyModule(eqx.Module):
            a: Any
            b: Any

        saved = MyModule(a=None, b=hax.zeros(A))  # only `b` hits the checkpoint
        exemplar = MyModule(a=hax.full(A, 4), b=hax.ones(A))

        with TemporaryDirectory() as tmpdir:
            tree_serialize_leaves_tensorstore(tmpdir, saved)
            restored = tree_deserialize_leaves_tensorstore(tmpdir, exemplar, allow_missing=True)
            # `a` never saved -> exemplar value survives; `b` comes from disk
            assert hax.all(restored.a == hax.full(A, 4))
            assert hax.all(restored.b == hax.zeros(A))

0 comments on commit 6ad37c9

Please sign in to comment.