def forward(
    self,  # TODO: add parents tensor input?
    rotation: torch.Tensor,  # [B, (T), J, 3, 3]
    position: torch.Tensor,  # [B, (T), 3]
    offsets: typing.Optional[torch.Tensor] = None,  # [B, (T), J, 3]
    parents: typing.Optional[torch.Tensor] = None,  # [B, J]
) -> typing.Dict[str, torch.Tensor]:  # { [B, (T), J, 3], [B, (T), J, 4, 4] }
    """Compose per-joint rigid transforms along the kinematic chain.

    Every joint gets a 4x4 homogeneous matrix built from its rotation and
    its offset; the first joint uses ``position`` as its translation
    instead. Each remaining joint's matrix is left-multiplied by the
    already composed matrix of its parent, and the translation column of
    the composed result is returned as the joint position.

    Args:
        rotation: per-joint 3x3 rotation matrices.
        position: translation written into the first joint.
        offsets: per-joint translations; ``self.offsets`` is used when
            omitted. An axis is inserted at dim 1 so the values broadcast
            against ``rotation`` (presumably the time axis - TODO confirm).
        parents: parent index of each joint; ``self.parents`` is used when
            omitted. Only row 0 is read (no per batch item parents). When
            it holds one entry per joint, the first entry is dropped.
            Parents must precede their children, because a parent's
            composed matrix has to exist before the child is processed.

    Returns:
        A dict with ``"positions"`` (translation part of the composed
        transforms) and ``"bone_transforms"`` (the un-composed per-joint
        4x4 matrices).
    """
    # NOTE: careful, col vs row major order
    offsets = (offsets if offsets is not None else self.offsets)[:, None]
    # Zero-filled on purpose: only the 3x3 block, the translation column and
    # the [3, 3] entry are written below, so the rest of the bottom row must
    # already be 0 for the matrices to be valid homogeneous transforms.
    # new_zeros also keeps rotation's dtype and device.
    transforms = rotation.new_zeros((*rotation.shape[:-2], 4, 4))
    transforms[..., :3, :3] = rotation
    transforms[..., :3, 3] = offsets
    transforms[..., 0, :3, 3] = position  # first joint according to global position
    transforms[..., 3, 3] = 1.0
    # NOTE: currently the op does not support per batch item parents
    parent_source = parents if parents is not None else self.parents
    parent_indices = parent_source[0].detach().cpu()
    if parent_indices.shape[-1] == offsets.shape[-2]:
        # NOTE: to support using the same parents everywhere
        parent_indices = parent_indices[1:]
    composed = [transforms[..., 0, :, :]]
    # NOTE: assumes parents exclude root
    for current_idx, parent_idx in enumerate(parent_indices.tolist(), start=1):
        composed.append(
            torch.matmul(
                composed[parent_idx],
                transforms[..., current_idx, :, :],
            )
        )
    composed = torch.stack(composed, dim=-3)
    joints = composed[..., :3, 3]
    return {
        "positions": joints,
        "bone_transforms": transforms,
    }