From 319d130f08269fa770714b0d28f346bdfd0a4cbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20BR=C3=89GIER?= Date: Tue, 23 Apr 2024 13:26:14 +0200 Subject: [PATCH] Euler functions support stacked tensors --- docsource/source/index.rst | 5 +++-- roma/euler.py | 32 ++++++++++++++++++++++---------- test/test_euler.py | 16 ++++++++++++++++ 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/docsource/source/index.rst b/docsource/source/index.rst index 801d06c..7aeeb67 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -57,8 +57,9 @@ Rotation matrix (rotmat) - Encoded as a ...xDxD tensor (D=3 for 3D rotations). - We use column-vector convention, i.e. :math:`R X` is the transformation of a 1xD vector :math:`X` by a rotation matrix :math:`R`. -Euler and Tait-Bryan angles are *NOT* currently supported. - This is because of the many different existing conventions, and because of the limited interest of such parameterization for numerical applications. +Euler and Tait-Bryan angles (euler) + - Encoded as a ...xD tensor (with D=3 for typical Euler angle conventions) or a list of D tensors corresponding to each angle. + - We provide mappings between Euler angles and other rotation representations (use an other representation for actual computations). Mappings between rotation representations diff --git a/roma/euler.py b/roma/euler.py index 8af25e0..44ca312 100644 --- a/roma/euler.py +++ b/roma/euler.py @@ -19,7 +19,7 @@ def _elementary_basis_index(axis): else: raise ValueError("Invalid axis.") -def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=True, dtype=None, device=None): +def euler_to_unitquat(convention: str, angles, degrees=False, normalize=True, dtype=None, device=None): """ Convert Euler angles to unit quaternion representation. @@ -28,12 +28,16 @@ def euler_to_unitquat(convention: str, angles : list, degrees=False, normalize=T The sequence of rotation is expressed either with respect to a global 'extrinsic' coordinate system (in which case axes are denoted in lowercase: 'x', 'y', or 'z'), or with respect to an 'intrinsic' coordinates system attached to the object under rotation (in which case axes are denoted in uppercase: 'X', 'Y', 'Z'). Intrinsic and extrinsic conventions cannot be mixed. - angles (list of floats or list of tensors): a list of angles associated to each axis, expressed in radians by default. + angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default. + If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension. degrees (bool): if True, input angles are assumed to be expressed in degrees. Returns: A batch of unit quaternions (...x4 tensor, XYZW convention). """ + if type(angles) == torch.Tensor: + angles = [t.squeeze(dim=-1) for t in torch.split(angles, split_size_or_sections=1, dim=-1)] + assert len(convention) == len(angles) extrinsics = convention.islower() @@ -67,7 +71,8 @@ def euler_to_rotvec(convention: str, angles : list, degrees=False, dtype=None, d Args: convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. - angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. + angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default. + If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension. degrees (bool): if True, input angles are assumed to be expressed in degrees. Returns: @@ -81,7 +86,8 @@ def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, d Args: convention (string): 'xyz' for example. See :func:`~roma.euler.euler_to_unitquat()`. - angles (list of floats or torch tensors): a list of angles associated to each axis, expressed in radians by default. + angles (list of floats, list of tensors, or tensor): a list of angles associated to each axis, expressed in radians by default. + If a single tensor is provided, Euler angles are assumed to be stacked along the latest dimension. degrees (bool): if True, input angles are assumed to be expressed in degrees. Returns: @@ -89,7 +95,7 @@ def euler_to_rotmat(convention: str, angles : list, degrees=False, dtype=None, d """ return roma.unitquat_to_rotmat(euler_to_unitquat(convention=convention, angles=angles, degrees=degrees, dtype=dtype, device=device)) -def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): +def unitquat_to_euler(convention : str, quat, as_tensor=False, degrees=False, epsilon=1e-7): """ Convert unit quaternion to Euler angles representation. @@ -97,7 +103,8 @@ def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. Consecutive axes should not be identical. quat (...x4 tensor, XYZW convention): input batch of unit quaternion. - degrees (bool): if True, returned angles are expressed in degrees. + as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor. + degrees (bool): if True, angles are returned in degrees. epsilon (float): a small value used to detect degenerate configurations. Returns: @@ -192,9 +199,12 @@ def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): foo = torch.rad2deg(foo) angles[idx] = roma.internal.unflatten_batch_dims(foo, batch_shape) + if as_tensor: + angles = torch.stack(angles, dim=-1) + return angles -def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7): +def rotvec_to_euler(convention : str, rotvec, as_tensor=False, degrees=False, epsilon=1e-7): """ Convert rotation vector to Euler angles representation. @@ -202,7 +212,8 @@ def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7): convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. Consecutive axes should not be identical. rotvec (...x3 tensor): input batch of rotation vectors. - degrees (bool): if True, returned angles are expressed in degrees. + as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor. + degrees (bool): if True, angles are returned in degrees. epsilon (float): a small value used to detect degenerate configurations. Returns: @@ -211,7 +222,7 @@ def rotvec_to_euler(convention : str, rotvec, degrees=False, epsilon=1e-7): """ return unitquat_to_euler(convention, roma.rotvec_to_unitquat(rotvec), degrees=degrees, epsilon=epsilon) -def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7): +def rotmat_to_euler(convention : str, rotmat, as_tensor=False, degrees=False, epsilon=1e-7): """ Convert rotation matrix to Euler angles representation. @@ -219,7 +230,8 @@ def rotmat_to_euler(convention : str, rotmat, degrees=False, epsilon=1e-7): convention (str): string of 3 characters belonging to {'x', 'y', 'z'} for extrinsic rotations, or {'X', 'Y', 'Z'} for intrinsic rotations. Consecutive axes should not be identical. rotmat (...x3x3 tensor): input batch of rotation matrices. - degrees (bool): if True, returned angles are expressed in degrees. + as_tensor (boolean): if True, angles are returned as a stacked ...x3 tensor. + degrees (bool): if True, angles are returned in degrees. epsilon (float): a small value used to detect degenerate configurations. Returns: diff --git a/test/test_euler.py b/test/test_euler.py index d29214e..6e7ba1e 100644 --- a/test/test_euler.py +++ b/test/test_euler.py @@ -91,5 +91,21 @@ def test_euler_backward(self): angles = roma.rotvec_to_euler('xyz', rotvec) sum(angles).backward() + def test_euler_tensor(self): + """ + Test that Euler conversion methods support both list and tensor inputs. + """ + batch_shape = (10,34) + q = roma.random_unitquat(batch_shape, device=device) + convention = 'xyz' + angles = roma.unitquat_to_euler(convention, q) + angles_tensor = roma.unitquat_to_euler(convention, q, as_tensor=True) + assert type(angles) == list + assert type(angles_tensor) == torch.Tensor + q1 = roma.euler_to_unitquat(convention, angles) + q2 = roma.euler_to_unitquat(convention, angles_tensor) + self.assertTrue(torch.all(roma.rotmat_geodesic_distance(q1, q2) < 1e-6)) + + if __name__ == "__main__": unittest.main() \ No newline at end of file