Skip to content

Commit

Permalink
v1.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Romain BRÉGIER committed Oct 1, 2023
1 parent 0da38a3 commit ef10b1e
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an

Changelog
==========
Version 1.4.1:
- Added XYZW / WXYZ quaternion conversion routines: :func:`~roma.mappings.quat_xyzw_to_wxyz()` and :func:`~roma.mappings.quat_wxyz_to_xyzw()`.
- Added :func:`~roma.utils.rotvec_geodesic_distance()`.
Version 1.4.0:
- Added the :ref:`spatial-transformations` module.
Version 1.3.4:
Expand Down
26 changes: 25 additions & 1 deletion roma/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,4 +428,28 @@ def rotmat_to_rotvec(R):
batch of rotation vectors (...x3 tensor).
"""
q = rotmat_to_unitquat(R)
return unitquat_to_rotvec(q)
return unitquat_to_rotvec(q)

def quat_xyzw_to_wxyz(xyzw):
"""
Convert quaternion from XYZW to WXYZ convention.
Args:
xyzw (...x4 tensor, XYZW convention): batch of quaternions.
Returns:
batch of quaternions (...x4 tensor, WXYZ convention).
"""
assert xyzw.shape[-1] == 4
return torch.cat((xyzw[...,-1,None], xyzw[...,:-1]), dim=-1)

def quat_wxyz_to_xyzw(wxyz):
"""
Convert quaternion from WXYZ to XYZW convention.
Args:
wxyz (...x4 tensor, WXYZ convention): batch of quaternions.
Returns:
batch of quaternions (...x4 tensor, XYZW convention).
"""
assert wxyz.shape[-1] == 4
return torch.cat((wxyz[...,1:], wxyz[...,0,None]), dim=-1)
12 changes: 12 additions & 0 deletions roma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ def unitquat_geodesic_distance(q1, q2):
"""
return 4.0 * torch.asin(0.5 * torch.min(roma.internal.norm(q2 - q1, dim=-1), roma.internal.norm(q2 + q1, dim=-1)))

def rotvec_geodesic_distance(vec1, vec2):
"""
Returns the angular distance between rotations represented by rotation vectors.
(use a conversion to unit quaternions internally).
Args:
vec1, vec2 (...x3 tensors): batch of rotation vectors.
Returns:
batch of angles in radians (... tensor).
"""
return unitquat_geodesic_distance(roma.mappings.rotvec_to_unitquat(vec1), roma.mappings.rotvec_to_unitquat(vec2))

def quat_conjugation(quat):
"""
Returns the conjugation of input batch of quaternions.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="roma",
version="1.4.0",
version="1.4.1",
author="Romain Brégier",
author_email="[email protected]",
description="A lightweight library to deal with 3D rotations in PyTorch.",
Expand Down
8 changes: 8 additions & 0 deletions test/test_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ def test_rotvec_unitquat_nan_issues(self):
self.assertTrue(torch.all(torch.isfinite(loss)))
self.assertTrue(torch.all(torch.isfinite(rotvec.grad)))

def test_quat_conventions(self):
for batch_shape in [(), (10,), (23,5)]:
quat_xyzw = torch.randn(batch_shape + (4,))
quat_wxyz = roma.mappings.quat_xyzw_to_wxyz(quat_xyzw)
self.assertTrue(quat_xyzw.shape == quat_wxyz.shape)
quat_xyzw_bis = roma.mappings.quat_wxyz_to_xyzw(quat_wxyz)
self.assertTrue(torch.all(quat_xyzw == quat_xyzw_bis))


if __name__ == "__main__":
unittest.main()
8 changes: 6 additions & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_flatten(self):
xbis = roma.internal.unflatten_batch_dims(xflat, batch_shape)
self.assertTrue(torch.all(xbis == x))

def test_geodesic_distance(self):
def test_rotmat_geodesic_distance(self):
batch_size = 100
for dtype in (torch.float32, torch.float64):
axis = torch.nn.functional.normalize(torch.randn((batch_size,3), dtype=dtype), dim=-1)
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_geodesic_distance(self):
geo_dist_naive = roma.rotmat_geodesic_distance_naive(M @ R, M @ I[None,:,:])
self.assertTrue(is_close(torch.abs(alpha), geo_dist_naive))

def test_unitquat_geodesic_distance(self):
def test_other_geodesic_distance(self):
batch_size = 100
for dtype in (torch.float32, torch.float64):
q1 = roma.random_unitquat(batch_size, dtype=dtype)
Expand All @@ -59,6 +59,10 @@ def test_unitquat_geodesic_distance(self):
R2 = roma.unitquat_to_rotmat(q2)
alpha_R = roma.rotmat_geodesic_distance(R1, R2)
self.assertTrue(is_close(alpha_q, alpha_R))
rotvec1 = roma.unitquat_to_rotvec(q1)
rotvec2 = roma.unitquat_to_rotvec(q2)
alpha_rotvec = roma.rotvec_geodesic_distance(rotvec1, rotvec2)
self.assertTrue(is_close(alpha_rotvec, alpha_q))

def test_random_unitquat(self):
q = roma.random_unitquat((3,5))
Expand Down

0 comments on commit ef10b1e

Please sign in to comment.