diff --git a/test/test_euler.py b/test/test_euler.py index 6e7ba1e..c6068fe 100644 --- a/test/test_euler.py +++ b/test/test_euler.py @@ -20,6 +20,7 @@ def test_euler(self): def test_euler_unitquat_consistency(self): device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + dtype = torch.float64 for degrees in (True, False): for batch_shape in [tuple(), torch.Size((30,)), @@ -28,7 +29,7 @@ def test_euler_unitquat_consistency(self): for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: if intrinsics: convention = convention.upper() - q = roma.random_unitquat(batch_shape, device=device) + q = roma.random_unitquat(batch_shape, device=device, dtype=dtype) angles = roma.unitquat_to_euler(convention, q, degrees=degrees) self.assertTrue(len(angles) == 3) self.assertTrue(all([angle.shape == batch_shape for angle in angles])) @@ -63,6 +64,7 @@ def test_euler_rotvec_consistency(self): def test_euler_rotmat_consistency(self): device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu') + dtype = torch.float64 for degrees in (True, False): for batch_shape in [tuple(), torch.Size((30,)), @@ -71,7 +73,7 @@ def test_euler_rotmat_consistency(self): for convention in ["".join(permutation) for permutation in itertools.permutations('xyz')] + ["xyx", "xzx", "yxy", "yzy", "zxz", "zyz"]: if intrinsics: convention = convention.upper() - q = roma.random_rotmat(batch_shape, device=device) + q = roma.random_rotmat(batch_shape, device=device, dtype=dtype) angles = roma.rotmat_to_euler(convention, q, degrees=degrees) self.assertTrue(len(angles) == 3) self.assertTrue(all([angle.shape == batch_shape for angle in angles])) @@ -96,7 +98,8 @@ def test_euler_tensor(self): Test that Euler conversion methods support both list and tensor inputs. """ batch_shape = (10,34) - q = roma.random_unitquat(batch_shape, device=device) + dtype = torch.float64 + q = roma.random_unitquat(batch_shape, device=device, dtype=dtype) convention = 'xyz' angles = roma.unitquat_to_euler(convention, q) angles_tensor = roma.unitquat_to_euler(convention, q, as_tensor=True)