From 30892cdb232539fb222a91cfd2575eb0ac3a7b8f Mon Sep 17 00:00:00 2001 From: ConnorTingley <50556748+ConnorTingley@users.noreply.github.com> Date: Thu, 16 Jan 2025 02:53:33 -0800 Subject: [PATCH] JAX Jacobian of from_axis_angle producing NaNs (#339) --- src/jaxsim/math/rotation.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index 58d730ee7..bcbe98a5f 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -68,26 +68,17 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: vector = vector.squeeze() - def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: + theta = safe_norm(vector) - v = axis - theta = safe_norm(v) + s = jnp.sin(theta) + c = jnp.cos(theta) - s = jnp.sin(theta) - c = jnp.cos(theta) + c1 = 2 * jnp.sin(theta / 2.0) ** 2 - c1 = 2 * jnp.sin(theta / 2.0) ** 2 + safe_theta = jnp.where(theta == 0, 1.0, theta) + u = vector / safe_theta + u = jnp.vstack(u.squeeze()) - u = v / theta - u = jnp.vstack(u.squeeze()) + R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T - R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T - - return R.transpose() - - return jnp.where( - jnp.allclose(vector, 0.0), - # Return an identity rotation matrix when the input vector is zero. - jnp.eye(3), - theta_is_not_zero(axis=vector), - ) + return R.transpose()