Commit 37b2c50 (v1.4.2)
Romain BRÉGIER committed Dec 12, 2023
1 parent: 40a0a9f
Showing 4 changed files with 41 additions and 3 deletions.
docsource/source/index.rst (3 additions, 0 deletions)

@@ -231,6 +231,9 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an
 
 Changelog
 ==========
+Version 1.4.2:
+- Fix for :func:`~roma.utils.quat_action()` to support arbitrary devices and types.
+- Added conversion functions between Rigid and RigidUnitQuat.
 Version 1.4.1:
 - Added XYZW / WXYZ quaternion conversion routines: :func:`~roma.mappings.quat_xyzw_to_wxyz()` and :func:`~roma.mappings.quat_wxyz_to_xyzw()`.
 - Added :func:`~roma.utils.rotvec_geodesic_distance()`.
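
For context, a minimal sketch of the two changelog items in use. This is an illustration only: it assumes roma 1.4.2 with PyTorch installed, and the shapes and dtype are arbitrary choices.

    import torch
    import roma

    # quat_action now runs on arbitrary devices and dtypes (the 1.4.2 fix), e.g. float64:
    q = roma.random_unitquat(dtype=torch.float64)
    v = torch.randn(3, dtype=torch.float64)
    w = roma.quat_action(q, v)

    # New in 1.4.2: conversions between the two rigid-transformation representations.
    rigid = roma.Rigid(roma.random_rotmat(dtype=torch.float64), torch.zeros(3, dtype=torch.float64))
    rigid_q = rigid.to_rigidunitquat()   # rotation stored as a unit quaternion
    rigid2 = rigid_q.to_rigid()          # back to a rotation-matrix representation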
roma/transforms.py (22 additions, 2 deletions)

@@ -63,6 +63,7 @@ def linear_compose(self, other):
         Returns:
             a tensor representing the composed transformation.
         """
+        assert len(self.linear.shape) == len(other.linear.shape), "Expecting the same number of batch dimensions for the two transformations."
         return torch.einsum("...ik, ...kj -> ...ij", self.linear, other.linear)
 
     def linear_inverse(self):
@@ -84,6 +85,7 @@ def linear_apply(self, v):
         See note in :func:`~roma.transforms.Linear.apply()` regarding broadcasting.
         """
+        assert len(self.linear.shape) == len(v.shape) + 1, "Expecting the same number of batch dimensions for the transformation and the vector."
         return torch.einsum("...ik, ...k -> ...i", self.linear, v)
 
     def linear_normalize(self):
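
As an aside, here is a small sketch of the batch-dimension contract these new assertions enforce. The shapes are made up for illustration, and it assumes the public apply() entry point routes through linear_apply():

    import torch
    import roma

    rigid = roma.Rigid(roma.random_rotmat((5,)), torch.zeros(5, 3))
    x = torch.randn(5, 3)
    y = rigid.apply(x)  # fine: transformation and points share the batch shape (5,)

    # A vector lacking the batch dimension no longer broadcasts silently;
    # it trips the new assertion instead:
    # rigid.apply(torch.randn(3))  # AssertionError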
@@ -343,7 +345,16 @@ class Rigid(Isometry, Rotation):
     :var translation: (...xD tensor): batch of matrices specifying the translation part.
     """
     def __init__(self, linear, translation):
-        Isometry.__init__(self, linear, translation)
+        Isometry.__init__(self, linear, translation)
+
+    def to_rigidunitquat(self):
+        """
+        Returns the corresponding RigidUnitQuat transformation.
+
+        Note:
+            Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.
+        """
+        return RigidUnitQuat(roma.rotmat_to_unitquat(self.linear), self.translation)
 
 class RigidUnitQuat(_BaseAffine, RotationUnitQuat):
     """
@@ -404,4 +415,13 @@ def from_homogeneous(matrix):
         D = H1 - 1
         linear = roma.rotmat_to_unitquat(matrix[...,:D, :D])
         translation = matrix[...,:D, D]
-        return RigidUnitQuat(linear, translation)
+        return RigidUnitQuat(linear, translation)
+
+    def to_rigid(self):
+        """
+        Returns the corresponding Rigid transformation.
+
+        Note:
+            Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications.
+        """
+        return Rigid(roma.unitquat_to_rotmat(self.linear), self.translation)
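
The Note above is worth spelling out: the conversion does not copy the translation, so in-place edits are visible through both objects. A minimal sketch of this aliasing, assuming RigidUnitQuat is exposed at the package top level like Rigid:

    import torch
    import roma

    rigid = roma.Rigid(roma.random_rotmat(), torch.zeros(3))
    rigid_q = rigid.to_rigidunitquat()

    rigid.translation += 1.0    # in-place modification of the original...
    print(rigid_q.translation)  # ...is visible through the converted transform too.

    # If independent copies are needed, clone the translation before (or after) converting:
    safe = roma.RigidUnitQuat(rigid_q.linear, rigid_q.translation.clone())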
setup.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="roma",
-    version="1.4.1",
+    version="1.4.2",
     author="Romain Brégier",
     author_email="[email protected]",
     description="A lightweight library to deal with 3D rotations in PyTorch.",
test/test_transforms.py (15 additions, 0 deletions)

@@ -266,4 +266,19 @@ def test_affine_different_dims(self):
         self.assertRaises(AssertionError, lambda : roma.Rotation(linear))
         self.assertRaises(AssertionError, lambda : roma.Isometry(linear, translation))
         self.assertRaises(AssertionError, lambda : roma.Rigid(linear, translation))
+
+    def test_rigid_conversions(self):
+        batch_shape = (2,3,6)
+        dtype = torch.float64
+        rigid = roma.Rigid(roma.random_rotmat(batch_shape, dtype=dtype), torch.zeros(batch_shape + (3,), dtype=dtype))
+        rigidunitquat = rigid.to_rigidunitquat()
+        rigid2 = rigidunitquat.to_rigid()
+
+        self.assertTrue(torch.all(torch.isclose(rigid.linear, rigid2.linear)))
+        self.assertTrue(torch.all(torch.isclose(rigid.translation, rigid2.translation)))
+
+        x = torch.randn(batch_shape + (3,), dtype=dtype)
+        x1 = rigid.apply(x)
+        x2 = rigidunitquat.apply(x)
+        self.assertTrue(torch.all(torch.isclose(x1, x2)))
 
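
Related to the round trip checked above, from_homogeneous() can be combined with the new to_rigid(). A sketch under the assumption that from_homogeneous is a static method taking a 4x4 homogeneous matrix, as the hunk above suggests:

    import torch
    import roma

    H = torch.eye(4, dtype=torch.float64)  # identity rotation in homogeneous form
    H[:3, 3] = torch.tensor([1., 2., 3.], dtype=torch.float64)

    rigid_q = roma.RigidUnitQuat.from_homogeneous(H)  # unit-quaternion representation
    rigid = rigid_q.to_rigid()                        # rotation-matrix representation
    print(rigid.translation)                          # tensor([1., 2., 3.], dtype=torch.float64)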
