From c0417f0e4f0f17879b1cdbfe846b111070a0f125 Mon Sep 17 00:00:00 2001
From: Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
Date: Thu, 12 Sep 2024 00:07:30 +0200
Subject: [PATCH 1/3] [ci skip] Add RBD to compute the Coriolis matrix

---
 src/jaxsim/rbda/coriolis.py | 187 ++++++++++++++++++++++++++++++++++++
 1 file changed, 187 insertions(+)
 create mode 100644 src/jaxsim/rbda/coriolis.py

diff --git a/src/jaxsim/rbda/coriolis.py b/src/jaxsim/rbda/coriolis.py
new file mode 100644
index 000000000..176f0e9df
--- /dev/null
+++ b/src/jaxsim/rbda/coriolis.py
@@ -0,0 +1,187 @@
+import jax
+import jax.numpy as jnp
+
+import jaxsim.api as js
+import jaxsim.typing as jtp
+from jaxsim.math import Adjoint, Cross, StandardGravity, Transform
+
+from . import utils
+
+
+def coriolis(
+    model: js.model.JaxSimModel,
+    *,
+    base_position: jtp.VectorLike,
+    base_quaternion: jtp.VectorLike,
+    joint_positions: jtp.VectorLike,
+    base_linear_velocity: jtp.VectorLike,
+    base_angular_velocity: jtp.VectorLike,
+    joint_velocities: jtp.VectorLike,
+    joint_forces: jtp.VectorLike | None = None,
+    link_forces: jtp.MatrixLike | None = None,
+    standard_gravity: jtp.FloatLike = StandardGravity,
+) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
+    """
+    Coriolis matrix
+    """
+
+    W_p_B, W_Q_B, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(
+        model=model,
+        base_position=base_position,
+        base_quaternion=base_quaternion,
+        joint_positions=joint_positions,
+        base_linear_velocity=base_linear_velocity,
+        base_angular_velocity=base_angular_velocity,
+        joint_velocities=joint_velocities,
+        standard_gravity=standard_gravity,
+    )
+
+    W_H_B = Transform.from_quaternion_and_translation(
+        quaternion=W_Q_B,
+        translation=W_p_B,
+    )
+
+    # Extract data from the physics model
+    pre_X_λi = model.tree_transforms
+    M = js.model.link_spatial_inertia_matrices(model=model)
+    i_X_pre, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
+        joint_positions=s, base_transform=W_H_B.as_matrix()
+    )
+    λ = model.kin_dyn_parameters.parent_array
+
+    # Initialize buffers
+    v = jnp.array([jnp.zeros([6, 1])] * model.number_of_links())
+    Ṡ = jnp.array([jnp.zeros([6, 1])] * model.number_of_links())
+    BC = jnp.array([jnp.zeros([6, 6])] * model.number_of_links())
+    IC = jnp.zeros_like(M)
+
+    i_X_λi = jnp.zeros_like(i_X_pre)
+
+    # 6D transform of base velocity
+    B_X_W = Adjoint.from_quaternion_and_translation(
+        quaternion=W_Q_B,
+        translation=W_p_B,
+        inverse=True,
+        normalize_quaternion=True,
+    )
+    i_X_λi = i_X_λi.at[0].set(B_X_W)
+
+    # Transforms link -> base
+    i_X_0 = jnp.zeros_like(pre_X_λi)
+    i_X_0 = i_X_0.at[0].set(jnp.eye(6))
+
+    Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
+
+    def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
+        i_X_λi, v, Ṡ, BC, IC = carry
+        vJ = S[i] * ṡ[i]
+        v_i = i_X_λi[i] @ v[λ[i]] + vJ
+        v = v.at[i].set(v_i)
+
+        Ṡ_i = Cross.vx(v[i]) @ S[i]
+        Ṡ = Ṡ.at[i].set(Ṡ_i)
+
+        IC = IC.at[i].set(M[i])
+        BC_i = (
+            Cross.vx_star(v[i]) @ Cross.vx(IC[i] @ v[i]) - IC[i] @ Cross.vx(v[i])
+        ) / 2
+        BC = BC.at[i].set(BC_i)
+
+        return (i_X_λi, v, Ṡ, BC, IC), None
+
+    (i_X_λi, v, Ṡ, BC, IC), _ = (
+        jax.lax.scan(
+            f=loop_pass_1,
+            init=(i_X_λi, v, Ṡ, BC, IC),
+            xs=jnp.arange(1, model.number_of_links() + 1),
+        )
+        if model.number_of_links() > 1
+        else [(i_X_λi, v, Ṡ, BC, IC), None]
+    )
+
+    C = jnp.zeros([model.number_of_links(), model.number_of_links()])
+    M = jnp.zeros([model.number_of_links(), model.number_of_links()])
+    Ṁ = jnp.zeros([model.number_of_links(), model.number_of_links()])
+
+    Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
+
+    def loop_pass_2(carry: Pass2Carry, j: jtp.Int) -> tuple[Pass2Carry, None]:
+        jj = λ[j] - 1
+
+        C, M, Ṁ, IC, BC = carry
+
+        F_1 = IC[j] @ Ṡ[j] + BC[j] @ S[j]
+        F_2 = IC[j] @ S[j]
+        F_3 = BC[j].T @ S[j]
+
+        C = C.at[jj, jj].set((S[j].T @ F_1).squeeze())
+        M = M.at[jj, jj].set((S[j].T @ F_2).squeeze())
+        Ṁ = Ṁ.at[jj, jj].set((Ṡ[j].T @ F_2 + S[j].T @ F_3).squeeze())
+
+        F_1 = i_X_λi[j] @ F_1
+        F_2 = i_X_λi[j] @ F_2
+        F_3 = i_X_λi[j] @ F_3
+
+        InnerLoopCarry = tuple[
+            jtp.Matrix,
+            jtp.Matrix,
+            jtp.Matrix,
+            jtp.Matrix,
+            jtp.Matrix,
+            jtp.Matrix,
+            jtp.Matrix,
+        ]
+
+        def inner_loop_body(carry: InnerLoopCarry) -> tuple[InnerLoopCarry]:
+            C, M, Ṁ, F_1, F_2, F_3, i = carry
+            ii = λ[i] - 1
+
+            C = C.at[ii, jj].set((S[i].T @ F_1).squeeze())
+            C = C.at[jj, ii].set((S[i].T @ F_1).squeeze())
+
+            M = M.at[ii, ii].set((S[i].T @ F_2).squeeze())
+            Ṁ = Ṁ.at[ii].set((Ṡ[i].T @ F_2 + S[i].T @ F_3).squeeze())
+
+            F_1 = F_1 + i_X_λi[i] @ F_1
+            F_2 = F_2 + i_X_λi[i] @ F_2
+            F_3 = F_3 + i_X_λi[i] @ F_3
+
+            i = λ[i]
+            return C, M, Ṁ, F_1, F_2, F_3, i
+
+        (C, M, Ṁ, F_1, F_2, F_3, _) = jax.lax.while_loop(
+            body_fun=inner_loop_body,
+            cond_fun=lambda idx: idx[-1] > 0,
+            init_val=(C, M, Ṁ, F_1, F_2, F_3, 0),
+        )
+
+        def propagate(
+            IC_BC: tuple[jtp.Matrix, jtp.Matrix]
+        ) -> tuple[jtp.Matrix, jtp.Matrix]:
+            IC, BC = IC_BC
+
+            IC = IC.at[λ[j]].set(IC[λ[j]] + i_X_λi[j] @ IC[j] @ i_X_λi[j].T)
+            BC = BC.at[λ[j]].set(BC[λ[j]] + i_X_λi[j] @ BC[j] @ i_X_λi[j].T)
+
+            return IC, BC
+
+        IC, BC = jax.lax.cond(
+            pred=jnp.array([λ[j] != 0, model.is_floating_base]).any(),
+            true_fun=propagate,
+            false_fun=lambda IC_BC: IC_BC,
+            operand=(IC, BC),
+        )
+
+        return (C, M, Ṁ, IC, BC), None
+
+    (C, M, Ṁ, IC, BC), _ = (
+        jax.lax.scan(
+            f=loop_pass_2,
+            init=(C, M, Ṁ, IC, BC),
+            xs=jnp.flip(jnp.arange(1, model.number_of_links() + 1)),
+        )
+        if model.number_of_links() > 1
+        else [(C, M, Ṁ, IC, BC), None]
+    )
+
+    return M, Ṁ, C

