diff --git a/guided_diffusion/fp16_util.py b/guided_diffusion/fp16_util.py index 35a3f461..df3882d0 100644 --- a/guided_diffusion/fp16_util.py +++ b/guided_diffusion/fp16_util.py @@ -199,7 +199,8 @@ def _optimize_fp16(self, opt: th.optim.Optimizer): logger.logkv_mean("grad_norm", grad_norm) logger.logkv_mean("param_norm", param_norm) - self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + for p in self.master_params: + p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) opt.step() zero_master_grads(self.master_params) master_params_to_model_params(self.param_groups_and_shapes, self.master_params)