[checkpointio] gather tensor before unpadding it if the tensor is both padded and distributed (#6168)
Lemon-412 authored Jan 21, 2025
1 parent 5b094a8 commit 97e60cb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -107,9 +107,9 @@ def _model_sharder(
             if param is None:
                 continue
             # Gather tensor pieces when using tensor parallel.
-            if is_padded_tensor(param):
-                param = to_unpadded_tensor(param)
             param_ = gather_distributed_param(param, keep_vars=False)
+            if is_padded_tensor(param_):
+                param_ = to_unpadded_tensor(param_)
             if pinned_state_dicts is not None:
                 if (prefix + name) not in pinned_state_dicts:
                     pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
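
Why the order matters: a parameter can be both padded (e.g. an embedding table padded so its vocab dimension divides the tensor-parallel size) and sharded across TP ranks. The fix unpads only after gathering, presumably because the recorded original length refers to the full parameter rather than to a local shard; applied to a shard, the unpad can truncate real values or silently do nothing. The sketch below illustrates the ordering bug in plain PyTorch. It does not use ColossalAI's API: gather_shards and unpad are hypothetical stand-ins for gather_distributed_param and to_unpadded_tensor, under the assumption that the padding metadata describes the full, gathered tensor.

# Minimal, self-contained sketch of the ordering bug (plain PyTorch,
# no ColossalAI imports; gather_shards/unpad are hypothetical stand-ins).
import torch

world_size = 4
origin_len = 10   # true parameter length recorded in the padding metadata
padded_len = 12   # padded so the tensor divides evenly across TP ranks

full = torch.arange(origin_len, dtype=torch.float32)
padded = torch.cat([full, torch.zeros(padded_len - origin_len)])
shards = list(padded.chunk(world_size))  # each TP rank holds 3 elements

def gather_shards(shards):
    # Stand-in for gather_distributed_param: reassemble the full padded tensor.
    return torch.cat(shards)

def unpad(t):
    # Stand-in for to_unpadded_tensor: strip padding using the global metadata.
    return t[:origin_len]

# Pre-fix order: unpadding a 3-element shard with origin_len = 10 is a silent
# no-op, so the padding values survive into the checkpointed tensor.
buggy = gather_shards([unpad(s) for s in shards])
print(buggy.shape)  # torch.Size([12]) -- padding still present

# Post-fix order: gather first, then unpad the reassembled tensor.
fixed = unpad(gather_shards(shards))
print(fixed.shape)  # torch.Size([10])
assert torch.equal(fixed, full)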
