Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add algorithm to compute the standalone Coriolis matrix C(q, ν) #75

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 65 additions & 38 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,16 +939,19 @@ def free_floating_mass_matrix(
raise ValueError(data.velocity_representation)


@jax.jit
@functools.partial(jax.jit, static_argnames=["prefer_rbd"])
def free_floating_coriolis_matrix(
model: JaxSimModel, data: js.data.JaxSimModelData
model: JaxSimModel, data: js.data.JaxSimModelData, prefer_rbd: bool = True
) -> jtp.Matrix:
"""
Compute the free-floating Coriolis matrix of the model.

Args:
model: The model to consider.
data: The data of the considered model.
prefer_rbd:
Whether to prefer the RBD algorithm over the computation that uses
the Jacobians.

Returns:
The free-floating Coriolis matrix of the model.
Expand All @@ -958,52 +961,76 @@ def free_floating_coriolis_matrix(
does not exploit any iterative algorithm. Therefore, the computation of
the Coriolis matrix may be much slower than other quantities.
"""
if prefer_rbd:
# Extract the link and joint serializations.
joint_names = model.joint_names()

# We perform all the calculation in body-fixed.
# The Coriolis matrix computed in this representation is converted later
# to the active representation stored in data.
with data.switch_velocity_representation(VelRepr.Body):
# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position()
W_v_WB = data.base_velocity()
W_Q_B = data.base_orientation(dcm=False)
s = data.joint_positions(model=model, joint_names=joint_names)
ṡ = data.joint_velocities(model=model, joint_names=joint_names)

B_ν = data.generalized_velocity()
M_B, Ṁ_B, C_B = jaxsim.rbda.coriolis( # noqa: F841
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
joint_positions=s,
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=ṡ,
standard_gravity=data.standard_gravity(),
)

# Doubly-left free-floating Jacobian.
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)
else:

# Doubly-left free-floating Jacobian derivative.
L_J̇_WL_B = jax.vmap(
lambda link_index: js.link.jacobian_derivative(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))
# We perform all the calculation in body-fixed.
# The Coriolis matrix computed in this representation is converted later
# to the active representation stored in data.
with data.switch_velocity_representation(VelRepr.Body):

L_M_L = link_spatial_inertia_matrices(model=model)
B_ν = data.generalized_velocity()

# Body-fixed link velocities.
# Note: we could have called link.velocity() instead of computing it ourselves,
# but since we need the link Jacobians later, we can save a double calculation.
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)
# Doubly-left free-floating Jacobian.
L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)

# Compute the contribution of each link to the Coriolis matrix.
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:
# Doubly-left free-floating Jacobian derivative.
L_J̇_WL_B = jax.vmap(
lambda link_index: js.link.jacobian_derivative(
model=model, data=data, link_index=link_index
)
)(js.link.names_to_idxs(model=model, link_names=model.link_names()))

return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)
L_M_L = link_spatial_inertia_matrices(model=model)

C_B_links = jax.vmap(compute_link_contribution)(
L_M_L,
L_v_WL,
L_J_WL_B,
L_J̇_WL_B,
)
# Body-fixed link velocities.
# Note: we could have called link.velocity() instead of computing it ourselves,
# but since we need the link Jacobians later, we can save a double calculation.
L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)

# We need to adjust the Coriolis matrix for fixed-base models.
# In this case, the base link does not contribute to the matrix, and we need to zero
# the off-diagonal terms mapping joint quantities onto the base configuration.
if model.floating_base():
C_B = C_B_links.sum(axis=0)
else:
C_B = C_B_links[1:].sum(axis=0)
C_B = C_B.at[0:6, 6:].set(0.0)
C_B = C_B.at[6:, 0:6].set(0.0)
# Compute the contribution of each link to the Coriolis matrix.
def compute_link_contribution(M, v, J, J̇) -> jtp.Array:

return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)

C_B_links = jax.vmap(compute_link_contribution)(
L_M_L,
L_v_WL,
L_J_WL_B,
L_J̇_WL_B,
)

