From ef10b1e4006c00cf8ee3bc7480994fb274d18841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20BR=C3=89GIER?= Date: Sun, 1 Oct 2023 10:52:30 +0200 Subject: [PATCH] v1.4.1 --- docsource/source/index.rst | 3 +++ roma/mappings.py | 26 +++++++++++++++++++++++++- roma/utils.py | 12 ++++++++++++ setup.py | 2 +- test/test_mappings.py | 8 ++++++++ test/test_utils.py | 8 ++++++-- 6 files changed, 55 insertions(+), 4 deletions(-) diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 30409b5..5546e96 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -231,6 +231,9 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an Changelog ========== +Version 1.4.1: + - Added XYZW / WXYZ quaternion conversion routines: :func:`~roma.mappings.quat_xyzw_to_wxyz()` and :func:`~roma.mappings.quat_wxyz_to_xyzw()`. + - Added :func:`~roma.utils.rotvec_geodesic_distance()`. Version 1.4.0: - Added the :ref:`spatial-transformations` module. Version 1.3.4: diff --git a/roma/mappings.py b/roma/mappings.py index d2e8d63..efdb0a8 100644 --- a/roma/mappings.py +++ b/roma/mappings.py @@ -428,4 +428,28 @@ def rotmat_to_rotvec(R): batch of rotation vectors (...x3 tensor). """ q = rotmat_to_unitquat(R) - return unitquat_to_rotvec(q) \ No newline at end of file + return unitquat_to_rotvec(q) + +def quat_xyzw_to_wxyz(xyzw): + """ + Convert quaternion from XYZW to WXYZ convention. + + Args: + xyzw (...x4 tensor, XYZW convention): batch of quaternions. + Returns: + batch of quaternions (...x4 tensor, WXYZ convention). + """ + assert xyzw.shape[-1] == 4 + return torch.cat((xyzw[...,-1,None], xyzw[...,:-1]), dim=-1) + +def quat_wxyz_to_xyzw(wxyz): + """ + Convert quaternion from WXYZ to XYZW convention. + + Args: + wxyz (...x4 tensor, WXYZ convention): batch of quaternions. + Returns: + batch of quaternions (...x4 tensor, XYZW convention). + """ + assert wxyz.shape[-1] == 4 + return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1) diff --git a/roma/utils.py b/roma/utils.py index 9e8c645..b4065d1 100644 --- a/roma/utils.py +++ b/roma/utils.py @@ -154,6 +154,18 @@ def unitquat_geodesic_distance(q1, q2): """ return 4.0 * torch.asin(0.5 * torch.min(roma.internal.norm(q2 - q1, dim=-1), roma.internal.norm(q2 + q1, dim=-1))) +def rotvec_geodesic_distance(vec1, vec2): + """ + Returns the angular distance between rotations represented by rotation vectors. + (use a conversion to unit quaternions internally). + + Args: + vec1, vec2 (...x3 tensors): batch of rotation vectors. + Returns: + batch of angles in radians (... tensor). + """ + return unitquat_geodesic_distance(roma.mappings.rotvec_to_unitquat(vec1), roma.mappings.rotvec_to_unitquat(vec2)) + def quat_conjugation(quat): """ Returns the conjugation of input batch of quaternions. diff --git a/setup.py b/setup.py index 9375833..312f7d0 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="roma", - version="1.4.0", + version="1.4.1", author="Romain Brégier", author_email="romain.bregier@naverlabs.com", description="A lightweight library to deal with 3D rotations in PyTorch.", diff --git a/test/test_mappings.py b/test/test_mappings.py index f99f058..2478aef 100644 --- a/test/test_mappings.py +++ b/test/test_mappings.py @@ -185,6 +185,14 @@ def test_rotvec_unitquat_nan_issues(self): self.assertTrue(torch.all(torch.isfinite(loss))) self.assertTrue(torch.all(torch.isfinite(rotvec.grad))) + def test_quat_conventions(self): + for batch_shape in [(), (10,), (23,5)]: + quat_xyzw = torch.randn(batch_shape + (4,)) + quat_wxyz = roma.mappings.quat_xyzw_to_wxyz(quat_xyzw) + self.assertTrue(quat_xyzw.shape == quat_wxyz.shape) + quat_xyzw_bis = roma.mappings.quat_wxyz_to_xyzw(quat_wxyz) + self.assertTrue(torch.all(quat_xyzw == quat_xyzw_bis)) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 83ad84e..d4c5533 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,7 @@ def test_flatten(self): xbis = roma.internal.unflatten_batch_dims(xflat, batch_shape) self.assertTrue(torch.all(xbis == x)) - def test_geodesic_distance(self): + def test_rotmat_geodesic_distance(self): batch_size = 100 for dtype in (torch.float32, torch.float64): axis = torch.nn.functional.normalize(torch.randn((batch_size,3), dtype=dtype), dim=-1) @@ -48,7 +48,7 @@ def test_geodesic_distance(self): geo_dist_naive = roma.rotmat_geodesic_distance_naive(M @ R, M @ I[None,:,:]) self.assertTrue(is_close(torch.abs(alpha), geo_dist_naive)) - def test_unitquat_geodesic_distance(self): + def test_other_geodesic_distance(self): batch_size = 100 for dtype in (torch.float32, torch.float64): q1 = roma.random_unitquat(batch_size, dtype=dtype) @@ -59,6 +59,10 @@ def test_unitquat_geodesic_distance(self): R2 = roma.unitquat_to_rotmat(q2) alpha_R = roma.rotmat_geodesic_distance(R1, R2) self.assertTrue(is_close(alpha_q, alpha_R)) + rotvec1 = roma.unitquat_to_rotvec(q1) + rotvec2 = roma.unitquat_to_rotvec(q2) + alpha_rotvec = roma.rotvec_geodesic_distance(rotvec1, rotvec2) + self.assertTrue(is_close(alpha_rotvec, alpha_q)) def test_random_unitquat(self): q = roma.random_unitquat((3,5))