Skip to content

Commit

Permalink
v1.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Romain BRÉGIER committed Oct 30, 2024
1 parent fb8dadf commit 87c6b20
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ Bits of code were adapted from SciPy. Documentation is generated, distributed an

Changelog
==========
Version 1.5.1:
- Syntactic sugar for :ref:`spatial-transformations`: support for default linear or translation parts, identity transformations and batch dimension squeezing.
Version 1.5.0:
- Added Euler angles mappings.
Version 1.4.5:
Expand Down
43 changes: 36 additions & 7 deletions roma/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,17 @@ def __getitem__(self, args):
"""
Slicing operator, for convenience.
"""
return type(self)(self.linear[args], self.translation[args])
return type(self)(self.linear[args], self.translation[args])

def squeeze(self, dim):
"""
Return a view of the transformation in which a batch dimension equal to 1 has been squeezed.
:var dim: positive integer: The dimension to squeeze.
"""
assert dim >= 0, "Only positive dimensions are supported to avoid ambiguities."
assert self.linear.shape[dim] == self.translation.shape[dim] == 1, ""
return type(self)(self.linear.squeeze(dim), self.translation.squeeze(dim))

def __len__(self):
return len(self.linear)
Expand Down Expand Up @@ -316,9 +326,12 @@ class Affine(_BaseAffine, Linear):
An affine transformation represented by a linear and a translation part.
:var linear: (...xCxD tensor): batch of matrices specifying the linear part.
:var translation: (...xD tensor): batch of matrices specifying the translation part.
:var translation: (...xD tensor or None): batch of matrices specifying the translation part.
"""
def __init__(self, linear, translation):
if translation is None:
# Set a default null translation.
translation = torch.zeros(linear.shape[:-2] + (linear.shape[-1],), dtype=linear.dtype, device=linear.device)
assert translation.shape[-1] == linear.shape[-2], "Incompatible linear and translation dimensions."
assert len(linear.shape[:-2]) == len(translation.shape[:-1]), "Batch dimensions should be broadcastable."
_BaseAffine.__init__(self, linear, translation)
Expand All @@ -328,20 +341,36 @@ class Isometry(Affine, Orthonormal):
"""
An isometric transformation represented by an orthonormal and a translation part.
:var linear: (...xDxD tensor): batch of matrices specifying the linear part.
:var translation: (...xD tensor): batch of matrices specifying the translation part.
:var linear: (...xDxD tensor or None): batch of matrices specifying the linear part.
:var translation: (...xD tensor or None): batch of matrices specifying the translation part.
"""
def __init__(self, linear, translation):
assert linear.shape[-1] == linear.shape[-2], "Expecting same dimensions for input and output."
if linear is None:
# Set a default identity linear part.
batch_dims = translation.shape[:-1]
D = translation.shape[-1]
linear = torch.eye(D, dtype=translation.dtype, device=translation.device)[[None] * len(batch_dims)]
else:
assert linear.shape[-1] == linear.shape[-2], "Expecting same dimensions for input and output."
Affine.__init__(self, linear, translation)

@classmethod
def Identity(cls, dim, batch_shape=tuple(), dtype=torch.float32, device=None):
"""
Return a default identity transformation.
:var dim: (strictly positive integer): dimension of the space in which the transformation operates (e.g. `dim=3` for 3D transformations).
:var batch_shape: (tuple): batch dimensions considered.
"""
translation = torch.zeros(batch_shape + (dim,), dtype=dtype, device=device)
return cls(linear=None, translation=translation)

class Rigid(Isometry, Rotation):
"""
A rigid transformation represented by an rotation and a translation part.
:var linear: (...xDxD tensor): batch of matrices specifying the linear part.
:var translation: (...xD tensor): batch of matrices specifying the translation part.
:var linear: (...xDxD tensor or None): batch of matrices specifying the linear part.
:var translation: (...xD tensor or None): batch of matrices specifying the translation part.
"""
def __init__(self, linear, translation):
Isometry.__init__(self, linear, translation)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="roma",
version="1.5.0",
version="1.5.1",
author="Romain Brégier",
author_email="[email protected]",
description="A lightweight library to deal with 3D rotations in PyTorch.",
Expand Down
56 changes: 55 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,4 +284,58 @@ def test_rigid_conversions(self):
x1 = rigid.apply(x)
x2 = rigidunitquat.apply(x)
self.assertTrue(torch.all(torch.isclose(x1, x2)))


def test_translation_only(self):
batch_shape = (2,3,6)
D = 3
dtype = torch.float64

# Translation-only transformation
translation = torch.randn(batch_shape + (D,), dtype=dtype)
identity = torch.eye(D, dtype=dtype)[[None] * len(batch_shape)].repeat(batch_shape + (1,1))
T = roma.Rigid(identity, translation)
T1 = roma.Rigid(None, translation)
delta = T1 @ T.inverse()
self.assertTrue(delta.linear.shape == T.linear.shape)
self.assertTrue(delta.translation.shape == T.translation.shape)
epsilon = 1e-7
self.assertTrue(torch.all(torch.abs(delta.translation) < epsilon))
self.assertTrue(torch.all(torch.isclose(T.linear, T1.linear)))

def test_identity(self):
batch_shape = (3,5)
D = 4
dtype = torch.float64
identity_transform = roma.Rigid.Identity(D, batch_shape=batch_shape)
self.assertTrue(torch.all(identity_transform.translation == torch.zeros((3,5,4), dtype=dtype)))
self.assertTrue(torch.all(identity_transform.linear == torch.eye(4)[None,None].repeat(3,5,1,1)))


def test_linear_only(self):
batch_shape = (2,3,6)
D = 3
dtype = torch.float64
# rotation-only transformation
null_translation = torch.zeros(batch_shape + (D,), dtype=dtype)
R = roma.random_rotmat(batch_shape, dtype=dtype)
T = roma.Rigid(R, null_translation)
T1 = roma.Rigid(R, None)
delta = T1 @ T.inverse()
self.assertTrue(delta.linear.shape == T.linear.shape)
self.assertTrue(delta.translation.shape == T.translation.shape)
epsilon = 1e-7
self.assertTrue(torch.all(torch.abs(delta.translation) < epsilon))
self.assertTrue(torch.all(torch.isclose(T.linear, T1.linear)))

def test_squeezing(self):
batch_shape = (2,3,6)
D = 3
dtype = torch.float64
# rotation-only transformation
t = torch.randn(batch_shape + (D,), dtype=dtype)
R = roma.random_rotmat(batch_shape, dtype=dtype)
T = roma.Rigid(R, t)
unsqueezed = T[None]
squeezed = unsqueezed.squeeze(dim=0)
self.assertTrue(torch.all(T.linear == squeezed.linear))
self.assertTrue(torch.all(T.translation == squeezed.translation))

0 comments on commit 87c6b20

Please sign in to comment.