Remove hooks on gradient accumulation on engine/optimizer destroy (#4858)
chiragjn authored Dec 30, 2023
1 parent 3e94f8c commit 4034205
Showing 4 changed files with 21 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/engine.py
@@ -69,7 +69,7 @@
STEP_MICRO_TIMER, \
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
STEP_GLOBAL_TIMER
-from deepspeed.utils.debug import debug_extract_module_and_param_names
+from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names
from deepspeed.monitor.monitor import MonitorMaster
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.utils import clip_grad_norm_
@@ -362,6 +362,7 @@ def __init__(
    def destroy(self):
        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
            self.optimizer.destroy()
+        debug_clear_module_and_param_names()

    def _get_model_parameters(self):
        if self.autotuning_profile_model_info():
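
For context, a minimal usage sketch of the engine lifecycle this change touches. The model and ds_config below are illustrative placeholders (not part of this commit), and it assumes a working single-process DeepSpeed setup:

    # Hypothetical sketch -- the model and ds_config are illustrative, not from this commit.
    import torch
    import deepspeed

    model = torch.nn.Linear(8, 8)
    ds_config = {
        "train_batch_size": 4,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": 2},
    }

    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)

    # ... forward pass, engine.backward(loss), engine.step() ...

    # With this change, destroy() also calls debug_clear_module_and_param_names()
    # after tearing down the optimizer, so no stale module/param name caches remain.
    engine.destroy()
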
6 changes: 5 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -377,6 +377,7 @@ def __init__(

        #creates backward hooks for gradient partitioning
        ###Calls all gather param
+        self._grad_acc_hooks = []
        self.create_reduce_and_remove_grad_hooks()

        #exit(0)
@@ -397,6 +398,9 @@ def __init__(

    def destroy(self):
        self.parameter_offload.destroy()
+        for hook in self._grad_acc_hooks:
+            hook.remove()
+        print_rank_0("Removed grad acc hooks", force=False)
        del self.__ipg_bucket_flat_buffer

def initialize_ds_offload(
@@ -1118,7 +1122,7 @@ def wrapper(param):
                        def reduce_partition_and_remove_grads(*notneeded):
                            self.reduce_ready_partitions_and_remove_grads(param)

-                        grad_acc.register_hook(reduce_partition_and_remove_grads)
+                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
                        self.grad_accs.append(grad_acc)

                    #print(f"param grad fn {param.expand_as(param).grad_fn}")
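
The mechanics behind this change, as a standalone sketch (not DeepSpeed code; the hook body is a stand-in): register_hook on a parameter's gradient accumulator node returns a removable handle, and keeping those handles is what lets the new destroy() detach the hooks later.

    # Standalone sketch of the handle pattern; names here are illustrative.
    import torch

    param = torch.nn.Parameter(torch.randn(4))

    # The accumulator node for a leaf parameter, reached the same way DeepSpeed does it:
    # a throwaway expand_as gives a grad_fn whose first next_function is AccumulateGrad.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

    def on_grad_ready(*notneeded):
        print("grad for param is ready")   # stand-in for reduce_ready_partitions_and_remove_grads

    hooks = []
    hooks.append(grad_acc.register_hook(on_grad_ready))   # register_hook returns a removable handle

    (param * 2).sum().backward()   # hook fires

    for h in hooks:
        h.remove()                 # what the new destroy() does

    (param * 2).sum().backward()   # hook no longer fires
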
8 changes: 7 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -490,6 +490,7 @@ def __init__(self,
        self.reset_partition_gradient_structures()

        # creates backward hooks for gradient partitioning
+        self._grad_acc_hooks = []
        if self.partition_gradients or self.overlap_comm:
            self.create_reduce_and_remove_grad_hooks()

@@ -522,6 +523,11 @@ def __init__(self,
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

+    def destroy(self):
+        for hook in self._grad_acc_hooks:
+            hook.remove()
+        self.print_rank_0("Removed grad acc hooks")
+
    def _enable_universal_checkpoint(self):
        for lp_param_group in self.bit16_groups:
            enable_universal_checkpoint(param_list=lp_param_group)
@@ -864,7 +870,7 @@ def wrapper(param, i):
                        def reduce_partition_and_remove_grads(*notneeded):
                            self.reduce_ready_partitions_and_remove_grads(param, i)

-                        grad_acc.register_hook(reduce_partition_and_remove_grads)
+                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
                        self.grad_accs.append(grad_acc)

                    wrapper(param, i)
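
Why keeping the handles matters: each registered closure captures the parameter and the optimizer's self, so hooks left behind after teardown keep referencing a dead optimizer and keep firing on later backward passes. A toy sketch of the bookkeeping (this class is hypothetical, not DeepSpeed code):

    # Toy sketch of the _grad_acc_hooks bookkeeping; the class is hypothetical.
    import torch

    class HookedOptimizerSketch:
        def __init__(self, params):
            self._grad_acc_hooks = []
            self.grad_accs = []
            for p in params:
                self._register(p)

        def _register(self, param):
            grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

            def on_ready(*notneeded):
                print("would reduce/partition grad of shape", tuple(param.shape))

            self._grad_acc_hooks.append(grad_acc.register_hook(on_ready))
            self.grad_accs.append(grad_acc)   # keep the accumulator node alive

        def destroy(self):
            # mirror of the new destroy(): detach every hook so the closures stop
            # referencing this object and stop firing on later backward passes
            for hook in self._grad_acc_hooks:
                hook.remove()
            self._grad_acc_hooks.clear()

    params = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
    opt = HookedOptimizerSketch(params)
    sum((p * p).sum() for p in params).backward()   # hooks fire
    opt.destroy()
    sum((p * p).sum() for p in params).backward()   # silent: hooks removed
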
7 changes: 7 additions & 0 deletions deepspeed/utils/debug.py
@@ -11,6 +11,13 @@
param_names = {}


+def debug_clear_module_and_param_names():
+    global module_names
+    global param_names
+    module_names = {}
+    param_names = {}
+
+
def debug_extract_module_and_param_names(model):
    # extract the fully qualified names as soon as the model is acquired
    global module_names
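
debug.py keeps two module-level dicts that map live modules and parameters to their fully qualified names; the new helper simply resets them so a destroyed engine leaves no stale entries. A self-contained sketch of that pattern (the extract function below is a plausible simplification, not a verbatim copy):

    # Sketch of the debug name caches; the extract body is a simplification.
    import torch

    module_names = {}
    param_names = {}

    def debug_extract_module_and_param_names(model):
        # record fully qualified names, keyed by the live module/parameter objects
        global module_names, param_names
        module_names = {module: name for name, module in model.named_modules()}
        param_names = {param: name for name, param in model.named_parameters()}

    def debug_clear_module_and_param_names():
        # drop the cached names (and the object references they hold)
        global module_names, param_names
        module_names = {}
        param_names = {}

    model = torch.nn.Sequential(torch.nn.Linear(2, 2))
    debug_extract_module_and_param_names(model)
    print(len(module_names), len(param_names))   # 2 modules, 2 parameters
    debug_clear_module_and_param_names()
    print(len(module_names), len(param_names))   # 0 0
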
