Skip to content

Commit

Permalink
[WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Sep 18, 2024
1 parent 7ac00b7 commit e042c22
Showing 1 changed file with 70 additions and 46 deletions.
116 changes: 70 additions & 46 deletions src/jaxsim/rbda/mass_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,13 @@ def mass_inverse(

# Allocate buffers.
MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))
M_inv = jnp.zeros(
shape=(
model.number_of_links() + 6 * model.floating_base(),
model.number_of_links() + 6 * model.floating_base(),
1,
)
)

# Allocate the buffer of transforms link -> base.
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

# Initialize base quantities.
if model.floating_base():

# Initialize the articulated-body inertia (Mᴬ) of base link.
MA_0 = M[0]
MA = MA.at[0].set(MA_0)
Expand Down Expand Up @@ -135,7 +127,25 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:

U = jnp.zeros_like(S)
d = jnp.zeros(shape=(model.number_of_links(), 1))
F = jnp.zeros(shape=(6, model.number_of_links()))
M_inv = jnp.zeros(
shape=(
model.number_of_joints() + 6 * model.floating_base(),
model.number_of_joints() + 6 * model.floating_base(),
1,
)
)

if model.number_of_joints() == 0:
M_inv_0 = jnp.linalg.solve(MA[0], jnp.eye(6))
M_inv = M_inv.at[:].set(M_inv_0)

F = jnp.zeros(
shape=(
model.number_of_links(),
6,
model.number_of_links(),
)
)

Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_2_carry: Pass2Carry = (U, d, M_inv, MA, F)
Expand All @@ -152,14 +162,10 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
d_i = S[i].T @ U[i]
d = d.at[i].set(d_i.squeeze())

# Compute the articulated-body inertia and bias force of this link.
Ma = MA[i] - U[i] / d[i] @ U[i].T
Fa = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]]

M_inv_ii = 1 / d[i]
M_inv = M_inv.at[i, i].set(M_inv_ii)
M_inv_i = 1 / d[i]
M_inv = M_inv.at[i, i].set(M_inv_i)

M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[:, ν[i]].squeeze() / d[i].T
M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[i, :, ν[i]].squeeze() / d[i].T
M_inv = M_inv.at[i, ν[i]].set(M_inv_iν)

# Propagate them to the parent, handling the base link.
Expand All @@ -168,10 +174,14 @@ def propagate(
) -> tuple[jtp.Matrix, jtp.Matrix]:
MA, F = MA_F

Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa
F = F.at[:, ν[i]].set(Fa_λi)
# Compute the articulated-body inertia and bias force of this link.
Ma_i = MA[i] - U[i] / d[i] @ U[i].T
Fa_i = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]]

MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa_i
F = F.at[λ[i], :, ν[i]].set(Fa_λi)

MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma_i @ i_X_λi[i]
MA = MA.at[λ[i]].set(MA_λi)

return MA, F
Expand All @@ -185,21 +195,29 @@ def propagate(

return (U, d, M_inv, MA, F), None

(U, d, M_inv, MA, F), _ = (
jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
with jax.disable_jit(True):
(U, d, M_inv, MA, F), _ = (
jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
)
if model.number_of_links() > 1
else [(U, d, M_inv, MA, F), None]
)
if model.number_of_links() > 1
else [(U, d, M_inv, MA, F), None]
)

# ======
# Pass 3
# ======

P = jnp.zeros_like(F)
P = jnp.zeros(
shape=(
model.number_of_joints(),
model.number_of_joints(),
6,
model.number_of_joints() + 6 * model.floating_base(),
)
)

Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_3_carry = (U, M_inv, P)
Expand All @@ -208,12 +226,17 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:

U, M_inv, P = carry

mask = jnp.arange(P.shape[1]) >= i
mask = jnp.arange(P.shape[-1]) >= i
mask_M = jnp.atleast_2d(mask).T

def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])

M_inv = M_inv.at[i].set(
jnp.where(mask, M_inv[i] - U[i].T @ P_λi / d[i], M_inv)
jnp.where(
mask_M,
M_inv[i] - (U[i].T @ i_X_λi[i].T @ P[λ[i], i]).T / d[i],
M_inv[i],
)
)

return M_inv
Expand All @@ -225,14 +248,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
operand=M_inv,
)

M_inv_ii = M_inv[i] * mask

P_ii = S[i].T @ M_inv_ii
P = P.at[i].set(P_ii.squeeze())
P_ii = jnp.where(mask, S[i] @ M_inv[i].T, P[i, i])
P = P.at[i].set(P_ii)

def propagate_P(P: jtp.Vector) -> jtp.Vector:
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])
P = P.at[i].set(jnp.where(mask, P[i] + P_λi, P[i]))

P = P.at[i, i].set(
jnp.where(mask, P[i, i] + i_X_λi[i].T @ P[λ[i], i], P[i, i])
)

return P

Expand All @@ -245,18 +268,19 @@ def propagate_P(P: jtp.Vector) -> jtp.Vector:

return (U, M_inv, P), None

(U, M_inv, P), _ = (
jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=jnp.arange(1, model.number_of_links()),
with jax.disable_jit(True):
(U, M_inv, P), _ = (
jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=jnp.arange(1, model.number_of_links()),
)
if model.number_of_links() > 1
else [(U, M_inv, P), None]
)
if model.number_of_links() > 1
else [(U, M_inv, P), None]
)

# ==============
# Adjust outputs
# ==============

return M_inv
return M_inv.squeeze()

0 comments on commit e042c22

Please sign in to comment.