Skip to content

Commit

Permalink
v1.4.3
Browse files Browse the repository at this point in the history
  • Loading branch information
Romain BRÉGIER committed Feb 26, 2024
1 parent 37b2c50 commit 8c22a56
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an

Changelog
==========
Version 1.4.3:
- Fix normalization bug in :func:`~roma.utils.quat_composition()` (thanks jamiesalter for reporting).
Version 1.4.2:
- Fix for :func:`~roma.utils.quat_action()` to support arbitrary devices and types.
- Added conversion functions between Rigid and RigidUnitQuat.
Expand Down
15 changes: 13 additions & 2 deletions roma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,17 @@ def quat_inverse(quat):
"""
return quat_conjugation(quat) / torch.sum(quat**2, dim=-1, keepdim=True)

def quat_normalize(quat):
"""
Returns a normalized, unit norm, copy of a batch of quaternions.
Args:
quat (...x4 tensor, XYZW convention): batch of quaternions.
Returns:
batch of quaternions (...x4 tensor, XYZW convention).
"""
return quat / torch.linalg.norm(quat, dim=-1, keepdim=True)

def quat_product(p, q):
"""
Returns the product of two quaternions.
Expand Down Expand Up @@ -235,7 +246,7 @@ def quat_composition(sequence, normalize = False):
for q in sequence[1:]:
res = quat_product(res, q)
if normalize:
q = q / torch.norm(q, dim=-1, keepdim=True)
res = quat_normalize(res)
return res

def quat_action(q, v, is_normalized=False):
Expand Down Expand Up @@ -380,7 +391,7 @@ def unitquat_slerp_fast(q0, q1, steps, shortest_arc=True):
# Interpolation
q = alpha * q0.reshape((1,) * steps.dim() + q0.shape) + beta * q1.reshape((1,) * steps.dim() + q1.shape)
# Normalization of the output
q /= torch.norm(q, dim=-1, keepdim=True)
q = quat_normalize(q)
return q.reshape(steps.shape + batch_shape + (4,))

def rotvec_slerp(rotvec0, rotvec1, steps):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="roma",
version="1.4.2",
version="1.4.3",
author="Romain Brégier",
author_email="[email protected]",
description="A lightweight library to deal with 3D rotations in PyTorch.",
Expand Down
2 changes: 2 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def test_quat(self):
q_id = roma.rotvec_to_unitquat(torch.zeros(1,3))
self.assertTrue(is_close(q_id, roma.quat_product(q, iq)))
self.assertTrue(is_close(q_id, roma.quat_product(iq, q)))
nq = roma.quat_normalize(q)
self.assertTrue(is_close(torch.linalg.norm(nq, dim=-1), torch.ones(batch_size, dtype=dtype)))

def test_quat_action(self):
batch_size=100
Expand Down

0 comments on commit 8c22a56

Please sign in to comment.