Skip to content

Commit

Permalink
Update RelaxedRigidContacts to consider only enabled collidable points
Browse files Browse the repository at this point in the history
  • Loading branch information
xela-95 committed Oct 29, 2024
1 parent d039b0f commit 29cbe95
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import numpy.typing as npt
import optax

import jaxsim.api as js
Expand Down Expand Up @@ -322,18 +323,25 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:

return jnp.dot(h, )

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

# Compute the position and linear velocities (mixed representation) of
# all collidable points belonging to the robot.
position, velocity = js.contact.collidable_point_kinematics(
model=model, data=data
)
p, v = js.contact.collidable_point_kinematics(model=model, data=data)
position = p[indices_of_enabled_collidable_points]
velocity = v[indices_of_enabled_collidable_points]

# Compute the activation state of the collidable points
δ = jax.vmap(detect_contact)(*position.T)

# Compute the transforms of the implicit frames corresponding to the
# collidable points.
W_H_C = js.contact.transforms(model=model, data=data)
W_H_C = js.contact.transforms(model=model, data=data)[
indices_of_enabled_collidable_points
]

with (
references.switch_velocity_representation(VelRepr.Mixed),
Expand All @@ -357,13 +365,19 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:

Jl_WC = jnp.vstack(
jax.vmap(lambda J, height: J * (height < 0))(
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
js.contact.jacobian(model=model, data=data)[
indices_of_enabled_collidable_points, :3, :
],
δ,
)
)

J̇_WC = jnp.vstack(
jax.vmap(lambda , height: * (height < 0))(
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
js.contact.jacobian_derivative(model=model, data=data)[
indices_of_enabled_collidable_points, :3
],
δ,
),
)

Expand All @@ -373,6 +387,7 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
penetration=δ,
velocity=velocity,
parameters=data.contacts_params,
indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
)

# Compute the Delassus matrix and the free mixed linear acceleration of
Expand Down Expand Up @@ -499,6 +514,7 @@ def _regularizers(
penetration: jtp.Array,
velocity: jtp.Array,
parameters: RelaxedRigidContactsParams,
indices_of_enabled_collidable_points: npt.NDArray,
) -> tuple:
"""
Compute the contact jacobian and the reference acceleration.
Expand All @@ -508,6 +524,7 @@ def _regularizers(
penetration: The penetration of the collidable points.
velocity: The velocity of the collidable points.
parameters: The parameters of the relaxed rigid contacts model.
indices_of_enabled_collidable_points: The indices of the enabled collidable points.
Returns:
A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
Expand Down Expand Up @@ -597,7 +614,7 @@ def compute_row(
*jax.vmap(compute_row)(
link_idx=jnp.array(
model.kin_dyn_parameters.contact_parameters.body
),
)[indices_of_enabled_collidable_points],
penetration=penetration,
velocity=velocity,
),
Expand Down

0 comments on commit 29cbe95

Please sign in to comment.