From baa4de8788c2d41ce6cbe1bbf63eec3782335c0c Mon Sep 17 00:00:00 2001
From: Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
Date: Thu, 12 Sep 2024 00:09:34 +0200
Subject: [PATCH 2/3] [ci skip] Add options to compute the Coriolis matrix from
 RBD

---
 src/jaxsim/api/model.py | 103 +++++++++++++++++++++++++---------------
 1 file changed, 65 insertions(+), 38 deletions(-)

diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py
index ad7edcec1..ac0ec41da 100644
--- a/src/jaxsim/api/model.py
+++ b/src/jaxsim/api/model.py
@@ -939,9 +939,9 @@ def free_floating_mass_matrix(
             raise ValueError(data.velocity_representation)
 
 
-@jax.jit
+@functools.partial(jax.jit, static_argnames=["prefer_rbd"])
 def free_floating_coriolis_matrix(
-    model: JaxSimModel, data: js.data.JaxSimModelData
+    model: JaxSimModel, data: js.data.JaxSimModelData, prefer_rbd: bool = True
 ) -> jtp.Matrix:
     """
     Compute the free-floating Coriolis matrix of the model.
@@ -949,6 +949,9 @@ def free_floating_coriolis_matrix(
     Args:
         model: The model to consider.
         data: The data of the considered model.
+        prefer_rbd:
+            Whether to prefer the RBD algorithm over the computation that uses
+            the Jacobians.
 
     Returns:
         The free-floating Coriolis matrix of the model.
@@ -958,52 +961,76 @@ def free_floating_coriolis_matrix(
         does not exploit any iterative algorithm. Therefore, the computation of
         the Coriolis matrix may be much slower than other quantities.
     """
+    if prefer_rbd:
+        # Extract the link and joint serializations.
+        joint_names = model.joint_names()
 
-    # We perform all the calculation in body-fixed.
-    # The Coriolis matrix computed in this representation is converted later
-    # to the active representation stored in data.
-    with data.switch_velocity_representation(VelRepr.Body):
+        # Extract the state in inertial-fixed representation.
+        with data.switch_velocity_representation(VelRepr.Inertial):
+            W_p_B = data.base_position()
+            W_v_WB = data.base_velocity()
+            W_Q_B = data.base_orientation(dcm=False)
+            s = data.joint_positions(model=model, joint_names=joint_names)
+            ṡ = data.joint_velocities(model=model, joint_names=joint_names)
 
-        B_ν = data.generalized_velocity()
+        M_B, Ṁ_B, C_B = jaxsim.rbda.coriolis(  # noqa: F841
+            model=model,
+            base_position=W_p_B,
+            base_quaternion=W_Q_B,
+            joint_positions=s,
+            base_linear_velocity=W_v_WB[0:3],
+            base_angular_velocity=W_v_WB[3:6],
+            joint_velocities=ṡ,
+            standard_gravity=data.standard_gravity(),
+        )
 
-        # Doubly-left free-floating Jacobian.
-        L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
+    else:
 
-        # Doubly-left free-floating Jacobian derivative.
-        L_J̇_WL_B = jax.vmap(
-            lambda link_index: js.link.jacobian_derivative(
-                model=model, data=data, link_index=link_index
-            )
-        )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
+        # We perform all the calculation in body-fixed.
+        # The Coriolis matrix computed in this representation is converted later
+        # to the active representation stored in data.
+        with data.switch_velocity_representation(VelRepr.Body):
 
-    L_M_L = link_spatial_inertia_matrices(model=model)
+            B_ν = data.generalized_velocity()
 
-    # Body-fixed link velocities.
-    # Note: we could have called link.velocity() instead of computing it ourselves,
-    # but since we need the link Jacobians later, we can save a double calculation.
-    L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
+            # Doubly-left free-floating Jacobian.
+            L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
 
-    # Compute the contribution of each link to the Coriolis matrix.
-    def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
+            # Doubly-left free-floating Jacobian derivative.
+            L_J̇_WL_B = jax.vmap(
+                lambda link_index: js.link.jacobian_derivative(
+                    model=model, data=data, link_index=link_index
+                )
+            )(js.link.names_to_idxs(model=model, link_names=model.link_names()))
 
-        return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
+        L_M_L = link_spatial_inertia_matrices(model=model)
 
-    C_B_links = jax.vmap(compute_link_contribution)(
-        L_M_L,
-        L_v_WL,
-        L_J_WL_B,
-        L_J̇_WL_B,
-    )
+        # Body-fixed link velocities.
+        # Note: we could have called link.velocity() instead of computing it ourselves,
+        # but since we need the link Jacobians later, we can save a double calculation.
+        L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
 
-    # We need to adjust the Coriolis matrix for fixed-base models.
-    # In this case, the base link does not contribute to the matrix, and we need to zero
-    # the off-diagonal terms mapping joint quantities onto the base configuration.
-    if model.floating_base():
-        C_B = C_B_links.sum(axis=0)
-    else:
-        C_B = C_B_links[1:].sum(axis=0)
-        C_B = C_B.at[0:6, 6:].set(0.0)
-        C_B = C_B.at[6:, 0:6].set(0.0)
+        # Compute the contribution of each link to the Coriolis matrix.
+        def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
+
+            return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
+
+        C_B_links = jax.vmap(compute_link_contribution)(
+            L_M_L,
+            L_v_WL,
+            L_J_WL_B,
+            L_J̇_WL_B,
+        )
+
+        # We need to adjust the Coriolis matrix for fixed-base models.
+        # In this case, the base link does not contribute to the matrix, and we need to zero
+        # the off-diagonal terms mapping joint quantities onto the base configuration.
+        if model.floating_base():
+            C_B = C_B_links.sum(axis=0)
+        else:
+            C_B = C_B_links[1:].sum(axis=0)
+            C_B = C_B.at[0:6, 6:].set(0.0)
+            C_B = C_B.at[6:, 0:6].set(0.0)
 
     # Adjust the representation of the Coriolis matrix.
     # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.

From ff0523305a87d7ca053779a8048c789405520c77 Mon Sep 17 00:00:00 2001
From: Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
Date: Thu, 12 Sep 2024 00:10:17 +0200
Subject: [PATCH 3/3] [ci skip] Add test for the Coriolis matrix RBD

---
 tests/test_api_model.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/tests/test_api_model.py b/tests/test_api_model.py
index 1c57f283f..e23da2937 100644
--- a/tests/test_api_model.py
+++ b/tests/test_api_model.py
@@ -407,7 +407,7 @@ def test_coriolis_matrix(
     # =====
 
     I_ν = data.generalized_velocity()
-    C = js.model.free_floating_coriolis_matrix(model=model, data=data)
+    C = js.model.free_floating_coriolis_matrix(model=model, data=data, prefer_rbd=False)
 
     h = js.model.free_floating_bias_forces(model=model, data=data)
     g = js.model.free_floating_gravity_forces(model=model, data=data)
@@ -477,6 +477,15 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:
     # Ensure that (Ṁ - 2C) is skew symmetric.
     assert Ṁ - C - C.T == pytest.approx(0)
 
+    M = js.model.free_floating_mass_matrix(model=model, data=data)
+
+    M_rbd, _, C_rbd = js.model.free_floating_coriolis_matrix(
+        model=model, data=data, prefer_rbd=True
+    )
+
+    assert C == pytest.approx(C_rbd)
+    assert M == pytest.approx(M_rbd)
+
 
 def test_model_fd_id_consistency(
     jaxsim_models_types: js.model.JaxSimModel,