From 9984a643e91c3b279bb03282d5e0eec50a8464eb Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 21 Jan 2025 07:18:14 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../distributed_checkpoint_utils.py        | 14 ++++------
 .../hybrid_parallel_checkpoint_io.py       | 28 +++++++++----------
 2 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
index a56386a1ffa3..286322486efb 100644
--- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py
+++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -1,5 +1,6 @@
 import json
 import os
+from contextlib import contextmanager
 from typing import Dict
 
 import torch
@@ -9,12 +10,8 @@
 from colossalai.interface import ModelWrapper
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 
-from contextlib import contextmanager
-from .utils import (
-    load_state_dict,
-    search_tp_partition_dim,
-)
+from .utils import load_state_dict, search_tp_partition_dim
 
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
@@ -34,8 +31,7 @@ def RestoreDefaultStateDictBehavior(model):
         yield model
     finally:
         for module, original_method in original_methods.items():
-            module._save_to_state_dict, module._load_from_state_dict = original_method
-
+            module._save_to_state_dict, module._load_from_state_dict = original_method
 
 
 def create_model_metadata(
@@ -260,12 +256,14 @@ def load_dist_model(
 
     return state_dict
 
+
 def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
 
     return weights_name
 
+
 def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
     if use_safetensors:
         return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
\ No newline at end of file
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 398740f62251..3b47a3a85d5a 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,11 +1,11 @@
 import copy
 import logging
 import os
+from contextlib import nullcontext
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -26,14 +26,14 @@
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 from .distributed_checkpoint_utils import (
+    MODEL_WEIGHT_PREFIX,
+    RestoreDefaultStateDictBehavior,
     create_model_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
-    get_dist_files_name,
-    get_dist_meta_file_name,
-    MODEL_WEIGHT_PREFIX,
-    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -108,7 +108,7 @@ def _model_sharder(
     keep_vars: bool = False,
     size_per_shard: int = 1024,
     pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-    gather_dtensor: bool = True, 
+    gather_dtensor: bool = True,
 ) -> Iterator[Tuple[OrderedDict, int]]:
     # An internel method that breaks state_dict of model into shards within limited size.
 
@@ -118,7 +118,7 @@ def _model_sharder(
     for name, param in model.named_parameters():
         if param is None:
             continue
-        
+
         # Gather tensor pieces when using tensor parallel.
         param_ = gather_distributed_param(param, keep_vars=False)
         if is_padded_tensor(param_):
@@ -245,12 +245,12 @@ def save_sharded_model(
         model._force_wait_all_gather()
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
-        
+
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-        
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -280,7 +280,9 @@ def save_sharded_model(
         if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
-            metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+            metadata_file = get_dist_meta_file_name(
+                checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors
+            )
 
         if use_async:
             total_size, writers = async_save_state_dict_shards(
@@ -413,9 +415,7 @@ def load_sharded_model(
             )
             model = model.unwrap()
             with RestoreDefaultStateDictBehavior(model):
-                load_state_dict_into_model(
-                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
-                )
+                load_state_dict_into_model(model, state_dict, missing_keys=[], strict=False, load_sub_module=True)
             return
 
         model_before_wrapping = model  # backup for model before wrapping
@@ -897,7 +897,7 @@ def load_unsharded_model(
                 load_dtensor = True
                 break
 
-        model_metadata = None # used for dist model
+        model_metadata = None  # used for dist model
         if load_dtensor:
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
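
Note on the RestoreDefaultStateDictBehavior hunk above: the fixed line unpacks a stored
(save, load) method pair back onto each module, i.e. original_methods maps a module to the
tuple of hooks it had before entering the context. The sketch below illustrates that
swap-and-restore pattern in isolation. It is a minimal approximation, not the ColossalAI
implementation: the set of modules patched and the use of the stock nn.Module hooks as the
"default" behavior are assumptions, and the function name here is a hypothetical stand-in.

# Minimal sketch of the swap-and-restore pattern behind RestoreDefaultStateDictBehavior.
# Assumptions (not confirmed by the patch): every submodule is patched, and the stock
# nn.Module implementations are the "default" behavior restored inside the block.
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def restore_default_state_dict_behavior(model: nn.Module):
    # Remember each module's current (possibly customized) state-dict hooks.
    original_methods = {
        module: (module._save_to_state_dict, module._load_from_state_dict)
        for module in model.modules()
    }
    # Temporarily rebind the stock nn.Module implementations to each instance.
    for module in original_methods:
        module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module)
        module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module)
    try:
        yield model
    finally:
        # Restore the original pair, mirroring the fixed line in the patch.
        for module, original_method in original_methods.items():
            module._save_to_state_dict, module._load_from_state_dict = original_method

A caller would wrap any state-dict access in the context, e.g.
`with restore_default_state_dict_behavior(model): sd = model.state_dict()`. For the
distributed file naming also touched by this patch: save_sharded_model computes
dist_id = tp_size * pp_rank + tp_rank, so with tp_size=2, pp_rank=1, tp_rank=0 the
dist_id is 2 and get_dist_files_name turns "pytorch_model.bin" into
"pytorch_model-dist-00002-shard.bin".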