Skip to content

Commit

Permalink
Euler functions support stacked tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Romain BRÉGIER committed Apr 23, 2024
1 parent d5eabcc commit 319d130
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
5 changes: 3 additions & 2 deletions docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ Rotation matrix (rotmat)
- Encoded as a ...xDxD tensor (D=3 for 3D rotations).
- We use column-vector convention, i.e. :math:`R X` is the transformation of a 1xD vector :math:`X` by a rotation matrix :math:`R`.

Euler and Tait-Bryan angles are *NOT* currently supported.
This is because of the many different existing conventions, and because of the limited interest of such parameterization for numerical applications.
Euler and Tait-Bryan angles (euler)
- Encoded as a ...xD tensor (with D=3 for typical Euler angle conventions) or a list of D tensors corresponding to each angle.
- We provide mappings between Euler angles and other rotation representations (use an other representation for actual computations).


Mappings between rotation representations
Expand Down
32 changes: 22 additions & 10 deletions roma/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _elementary_basis_index(axis):
else:
raise ValueError("Invalid axis.")

def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=True, dtype=None, device=None):
def euler_to_unitquat(convention: str, angles, degrees=False, normalize=True, dtype=None, device=None):
"""
Convert Euler angles to unit quaternion representation.
Expand All @@ -28,12 +28,16 @@ def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=T
The sequence of rotation is expressed either with respect to a global 'extrinsic' coordinate system (in which case axes are denoted in lowercase: 'x', 'y', or 'z'),
or with respect to an 'intrinsic' coordinates system attached to the object under rotation (in which case axes are denoted in uppercase: 'X', 'Y', 'Z').
Intrinsic and extrinsic conventions cannot be mixed.
angles (list of floats or list of tensors): a list of angles associated to each axis, expressed in radians by default.
angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default.
If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
A batch of unit quaternions (...x4 tensor, XYZW convention).
"""
if type(angles) == torch.Tensor:
angles = [t.squeeze(dim=-1) for t in torch.split(angles, split_size_or_sections=1, dim=-1)]

assert len(convention) == len(angles)

extrinsics = convention.islower()
Expand Down Expand Up @@ -67,7 +71,8 @@ def euler_to_rotvec(convention: str, angles : list, degrees=False, dtype=None, d
Args:
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`.
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default.
angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default.
If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
Expand All @@ -81,23 +86,25 @@ def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, d
Args:
convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`.
angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default.
angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default.
If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension.
degrees (bool): if True, input angles are assumed to be expressed in degrees.
Returns:
a batch of rotation matrices (...x3x3 tensor).
"""
return roma.unitquat_to_rotmat(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device))

def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7):
def unitquat_to_euler(convention : str, quat, as_tensor=False, degrees=False, epsilon=1e-7):
"""
Convert unit quaternion to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
quat (...x4 tensor, XYZW convention): input batch of unit quaternion.
degrees (bool): if True, returned angles are expressed in degrees.
as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor.
degrees (bool): if True, angles are returned in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
Expand Down Expand Up @@ -192,17 +199,21 @@ def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7):
foo = torch.rad2deg(foo)
angles[idx] = roma.internal.unflatten_batch_dims(foo, batch_shape)

if as_tensor:
angles = torch.stack(angles, dim=-1)

return angles

def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7):
def rotvec_to_euler(convention : str, rotvec, as_tensor=False, degrees=False, epsilon=1e-7):
"""
Convert rotation vector to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
rotvec (...x3 tensor): input batch of rotation vectors.
degrees (bool): if True, returned angles are expressed in degrees.
as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor.
degrees (bool): if True, angles are returned in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
Expand All @@ -211,15 +222,16 @@ def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7):
"""
return unitquat_to_euler(convention, roma.rotvec_to_unitquat(rotvec), degrees=degrees, epsilon=epsilon)

def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7):
def rotmat_to_euler(convention : str, rotmat, as_tensor=False, degrees=False, epsilon=1e-7):
"""
Convert rotation matrix to Euler angles representation.
Args:
convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations.
Consecutive axes should not be identical.
rotmat (...x3x3 tensor): input batch of rotation matrices.
degrees (bool): if True, returned angles are expressed in degrees.
as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor.
degrees (bool): if True, angles are returned in degrees.
epsilon (float): a small value used to detect degenerate configurations.
Returns:
Expand Down
16 changes: 16 additions & 0 deletions test/test_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,21 @@ def test_euler_backward(self):
angles = roma.rotvec_to_euler('xyz', rotvec)
sum(angles).backward()

def test_euler_tensor(self):
"""
Test that Euler conversion methods support both list and tensor inputs.
"""
batch_shape = (10,34)
q = roma.random_unitquat(batch_shape, device=device)
convention = 'xyz'
angles = roma.unitquat_to_euler(convention, q)
angles_tensor = roma.unitquat_to_euler(convention, q, as_tensor=True)
assert type(angles) == list
assert type(angles_tensor) == torch.Tensor
q1 = roma.euler_to_unitquat(convention, angles)
q2 = roma.euler_to_unitquat(convention, angles_tensor)
self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q1, q2) < 1e-6))


if __name__ == "__main__":
unittest.main()

0 comments on commit 319d130

Please sign in to comment.