# We need to adjust the Coriolis matrix for fixed-base models.
# In this case, the base link does not contribute to the matrix, and we need to zero
# the off-diagonal terms mapping joint quantities onto the base configuration.
if model.floating_base():
C_B = C_B_links.sum(axis=0)
else:
C_B = C_B_links[1:].sum(axis=0)
C_B = C_B.at[0:6, 6:].set(0.0)
C_B = C_B.at[6:, 0:6].set(0.0)

# Adjust the representation of the Coriolis matrix.
# Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
Expand Down
187 changes: 187 additions & 0 deletions src/jaxsim/rbda/coriolis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross, StandardGravity, Transform

from . import utils


def coriolis(
model: js.model.JaxSimModel,
*,
base_position: jtp.VectorLike,
base_quaternion: jtp.VectorLike,
joint_positions: jtp.VectorLike,
base_linear_velocity: jtp.VectorLike,
base_angular_velocity: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
standard_gravity: jtp.FloatLike = StandardGravity,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
"""
Coriolis matrix
"""

W_p_B, W_Q_B, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
joint_velocities=joint_velocities,
standard_gravity=standard_gravity,
)

W_H_B = Transform.from_quaternion_and_translation(
quaternion=W_Q_B,
translation=W_p_B,
)

# Extract data from the physics model
pre_X_λi = model.tree_transforms
M = js.model.link_spatial_inertia_matrices(model=model)
i_X_pre, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces(
joint_positions=s, base_transform=W_H_B.as_matrix()
)
λ = model.kin_dyn_parameters.parent_array

# Initialize buffers
v = jnp.array([jnp.zeros([6, 1])] * model.number_of_links())
Ṡ = jnp.array([jnp.zeros([6, 1])] * model.number_of_links())
BC = jnp.array([jnp.zeros([6, 6])] * model.number_of_links())
IC = jnp.zeros_like(M)

i_X_λi = jnp.zeros_like(i_X_pre)

# 6D transform of base velocity
B_X_W = Adjoint.from_quaternion_and_translation(
quaternion=W_Q_B,
translation=W_p_B,
inverse=True,
normalize_quaternion=True,
)
i_X_λi = i_X_λi.at[0].set(B_X_W)

# Transforms link -> base
i_X_0 = jnp.zeros_like(pre_X_λi)
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]

def loop_pass_1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
i_X_λi, v, Ṡ, BC, IC = carry
vJ = S[i] * ṡ[i]
v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)

Ṡ_i = Cross.vx(v[i]) @ S[i]
Ṡ = Ṡ.at[i].set(Ṡ_i)

IC = IC.at[i].set(M[i])
BC_i = (
Cross.vx_star(v[i]) @ Cross.vx(IC[i] @ v[i]) - IC[i] @ Cross.vx(v[i])
) / 2
BC = BC.at[i].set(BC_i)

return (i_X_λi, v, Ṡ, BC, IC), None

(i_X_λi, v, Ṡ, BC, IC), _ = (
jax.lax.scan(
f=loop_pass_1,
init=(i_X_λi, v, Ṡ, BC, IC),
xs=jnp.arange(1, model.number_of_links() + 1),
)
if model.number_of_links() > 1
else [(i_X_λi, v, Ṡ, BC, IC), None]
)

C = jnp.zeros([model.number_of_links(), model.number_of_links()])
M = jnp.zeros([model.number_of_links(), model.number_of_links()])
Ṁ = jnp.zeros([model.number_of_links(), model.number_of_links()])

Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]

def loop_pass_2(carry: Pass2Carry, j: jtp.Int) -> tuple[Pass2Carry, None]:
jj = λ[j] - 1

C, M, Ṁ, IC, BC = carry

F_1 = IC[j] @ Ṡ[j] + BC[j] @ S[j]
F_2 = IC[j] @ S[j]
F_3 = BC[j].T @ S[j]

