From 44752e65df122cfd70e05804c45d00a149c2f4a4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 11 Oct 2024 23:35:34 -0700 Subject: [PATCH 1/2] Removed CPU randn() from schedulers Signed-off-by: Boris Fomitchev --- monai/networks/schedulers/ddim.py | 2 +- monai/networks/schedulers/ddpm.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 2a0121d063..50a680336d 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -220,7 +220,7 @@ def step( if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device) variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise pred_prev_sample = pred_prev_sample + variance diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index 93ad833031..d64e11d379 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -241,8 +241,12 @@ def step( variance = 0 if timestep > 0: noise = torch.randn( - model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator - ).to(model_output.device) + model_output.size(), + dtype=model_output.dtype, + layout=model_output.layout, + generator=generator, + device=model_output.device, + ) variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance From 1893375e8ec5cdd52c99d431e607feb7dbd80802 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:05:36 +0800 Subject: [PATCH 2/2] workaround for #8149 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6d0ccd378a..72654d3534 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,7 +22,7 @@ isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" types-setuptools -mypy>=1.5.0 +mypy>=1.5.0, <1.12.0 ninja torchvision psutil