From 97e60cbbcb429f51e0b0ff71f23dd20240ac9b84 Mon Sep 17 00:00:00 2001
From: Lemon Qin <57213526+Lemon-412@users.noreply.github.com>
Date: Tue, 21 Jan 2025 10:23:15 +0800
Subject: [PATCH] [checkpointio] gather tensor before unpad it if the tensor
 is both padded and distributed (#6168)

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 154d5cb5e5f3..1b7ae18889fd 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -107,9 +107,9 @@ def _model_sharder(
             if param is None:
                 continue
             # Gather tensor pieces when using tensor parallel.
-            if is_padded_tensor(param):
-                param = to_unpadded_tensor(param)
             param_ = gather_distributed_param(param, keep_vars=False)
+            if is_padded_tensor(param_):
+                param_ = to_unpadded_tensor(param_)
             if pinned_state_dicts is not None:
                 if (prefix + name) not in pinned_state_dicts:
                     pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
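
A plausible reading of why the order matters: the padding bookkeeping (e.g. the recorded original
length) describes the full parameter, so unpadding has to run on the gathered tensor rather than on
a per-rank shard; applied to a shard, the unpad step can be a no-op and the padding then leaks into
the saved checkpoint. The minimal sketch below reproduces this effect with plain PyTorch on a single
process. All helper names in it (pad_to_multiple, shard, gather, unpad) are toy stand-ins written
for this illustration only, not ColossalAI APIs.

import torch

WORLD_SIZE = 4   # pretend tensor-parallel group size
ORIG_ROWS = 10   # original (unpadded) size of the sharded dimension


def pad_to_multiple(t: torch.Tensor, multiple: int):
    """Zero-pad dim 0 so it divides evenly by `multiple`; return (padded, original length)."""
    orig_len = t.shape[0]
    pad = (-orig_len) % multiple
    if pad:
        t = torch.cat([t, t.new_zeros(pad, *t.shape[1:])], dim=0)
    return t, orig_len


def shard(t: torch.Tensor, world_size: int):
    """Split the padded tensor evenly across ranks along dim 0 (toy 1D tensor parallelism)."""
    return list(t.chunk(world_size, dim=0))


def gather(shards):
    """Toy stand-in for gathering a distributed parameter: concatenate all shards."""
    return torch.cat(shards, dim=0)


def unpad(t: torch.Tensor, orig_len: int):
    """Toy stand-in for unpadding: keep only the first `orig_len` rows."""
    return t[:orig_len]


full = torch.arange(ORIG_ROWS * 3, dtype=torch.float32).reshape(ORIG_ROWS, 3)
padded, orig_len = pad_to_multiple(full, WORLD_SIZE)   # 10 rows -> 12 rows
shards = shard(padded, WORLD_SIZE)                     # four shards of 3 rows each

# Old order (unpad each local shard, then gather): every shard is shorter than
# orig_len, so slicing to orig_len is a no-op and the zero padding survives.
wrong = gather([unpad(s, orig_len) for s in shards])

# Patched order (gather first, then unpad): the full padded tensor is restored,
# so dropping rows beyond orig_len removes exactly the padding.
right = unpad(gather(shards), orig_len)

print(wrong.shape)               # torch.Size([12, 3]) -- padding leaks into the checkpoint
print(right.shape)               # torch.Size([10, 3])
print(torch.equal(right, full))  # True

With the patched order in _model_sharder, gather_distributed_param reassembles the full tensor
before is_padded_tensor / to_unpadded_tensor run, so the saved parameter keeps its original shape.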