Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AleHD committed Jul 17, 2024
1 parent 59bfb6b commit 2c69e9a
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
def backward(ctx, grad_output):
group = ctx.group
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None, None
return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
Expand Down Expand Up @@ -122,7 +122,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllGather.apply(grad_output, group, False), None
return DifferentiableAllGather.apply(grad_output, group), None


# -----------------
Expand All @@ -138,7 +138,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllReduceSum.apply(tensor, group)


def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None)
def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllGather.apply(tensor, group)


Expand Down

0 comments on commit 2c69e9a

Please sign in to comment.