There seems to be a bug when setting use_fp16 = False. The log is:
Traceback (most recent call last):
  File "scripts/image_train_stable.py", line 150, in <module>
    main()
  File "scripts/image_train_stable.py", line 78, in main
    TrainLoop(
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 194, in run_loop
    self.run_step(batch, cond)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 208, in run_step
    self.forward_backward(batch, cond)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 236, in forward_backward
    losses = compute_losses()
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/respace.py", line 96, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/gaussian_diffusion.py", line 1137, in training_losses
    model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/respace.py", line 133, in __call__
    return self.model(x, new_ts, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 882, in forward
    h = module(h, emb, context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 217, in forward
    x = layer(x, context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 188, in forward
    x = block(x, context=context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 140, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/nn.py", line 162, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/nn.py", line 174, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 144, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 112, in forward
    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/functional.py", line 327, in einsum
    return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found Half
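The last frames suggest that the two einsum operands in the cross-attention layer have different dtypes. A minimal sketch of one possible workaround, assuming (this is a guess, not confirmed from the repository) that the model runs in fp32 while the conditioning context still arrives as fp16. The helper name `attention_scores` and the tensor shapes are hypothetical and only illustrate the cast:

```python
import torch


def attention_scores(q, k, scale):
    """Hypothetical helper: similarity scores with matching operand dtypes.

    Assumption: q comes from fp32 weights, while k was computed from a
    context tensor that is still in half precision.
    """
    # Cast the key to the query's dtype so both einsum operands agree.
    k = k.to(q.dtype)
    return torch.einsum('bid,bjd->bij', q, k) * scale


# Example with made-up shapes: an fp32 query and an fp16 key.
q = torch.randn(2, 3, 4)           # float32
k = torch.randn(2, 5, 4).half()    # float16
sim = attention_scores(q, k, scale=0.5)
print(sim.dtype, tuple(sim.shape))
```

If the context really is the fp16 operand, casting it once where it enters the model (rather than inside every attention layer) would likely be the cleaner place for the fix.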
Thanks! One interesting thing I found is that fp16 seems to use as much GPU memory as fp32, or even more. Both fp16 and fp32 can train with a batch size of 8 on a GPU with 48 GB of memory, and with a batch size of 2, fp32 actually takes less memory.