From 4d94731017e63a73008c94a99a1bee87585bf9b0 Mon Sep 17 00:00:00 2001 From: Angelica Chen <72049239+angie-chen55@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:42:56 -0400 Subject: [PATCH] move tensors to correct device (#1) --- holo/test_functions/closed_form/_rough_mt_fuji.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/holo/test_functions/closed_form/_rough_mt_fuji.py b/holo/test_functions/closed_form/_rough_mt_fuji.py index 41256b9..9864c7b 100644 --- a/holo/test_functions/closed_form/_rough_mt_fuji.py +++ b/holo/test_functions/closed_form/_rough_mt_fuji.py @@ -60,8 +60,10 @@ def _optimal_value(self): def optimal_solution(self): soln = self.centroids.clone() mask = self._random_term - self._additive_term > 0 + mask = mask.to(device=soln.device) soln = torch.where(mask, torch.ones_like(soln), soln) mask = self._random_term + self._additive_term < 0 + mask = mask.to(device=soln.device) soln = torch.where(mask, torch.zeros_like(soln), soln) return soln @@ -69,6 +71,7 @@ def optimal_solution(self): def to(self, device, dtype): self.centroids = self.centroids.to(device, dtype) self._generator = torch.Generator(device=device).manual_seed(self._random_seed) + self._random_term = self._random_term.to(device) return self def __repr__(self):