Skip to content

Commit

Permalink
Update torch.norm to torch.linalg.norm and `torch.linalg.vector_norm` (#6931)
Browse files Browse the repository at this point in the history

- [x] Update PR since `torch.norm` and `torch.linalg.norm` have
[different function
signatures](https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm).
- [x] Check if there are any numeric differences between the functions.
- [x] Determine why there appear to be performance differences from
others [here](pytorch/pytorch#136360).
- [x] Update to `torch.linalg.vector_norm`
Follow-up PR handles these in the comm folder: #6960
  • Loading branch information
loadams authored Jan 21, 2025
1 parent bc76b04 commit 8d1bc0a
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/comm/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/comm/hccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/onebit/lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
# This is used to reduce compression error during compression stage.
momentum_scales = []
for group in self.param_groups:
momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
for p in group['params']])
united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ def step(self, closure=None):
return

norm_groups = self._get_norm_groups()
scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
Expand Down
5 changes: 3 additions & 2 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
all_norms.append(
torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
torch.linalg.vector_norm(g.data.double().detach(),
ord=norm_type).to(get_accelerator().current_device_name()))
if len(all_norms) > 0:
total_norm = torch.stack(all_norms).square().sum().float()
else:
Expand Down Expand Up @@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
self._average_expert_grad_norms(norm_groups)

# calculating L2 norm
return torch.norm(torch.stack(norm_groups), p=norm_type)
return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)

def get_bit16_param_group(self, group_no):
bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
Expand Down

0 comments on commit 8d1bc0a

Please sign in to comment.