diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 154d5cb5e5f3..1b7ae18889fd 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -107,9 +107,9 @@ def _model_sharder( if param is None: continue # Gather tensor pieces when using tensor parallel. - if is_padded_tensor(param): - param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) + if is_padded_tensor(param_): + param_ = to_unpadded_tensor(param_) if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")