Skip to content

Commit

Permalink
JAX Jacobian of from_axis_angle producing NaNs (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorTingley authored Jan 16, 2025
1 parent b9eadd3 commit 30892cd
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,17 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:

vector = vector.squeeze()

def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
theta = safe_norm(vector)

v = axis
theta = safe_norm(v)
s = jnp.sin(theta)
c = jnp.cos(theta)

s = jnp.sin(theta)
c = jnp.cos(theta)
c1 = 2 * jnp.sin(theta / 2.0) ** 2

c1 = 2 * jnp.sin(theta / 2.0) ** 2
safe_theta = jnp.where(theta == 0, 1.0, theta)
u = vector / safe_theta
u = jnp.vstack(u.squeeze())

u = v / theta
u = jnp.vstack(u.squeeze())
R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T

R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T

return R.transpose()

return jnp.where(
jnp.allclose(vector, 0.0),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
theta_is_not_zero(axis=vector),
)
return R.transpose()

0 comments on commit 30892cd

Please sign in to comment.