From 2c69e9ad2887b7e78c88c2db3209713542dad7e2 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 17 Jul 2024 10:01:44 +0000
Subject: [PATCH] Minor fixes

---
 .../distributed_differentiable_primitives.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index f1102908..bd41347a 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -86,7 +86,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     def backward(ctx, grad_output):
         group = ctx.group
         out = DifferentiableReduceScatterSum.apply(grad_output, group)
-        return out, None, None
+        return out, None
 
 
 class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -122,7 +122,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableAllGather.apply(grad_output, group, False), None
+        return DifferentiableAllGather.apply(grad_output, group), None
 
 
 # -----------------
@@ -138,7 +138,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
     return DifferentiableAllReduceSum.apply(tensor, group)
 
 
-def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None)
+def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
    return DifferentiableAllGather.apply(tensor, group)
 
 
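
Note (not part of the patch): all three hunks enforce the same PyTorch autograd rule that the original code violated. A custom torch.autograd.Function's backward() must return exactly one value per argument of forward() (excluding ctx), with None for non-differentiable inputs such as the process group, and apply() must be called with the same number of arguments forward() declares. A minimal, self-contained sketch of that rule, using a hypothetical Scale function rather than the nanotron primitives:

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor, factor):
            # Two forward inputs (besides ctx) -> backward must return two values.
            ctx.factor = factor
            return tensor * factor

        @staticmethod
        def backward(ctx, grad_output):
            # Gradient for `tensor`, and None for the non-differentiable `factor`.
            return grad_output * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])

Since DifferentiableAllGather.forward takes only (tensor, group), its backward returns two values, and callers pass exactly two arguments to apply(); the extra None and the extra False removed by this patch would otherwise raise at runtime.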