Skip to content

Commit

Permalink
[ci skip] Use boolean mask for slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Sep 18, 2024
1 parent 548286b commit 7ac00b7
Showing 1 changed file with 12 additions and 26 deletions.
38 changes: 12 additions & 26 deletions src/jaxsim/rbda/mass_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:

# Compute the articulated-body inertia and bias force of this link.
Ma = MA[i] - U[i] / d[i] @ U[i].T
Fa = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]]

M_inv_ii = 1 / d[i]
M_inv = M_inv.at[i, i].set(M_inv_ii)
Expand All @@ -167,7 +168,7 @@ def propagate(
) -> tuple[jtp.Matrix, jtp.Matrix]:
MA, F = MA_F

Fa_λi = F[:, ν[i]] + U[i] @ M_inv[i, ν[i]]
Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa
F = F.at[:, ν[i]].set(Fa_λi)

MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
Expand Down Expand Up @@ -198,13 +199,7 @@ def propagate(
# Pass 3
# ======

P = jnp.zeros(
shape=(
model.number_of_links(),
model.number_of_links(),
model.number_of_links(),
)
)
P = jnp.zeros_like(F)

Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_3_carry = (U, M_inv, P)
Expand All @@ -213,17 +208,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:

U, M_inv, P = carry

mask = jnp.arange(P.shape[1]) >= i # equivalent to [i, i:]
mask = jnp.arange(P.shape[1]) >= i

def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
P_ii = jax.lax.dynamic_slice(
P, (i, i - P.shape[1], P.shape[2]), (P.shape[0], 1) * mask
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])
M_inv = M_inv.at[i].set(
jnp.where(mask, M_inv[i] - U[i].T @ P_λi / d[i], M_inv)
)
M_inv_ii = jax.lax.dynamic_slice(
M_inv.squeeze(), (i, i - M_inv.squeeze().shape[0]), i_X_λi[i].shape
)
M_inv_ii = M_inv_ii.at[:].set(M_inv_ii - U[i].T @ i_X_λi[i] @ P_ii / d[i])
jax.lax.dynamic_update_slice(M, M_inv_ii, (i, i))

return M_inv

Expand All @@ -234,19 +225,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
operand=M_inv,
)

M_inv_ii = jax.lax.dynamic_slice(
M_inv, (i, i - d.shape[0], 1), (1, d[i].shape - i, 1)
)
M_inv_ii = M_inv[i] * mask

P_i = S[i].T @ M_inv_ii
P = P.at[i].set(P_i.squeeze())
P_ii = S[i].T @ M_inv_ii
P = P.at[i].set(P_ii.squeeze())

def propagate_P(P: jtp.Vector) -> jtp.Vector:
P_λii = jax.lax.dynamic_slice(P, (λ[i], i), (1, i))
P_iii = jax.lax.dynamic_slice(P, (i, i), (1, i))

P_iii = P_iii.at[:].set(P_iii + i_X_λi[i].T @ P_λii)
jax.lax.dynamic_update_slice(P, P_iii, (i, i))
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])
P = P.at[i].set(jnp.where(mask, P[i] + P_λi, P[i]))

return P

Expand Down

0 comments on commit 7ac00b7

Please sign in to comment.