[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 21, 2025
1 parent 794c6b1 commit 9984a64
Showing 2 changed files with 20 additions and 22 deletions.
colossalai/checkpoint_io/distributed_checkpoint_utils.py (14 changes: 6 additions & 8 deletions)
@@ -1,5 +1,6 @@
 import json
 import os
+from contextlib import contextmanager
 from typing import Dict
 
 import torch
@@ -9,12 +10,8 @@
 
 from colossalai.interface import ModelWrapper
 from colossalai.shardformer.layer.parallel_module import ParallelModule
-from contextlib import contextmanager
 
-from .utils import (
-    load_state_dict,
-    search_tp_partition_dim,
-)
+from .utils import load_state_dict, search_tp_partition_dim
 
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
@@ -34,8 +31,7 @@ def RestoreDefaultStateDictBehavior(model):
         yield model
     finally:
         for module, original_method in original_methods.items():
-            module._save_to_state_dict, module._load_from_state_dict = original_method
-
+            module._save_to_state_dict, module._load_from_state_dict = original_method
 
 
 def create_model_metadata(
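
For context on the hunk above: only the tail of RestoreDefaultStateDictBehavior is visible here, where each module's _save_to_state_dict / _load_from_state_dict pair is restored in the finally block. A minimal generic sketch of that save/patch/restore context-manager pattern is shown below; the lowercase function name, the iteration over model.modules(), and the fallback to the stock nn.Module implementations are assumptions for illustration, not ColossalAI's actual code.

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def restore_default_state_dict_behavior(model: nn.Module):
    # Remember every submodule's current (possibly customized) state-dict hooks.
    original_methods = {
        module: (module._save_to_state_dict, module._load_from_state_dict) for module in model.modules()
    }
    # Temporarily fall back to the default nn.Module behavior (assumed for this sketch).
    for module in original_methods:
        module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module)
        module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module)
    try:
        yield model
    finally:
        # Put the original hooks back, mirroring the restore loop shown in the diff.
        for module, original_method in original_methods.items():
            module._save_to_state_dict, module._load_from_state_dict = original_method
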
@@ -260,12 +256,14 @@ def load_dist_model(
 
     return state_dict
 
+
 def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
     return weights_name
 
+
 def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
     if use_safetensors:
         return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
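
Both helpers in the last hunk are shown in full, so their outputs follow directly from the string templates above. A quick usage sketch (the import path matches the file being edited; the meta-file suffix constants live in the same module but are not shown in this diff):

from colossalai.checkpoint_io.distributed_checkpoint_utils import (
    get_dist_files_name,
    get_dist_meta_file_name,
)

# dist_id is zero-padded to five digits by both helpers.
print(get_dist_files_name("model.safetensors", dist_id=7))
# model-dist-00007-shard.safetensors
print(get_dist_meta_file_name("ckpt_dir", dist_id=7, use_safetensors=True))
# ckpt_dir/pytorch_model-meta-dist-00007 followed by SHARD_META_SUFFIX (suffix constant not shown above)
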
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (28 changes: 14 additions & 14 deletions)
@@ -1,11 +1,11 @@
 import copy
 import logging
 import os
+from contextlib import nullcontext
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -26,14 +26,14 @@
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 from .distributed_checkpoint_utils import (
+    MODEL_WEIGHT_PREFIX,
+    RestoreDefaultStateDictBehavior,
     create_model_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
-    get_dist_files_name,
-    get_dist_meta_file_name,
-    MODEL_WEIGHT_PREFIX,
-    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -108,7 +108,7 @@ def _model_sharder(
         keep_vars: bool = False,
         size_per_shard: int = 1024,
         pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-        gather_dtensor: bool = True,
+        gather_dtensor: bool = True,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         # An internel method that breaks state_dict of model into shards within limited size.
 
@@ -118,7 +118,7 @@ def _model_sharder(
         for name, param in model.named_parameters():
             if param is None:
                 continue
-
+
             # Gather tensor pieces when using tensor parallel.
             param_ = gather_distributed_param(param, keep_vars=False)
             if is_padded_tensor(param_):
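
The _model_sharder hunks above show only the signature and the start of the parameter loop; its inline comment says it breaks the model's state_dict into shards of limited size. A minimal sketch of that shard-by-size idea, independent of ColossalAI's actual implementation (the helper name and the MB interpretation of size_per_shard are assumptions):

from typing import Dict, Iterator, Tuple

import torch


def naive_model_sharder(
    named_tensors: Dict[str, torch.Tensor], size_per_shard: int = 1024
) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    # Accumulate tensors until the next one would push the shard past the budget,
    # then yield the current shard together with its size in bytes.
    budget = size_per_shard * 1024 * 1024  # size_per_shard interpreted as MB here
    shard: Dict[str, torch.Tensor] = {}
    shard_size = 0
    for name, tensor in named_tensors.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and shard_size + tensor_size > budget:
            yield shard, shard_size
            shard, shard_size = {}, 0
        shard[name] = tensor
        shard_size += tensor_size
    if shard:
        yield shard, shard_size
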
@@ -245,12 +245,12 @@ def save_sharded_model(
         model._force_wait_all_gather()
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
-
+
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -280,7 +280,9 @@ def save_sharded_model(
         if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
-            metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+            metadata_file = get_dist_meta_file_name(
+                checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors
+            )
 
         if use_async:
             total_size, writers = async_save_state_dict_shards(
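
A quick worked example of the dist_id flattening in the hunk above (the concrete sizes and ranks are made up for illustration):

tp_size, pp_rank, tp_rank = 4, 2, 3
dist_id = tp_size * pp_rank + tp_rank  # 4 * 2 + 3 = 11
# get_dist_files_name then tags this rank's shard as "...-dist-00011-shard...".
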
@@ -413,9 +415,7 @@ def load_sharded_model(
             )
             model = model.unwrap()
             with RestoreDefaultStateDictBehavior(model):
-                load_state_dict_into_model(
-                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
-                )
+                load_state_dict_into_model(model, state_dict, missing_keys=[], strict=False, load_sub_module=True)
             return
 
         model_before_wrapping = model  # backup for model before wrapping
@@ -897,7 +897,7 @@ def load_unsharded_model(
                 load_dtensor = True
                 break
 
-        model_metadata = None # used for dist model
+        model_metadata = None  # used for dist model
         if load_dtensor:
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
 