diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 4bb3b56..7956dad 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -235,6 +235,8 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an Changelog ========== +Version 1.5.1: + - Syntactic sugar for :ref:`spatial-transformations`: support for default linear or translation parts, identity transformations and batch dimension squeezing. Version 1.5.0: - Added Euler angles mappings. Version 1.4.5: diff --git a/roma/transforms.py b/roma/transforms.py index 6ad5254..efc99dd 100644 --- a/roma/transforms.py +++ b/roma/transforms.py @@ -249,7 +249,17 @@ def __getitem__(self, args): """ Slicing operator, for convenience. """ - return type(self)(self.linear[args], self.translation[args]) + return type(self)(self.linear[args], self.translation[args]) + + def squeeze(self, dim): + """ + Return a view of the transformation in which a batch dimension equal to 1 has been squeezed. + + :var dim: positive integer: The dimension to squeeze. + """ + assert dim >= 0, "Only positive dimensions are supported to avoid ambiguities." + assert self.linear.shape[dim] == self.translation.shape[dim] == 1, "" + return type(self)(self.linear.squeeze(dim), self.translation.squeeze(dim)) def __len__(self): return len(self.linear) @@ -316,9 +326,12 @@ class Affine(_BaseAffine, Linear): An affine transformation represented by a linear and a translation part. :var linear: (...xCxD tensor): batch of matrices specifying the linear part. - :var translation: (...xD tensor): batch of matrices specifying the translation part. + :var translation: (...xD tensor or None): batch of matrices specifying the translation part. """ def __init__(self, linear, translation): + if translation is None: + # Set a default null translation. + translation = torch.zeros(linear.shape[:-2] + (linear.shape[-1],), dtype=linear.dtype, device=linear.device) assert translation.shape[-1] == linear.shape[-2], "Incompatible linear and translation dimensions." assert len(linear.shape[:-2]) == len(translation.shape[:-1]), "Batch dimensions should be broadcastable." _BaseAffine.__init__(self, linear, translation) @@ -328,20 +341,36 @@ class Isometry(Affine, Orthonormal): """ An isometric transformation represented by an orthonormal and a translation part. - :var linear: (...xDxD tensor): batch of matrices specifying the linear part. - :var translation: (...xD tensor): batch of matrices specifying the translation part. + :var linear: (...xDxD tensor or None): batch of matrices specifying the linear part. + :var translation: (...xD tensor or None): batch of matrices specifying the translation part. """ def __init__(self, linear, translation): - assert linear.shape[-1] == linear.shape[-2], "Expecting same dimensions for input and output." + if linear is None: + # Set a default identity linear part. + batch_dims = translation.shape[:-1] + D = translation.shape[-1] + linear = torch.eye(D, dtype=translation.dtype, device=translation.device)[[None] * len(batch_dims)] + else: + assert linear.shape[-1] == linear.shape[-2], "Expecting same dimensions for input and output." Affine.__init__(self, linear, translation) + @classmethod + def Identity(cls, dim, batch_shape=tuple(), dtype=torch.float32, device=None): + """ + Return a default identity transformation. + + :var dim: (strictly positive integer): dimension of the space in which the transformation operates (e.g. `dim=3` for 3D transformations). + :var batch_shape: (tuple): batch dimensions considered. + """ + translation = torch.zeros(batch_shape + (dim,), dtype=dtype, device=device) + return cls(linear=None, translation=translation) class Rigid(Isometry, Rotation): """ A rigid transformation represented by an rotation and a translation part. - :var linear: (...xDxD tensor): batch of matrices specifying the linear part. - :var translation: (...xD tensor): batch of matrices specifying the translation part. + :var linear: (...xDxD tensor or None): batch of matrices specifying the linear part. + :var translation: (...xD tensor or None): batch of matrices specifying the translation part. """ def __init__(self, linear, translation): Isometry.__init__(self, linear, translation) diff --git a/setup.py b/setup.py index 4ce2312..10ef965 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="roma", - version="1.5.0", + version="1.5.1", author="Romain Brégier", author_email="romain.bregier@naverlabs.com", description="A lightweight library to deal with 3D rotations in PyTorch.", diff --git a/test/test_transforms.py b/test/test_transforms.py index 47707f1..aa0cc53 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -284,4 +284,58 @@ def test_rigid_conversions(self): x1 = rigid.apply(x) x2 = rigidunitquat.apply(x) self.assertTrue(torch.all(torch.isclose(x1, x2))) - \ No newline at end of file + + def test_translation_only(self): + batch_shape = (2,3,6) + D = 3 + dtype = torch.float64 + + # Translation-only transformation + translation = torch.randn(batch_shape + (D,), dtype=dtype) + identity = torch.eye(D, dtype=dtype)[[None] * len(batch_shape)].repeat(batch_shape + (1,1)) + T = roma.Rigid(identity, translation) + T1 = roma.Rigid(None, translation) + delta = T1 @ T.inverse() + self.assertTrue(delta.linear.shape == T.linear.shape) + self.assertTrue(delta.translation.shape == T.translation.shape) + epsilon = 1e-7 + self.assertTrue(torch.all(torch.abs(delta.translation) < epsilon)) + self.assertTrue(torch.all(torch.isclose(T.linear, T1.linear))) + + def test_identity(self): + batch_shape = (3,5) + D = 4 + dtype = torch.float64 + identity_transform = roma.Rigid.Identity(D, batch_shape=batch_shape) + self.assertTrue(torch.all(identity_transform.translation == torch.zeros((3,5,4), dtype=dtype))) + self.assertTrue(torch.all(identity_transform.linear == torch.eye(4)[None,None].repeat(3,5,1,1))) + + + def test_linear_only(self): + batch_shape = (2,3,6) + D = 3 + dtype = torch.float64 + # rotation-only transformation + null_translation = torch.zeros(batch_shape + (D,), dtype=dtype) + R = roma.random_rotmat(batch_shape, dtype=dtype) + T = roma.Rigid(R, null_translation) + T1 = roma.Rigid(R, None) + delta = T1 @ T.inverse() + self.assertTrue(delta.linear.shape == T.linear.shape) + self.assertTrue(delta.translation.shape == T.translation.shape) + epsilon = 1e-7 + self.assertTrue(torch.all(torch.abs(delta.translation) < epsilon)) + self.assertTrue(torch.all(torch.isclose(T.linear, T1.linear))) + + def test_squeezing(self): + batch_shape = (2,3,6) + D = 3 + dtype = torch.float64 + # rotation-only transformation + t = torch.randn(batch_shape + (D,), dtype=dtype) + R = roma.random_rotmat(batch_shape, dtype=dtype) + T = roma.Rigid(R, t) + unsqueezed = T[None] + squeezed = unsqueezed.squeeze(dim=0) + self.assertTrue(torch.all(T.linear == squeezed.linear)) + self.assertTrue(torch.all(T.translation == squeezed.translation)) \ No newline at end of file