From 456c9ac67975da698e44dfd4f90c4f7b867d08bd Mon Sep 17 00:00:00 2001 From: Max Kovalenko Date: Fri, 3 Jan 2025 17:48:24 +0200 Subject: [PATCH] Stage3: Use new torch grad accumulation hooks API (#6773) * This commit addresses a Deepspeed issue [#6718](https://github.com/microsoft/DeepSpeed/issues/6718) * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as `param.data = replicated_tensor.data` used in `allgather_params(..)` are compiled into `param.set()` causing the hook assigned to the grad_acc node not being called. * Starting from PyTorch 2.1 there is a new and robust hook API on a param itself: `param.register_post_accumulate_grad_hook(..)` * This commit will make use of the proper API depending on the PyTorch version * It will also disable compile for PyTorch versions < 2.1 --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- deepspeed/runtime/compiler.py | 3 ++- deepspeed/runtime/zero/stage3.py | 7 ++----- deepspeed/utils/torch.py | 9 +++++++++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index fa9220f4fcd0..be778b83f8bb 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.utils.torch import required_torch_version try: from torch.compiler import is_compiling as torch_is_compiling @@ -16,7 +17,7 @@ def is_compile_supported(): - return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") + return required_torch_version(min_version=2.1) def disable(func): diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 459cffce52c8..28f91cb9b3ab 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,6 +16,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item @@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: @@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self): #print(f"After all gather {param.device}, {param.shape}") def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py index eb22d3561035..1d32775fe64a 100644 --- a/deepspeed/utils/torch.py +++ b/deepspeed/utils/torch.py @@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook)