Force param sync when using distributed optimizer and overlap_param_gather (NVIDIA#11486)

* Add disable/enable forward pre hook for DDP and overlap param gather

Signed-off-by: Hemil Desai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

* Force param sync before saving checkpoint

Signed-off-by: Hemil Desai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
hemildesai authored Dec 6, 2024
1 parent 54572f6 commit bde672e
Showing 2 changed files with 25 additions and 0 deletions.
18 changes: 18 additions & 0 deletions nemo/lightning/megatron_parallel.py
@@ -653,6 +653,24 @@ def _module_sharded_state_dict(self, module, *args, **kwargs) -> Dict[str, Any]:

        raise ValueError("Could not find sharded state dict")

    def enable_forward_pre_hook(self):
        for model in self:
            model_chunk = model.module
            assert isinstance(model_chunk, DDP)
            model_chunk.enable_forward_pre_hook()

    def disable_forward_pre_hook(self):
        for model in self:
            model_chunk = model.module
            assert isinstance(model_chunk, DDP)
            model_chunk.disable_forward_pre_hook()

    def force_param_sync(self):
        for model in self:
            model_chunk = model.module
            assert isinstance(model_chunk, DDP)
            model_chunk.start_param_sync(force_sync=True)

    @property
    def pipeline(self) -> Union[ModelT, List[ModelT]]:
        if len(self) == 1:
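The three helpers added above each unwrap a model chunk to Megatron-core's DistributedDataParallel and delegate to it. Below is a minimal usage sketch, not part of this commit: the run_with_synced_params helper and the megatron_parallel argument are hypothetical stand-ins for an instance of the class extended in this file.

def run_with_synced_params(megatron_parallel, fn):
    # Hypothetical sketch: flush any in-flight parameter all-gathers so each
    # rank holds up-to-date parameter values before running `fn`.
    megatron_parallel.force_param_sync()
    # Keep the overlap-param-gather forward pre-hooks out of the way while
    # `fn` runs (e.g. saving a checkpoint or exporting weights).
    megatron_parallel.disable_forward_pre_hook()
    try:
        return fn()
    finally:
        # Restore overlapped parameter gathering for subsequent training steps.
        megatron_parallel.enable_forward_pre_hook()
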
7 changes: 7 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -693,6 +693,13 @@ def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        """Saves checkpoint"""
        if (
            isinstance(self.ddp_config, DistributedDataParallelConfig)
            and self.ddp_config.use_distributed_optimizer
            and self.ddp_config.overlap_param_gather
        ):
            self.megatron_parallel.force_param_sync()

        checkpoint["state_dict"] = OrderedDict([])  # remove device state_dict
        # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
        if "sharded_state_dict" not in checkpoint:
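The guard above forces the all-gather only when the distributed optimizer shards parameters and the parameter gather is overlapped with forward compute, since only then can a rank's local parameter copies be stale at save time. A hedged, self-contained sketch of the same gating logic follows; the maybe_force_param_sync name is hypothetical, and the strategy object is assumed to expose the same attributes used in save_checkpoint above.

from megatron.core.distributed import DistributedDataParallelConfig

def maybe_force_param_sync(strategy):
    # Hypothetical helper mirroring the guard in save_checkpoint above.
    cfg = strategy.ddp_config
    if (
        isinstance(cfg, DistributedDataParallelConfig)
        and cfg.use_distributed_optimizer
        and cfg.overlap_param_gather
    ):
        # Make sure every rank's parameter copy is current before the
        # checkpoint's sharded state dict is gathered and written.
        strategy.megatron_parallel.force_param_sync()
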
