
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 20, 2025
1 parent 51c208c commit e77d1e3
Showing 4 changed files with 77 additions and 29 deletions.
30 changes: 20 additions & 10 deletions colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -58,6 +58,7 @@ def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool =
destination[extra_state_key] = extra_state
return destination


def load_state_dict_into_dist_model(
model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
):
@@ -86,11 +87,12 @@ def load_state_dict_into_dist_model(
extra_state.copy_(state_dict[extra_state_key])
return destination


def create_model_metadata(
model: nn.Module,
prefix: str = "",
tp_size = None,
tp_rank = None,
tp_size=None,
tp_rank=None,
):
param_origin_shape = model.param_origin_shape
model = model.unwrap()
@@ -110,11 +112,12 @@ def create_model_metadata(
partition_size = param.shape[tp_partition_dim]
model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank
if tp_rank == tp_size - 1:
model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
tp_partition_dim
] - (partition_size * (tp_size - 1))
model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (
partition_size * (tp_size - 1)
)
return model_metadata


def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None):
metadata_dicts = {
"checkpoint_version": "1.0",
@@ -133,6 +136,7 @@ def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_siz
with open(metadata_file, "w") as json_file:
json.dump(metadata_dicts, json_file, indent=4)


def load_metadata(checkpoint: str):
metadata_dict = {}
for filename in os.listdir(checkpoint):
@@ -197,6 +201,7 @@ def find_covering_shards(shards, target_offsets, target_lengths)
assert total_lengths == global_shape
return covering_shards


def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
"""
Extract the target range of weights from shard data, supporting partial overlap.
@@ -233,6 +238,7 @@ def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
target_weight = weight[tuple(slices)]
return target_weight, target_slices


def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype):
target_tensor = torch.zeros(target_lengths, dtype=dtype)

@@ -310,7 +316,13 @@ def dist_model_sharder(


def save_dist_unshard_model(
model: ModelWrapper, model_metadata: Dict, checkpoint: str, use_safetensors: bool, use_async: bool = False, dist_id = 0, pinned_state_dicts = None
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
use_safetensors: bool,
use_async: bool = False,
dist_id=0,
pinned_state_dicts=None,
):
"""
Save model state dict to a single file with given checkpointing path.
@@ -426,7 +438,7 @@ def save_dist_sharded_model(
use_safetensors: bool = False,
use_async: bool = False,
dist_id: int = 0,
pinned_state_dicts = None,
pinned_state_dicts=None,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
@@ -463,9 +475,7 @@ def save_dist_sharded_model(
pinned_state_dicts = pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = dist_model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
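
For orientation, the functions touched in this file implement tensor-parallel checkpoint bookkeeping: create_model_metadata records, for each parameter, where the local rank's slice sits inside the parameter's original shape, and the covering/partial-overlap helpers (find_covering_shards, extract_weight_from_shard_partial, assemble_tensor_from_shards_partial) work in terms of those same offsets and lengths when a target slice spans several shard files. A minimal standalone sketch of the offset/length rule, using a hypothetical helper rather than the module's actual code:

# Hypothetical standalone sketch (not the module's actual code) of the
# offsets/lengths bookkeeping that create_model_metadata performs for each
# tensor-parallel (TP) partitioned parameter.
from typing import Dict, List


def shard_metadata(
    original_shape: List[int],  # parameter shape before TP partitioning
    local_shape: List[int],     # shape of this rank's local slice
    tp_partition_dim: int,      # dimension split across TP ranks
    tp_size: int,
    tp_rank: int,
) -> Dict[str, List[int]]:
    offsets = [0] * len(original_shape)
    lengths = list(local_shape)
    partition_size = local_shape[tp_partition_dim]
    # Each rank's slice starts where the previous ranks' slices end.
    offsets[tp_partition_dim] = partition_size * tp_rank
    # The last rank's length is whatever remains of the original dimension,
    # mirroring the `tp_rank == tp_size - 1` branch in the diff above.
    if tp_rank == tp_size - 1:
        lengths[tp_partition_dim] = original_shape[tp_partition_dim] - partition_size * (tp_size - 1)
    return {"offsets": offsets, "lengths": lengths}


# Example: an [8, 4] weight split across 2 TP ranks along dim 0.
print(shard_metadata([8, 4], [4, 4], 0, tp_size=2, tp_rank=0))
# -> {'offsets': [0, 0], 'lengths': [4, 4]}
print(shard_metadata([8, 4], [4, 4], 0, tp_size=2, tp_rank=1))
# -> {'offsets': [4, 0], 'lengths': [4, 4]}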

2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
@@ -309,4 +309,4 @@ def load_sharded_model(
)

def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
raise NotImplementedError
raise NotImplementedError
60 changes: 44 additions & 16 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -24,6 +24,13 @@
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

from .distributed_checkpoint_utils import (
create_model_metadata,
is_pytorch_model_meta_dist_file,
load_dist_model,
save_dist_sharded_model,
save_dist_unshard_model,
)
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
@@ -47,14 +54,6 @@
sharded_optimizer_loading_epilogue,
)

from .distributed_checkpoint_utils import (
save_dist_sharded_model,
save_dist_unshard_model,
load_dist_model,
is_pytorch_model_meta_dist_file,
create_model_metadata
)

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
return
dist_id = self.tp_size * self.pp_rank + self.tp_rank
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
save_dist_sharded_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
use_async=use_async,
dist_id=dist_id,
pinned_state_dicts=self.pinned_state_dicts,
)
return

model = model.unwrap()

if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(

if is_pytorch_model_meta_dist_file(checkpoint_index_file):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
load_dist_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint_index_file,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
return

model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()

@@ -792,9 +807,17 @@ def save_unsharded_model(
if self.dp_rank != 0 and self.sp_rank != 0:
return
dist_id = self.tp_size * self.pp_rank + self.tp_rank
save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
save_dist_unshard_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
use_safetensors=use_safetensors,
use_async=use_async,
dist_id=dist_id,
pinned_state_dicts=self.pinned_state_dicts,
)
return

model = model.unwrap()
if self.dp_rank != 0:
return
@@ -867,7 +890,13 @@ def load_unsharded_model(
for filename in os.listdir(checkpoint):
if is_pytorch_model_meta_dist_file(filename):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
load_dist_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
return

strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)


# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
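
The call-site changes in this file are formatting only (one keyword argument per line); behavior is unchanged. For context, the dist_id threaded through those calls is the writer index derived from the pipeline and tensor-parallel ranks, as in this small sketch with a hypothetical helper name:

# Hypothetical sketch of the writer-id indexing used by the calls above:
# each (pipeline rank, tensor-parallel rank) pair gets a unique index,
# which is what save_dist_sharded_model / save_dist_unshard_model receive
# as dist_id.
def dist_id(tp_size: int, pp_rank: int, tp_rank: int) -> int:
    return tp_size * pp_rank + tp_rank


# With tp_size=2 and two pipeline stages, the four writer ranks map to 0..3.
for pp_rank in range(2):
    for tp_rank in range(2):
        print(f"pp_rank={pp_rank}, tp_rank={tp_rank} -> dist_id={dist_id(2, pp_rank, tp_rank)}")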
14 changes: 12 additions & 2 deletions tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -79,7 +79,12 @@ def _preprocess_data(data):
model_ckpt_path_0 = f"{tempdir}/model_0"

booster_0.save_model(
model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
model_0,
model_ckpt_path_0,
shard=shard,
gather_dtensor=True,
size_per_shard=size_per_shard,
use_async=use_async,
)
booster_0.checkpoint_io._sync_d2h()
booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):

model_ckpt_path_1 = f"{tempdir}/model_1"
booster_1.save_model(
model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
model_1,
model_ckpt_path_1,
shard=shard,
gather_dtensor=True,
size_per_shard=size_per_shard,
use_async=use_async,
)
booster_1.checkpoint_io._sync_d2h()
booster_1.checkpoint_io._sync_io()
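
The reflowed calls above exercise the save path end to end. A condensed sketch of that pattern, assuming booster, model, tempdir, shard, size_per_shard, and use_async are set up as in the full test:

# Condensed sketch of the save pattern exercised by the test above
# (assumes the surrounding test fixtures; not runnable on its own).
model_ckpt_path = f"{tempdir}/model_0"
booster.save_model(
    model,
    model_ckpt_path,
    shard=shard,
    gather_dtensor=True,
    size_per_shard=size_per_shard,
    use_async=use_async,
)
# With use_async=True these explicit syncs presumably block until the
# device-to-host copies and the asynchronous file writes have finished,
# so the checkpoint on disk is complete before it is read back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()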