C = C.at[jj, jj].set((S[j].T @ F_1).squeeze())
M = M.at[jj, jj].set((S[j].T @ F_2).squeeze())
Ṁ = Ṁ.at[jj, jj].set((Ṡ[j].T @ F_2 + S[j].T @ F_3).squeeze())

F_1 = i_X_λi[j] @ F_1
F_2 = i_X_λi[j] @ F_2
F_3 = i_X_λi[j] @ F_3

InnerLoopCarry = tuple[
jtp.Matrix,
jtp.Matrix,
jtp.Matrix,
jtp.Matrix,
jtp.Matrix,
jtp.Matrix,
jtp.Matrix,
]

def inner_loop_body(carry: InnerLoopCarry) -> tuple[InnerLoopCarry]:
C, M, Ṁ, F_1, F_2, F_3, i = carry
ii = λ[i] - 1

C = C.at[ii, jj].set((S[i].T @ F_1).squeeze())
C = C.at[jj, ii].set((S[i].T @ F_1).squeeze())

M = M.at[ii, ii].set((S[i].T @ F_2).squeeze())
Ṁ = Ṁ.at[ii].set((Ṡ[i].T @ F_2 + S[i].T @ F_3).squeeze())

F_1 = F_1 + i_X_λi[i] @ F_1
F_2 = F_2 + i_X_λi[i] @ F_2
F_3 = F_3 + i_X_λi[i] @ F_3

i = λ[i]
return C, M, Ṁ, F_1, F_2, F_3, i

(C, M, Ṁ, F_1, F_2, F_3, _) = jax.lax.while_loop(
body_fun=inner_loop_body,
cond_fun=lambda idx: idx[-1] > 0,
init_val=(C, M, Ṁ, F_1, F_2, F_3, 0),
)

def propagate(
IC_BC: tuple[jtp.Matrix, jtp.Matrix]
) -> tuple[jtp.Matrix, jtp.Matrix]:
IC, BC = IC_BC

IC = IC.at[λ[j]].set(IC[λ[j]] + i_X_λi[j] @ IC[j] @ i_X_λi[j].T)
BC = BC.at[λ[j]].set(BC[λ[j]] + i_X_λi[j] @ BC[j] @ i_X_λi[j].T)

return IC, BC

IC, BC = jax.lax.cond(
pred=jnp.array([λ[j] != 0, model.is_floating_base]).any(),
true_fun=propagate,
false_fun=lambda IC_BC: IC_BC,
operand=(IC, BC),
)

return (C, M, Ṁ, IC, BC), None

(C, M, Ṁ, IC, BC), _ = (
jax.lax.scan(
f=loop_pass_2,
init=(C, M, Ṁ, IC, BC),
xs=jnp.flip(jnp.arange(1, model.number_of_links() + 1)),
)
if model.number_of_links() > 1
else [(C, M, Ṁ, IC, BC), None]
)

return M, Ṁ, C
11 changes: 10 additions & 1 deletion tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_coriolis_matrix(
# =====

I_ν = data.generalized_velocity()
C = js.model.free_floating_coriolis_matrix(model=model, data=data)
C = js.model.free_floating_coriolis_matrix(model=model, data=data, prefer_rbd=False)

h = js.model.free_floating_bias_forces(model=model, data=data)
g = js.model.free_floating_gravity_forces(model=model, data=data)
Expand Down Expand Up @@ -477,6 +477,15 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:
# Ensure that (Ṁ - 2C) is skew symmetric.
assert Ṁ - C - C.T == pytest.approx(0)

M = js.model.free_floating_mass_matrix(model=model, data=data)

M_rbd, _, C_rbd = js.model.free_floating_coriolis_matrix(
model=model, data=data, prefer_rbd=True
)

assert C == pytest.approx(C_rbd)
assert M == pytest.approx(M_rbd)


def test_model_fd_id_consistency(
jaxsim_models_types: js.model.JaxSimModel,
Expand Down