Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider only enabled collidable points in contact forces computation for Rigid, RelaxedRigid and Soft contact models #274

Merged
merged 14 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 65 additions & 28 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def collidable_point_dynamics(
**kwargs,
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
r"""
Compute the 6D force applied to each collidable point.
Compute the 6D force applied to each enabled collidable point.

Args:
model: The model to consider.
Expand All @@ -151,7 +151,7 @@ def collidable_point_dynamics(
kwargs: Additional keyword arguments to pass to the active contact model.

Returns:
The 6D force applied to each collidable point and additional data based
The 6D force applied to each eneabled collidable point and additional data based
on the contact model configured:
- Soft: the material deformation rate.
- Rigid: no additional data.
Expand Down Expand Up @@ -199,15 +199,19 @@ def collidable_point_dynamics(
)

# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
# associated to each collidable point.
# associated to the enabled collidable point.
# In inertial-fixed representation, the computation of these transforms
# is not necessary and the conversion below becomes a no-op.

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)
flferretti marked this conversation as resolved.
Show resolved Hide resolved

W_H_C = (
js.contact.transforms(model=model, data=data)
if data.velocity_representation is not VelRepr.Inertial
else jnp.zeros(
shape=(len(model.kin_dyn_parameters.contact_parameters.body), 4, 4)
)
else jnp.zeros(shape=(len(indices_of_enabled_collidable_points), 4, 4))
)

# Convert the 6D forces to the active representation.
Expand Down Expand Up @@ -246,6 +250,15 @@ def in_contact(
if link_names is not None and set(link_names).difference(model.link_names()):
raise ValueError("One or more link names are not part of the model")

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

parent_link_idx_of_enabled_collidable_points = jnp.array(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

W_p_Ci = collidable_point_positions(model=model, data=data)

terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
Expand All @@ -262,7 +275,7 @@ def in_contact(

links_in_contact = jax.vmap(
lambda link_index: jnp.where(
jnp.array(model.kin_dyn_parameters.contact_parameters.body) == link_index,
parent_link_idx_of_enabled_collidable_points == link_index,
below_terrain,
jnp.zeros_like(below_terrain, dtype=bool),
).any()
Expand Down Expand Up @@ -426,14 +439,14 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
@jax.jit
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
r"""
Return the pose of the collidable points.
Return the pose of the enabled collidable points.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The stacked SE(3) matrices of all collidable points.
The stacked SE(3) matrices of all enabled collidable points.

Note:
Each collidable point is implicitly associated with a frame
Expand All @@ -442,16 +455,27 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt
rigidly attached to.
"""

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

parent_link_idx_of_enabled_collidable_points = jnp.array(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

# Get the transforms of the parent link of all collidable points.
W_H_L = js.model.forward_kinematics(model=model, data=data)[
jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
parent_link_idx_of_enabled_collidable_points
]

L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
indices_of_enabled_collidable_points
]

# Build the link-to-point transform from the displacement between the link frame L
# and the implicit contact frame C.
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(
model.kin_dyn_parameters.contact_parameters.point
)
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)

# Compose the work-to-link and link-to-point transforms.
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
Expand All @@ -465,7 +489,7 @@ def jacobian(
output_vel_repr: VelRepr | None = None,
) -> jtp.Array:
r"""
Return the free-floating Jacobian of the collidable points.
Return the free-floating Jacobian of the enabled collidable points.

Args:
model: The model to consider.
Expand All @@ -475,7 +499,7 @@ def jacobian(

Returns:
The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
collidable points.
enabled collidable points.

Note:
Each collidable point is implicitly associated with a frame
Expand All @@ -488,6 +512,15 @@ def jacobian(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

parent_link_idx_of_enabled_collidable_points = jnp.array(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

# Compute the Jacobians of all links.
W_J_WL = js.model.generalized_free_floating_jacobian(
model=model, data=data, output_vel_repr=VelRepr.Inertial
Expand All @@ -496,9 +529,7 @@ def jacobian(
# Compute the contact Jacobian.
# In inertial-fixed output representation, the Jacobian of the parent link is also
# the Jacobian of the frame C implicitly associated with the collidable point.
W_J_WC = W_J_WL[
jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int)
]
W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]

# Adjust the output representation.
match output_vel_repr:
Expand Down Expand Up @@ -550,7 +581,7 @@ def jacobian_derivative(
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
r"""
Compute the derivative of the free-floating jacobian of the contact points.
Compute the derivative of the free-floating jacobian of the enabled collidable points.

Args:
model: The model to consider.
Expand All @@ -559,7 +590,7 @@ def jacobian_derivative(
The output velocity representation of the free-floating jacobian derivative.

Returns:
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.

Note:
The input representation of the free-floating jacobian derivative is the active
Expand All @@ -570,10 +601,18 @@ def jacobian_derivative(
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

# Get the index of the parent link and the position of the collidable point.
parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
contact_idxs = jnp.arange(L_p_Ci.shape[0])
parent_link_idx_of_enabled_collidable_points = jnp.array(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
indices_of_enabled_collidable_points
]

# Get the transforms of all the parent links.
W_H_Li = js.model.forward_kinematics(model=model, data=data)
Expand Down Expand Up @@ -646,7 +685,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
output_vel_repr=VelRepr.Inertial,
)

# Get the Jacobian of the collidable points in the mixed representation.
# Get the Jacobian of the enabled collidable points in the mixed representation.
with data.switch_velocity_representation(VelRepr.Mixed):
CW_J_WC_BW = jacobian(
model=model,
Expand All @@ -656,13 +695,11 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:

def compute_O_J̇_WC_I(
L_p_C: jtp.Vector,
contact_idx: jtp.Int,
parent_link_idx: jtp.Int,
CW_J_WC_BW: jtp.Matrix,
W_H_L: jtp.Matrix,
) -> jtp.Matrix:

parent_link_idx = parent_link_idxs[contact_idx]

match output_vel_repr:
case VelRepr.Inertial:
O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841
Expand Down Expand Up @@ -703,7 +740,7 @@ def compute_O_J̇_WC_I(
return O_J̇_WC_I

O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
L_p_Ci, contact_idxs, CW_J_WC_BW, W_H_Li
L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li
)

return O_J̇_WC
3 changes: 3 additions & 0 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,9 @@ class ContactParameters(JaxsimDataclass):
point:
The translations between the link frame and the collidable point, expressed
in the coordinates of the parent link frame.
enabled:
A tuple of booleans representing, for each collidable point, whether it is
enabled or not in contact models.

Note:
Contrarily to LinkParameters and JointParameters, this class is not meant
Expand Down
14 changes: 11 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2287,7 +2287,14 @@ def step(
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
)

W_p_C = js.contact.collidable_point_positions(model, data_tf)
# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

W_p_C = js.contact.collidable_point_positions(model, data_tf)[
indices_of_enabled_collidable_points
]

# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
Expand All @@ -2296,8 +2303,9 @@ def step(
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

with data_tf.switch_velocity_representation(VelRepr.Mixed):

J_WC = js.contact.jacobian(model, data_tf)
J_WC = js.contact.jacobian(model, data_tf)[
indices_of_enabled_collidable_points
]
M = js.model.free_floating_mass_matrix(model, data_tf)

# Compute the impact velocity.
Expand Down
23 changes: 18 additions & 5 deletions src/jaxsim/rbda/collidable_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def collidable_points_pos_vel(
) -> tuple[jtp.Matrix, jtp.Matrix]:
"""

Compute the position and linear velocity of collidable points in the world frame.
Compute the position and linear velocity of the enabled collidable points in the world frame.

Args:
model: The model to consider.
Expand All @@ -35,10 +35,23 @@ def collidable_points_pos_vel(
joint_velocities: The velocities of the joints.

Returns:
A tuple containing the position and linear velocity of collidable points.
A tuple containing the position and linear velocity of the enabled collidable points.
"""

if len(model.kin_dyn_parameters.contact_parameters.body) == 0:
# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

parent_link_idx_of_enabled_collidable_points = jnp.array(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
indices_of_enabled_collidable_points
]

if len(indices_of_enabled_collidable_points) == 0:
return jnp.array(0).astype(float), jnp.empty(0).astype(float)

W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(
Expand Down Expand Up @@ -136,8 +149,8 @@ def process_point_kinematics(

# Process all the collidable points in parallel.
W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
model.kin_dyn_parameters.contact_parameters.point,
jnp.array(model.kin_dyn_parameters.contact_parameters.body),
L_p_Ci,
parent_link_idx_of_enabled_collidable_points,
)

return W_p_Ci, CW_vl_WC
20 changes: 11 additions & 9 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,21 @@ def link_forces_from_contact_forces(
the velocity representation of data.
"""

# Get the object storing the contact parameters of the model.
contact_parameters = model.kin_dyn_parameters.contact_parameters

# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
contact_parameters.indices_of_enabled_collidable_points
)

# Convert the contact forces to a JAX array.
f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

# Get the pose of the enabled collidable points.
W_H_C = js.contact.transforms(model=model, data=data)
W_H_C = js.contact.transforms(model=model, data=data)[
indices_of_enabled_collidable_points
]

# Convert the contact forces to inertial-fixed representation.
W_f_C = jax.vmap(
Expand All @@ -234,14 +244,6 @@ def link_forces_from_contact_forces(
)
)(f_C, W_H_C)

# Get the object storing the contact parameters of the model.
contact_parameters = model.kin_dyn_parameters.contact_parameters

# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
contact_parameters.indices_of_enabled_collidable_points
)

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
# attached to the same link.
Expand Down
Loading