diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index ad7edcec1..ac0ec41da 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -939,9 +939,9 @@ def free_floating_mass_matrix( raise ValueError(data.velocity_representation) -@jax.jit +@functools.partial(jax.jit, static_argnames=["prefer_rbd"]) def free_floating_coriolis_matrix( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, data: js.data.JaxSimModelData, prefer_rbd: bool = True ) -> jtp.Matrix: """ Compute the free-floating Coriolis matrix of the model. @@ -949,6 +949,9 @@ def free_floating_coriolis_matrix( Args: model: The model to consider. data: The data of the considered model. + prefer_rbd: + Whether to prefer the RBD algorithm over the computation that uses + the Jacobians. Returns: The free-floating Coriolis matrix of the model. @@ -958,52 +961,76 @@ def free_floating_coriolis_matrix( does not exploit any iterative algorithm. Therefore, the computation of the Coriolis matrix may be much slower than other quantities. """ + if prefer_rbd: + # Extract the link and joint serializations. + joint_names = model.joint_names() - # We perform all the calculation in body-fixed. - # The Coriolis matrix computed in this representation is converted later - # to the active representation stored in data. - with data.switch_velocity_representation(VelRepr.Body): + # Extract the state in inertial-fixed representation. + with data.switch_velocity_representation(VelRepr.Inertial): + W_p_B = data.base_position() + W_v_WB = data.base_velocity() + W_Q_B = data.base_orientation(dcm=False) + s = data.joint_positions(model=model, joint_names=joint_names) + ṡ = data.joint_velocities(model=model, joint_names=joint_names) - B_ν = data.generalized_velocity() + M_B, Ṁ_B, C_B = jaxsim.rbda.coriolis( # noqa: F841 + model=model, + base_position=W_p_B, + base_quaternion=W_Q_B, + joint_positions=s, + base_linear_velocity=W_v_WB[0:3], + base_angular_velocity=W_v_WB[3:6], + joint_velocities=ṡ, + standard_gravity=data.standard_gravity(), + ) - # Doubly-left free-floating Jacobian. - L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data) + else: - # Doubly-left free-floating Jacobian derivative. - L_J̇_WL_B = jax.vmap( - lambda link_index: js.link.jacobian_derivative( - model=model, data=data, link_index=link_index - ) - )(js.link.names_to_idxs(model=model, link_names=model.link_names())) + # We perform all the calculation in body-fixed. + # The Coriolis matrix computed in this representation is converted later + # to the active representation stored in data. + with data.switch_velocity_representation(VelRepr.Body): - L_M_L = link_spatial_inertia_matrices(model=model) + B_ν = data.generalized_velocity() - # Body-fixed link velocities. - # Note: we could have called link.velocity() instead of computing it ourselves, - # but since we need the link Jacobians later, we can save a double calculation. - L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B) + # Doubly-left free-floating Jacobian. + L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data) - # Compute the contribution of each link to the Coriolis matrix. - def compute_link_contribution(M, v, J, J̇) -> jtp.Array: + # Doubly-left free-floating Jacobian derivative. + L_J̇_WL_B = jax.vmap( + lambda link_index: js.link.jacobian_derivative( + model=model, data=data, link_index=link_index + ) + )(js.link.names_to_idxs(model=model, link_names=model.link_names())) - return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇) + L_M_L = link_spatial_inertia_matrices(model=model) - C_B_links = jax.vmap(compute_link_contribution)( - L_M_L, - L_v_WL, - L_J_WL_B, - L_J̇_WL_B, - ) + # Body-fixed link velocities. + # Note: we could have called link.velocity() instead of computing it ourselves, + # but since we need the link Jacobians later, we can save a double calculation. + L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B) - # We need to adjust the Coriolis matrix for fixed-base models. - # In this case, the base link does not contribute to the matrix, and we need to zero - # the off-diagonal terms mapping joint quantities onto the base configuration. - if model.floating_base(): - C_B = C_B_links.sum(axis=0) - else: - C_B = C_B_links[1:].sum(axis=0) - C_B = C_B.at[0:6, 6:].set(0.0) - C_B = C_B.at[6:, 0:6].set(0.0) + # Compute the contribution of each link to the Coriolis matrix. + def compute_link_contribution(M, v, J, J̇) -> jtp.Array: + + return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇) + + C_B_links = jax.vmap(compute_link_contribution)( + L_M_L, + L_v_WL, + L_J_WL_B, + L_J̇_WL_B, + ) + + # We need to adjust the Coriolis matrix for fixed-base models. + # In this case, the base link does not contribute to the matrix, and we need to zero + # the off-diagonal terms mapping joint quantities onto the base configuration. + if model.floating_base(): + C_B = C_B_links.sum(axis=0) + else: + C_B = C_B_links[1:].sum(axis=0) + C_B = C_B.at[0:6, 6:].set(0.0) + C_B = C_B.at[6:, 0:6].set(0.0) # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. diff --git a/src/jaxsim/rbda/coriolis.py b/src/jaxsim/rbda/coriolis.py new file mode 100644 index 000000000..176f0e9df --- /dev/null +++ b/src/jaxsim/rbda/coriolis.py @@ -0,0 +1,187 @@ +import jax +import jax.numpy as jnp + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim.math import Adjoint, Cross, StandardGravity, Transform + +from . import utils + + +def coriolis( + model: js.model.JaxSimModel, + *, + base_position: jtp.VectorLike, + base_quaternion: jtp.VectorLike, + joint_positions: jtp.VectorLike, + base_linear_velocity: jtp.VectorLike, + base_angular_velocity: jtp.VectorLike, + joint_velocities: jtp.VectorLike, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, + standard_gravity: jtp.FloatLike = StandardGravity, +) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: + """ + Coriolis matrix + """ + + W_p_B, W_Q_B, s, _, ṡ, _, _, _, _, _ = utils.process_inputs( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity=base_linear_velocity, + base_angular_velocity=base_angular_velocity, + joint_velocities=joint_velocities, + standard_gravity=standard_gravity, + ) + + W_H_B = Transform.from_quaternion_and_translation( + quaternion=W_Q_B, + translation=W_p_B, + ) + + # Extract data from the physics model + pre_X_λi = model.tree_transforms + M = js.model.link_spatial_inertia_matrices(model=model) + i_X_pre, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=s, base_transform=W_H_B.as_matrix() + ) + λ = model.kin_dyn_parameters.parent_array + + # Initialize buffers + v = jnp.array([jnp.zeros([6, 1])] * model.number_of_links()) + Ṡ = jnp.array([jnp.zeros([6, 1])] * model.number_of_links()) + BC = jnp.array([jnp.zeros([6, 6])] * model.number_of_links()) + IC = jnp.zeros_like(M) + + i_X_λi = jnp.zeros_like(i_X_pre) + + # 6D transform of base velocity + B_X_W = Adjoint.from_quaternion_and_translation( + quaternion=W_Q_B, + translation=W_p_B, + inverse=True, + normalize_quaternion=True, + ) + i_X_λi = i_X_λi.at[0].set(B_X_W) + + # Transforms link -> base + i_X_0 = jnp.zeros_like(pre_X_λi) + i_X_0 = i_X_0.at[0].set(jnp.eye(6)) + + Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] + + def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: + i_X_λi, v, Ṡ, BC, IC = carry + vJ = S[i] * ṡ[i] + v_i = i_X_λi[i] @ v[λ[i]] + vJ + v = v.at[i].set(v_i) + + Ṡ_i = Cross.vx(v[i]) @ S[i] + Ṡ = Ṡ.at[i].set(Ṡ_i) + + IC = IC.at[i].set(M[i]) + BC_i = ( + Cross.vx_star(v[i]) @ Cross.vx(IC[i] @ v[i]) - IC[i] @ Cross.vx(v[i]) + ) / 2 + BC = BC.at[i].set(BC_i) + + return (i_X_λi, v, Ṡ, BC, IC), None + + (i_X_λi, v, Ṡ, BC, IC), _ = ( + jax.lax.scan( + f=loop_pass_1, + init=(i_X_λi, v, Ṡ, BC, IC), + xs=jnp.arange(1, model.number_of_links() + 1), + ) + if model.number_of_links() > 1 + else [(i_X_λi, v, Ṡ, BC, IC), None] + ) + + C = jnp.zeros([model.number_of_links(), model.number_of_links()]) + M = jnp.zeros([model.number_of_links(), model.number_of_links()]) + Ṁ = jnp.zeros([model.number_of_links(), model.number_of_links()]) + + Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] + + def loop_pass_2(carry: Pass2Carry, j: jtp.Int) -> tuple[Pass2Carry, None]: + jj = λ[j] - 1 + + C, M, Ṁ, IC, BC = carry + + F_1 = IC[j] @ Ṡ[j] + BC[j] @ S[j] + F_2 = IC[j] @ S[j] + F_3 = BC[j].T @ S[j] + + C = C.at[jj, jj].set((S[j].T @ F_1).squeeze()) + M = M.at[jj, jj].set((S[j].T @ F_2).squeeze()) + Ṁ = Ṁ.at[jj, jj].set((Ṡ[j].T @ F_2 + S[j].T @ F_3).squeeze()) + + F_1 = i_X_λi[j] @ F_1 + F_2 = i_X_λi[j] @ F_2 + F_3 = i_X_λi[j] @ F_3 + + InnerLoopCarry = tuple[ + jtp.Matrix, + jtp.Matrix, + jtp.Matrix, + jtp.Matrix, + jtp.Matrix, + jtp.Matrix, + jtp.Matrix, + ] + + def inner_loop_body(carry: InnerLoopCarry) -> tuple[InnerLoopCarry]: + C, M, Ṁ, F_1, F_2, F_3, i = carry + ii = λ[i] - 1 + + C = C.at[ii, jj].set((S[i].T @ F_1).squeeze()) + C = C.at[jj, ii].set((S[i].T @ F_1).squeeze()) + + M = M.at[ii, ii].set((S[i].T @ F_2).squeeze()) + Ṁ = Ṁ.at[ii].set((Ṡ[i].T @ F_2 + S[i].T @ F_3).squeeze()) + + F_1 = F_1 + i_X_λi[i] @ F_1 + F_2 = F_2 + i_X_λi[i] @ F_2 + F_3 = F_3 + i_X_λi[i] @ F_3 + + i = λ[i] + return C, M, Ṁ, F_1, F_2, F_3, i + + (C, M, Ṁ, F_1, F_2, F_3, _) = jax.lax.while_loop( + body_fun=inner_loop_body, + cond_fun=lambda idx: idx[-1] > 0, + init_val=(C, M, Ṁ, F_1, F_2, F_3, 0), + ) + + def propagate( + IC_BC: tuple[jtp.Matrix, jtp.Matrix] + ) -> tuple[jtp.Matrix, jtp.Matrix]: + IC, BC = IC_BC + + IC = IC.at[λ[j]].set(IC[λ[j]] + i_X_λi[j] @ IC[j] @ i_X_λi[j].T) + BC = BC.at[λ[j]].set(BC[λ[j]] + i_X_λi[j] @ BC[j] @ i_X_λi[j].T) + + return IC, BC + + IC, BC = jax.lax.cond( + pred=jnp.array([λ[j] != 0, model.is_floating_base]).any(), + true_fun=propagate, + false_fun=lambda IC_BC: IC_BC, + operand=(IC, BC), + ) + + return (C, M, Ṁ, IC, BC), None + + (C, M, Ṁ, IC, BC), _ = ( + jax.lax.scan( + f=loop_pass_2, + init=(C, M, Ṁ, IC, BC), + xs=jnp.flip(jnp.arange(1, model.number_of_links() + 1)), + ) + if model.number_of_links() > 1 + else [(C, M, Ṁ, IC, BC), None] + ) + + return M, Ṁ, C diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 1c57f283f..e23da2937 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -407,7 +407,7 @@ def test_coriolis_matrix( # ===== I_ν = data.generalized_velocity() - C = js.model.free_floating_coriolis_matrix(model=model, data=data) + C = js.model.free_floating_coriolis_matrix(model=model, data=data, prefer_rbd=False) h = js.model.free_floating_bias_forces(model=model, data=data) g = js.model.free_floating_gravity_forces(model=model, data=data) @@ -477,6 +477,15 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: # Ensure that (Ṁ - 2C) is skew symmetric. assert Ṁ - C - C.T == pytest.approx(0) + M = js.model.free_floating_mass_matrix(model=model, data=data) + + M_rbd, _, C_rbd = js.model.free_floating_coriolis_matrix( + model=model, data=data, prefer_rbd=True + ) + + assert C == pytest.approx(C_rbd) + assert M == pytest.approx(M_rbd) + def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel,