diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 5546e96..0d4fbe1 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -231,6 +231,9 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an Changelog ========== +Version 1.4.2: + - Fix for :func:`~roma.utils.quat_action()` to support arbitrary devices and types. + - Added conversion functions between Rigid and RigidUnitQuat. Version 1.4.1: - Added XYZW / WXYZ quaternion conversion routines: :func:`~roma.mappings.quat_xyzw_to_wxyz()` and :func:`~roma.mappings.quat_wxyz_to_xyzw()`. - Added :func:`~roma.utils.rotvec_geodesic_distance()`. diff --git a/roma/transforms.py b/roma/transforms.py index 1a9e4f9..e2a0029 100644 --- a/roma/transforms.py +++ b/roma/transforms.py @@ -63,6 +63,7 @@ def linear_compose(self, other): Returns: a tensor representing the composed transformation. """ + assert len(self.linear.shape) == len(other.linear.shape), "Expecting the same number of batch dimensions for the two transformations." return torch.einsum("...ik, ...kj -> ...ij", self.linear, other.linear) def linear_inverse(self): @@ -84,6 +85,7 @@ def linear_apply(self, v): See note in :func:`~roma.transforms.Linear.apply()` regarding broadcasting. """ + assert len(self.linear.shape) == len(v.shape) + 1, "Expecting the same number of batch dimensions for the transformation and the vector." return torch.einsum("...ik, ...k -> ...i", self.linear, v) def linear_normalize(self): @@ -343,7 +345,16 @@ class Rigid(Isometry, Rotation): :var translation: (...xD tensor): batch of matrices specifying the translation part. """ def __init__(self, linear, translation): - Isometry.__init__(self, linear, translation) + Isometry.__init__(self, linear, translation) + + def to_rigidunitquat(self): + """ + Returns the corresponding RigidUnitQuat transformation. + + Note: + Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications. + """ + return RigidUnitQuat(roma.rotmat_to_unitquat(self.linear), self.translation) class RigidUnitQuat(_BaseAffine, RotationUnitQuat): """ @@ -404,4 +415,13 @@ def from_homogeneous(matrix): D = H1 - 1 linear = roma.rotmat_to_unitquat(matrix[...,:D, :D]) translation = matrix[...,:D, D] - return RigidUnitQuat(linear, translation) \ No newline at end of file + return RigidUnitQuat(linear, translation) + + def to_rigid(self): + """ + Returns the corresponding Rigid transformation. + + Note: + Original and resulting transformations share the same translation tensor. Be careful in case of in-place modifications. + """ + return Rigid(roma.unitquat_to_rotmat(self.linear), self.translation) \ No newline at end of file diff --git a/setup.py b/setup.py index 312f7d0..d612caa 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="roma", - version="1.4.1", + version="1.4.2", author="Romain Brégier", author_email="romain.bregier@naverlabs.com", description="A lightweight library to deal with 3D rotations in PyTorch.", diff --git a/test/test_transforms.py b/test/test_transforms.py index be6f6df..2f86456 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -266,4 +266,19 @@ def test_affine_different_dims(self): self.assertRaises(AssertionError, lambda : roma.Rotation(linear)) self.assertRaises(AssertionError, lambda : roma.Isometry(linear, translation)) self.assertRaises(AssertionError, lambda : roma.Rigid(linear, translation)) + + def test_rigid_conversions(self): + batch_shape = (2,3,6) + dtype = torch.float64 + rigid = roma.Rigid(roma.random_rotmat(batch_shape, dtype=dtype), torch.zeros(batch_shape + (3,), dtype=dtype)) + rigidunitquat = rigid.to_rigidunitquat() + rigid2 = rigidunitquat.to_rigid() + + self.assertTrue(torch.all(torch.isclose(rigid.linear, rigid2.linear))) + self.assertTrue(torch.all(torch.isclose(rigid.translation, rigid2.translation))) + + x = torch.randn(batch_shape + (3,), dtype=dtype) + x1 = rigid.apply(x) + x2 = rigidunitquat.apply(x) + self.assertTrue(torch.all(torch.isclose(x1, x2))) \ No newline at end of file