Skip to content

Commit

Permalink
fix for fp16 loss scaling (#52)
Browse files · Browse the repository at this point in the history
fixes #44
  • Loading branch information
unixpickle authored Jul 15, 2022
1 parent 4afb0ac commit 22e0df8
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion guided_diffusion/fp16_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def _optimize_fp16(self, opt: th.optim.Optimizer):
logger.logkv_mean("grad_norm", grad_norm)
logger.logkv_mean("param_norm", param_norm)

self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
for p in self.master_params:
p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
opt.step()
zero_master_grads(self.master_params)
master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
Expand Down

0 comments on commit 22e0df8

Please sign in to comment.