From e77d1e33c947e68edc5bc5d5338ab90320d019b7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 20 Jan 2025 03:30:57 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../distributed_checkpoint_utils.py | 30 ++++++----
 .../checkpoint_io/general_checkpoint_io.py | 2 +-
 .../hybrid_parallel_checkpoint_io.py | 60 ++++++++++++++-----
 .../test_dist_checkpointio.py | 14 ++++-
 4 files changed, 77 insertions(+), 29 deletions(-)

diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
index f39c8d41dcd8..17d86161e61c 100644
--- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py
+++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -58,6 +58,7 @@ def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool =
         destination[extra_state_key] = extra_state
     return destination
 
+
 def load_state_dict_into_dist_model(
     model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
 ):
@@ -86,11 +87,12 @@ def load_state_dict_into_dist_model(
             extra_state.copy_(state_dict[extra_state_key])
     return destination
 
+
 def create_model_metadata(
     model: nn.Module,
     prefix: str = "",
-    tp_size = None,
-    tp_rank = None,
+    tp_size=None,
+    tp_rank=None,
 ):
     param_origin_shape = model.param_origin_shape
     model = model.unwrap()
@@ -110,11 +112,12 @@ def create_model_metadata(
             partition_size = param.shape[tp_partition_dim]
             model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * tp_rank
             if tp_rank == tp_size - 1:
-                model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
-                    tp_partition_dim
-                ] - (partition_size * (tp_size - 1))
+                model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (
+                    partition_size * (tp_size - 1)
+                )
     return model_metadata
 
+
 def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None):
     metadata_dicts = {
         "checkpoint_version": "1.0",
@@ -133,6 +136,7 @@ def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_siz
     with open(metadata_file, "w") as json_file:
         json.dump(metadata_dicts, json_file, indent=4)
 
+
 def load_metadata(checkpoint: str):
     metadata_dict = {}
     for filename in os.listdir(checkpoint):
@@ -197,6 +201,7 @@ def find_covering_shards(shards, target_offsets, target_lengths):
     assert total_lengths == global_shape
     return covering_shards
 
+
 def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
     """
     Extract the target range of weights from shard data, supporting partial overlap.
@@ -233,6 +238,7 @@ def extract_weight_from_shard_partial(shard, target_offsets, target_lengths):
     target_weight = weight[tuple(slices)]
     return target_weight, target_slices
 
+
 def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype):
     target_tensor = torch.zeros(target_lengths, dtype=dtype)
 
@@ -310,7 +316,13 @@ def dist_model_sharder(
 
 
 def save_dist_unshard_model(
-    model: ModelWrapper, model_metadata: Dict, checkpoint: str, use_safetensors: bool, use_async: bool = False, dist_id = 0, pinned_state_dicts = None
+    model: ModelWrapper,
+    model_metadata: Dict,
+    checkpoint: str,
+    use_safetensors: bool,
+    use_async: bool = False,
+    dist_id=0,
+    pinned_state_dicts=None,
 ):
     """
     Save model state dict to a single file with given checkpointing path.
@@ -426,7 +438,7 @@ def save_dist_sharded_model(
     use_safetensors: bool = False,
     use_async: bool = False,
     dist_id: int = 0,
-    pinned_state_dicts = None,
+    pinned_state_dicts=None,
 ) -> None:
     """
     Save sharded model checkpoint under the given checkpointing path.
@@ -463,9 +475,7 @@ def save_dist_sharded_model(
         pinned_state_dicts = pinned_state_dicts[id(model)]
     else:
         pinned_state_dicts = None
-    state_dict_shard = dist_model_sharder(
-        model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
-    )
+    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
     weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
     index_file = CheckpointIndexFile(checkpoint)
 
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index d5ed5b848de3..c38958ee31b9 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -309,4 +309,4 @@ def load_sharded_model(
         )
 
     def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 44c119eef6d5..dd1dd4258d9e 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -24,6 +24,13 @@
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
+from .distributed_checkpoint_utils import (
+    create_model_metadata,
+    is_pytorch_model_meta_dist_file,
+    load_dist_model,
+    save_dist_sharded_model,
+    save_dist_unshard_model,
+)
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -47,14 +54,6 @@
     sharded_optimizer_loading_epilogue,
 )
 
-from .distributed_checkpoint_utils import (
-    save_dist_sharded_model,
-    save_dist_unshard_model,
-    load_dist_model,
-    is_pytorch_model_meta_dist_file,
-    create_model_metadata
-)
-
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
         model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-        save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        save_dist_sharded_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            prefix=prefix,
+            size_per_shard=size_per_shard,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         return
-
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(
 
         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-            load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+            load_dist_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint_index_file,
+                low_cpu_mem_mode=low_cpu_mem_mode,
+                num_threads=num_threads,
+            )
             return
-
+
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()
 
@@ -792,9 +807,17 @@ def save_unsharded_model(
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
-        save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        save_dist_unshard_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         return
-
+
         model = model.unwrap()
         if self.dp_rank != 0:
             return
@@ -867,7 +890,13 @@ def load_unsharded_model(
         for filename in os.listdir(checkpoint):
             if is_pytorch_model_meta_dist_file(filename):
                 model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-                load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+                load_dist_model(
+                    model=model,
+                    model_metadata=model_metadata,
+                    checkpoint=checkpoint,
+                    low_cpu_mem_mode=low_cpu_mem_mode,
+                    num_threads=num_threads,
+                )
                 return
 
         strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
                 dist.all_gather(gather_tensor, v, group=dp_group)
                 v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
 
-
             # Then gather TP shards.
             partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
             if partition_dim is not None:
diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py
index 08354c214a62..850a10c17ce6 100644
--- a/tests/test_checkpoint_io/test_dist_checkpointio.py
+++ b/tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -79,7 +79,12 @@ def _preprocess_data(data):
         model_ckpt_path_0 = f"{tempdir}/model_0"
 
         booster_0.save_model(
-            model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_0,
+            model_ckpt_path_0,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_0.checkpoint_io._sync_d2h()
         booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):
         model_ckpt_path_1 = f"{tempdir}/model_1"
 
         booster_1.save_model(
-            model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_1,
+            model_ckpt_path_1,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_1.checkpoint_io._sync_d2h()
         booster_1.checkpoint_io._sync_io()
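Reviewer note: the hunks above only re-wrap existing call sites; no behavior changes. For context, a minimal sketch of the save path the test exercises, assuming a 2-rank torchrun launch, an illustrative `tp_size=2` layout, and a made-up `ckpt/model` path (none of these values come from the patch itself):

```python
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from torch.optim import Adam

colossalai.launch_from_torch()  # reads rank/world-size from the torchrun env vars

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # assumed parallel layout
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8).cuda()  # stand-in model for illustration
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded async save in the distributed checkpoint format touched by this
# patch; the two _sync_* calls (used the same way in the test above) block
# until the device-to-host copy and the file writes have finished.
booster.save_model(
    model,
    "ckpt/model",
    shard=True,
    gather_dtensor=True,
    size_per_shard=32,
    use_async=True,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
```

The keyword-per-line layout in the sketch mirrors what the formatter enforces throughout the diff: once a call no longer fits on one line, every argument gets its own line with a trailing comma.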