Skip to content

Commit

Permalink
feat: add forward kinematics & euler angle conversions monads
Browse files Browse the repository at this point in the history
  • Loading branch information
nmvrs committed Jun 6, 2024
1 parent a77da8e commit 525391c
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 132 deletions.
4 changes: 4 additions & 0 deletions moai/conf/model/monads/geometry/rotation/euler_to_rotmat.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package model.monads.euler_to_rotmat

_target_: moai.monads.geometry.rotations.euler.EulerToRotationMatrix
order: XYZ
4 changes: 4 additions & 0 deletions moai/conf/model/monads/geometry/rotation/rotmat_to_euler.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package model.monads.rotmat_to_euler

_target_: moai.monads.geometry.rotations.euler.RotationMatrixToEuler
order: XYZ
4 changes: 4 additions & 0 deletions moai/conf/model/monads/human/pose/forward_kinematics.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package model.monads.forward_kinematics

_target_: moai.monads.human.pose.forward_kinematics.ForwardKinematics
parents: ??? # int[]
167 changes: 167 additions & 0 deletions moai/monads/geometry/rotations/euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# NOTE: adapted from PyTorch3D

import torch

__all__ = ["euler_angles_to_matrix", "EulerToRotationMatrix"]


def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""

cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)

if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
elif axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
elif axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError("letter must be either X, Y or Z.")

return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
"""
Convert rotations given as Euler angles in radians to rotation matrices.
Args:
euler_angles: Euler angles in radians as tensor of shape (..., 3).
convention: Convention string of three uppercase letters from
{"X", "Y", and "Z"}.
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
raise ValueError("Invalid input euler angles.")
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = [
_axis_angle_rotation(c, e)
for c, e in zip(convention, torch.unbind(euler_angles, -1))
]
# return functools.reduce(torch.matmul, matrices)
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])


def _index_from_letter(letter: str) -> int:
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
raise ValueError("letter must be either X, Y or Z.")


def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler Angles in radians for each matrix in data as a tensor
of shape (...).
"""

i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to Euler angles in radians.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
convention: Convention string of three uppercase letters.
Returns:
Euler angles in radians as tensor of shape (..., 3).
"""
if len(convention) != 3:
raise ValueError("Convention must have 3 letters.")
if convention[1] in (convention[0], convention[2]):
raise ValueError(f"Invalid convention {convention}.")
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
i0 = _index_from_letter(convention[0])
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
central_angle = torch.asin(
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
)
else:
central_angle = torch.acos(matrix[..., i0, i0])

o = (
_angle_from_tan(
convention[0], convention[1], matrix[..., i2], False, tait_bryan
),
central_angle,
_angle_from_tan(
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
),
)
return torch.stack(o, -1)


class EulerToRotationMatrix(torch.nn.Module):
def __init__(self, order: str = "XYZ") -> None:
super().__init__()
self.order = order

def forward(self, euler: torch.Tensor) -> torch.Tensor:
return euler_angles_to_matrix(euler, self.order)


class RotationMatrixToEuler(torch.nn.Module):
def __init__(self, order: str = "XYZ") -> None:
super().__init__()
self.order = order

def forward(self, rotation: torch.Tensor) -> torch.Tensor:
return matrix_to_euler_angles(rotation, self.order)
9 changes: 0 additions & 9 deletions moai/monads/human/body/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
# from moai.monads.human.body.init_translation import InitTranslation
# from moai.monads.human.body.joint_regressor import JointRegressor
# from moai.monads.human.body.transfer import BodyTransfer

# __all__ = [
# 'InitTranslation',
# 'JointRegressor',
# 'BodyTransfer',
# ]
123 changes: 0 additions & 123 deletions moai/monads/human/body/kinematics.py

This file was deleted.

57 changes: 57 additions & 0 deletions moai/monads/human/pose/forward_kinematics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import typing

import numpy as np
import torch

__all__ = ["ForwardKinematics"]


class ForwardKinematics(torch.nn.Module):
def __init__(
self, parents
): # TODO: add a col/row major param to adjust offset slicing
super().__init__()
self.parents = parents # TODO: register buffer from list?

def forward(
self, # TODO: add parents tensor input?
rotation: torch.Tensor, # [B, (T), J, 3, 3]
position: torch.Tensor, # [B, (T), 3]
offset: torch.Tensor, # [B, (T), J, 3]
parents: typing.Optional[torch.Tensor] = None, # [B, J]
) -> typing.Dict[str, torch.Tensor]: # { [B, (T), J, 3], [B, (T), J, 3, 3] }
joints = torch.empty(rotation.shape[:-1], device=rotation.device)
joints[..., 0, :] = position.clone() # first joint according to global position
offset = offset[
:, np.newaxis, ..., np.newaxis
] # NOTE: careful, col vs row major order
# offset = offset[np.newaxis, :, np.newaxis, :] #NOTE: careful, col vs row major order
global_rotation = rotation.clone()
# global_rotation = torch.empty(rotation.shape, device=rotation.device)
# global_rotation[..., 0, :3, :3] = rotation[..., 0, :3, :3].clone()
# NOTE: currently the op does not support per batch item parents
parent_indices = (
parents[0].detach().cpu()
if parents is not None
else self.parents[0].detach().cpu()
)
if (
parent_indices.shape[-1] == offset.shape[-3]
): # NOTE: to support using the same parents everywhere
parent_indices = parent_indices[1:]
for current_idx, parent_idx in enumerate(
parent_indices, start=1
): # NOTE: assumes parents exclude root
joints[..., current_idx, :] = torch.matmul(
global_rotation[..., parent_idx, :, :], offset[..., current_idx, :, :]
).squeeze(-1)
global_rotation[..., current_idx, :, :] = torch.matmul(
global_rotation[..., parent_idx, :, :].clone(),
rotation[..., current_idx, :, :].clone(),
)
joints[..., current_idx, :] += joints[..., parent_idx, :]

return {
"positions": joints,
"rotations": global_rotation,
}

0 comments on commit 525391c

Please sign in to comment.