-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add forward kinematics & euler angle conversions monads
- Loading branch information
Showing
7 changed files
with
236 additions
and
132 deletions.
There are no files selected for viewing
4 changes: 4 additions & 0 deletions
4
moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# @package model.monads.euler_to_rotmat | ||
|
||
_target_: moai.monads.geometry.rotations.euler.EulerToRotationMatrix | ||
order: XYZ |
4 changes: 4 additions & 0 deletions
4
moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# @package model.monads.rotmat_to_euler | ||
|
||
_target_: moai.monads.geometry.rotations.euler.RotationMatrixToEuler | ||
order: XYZ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# @package model.monads.forward_kinematics | ||
|
||
_target_: moai.monads.human.pose.forward_kinematics.ForwardKinematics | ||
parents: ??? # int[] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# NOTE: adapted from PyTorch3D | ||
|
||
import torch | ||
|
||
__all__ = ["euler_angles_to_matrix", "EulerToRotationMatrix"] | ||
|
||
|
||
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Return the rotation matrices for one of the rotations about an axis | ||
of which Euler angles describe, for each value of the angle given. | ||
Args: | ||
axis: Axis label "X" or "Y or "Z". | ||
angle: any shape tensor of Euler angles in radians | ||
Returns: | ||
Rotation matrices as tensor of shape (..., 3, 3). | ||
""" | ||
|
||
cos = torch.cos(angle) | ||
sin = torch.sin(angle) | ||
one = torch.ones_like(angle) | ||
zero = torch.zeros_like(angle) | ||
|
||
if axis == "X": | ||
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) | ||
elif axis == "Y": | ||
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) | ||
elif axis == "Z": | ||
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) | ||
else: | ||
raise ValueError("letter must be either X, Y or Z.") | ||
|
||
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) | ||
|
||
|
||
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: | ||
""" | ||
Convert rotations given as Euler angles in radians to rotation matrices. | ||
Args: | ||
euler_angles: Euler angles in radians as tensor of shape (..., 3). | ||
convention: Convention string of three uppercase letters from | ||
{"X", "Y", and "Z"}. | ||
Returns: | ||
Rotation matrices as tensor of shape (..., 3, 3). | ||
""" | ||
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: | ||
raise ValueError("Invalid input euler angles.") | ||
if len(convention) != 3: | ||
raise ValueError("Convention must have 3 letters.") | ||
if convention[1] in (convention[0], convention[2]): | ||
raise ValueError(f"Invalid convention {convention}.") | ||
for letter in convention: | ||
if letter not in ("X", "Y", "Z"): | ||
raise ValueError(f"Invalid letter {letter} in convention string.") | ||
matrices = [ | ||
_axis_angle_rotation(c, e) | ||
for c, e in zip(convention, torch.unbind(euler_angles, -1)) | ||
] | ||
# return functools.reduce(torch.matmul, matrices) | ||
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) | ||
|
||
|
||
def _index_from_letter(letter: str) -> int: | ||
if letter == "X": | ||
return 0 | ||
if letter == "Y": | ||
return 1 | ||
if letter == "Z": | ||
return 2 | ||
raise ValueError("letter must be either X, Y or Z.") | ||
|
||
|
||
def _angle_from_tan( | ||
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool | ||
) -> torch.Tensor: | ||
""" | ||
Extract the first or third Euler angle from the two members of | ||
the matrix which are positive constant times its sine and cosine. | ||
Args: | ||
axis: Axis label "X" or "Y or "Z" for the angle we are finding. | ||
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the | ||
convention. | ||
data: Rotation matrices as tensor of shape (..., 3, 3). | ||
horizontal: Whether we are looking for the angle for the third axis, | ||
which means the relevant entries are in the same row of the | ||
rotation matrix. If not, they are in the same column. | ||
tait_bryan: Whether the first and third axes in the convention differ. | ||
Returns: | ||
Euler Angles in radians for each matrix in data as a tensor | ||
of shape (...). | ||
""" | ||
|
||
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] | ||
if horizontal: | ||
i2, i1 = i1, i2 | ||
even = (axis + other_axis) in ["XY", "YZ", "ZX"] | ||
if horizontal == even: | ||
return torch.atan2(data[..., i1], data[..., i2]) | ||
if tait_bryan: | ||
return torch.atan2(-data[..., i2], data[..., i1]) | ||
return torch.atan2(data[..., i2], -data[..., i1]) | ||
|
||
|
||
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: | ||
""" | ||
Convert rotations given as rotation matrices to Euler angles in radians. | ||
Args: | ||
matrix: Rotation matrices as tensor of shape (..., 3, 3). | ||
convention: Convention string of three uppercase letters. | ||
Returns: | ||
Euler angles in radians as tensor of shape (..., 3). | ||
""" | ||
if len(convention) != 3: | ||
raise ValueError("Convention must have 3 letters.") | ||
if convention[1] in (convention[0], convention[2]): | ||
raise ValueError(f"Invalid convention {convention}.") | ||
for letter in convention: | ||
if letter not in ("X", "Y", "Z"): | ||
raise ValueError(f"Invalid letter {letter} in convention string.") | ||
if matrix.size(-1) != 3 or matrix.size(-2) != 3: | ||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") | ||
i0 = _index_from_letter(convention[0]) | ||
i2 = _index_from_letter(convention[2]) | ||
tait_bryan = i0 != i2 | ||
if tait_bryan: | ||
central_angle = torch.asin( | ||
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) | ||
) | ||
else: | ||
central_angle = torch.acos(matrix[..., i0, i0]) | ||
|
||
o = ( | ||
_angle_from_tan( | ||
convention[0], convention[1], matrix[..., i2], False, tait_bryan | ||
), | ||
central_angle, | ||
_angle_from_tan( | ||
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan | ||
), | ||
) | ||
return torch.stack(o, -1) | ||
|
||
|
||
class EulerToRotationMatrix(torch.nn.Module): | ||
def __init__(self, order: str = "XYZ") -> None: | ||
super().__init__() | ||
self.order = order | ||
|
||
def forward(self, euler: torch.Tensor) -> torch.Tensor: | ||
return euler_angles_to_matrix(euler, self.order) | ||
|
||
|
||
class RotationMatrixToEuler(torch.nn.Module): | ||
def __init__(self, order: str = "XYZ") -> None: | ||
super().__init__() | ||
self.order = order | ||
|
||
def forward(self, rotation: torch.Tensor) -> torch.Tensor: | ||
return matrix_to_euler_angles(rotation, self.order) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +0,0 @@ | ||
# from moai.monads.human.body.init_translation import InitTranslation | ||
# from moai.monads.human.body.joint_regressor import JointRegressor | ||
# from moai.monads.human.body.transfer import BodyTransfer | ||
|
||
# __all__ = [ | ||
# 'InitTranslation', | ||
# 'JointRegressor', | ||
# 'BodyTransfer', | ||
# ] | ||
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import typing | ||
|
||
import numpy as np | ||
import torch | ||
|
||
__all__ = ["ForwardKinematics"] | ||
|
||
|
||
class ForwardKinematics(torch.nn.Module): | ||
def __init__( | ||
self, parents | ||
): # TODO: add a col/row major param to adjust offset slicing | ||
super().__init__() | ||
self.parents = parents # TODO: register buffer from list? | ||
|
||
def forward( | ||
self, # TODO: add parents tensor input? | ||
rotation: torch.Tensor, # [B, (T), J, 3, 3] | ||
position: torch.Tensor, # [B, (T), 3] | ||
offset: torch.Tensor, # [B, (T), J, 3] | ||
parents: typing.Optional[torch.Tensor] = None, # [B, J] | ||
) -> typing.Dict[str, torch.Tensor]: # { [B, (T), J, 3], [B, (T), J, 3, 3] } | ||
joints = torch.empty(rotation.shape[:-1], device=rotation.device) | ||
joints[..., 0, :] = position.clone() # first joint according to global position | ||
offset = offset[ | ||
:, np.newaxis, ..., np.newaxis | ||
] # NOTE: careful, col vs row major order | ||
# offset = offset[np.newaxis, :, np.newaxis, :] #NOTE: careful, col vs row major order | ||
global_rotation = rotation.clone() | ||
# global_rotation = torch.empty(rotation.shape, device=rotation.device) | ||
# global_rotation[..., 0, :3, :3] = rotation[..., 0, :3, :3].clone() | ||
# NOTE: currently the op does not support per batch item parents | ||
parent_indices = ( | ||
parents[0].detach().cpu() | ||
if parents is not None | ||
else self.parents[0].detach().cpu() | ||
) | ||
if ( | ||
parent_indices.shape[-1] == offset.shape[-3] | ||
): # NOTE: to support using the same parents everywhere | ||
parent_indices = parent_indices[1:] | ||
for current_idx, parent_idx in enumerate( | ||
parent_indices, start=1 | ||
): # NOTE: assumes parents exclude root | ||
joints[..., current_idx, :] = torch.matmul( | ||
global_rotation[..., parent_idx, :, :], offset[..., current_idx, :, :] | ||
).squeeze(-1) | ||
global_rotation[..., current_idx, :, :] = torch.matmul( | ||
global_rotation[..., parent_idx, :, :].clone(), | ||
rotation[..., current_idx, :, :].clone(), | ||
) | ||
joints[..., current_idx, :] += joints[..., parent_idx, :] | ||
|
||
return { | ||
"positions": joints, | ||
"rotations": global_rotation, | ||
} |