From 40a0a9f21afe81bde31ee3acad695a96ce0f86aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20BR=C3=89GIER?= Date: Mon, 20 Nov 2023 17:28:54 +0900 Subject: [PATCH] Fix: copy dtype and device in quat_action allocation --- roma/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roma/utils.py b/roma/utils.py index b4065d1..3fcdc47 100644 --- a/roma/utils.py +++ b/roma/utils.py @@ -256,7 +256,7 @@ def quat_action(q, v, is_normalized=False): """ batch_shape = v.shape[:-1] iquat = quat_conjugation(q) if is_normalized else quat_inverse(q) - pure = torch.cat((v, torch.zeros(batch_shape + (1,))), dim=-1) + pure = torch.cat((v, torch.zeros(batch_shape + (1,), dtype=q.dtype, device=q.device)), dim=-1) res = quat_product(q, quat_product(pure, iquat)) return res[...,:3]