Precisely track nvme optimizer offload #6963

Open · wants to merge 2 commits into master
10 changes: 4 additions & 6 deletions deepspeed/runtime/engine.py
@@ -799,10 +799,8 @@ def zero_load_from_fp32_weights(self):
     def zero_elastic_checkpoint(self):
         return self._config.zero_config.elastic_checkpoint

-    def zero_has_nvme_offload(self):
-        if not hasattr(self.optimizer, "swap_optimizer"):
-            return False
-        return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
+    def zero_nvme_offload_optimizer(self):
+        return getattr(self.optimizer, "swap_optimizer", False)

     def zero_max_live_parameters(self):
         return self._config.zero_config.max_live_parameters
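
The renamed predicate also narrows its meaning: the old zero_has_nvme_offload() returned True when only parameters lived on NVMe (params_in_nvme_and_cpu), so the checkpoint paths below copied the optimizer swap folder even when no optimizer state was swapped. A standalone sketch of the behavioral difference (FakeOptimizer is a hypothetical stand-in; only the two flags matter):

# Standalone sketch, not DeepSpeed code: contrast of the old and new checks.
class FakeOptimizer:                      # hypothetical stand-in
    swap_optimizer = False                # optimizer state is NOT swapped to NVMe
    params_in_nvme_and_cpu = True         # parameters ARE offloaded to NVMe

def old_check(optimizer):
    if not hasattr(optimizer, "swap_optimizer"):
        return False
    return optimizer.swap_optimizer or optimizer.params_in_nvme_and_cpu

def new_check(optimizer):
    return getattr(optimizer, "swap_optimizer", False)

opt = FakeOptimizer()
assert old_check(opt)        # old: param-only NVMe offload also triggered the copy
assert not new_check(opt)    # new: only swapped optimizer state triggers it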
@@ -2865,7 +2863,7 @@ def load_checkpoint(self,
             if not success:
                 self.optimizer._restore_from_bit16_weights()

-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors")
@@ -3205,7 +3203,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True,
             self._create_zero_checkpoint_files(save_dir, tag)
             self._save_zero_checkpoint(save_dir, tag)

-        if self.zero_has_nvme_offload():
+        if self.zero_nvme_offload_optimizer():
             from shutil import copytree, disk_usage
             offload_dir = self.optimizer.optimizer_swapper.swap_folder
             offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
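Both guarded blocks mirror the optimizer swap folder into (save) and out of (load) the checkpoint directory. A hedged sketch of the save-side copy, reusing only the names visible in the diff; the free-space guard and the dir_size_bytes helper are assumptions about the elided lines, not the actual implementation:

import os
from shutil import copytree, disk_usage

def dir_size_bytes(path):
    # total size of all files under path (helper invented for this sketch)
    return sum(os.path.getsize(os.path.join(root, f))
               for root, _, files in os.walk(path) for f in files)

def copy_swap_folder_to_checkpoint(offload_dir, save_dir, tag):
    offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors")
    # assumed guard: ensure the checkpoint volume can hold the swap folder
    if disk_usage(save_dir).free < dir_size_bytes(offload_dir):
        raise RuntimeError("not enough space to persist NVMe-offloaded tensors")
    copytree(offload_dir, offload_ckpt_dir, dirs_exist_ok=True)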
5 changes: 5 additions & 0 deletions deepspeed/runtime/swap_tensor/optimizer_utils.py
@@ -153,6 +153,11 @@ def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_nume
             'timer_names',
         ]

+    def purge_state(self):
+        for swap_info in self.swap_params_info.values():
+            swap_info.tensors = [swap_info.tensors[0]]
+            swap_info.has_state_tensors = False
+
     def swappable_tensor(self, param=None, numel=None):
         assert param is not None or numel is not None, "Either param or numel must be provided"
         if param is not None:
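The new purge_state() keeps only the first entry of each swap-info tensor list and marks state tensors absent, so a subsequent checkpoint load replaces the optimizer state that was swapped out for the freshly created model. A toy illustration of the effect; ToySwapInfo and the tensor labels are assumptions about the real swap-info object, where tensors[0] is taken to be the swapped parameter and the rest optimizer state (e.g. Adam moments):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToySwapInfo:                        # stand-in for the real swap-info object
    tensors: List[str] = field(default_factory=lambda: ["param", "exp_avg", "exp_avg_sq"])
    has_state_tensors: bool = True

swap_params_info = {0: ToySwapInfo()}

# Mirrors the new purge_state() body:
for swap_info in swap_params_info.values():
    swap_info.tensors = [swap_info.tensors[0]]
    swap_info.has_state_tensors = False

assert swap_params_info[0].tensors == ["param"]
assert swap_params_info[0].has_state_tensors is False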
12 changes: 4 additions & 8 deletions deepspeed/runtime/zero/stage3.py
@@ -2652,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
         self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
         self._clear_fp32_optimizer_param_groups()

-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()

         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -2773,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             else:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor

-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()

         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
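Both stage3.py call sites previously inlined the same three-line purge loop; they now delegate to the new helper, and the guard drops params_in_nvme_and_cpu because a param-only NVMe offload has no swapped optimizer state to purge. A self-contained mini-version of the consolidated pattern (Swapper and Stage3Like are hypothetical stand-ins for the real classes):

class Swapper:
    def __init__(self):
        self.swap_params_info = {}
    def purge_state(self):                # same body as the new helper
        for swap_info in self.swap_params_info.values():
            swap_info.tensors = [swap_info.tensors[0]]
            swap_info.has_state_tensors = False

class Stage3Like:
    def __init__(self, swap_optimizer):
        self.swap_optimizer = swap_optimizer
        self.optimizer_swapper = Swapper()
    def load_optimizer_state(self):
        if self.swap_optimizer:
            # swapped state belongs to the fresh model, not the checkpoint
            self.optimizer_swapper.purge_state()

Stage3Like(swap_optimizer=True).load_optimizer_state()   # purges swapped state
Stage3Like(swap_optimizer=False).load_optimizer_state()  # nothing to purge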
4 changes: 3 additions & 1 deletion tests/unit/runtime/zero/test_nvme_checkpointing.py
@@ -22,8 +22,10 @@ class TestNVMeCheckpointing(DistributedTest):
     world_size = 1

     @pytest.mark.parametrize('param_offload_device, optim_offload_device',
-                             [(OffloadDeviceEnum.cpu, OffloadDeviceEnum.cpu),
+                             [(OffloadDeviceEnum.none, OffloadDeviceEnum.nvme),
                               (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.none),
+                              (OffloadDeviceEnum.nvme, OffloadDeviceEnum.cpu),
                               (OffloadDeviceEnum.nvme, OffloadDeviceEnum.nvme)])
     def test_nvme_checkpointing(self, tmpdir, param_offload_device, optim_offload_device):
         zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
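The widened matrix now also exercises mixed placements, e.g. parameters kept on device while optimizer state goes to NVMe. For context, a hedged sketch of the kind of ZeRO-3 config these device pairs translate to; the field names follow DeepSpeed's documented zero_optimization schema, but the concrete values the test uses are assumptions:

def make_config(param_device, optim_device, nvme_path):
    # Hedged sketch: maps a (param, optimizer) offload-device pair to a config.
    return {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {
            "stage": 3,
            "offload_param": {"device": param_device, "nvme_path": nvme_path},
            "offload_optimizer": {"device": optim_device, "nvme_path": nvme_path},
        },
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }

# e.g. the new (none, nvme) case: parameters stay put, optimizer state on NVMe
cfg = make_config("none", "nvme", "/local_nvme")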