diff --git a/roma/utils.py b/roma/utils.py index b4065d1..3fcdc47 100644 --- a/roma/utils.py +++ b/roma/utils.py @@ -256,7 +256,7 @@ def quat_action(q, v, is_normalized=False): """ batch_shape = v.shape[:-1] iquat = quat_conjugation(q) if is_normalized else quat_inverse(q) - pure = torch.cat((v, torch.zeros(batch_shape + (1,))), dim=-1) + pure = torch.cat((v, torch.zeros(batch_shape + (1,), dtype=q.dtype, device=q.device)), dim=-1) res = quat_product(q, quat_product(pure, iquat)) return res[...,:3]