From 8c22a56413604dfcad04641a441398904c10edef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20BR=C3=89GIER?= Date: Mon, 26 Feb 2024 09:57:49 +0100 Subject: [PATCH] v1.4.3 --- docsource/source/index.rst | 2 ++ roma/utils.py | 15 +++++++++++++-- setup.py | 2 +- test/test_utils.py | 2 ++ 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 0d4fbe1..c08cedd 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -231,6 +231,8 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an Changelog ========== +Version 1.4.3: + - Fix normalization bug in :func:`~roma.utils.quat_composition()` (thanks jamiesalter for reporting). Version 1.4.2: - Fix for :func:`~roma.utils.quat_action()` to support arbitrary devices and types. - Added conversion functions between Rigid and RigidUnitQuat. diff --git a/roma/utils.py b/roma/utils.py index 3fcdc47..ee3b5b5 100644 --- a/roma/utils.py +++ b/roma/utils.py @@ -195,6 +195,17 @@ def quat_inverse(quat): """ return quat_conjugation(quat) / torch.sum(quat**2, dim=-1, keepdim=True) +def quat_normalize(quat): + """ + Returns a normalized, unit norm, copy of a batch of quaternions. + + Args: + quat (...x4 tensor, XYZW convention): batch of quaternions. + Returns: + batch of quaternions (...x4 tensor, XYZW convention). + """ + return quat / torch.linalg.norm(quat, dim=-1, keepdim=True) + def quat_product(p, q): """ Returns the product of two quaternions. @@ -235,7 +246,7 @@ def quat_composition(sequence, normalize = False): for q in sequence[1:]: res = quat_product(res, q) if normalize: - q = q / torch.norm(q, dim=-1, keepdim=True) + res = quat_normalize(res) return res def quat_action(q, v, is_normalized=False): @@ -380,7 +391,7 @@ def unitquat_slerp_fast(q0, q1, steps, shortest_arc=True): # Interpolation q = alpha * q0.reshape((1,) * steps.dim() + q0.shape) + beta * q1.reshape((1,) * steps.dim() + q1.shape) # Normalization of the output - q /= torch.norm(q, dim=-1, keepdim=True) + q = quat_normalize(q) return q.reshape(steps.shape + batch_shape + (4,)) def rotvec_slerp(rotvec0, rotvec1, steps): diff --git a/setup.py b/setup.py index d612caa..0fb36cb 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="roma", - version="1.4.2", + version="1.4.3", author="Romain Brégier", author_email="romain.bregier@naverlabs.com", description="A lightweight library to deal with 3D rotations in PyTorch.", diff --git a/test/test_utils.py b/test/test_utils.py index d4c5533..cb70bf7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -112,6 +112,8 @@ def test_quat(self): q_id = roma.rotvec_to_unitquat(torch.zeros(1,3)) self.assertTrue(is_close(q_id, roma.quat_product(q, iq))) self.assertTrue(is_close(q_id, roma.quat_product(iq, q))) + nq = roma.quat_normalize(q) + self.assertTrue(is_close(torch.linalg.norm(nq, dim=-1), torch.ones(batch_size, dtype=dtype))) def test_quat_action(self): batch_size=100