From 096a64678553feeda64cfb2b8c380a99a46e7b57 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 01/14] Update `ContactParameters` docstring --- src/jaxsim/api/kin_dyn_parameters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 77eee1a0f..d2304f018 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -743,6 +743,9 @@ class ContactParameters(JaxsimDataclass): point: The translations between the link frame and the collidable point, expressed in the coordinates of the parent link frame. + enabled: + A tuple of booleans representing, for each collidable point, whether it is + enabled or not in contact models. Note: Contrarily to LinkParameters and JointParameters, this class is not meant From 66829d475f83b2b9cf72925546efa7e06fd0b397 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 02/14] Update `jaxsim.api.model.step` to consider only enabled collidable points --- src/jaxsim/api/model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index eeb32fb39..ca07af9f9 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2287,7 +2287,14 @@ def step( msg="Baumgarte stabilization is not supported with ForwardEuler integrators", ) - W_p_C = js.contact.collidable_point_positions(model, data_tf) + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + W_p_C = js.contact.collidable_point_positions(model, data_tf)[ + indices_of_enabled_collidable_points + ] # Compute the penetration depth of the collidable points. δ, *_ = jax.vmap( @@ -2296,8 +2303,9 @@ def step( )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) with data_tf.switch_velocity_representation(VelRepr.Mixed): - - J_WC = js.contact.jacobian(model, data_tf) + J_WC = js.contact.jacobian(model, data_tf)[ + indices_of_enabled_collidable_points + ] M = js.model.free_floating_mass_matrix(model, data_tf) # Compute the impact velocity. From 3f9e9f2928025f7b228c9a3867be94a7c53fe24c Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 03/14] Update `RigidContacts` to consider only enabled collidable points --- src/jaxsim/rbda/contacts/rigid.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index e8fe82c79..4eba17e73 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -285,6 +285,13 @@ def compute_contact_forces( # Import qpax privately just in this method. import qpax + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + n_collidable_points = len(indices_of_enabled_collidable_points) + link_forces = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None @@ -299,24 +306,26 @@ def compute_contact_forces( # Compute kin-dyn quantities used in the contact model. with data.switch_velocity_representation(VelRepr.Mixed): - BW_ν = data.generalized_velocity() M = js.model.free_floating_mass_matrix(model=model, data=data) - J_WC = js.contact.jacobian(model=model, data=data) - J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + J_WC = js.contact.jacobian(model=model, data=data)[ + indices_of_enabled_collidable_points + ] + J̇_WC = js.contact.jacobian_derivative(model=model, data=data)[ + indices_of_enabled_collidable_points + ] - W_H_C = js.contact.transforms(model=model, data=data) + W_H_C = js.contact.transforms(model=model, data=data)[ + indices_of_enabled_collidable_points + ] # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) - - # Get the number of collidable points. - n_collidable_points = len(model.kin_dyn_parameters.contact_parameters.body) + p, v = js.contact.collidable_point_kinematics(model=model, data=data) + position = p[indices_of_enabled_collidable_points] + velocity = v[indices_of_enabled_collidable_points] # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. @@ -460,7 +469,7 @@ def _compute_ineq_constraint_matrix( return G @staticmethod - def _compute_ineq_bounds(n_collidable_points: jtp.FloatLike) -> jtp.Vector: + def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: n_constraints = 6 * n_collidable_points return jnp.zeros(shape=(n_constraints,)) From fcea9170e45200f08b10c78c91bcbc22f9dd0741 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 04/14] Update `link_forces_from_contact_forces` in rbda/contacts/common.py --- src/jaxsim/rbda/contacts/common.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 517ecb483..6bdce1762 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -216,11 +216,21 @@ def link_forces_from_contact_forces( the velocity representation of data. """ + # Get the object storing the contact parameters of the model. + contact_parameters = model.kin_dyn_parameters.contact_parameters + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + contact_parameters.indices_of_enabled_collidable_points + ) + # Convert the contact forces to a JAX array. f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) # Get the pose of the enabled collidable points. - W_H_C = js.contact.transforms(model=model, data=data) + W_H_C = js.contact.transforms(model=model, data=data)[ + indices_of_enabled_collidable_points + ] # Convert the contact forces to inertial-fixed representation. W_f_C = jax.vmap( @@ -234,14 +244,6 @@ def link_forces_from_contact_forces( ) )(f_C, W_H_C) - # Get the object storing the contact parameters of the model. - contact_parameters = model.kin_dyn_parameters.contact_parameters - - # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - contact_parameters.indices_of_enabled_collidable_points - ) - # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly # attached to the same link. From a2938e53d42d0a3343d3ae8c4cf0ebb24d137ec5 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 05/14] Update `collidable_point_dynamics` in api/contact.py to consider only enabled collidable points --- src/jaxsim/api/contact.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index a01f9b35e..8537b5eab 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -199,16 +199,22 @@ def collidable_point_dynamics( ) # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])` - # associated to each collidable point. + # associated to the enabled collidable point. # In inertial-fixed representation, the computation of these transforms # is not necessary and the conversion below becomes a no-op. + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + W_H_C = ( js.contact.transforms(model=model, data=data) if data.velocity_representation is not VelRepr.Inertial else jnp.zeros( shape=(len(model.kin_dyn_parameters.contact_parameters.body), 4, 4) ) - ) + )[indices_of_enabled_collidable_points] # Convert the 6D forces to the active representation. f_Ci = jax.vmap( From 951fe09bf19a188266eae184e944f151790abcd5 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 06/14] Update `RelaxedRigidContacts` to consider only enabled collidable points --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 31 ++++++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index ee58790d6..34c8ce2bf 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import jax_dataclasses +import numpy.typing as npt import optax import jaxsim.api as js @@ -322,18 +323,25 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: return jnp.dot(h, n̂) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) + p, v = js.contact.collidable_point_kinematics(model=model, data=data) + position = p[indices_of_enabled_collidable_points] + velocity = v[indices_of_enabled_collidable_points] # Compute the activation state of the collidable points δ = jax.vmap(detect_contact)(*position.T) # Compute the transforms of the implicit frames corresponding to the # collidable points. - W_H_C = js.contact.transforms(model=model, data=data) + W_H_C = js.contact.transforms(model=model, data=data)[ + indices_of_enabled_collidable_points + ] with ( references.switch_velocity_representation(VelRepr.Mixed), @@ -357,13 +365,19 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: Jl_WC = jnp.vstack( jax.vmap(lambda J, height: J * (height < 0))( - js.contact.jacobian(model=model, data=data)[:, :3, :], δ + js.contact.jacobian(model=model, data=data)[ + indices_of_enabled_collidable_points, :3, : + ], + δ, ) ) J̇_WC = jnp.vstack( jax.vmap(lambda J̇, height: J̇ * (height < 0))( - js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ + js.contact.jacobian_derivative(model=model, data=data)[ + indices_of_enabled_collidable_points, :3 + ], + δ, ), ) @@ -373,6 +387,7 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: penetration=δ, velocity=velocity, parameters=data.contacts_params, + indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, ) # Compute the Delassus matrix and the free mixed linear acceleration of @@ -499,6 +514,7 @@ def _regularizers( penetration: jtp.Array, velocity: jtp.Array, parameters: RelaxedRigidContactsParams, + indices_of_enabled_collidable_points: npt.NDArray, ) -> tuple: """ Compute the contact jacobian and the reference acceleration. @@ -508,6 +524,7 @@ def _regularizers( penetration: The penetration of the collidable points. velocity: The velocity of the collidable points. parameters: The parameters of the relaxed rigid contacts model. + indices_of_enabled_collidable_points: The indices of the enabled collidable points. Returns: A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping. @@ -597,7 +614,7 @@ def compute_row( *jax.vmap(compute_row)( link_idx=jnp.array( model.kin_dyn_parameters.contact_parameters.body - ), + )[indices_of_enabled_collidable_points], penetration=penetration, velocity=velocity, ), From 7c61e468dc92db266101c15767ebef8a74f1d636 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 07/14] Update `SoftContacts` to consider enabled collidable points --- src/jaxsim/rbda/contacts/soft.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index e726df379..b9332e32f 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -445,16 +445,29 @@ def compute_contact_forces( # contact parameters are not compatible. model, data = self.initialize_model_and_data(model=model, data=data) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + # Compute the position and linear velocities (mixed representation) of - # all collidable points belonging to the robot. + # all the collidable points belonging to the robot and extract the ones + # for the enabled collidable points. W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) + W_p_C_enabled = W_p_C[indices_of_enabled_collidable_points] + W_ṗ_C_enabled = W_ṗ_C[indices_of_enabled_collidable_points] # Extract the material deformation corresponding to the collidable points. m = data.state.extended["tangential_deformation"] - # Compute the contact forces for all collidable points. + m_enabled = m[indices_of_enabled_collidable_points] + + # Initialize the tangential deformation rate array for every collidable point. + ṁ = jnp.zeros_like(m) + + # Compute the contact forces only for the enabled collidable points. # Since we treat them as independent, we can vmap the computation. - W_f, ṁ = jax.vmap( + W_f, ṁ_enabled = jax.vmap( lambda p, v, m: SoftContacts.compute_contact_force( position=p, velocity=v, @@ -462,6 +475,8 @@ def compute_contact_forces( parameters=data.contacts_params, terrain=model.terrain, ) - )(W_p_C, W_ṗ_C, m) + )(W_p_C_enabled, W_ṗ_C_enabled, m_enabled) + + ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) return W_f, dict(m_dot=ṁ) From c6b06f5ba2cdea6542d31c9158905f79316cb8a4 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 08/14] Update `test_simulations.py` Enable a subset of collidable points in SoftContacts, RigidContacts, and RelaxedRigidContacts --- tests/test_simulations.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 49240987a..20164ce53 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -212,6 +212,16 @@ def test_simulation_with_soft_contacts( model.contact_model = jaxsim.rbda.contacts.SoftContacts.build( terrain=model.terrain, ) + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Initialize the maximum penetration of each collidable point at steady state. max_penetration = 0.001 @@ -296,6 +306,16 @@ def test_simulation_with_rigid_contacts( model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( terrain=model.terrain, ) + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Initialize the maximum penetration of each collidable point at steady state. # This model is rigid, so we expect (almost) no penetration. @@ -338,6 +358,16 @@ def test_simulation_with_relaxed_rigid_contacts( model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( terrain=model.terrain, ) + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Initialize the maximum penetration of each collidable point at steady state. # This model is quasi-rigid, so we expect (almost) no penetration. From da9c3f2e97e91ab1cc1b277948891435fc233d60 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 09/14] Uniform `jaxsim.api.contact` APIs to consider only enabled collidable points --- src/jaxsim/api/contact.py | 87 ++++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 8537b5eab..eb1fa19f1 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -138,7 +138,7 @@ def collidable_point_dynamics( **kwargs, ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: r""" - Compute the 6D force applied to each collidable point. + Compute the 6D force applied to each enabled collidable point. Args: model: The model to consider. @@ -151,7 +151,7 @@ def collidable_point_dynamics( kwargs: Additional keyword arguments to pass to the active contact model. Returns: - The 6D force applied to each collidable point and additional data based + The 6D force applied to each eneabled collidable point and additional data based on the contact model configured: - Soft: the material deformation rate. - Rigid: no additional data. @@ -211,10 +211,8 @@ def collidable_point_dynamics( W_H_C = ( js.contact.transforms(model=model, data=data) if data.velocity_representation is not VelRepr.Inertial - else jnp.zeros( - shape=(len(model.kin_dyn_parameters.contact_parameters.body), 4, 4) - ) - )[indices_of_enabled_collidable_points] + else jnp.zeros(shape=(len(indices_of_enabled_collidable_points), 4, 4)) + ) # Convert the 6D forces to the active representation. f_Ci = jax.vmap( @@ -252,6 +250,15 @@ def in_contact( if link_names is not None and set(link_names).difference(model.link_names()): raise ValueError("One or more link names are not part of the model") + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + W_p_Ci = collidable_point_positions(model=model, data=data) terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))( @@ -268,7 +275,7 @@ def in_contact( links_in_contact = jax.vmap( lambda link_index: jnp.where( - jnp.array(model.kin_dyn_parameters.contact_parameters.body) == link_index, + parent_link_idx_of_enabled_collidable_points == link_index, below_terrain, jnp.zeros_like(below_terrain, dtype=bool), ).any() @@ -432,14 +439,14 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: @jax.jit def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array: r""" - Return the pose of the collidable points. + Return the pose of the enabled collidable points. Args: model: The model to consider. data: The data of the considered model. Returns: - The stacked SE(3) matrices of all collidable points. + The stacked SE(3) matrices of all enabled collidable points. Note: Each collidable point is implicitly associated with a frame @@ -448,16 +455,27 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt rigidly attached to. """ + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + # Get the transforms of the parent link of all collidable points. W_H_L = js.model.forward_kinematics(model=model, data=data)[ - jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int) + parent_link_idx_of_enabled_collidable_points + ] + + L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ + indices_of_enabled_collidable_points ] # Build the link-to-point transform from the displacement between the link frame L # and the implicit contact frame C. - L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))( - model.kin_dyn_parameters.contact_parameters.point - ) + L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci) # Compose the work-to-link and link-to-point transforms. return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) @@ -471,7 +489,7 @@ def jacobian( output_vel_repr: VelRepr | None = None, ) -> jtp.Array: r""" - Return the free-floating Jacobian of the collidable points. + Return the free-floating Jacobian of the enabled collidable points. Args: model: The model to consider. @@ -481,7 +499,7 @@ def jacobian( Returns: The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the - collidable points. + enabled collidable points. Note: Each collidable point is implicitly associated with a frame @@ -494,6 +512,15 @@ def jacobian( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + # Compute the Jacobians of all links. W_J_WL = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=VelRepr.Inertial @@ -502,9 +529,7 @@ def jacobian( # Compute the contact Jacobian. # In inertial-fixed output representation, the Jacobian of the parent link is also # the Jacobian of the frame C implicitly associated with the collidable point. - W_J_WC = W_J_WL[ - jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int) - ] + W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points] # Adjust the output representation. match output_vel_repr: @@ -556,7 +581,7 @@ def jacobian_derivative( output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" - Compute the derivative of the free-floating jacobian of the contact points. + Compute the derivative of the free-floating jacobian of the enabled collidable points. Args: model: The model to consider. @@ -565,7 +590,7 @@ def jacobian_derivative( The output velocity representation of the free-floating jacobian derivative. Returns: - The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points. + The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points. Note: The input representation of the free-floating jacobian derivative is the active @@ -576,10 +601,18 @@ def jacobian_derivative( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + # Get the index of the parent link and the position of the collidable point. - parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body) - L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point) - contact_idxs = jnp.arange(L_p_Ci.shape[0]) + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ + indices_of_enabled_collidable_points + ] # Get the transforms of all the parent links. W_H_Li = js.model.forward_kinematics(model=model, data=data) @@ -652,7 +685,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: output_vel_repr=VelRepr.Inertial, ) - # Get the Jacobian of the collidable points in the mixed representation. + # Get the Jacobian of the enabled collidable points in the mixed representation. with data.switch_velocity_representation(VelRepr.Mixed): CW_J_WC_BW = jacobian( model=model, @@ -662,13 +695,11 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: def compute_O_J̇_WC_I( L_p_C: jtp.Vector, - contact_idx: jtp.Int, + parent_link_idx: jtp.Int, CW_J_WC_BW: jtp.Matrix, W_H_L: jtp.Matrix, ) -> jtp.Matrix: - parent_link_idx = parent_link_idxs[contact_idx] - match output_vel_repr: case VelRepr.Inertial: O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 @@ -709,7 +740,7 @@ def compute_O_J̇_WC_I( return O_J̇_WC_I O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))( - L_p_Ci, contact_idxs, CW_J_WC_BW, W_H_Li + L_p_Ci, parent_link_idx_of_enabled_collidable_points, CW_J_WC_BW, W_H_Li ) return O_J̇_WC From deffe0830d549ceed8808740e624cd630c5642f2 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 10/14] Update `collidable_points_pos_vel` to compute positions and velocities of enabled collidable points --- src/jaxsim/rbda/collidable_points.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index a7940c62a..543be5328 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -21,7 +21,7 @@ def collidable_points_pos_vel( ) -> tuple[jtp.Matrix, jtp.Matrix]: """ - Compute the position and linear velocity of collidable points in the world frame. + Compute the position and linear velocity of the enabled collidable points in the world frame. Args: model: The model to consider. @@ -35,10 +35,23 @@ def collidable_points_pos_vel( joint_velocities: The velocities of the joints. Returns: - A tuple containing the position and linear velocity of collidable points. + A tuple containing the position and linear velocity of the enabled collidable points. """ - if len(model.kin_dyn_parameters.contact_parameters.body) == 0: + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ + indices_of_enabled_collidable_points + ] + + if len(indices_of_enabled_collidable_points) == 0: return jnp.array(0).astype(float), jnp.empty(0).astype(float) W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( @@ -136,8 +149,8 @@ def process_point_kinematics( # Process all the collidable points in parallel. W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( - model.kin_dyn_parameters.contact_parameters.point, - jnp.array(model.kin_dyn_parameters.contact_parameters.body), + L_p_Ci, + parent_link_idx_of_enabled_collidable_points, ) return W_p_Ci, CW_vl_WC From d113766f9d8f484c16f5289c06dd8e3d86f7e257 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 11/14] Refactor `RelaxedRigidContacts` to streamline handling of enabled collidable points --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 40 +++++++++-------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 34c8ce2bf..f91d4bd39 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp import jax_dataclasses -import numpy.typing as npt import optax import jaxsim.api as js @@ -323,25 +322,18 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: return jnp.dot(h, n̂) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. - p, v = js.contact.collidable_point_kinematics(model=model, data=data) - position = p[indices_of_enabled_collidable_points] - velocity = v[indices_of_enabled_collidable_points] + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) # Compute the activation state of the collidable points δ = jax.vmap(detect_contact)(*position.T) # Compute the transforms of the implicit frames corresponding to the # collidable points. - W_H_C = js.contact.transforms(model=model, data=data)[ - indices_of_enabled_collidable_points - ] + W_H_C = js.contact.transforms(model=model, data=data) with ( references.switch_velocity_representation(VelRepr.Mixed), @@ -365,18 +357,14 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: Jl_WC = jnp.vstack( jax.vmap(lambda J, height: J * (height < 0))( - js.contact.jacobian(model=model, data=data)[ - indices_of_enabled_collidable_points, :3, : - ], + js.contact.jacobian(model=model, data=data)[:, :3, :], δ, ) ) J̇_WC = jnp.vstack( jax.vmap(lambda J̇, height: J̇ * (height < 0))( - js.contact.jacobian_derivative(model=model, data=data)[ - indices_of_enabled_collidable_points, :3 - ], + js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ, ), ) @@ -387,7 +375,6 @@ def detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array: penetration=δ, velocity=velocity, parameters=data.contacts_params, - indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, ) # Compute the Delassus matrix and the free mixed linear acceleration of @@ -514,7 +501,6 @@ def _regularizers( penetration: jtp.Array, velocity: jtp.Array, parameters: RelaxedRigidContactsParams, - indices_of_enabled_collidable_points: npt.NDArray, ) -> tuple: """ Compute the contact jacobian and the reference acceleration. @@ -524,7 +510,6 @@ def _regularizers( penetration: The penetration of the collidable points. velocity: The velocity of the collidable points. parameters: The parameters of the relaxed rigid contacts model. - indices_of_enabled_collidable_points: The indices of the enabled collidable points. Returns: A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping. @@ -547,6 +532,15 @@ def _regularizers( ) ) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + # Compute the 6D inertia matrices of all links. M_L = js.model.link_spatial_inertia_matrices(model=model) @@ -612,9 +606,7 @@ def compute_row( f=jnp.concatenate, tree=( *jax.vmap(compute_row)( - link_idx=jnp.array( - model.kin_dyn_parameters.contact_parameters.body - )[indices_of_enabled_collidable_points], + link_idx=parent_link_idx_of_enabled_collidable_points, penetration=penetration, velocity=velocity, ), From e21f84b96310b8385fd6593473b5f20882ca6e93 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 12/14] Refactor `RigidContacts` to streamline handling of enabled collidable points --- src/jaxsim/rbda/contacts/rigid.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 4eba17e73..5d8a8b87e 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -310,22 +310,16 @@ def compute_contact_forces( M = js.model.free_floating_mass_matrix(model=model, data=data) - J_WC = js.contact.jacobian(model=model, data=data)[ - indices_of_enabled_collidable_points - ] - J̇_WC = js.contact.jacobian_derivative(model=model, data=data)[ - indices_of_enabled_collidable_points - ] + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) - W_H_C = js.contact.transforms(model=model, data=data)[ - indices_of_enabled_collidable_points - ] + W_H_C = js.contact.transforms(model=model, data=data) # Compute the position and linear velocities (mixed representation) of - # all collidable points belonging to the robot. - p, v = js.contact.collidable_point_kinematics(model=model, data=data) - position = p[indices_of_enabled_collidable_points] - velocity = v[indices_of_enabled_collidable_points] + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. From 56ff3aab54a14055ffa6cd028c8d46048f31da2c Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 15:28:48 +0100 Subject: [PATCH 13/14] Refactor `SoftContacts` to streamline handling of enabled collidable points --- src/jaxsim/rbda/contacts/soft.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index b9332e32f..92f0ba091 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -454,8 +454,6 @@ def compute_contact_forces( # all the collidable points belonging to the robot and extract the ones # for the enabled collidable points. W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) - W_p_C_enabled = W_p_C[indices_of_enabled_collidable_points] - W_ṗ_C_enabled = W_ṗ_C[indices_of_enabled_collidable_points] # Extract the material deformation corresponding to the collidable points. m = data.state.extended["tangential_deformation"] @@ -475,7 +473,7 @@ def compute_contact_forces( parameters=data.contacts_params, terrain=model.terrain, ) - )(W_p_C_enabled, W_ṗ_C_enabled, m_enabled) + )(W_p_C, W_ṗ_C, m_enabled) ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) From c531bbd94e729e446c0bc08ea943cac43ea2cf1a Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 13 Nov 2024 16:52:28 +0100 Subject: [PATCH 14/14] Refactor `test_api_contact.py` to use indices of enabled collidable points --- tests/test_api_contact.py | 40 +++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 6456f2645..4a0882737 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -22,6 +22,15 @@ def test_contact_kinematics( velocity_representation=velocity_representation, ) + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + parent_link_idx_of_enabled_collidable_points = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + # ===== # Tests # ===== @@ -34,7 +43,7 @@ def test_contact_kinematics( # Check that the orientation of the implicit contact frame matches with the # orientation of the link to which the contact point is attached. for contact_idx, index_of_parent_link in enumerate( - model.kin_dyn_parameters.contact_parameters.body + parent_link_idx_of_enabled_collidable_points ): assert W_H_C[contact_idx, 0:3, 0:3] == pytest.approx( W_H_L[index_of_parent_link][0:3, 0:3] @@ -74,29 +83,40 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # ===== - # Tests - # ===== + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) # Extract the parent link names and the poses of the contact points. parent_link_names = js.link.idxs_to_names( - model=model, link_indices=model.kin_dyn_parameters.contact_parameters.body + model=model, + link_indices=jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points], ) - W_p_Ci = model.kin_dyn_parameters.contact_parameters.point + + L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ + indices_of_enabled_collidable_points + ] + + # ===== + # Tests + # ===== # Load the model in ROD. rod_model = rod.Sdf.load(sdf=model.built_from).model # Add dummy frames on the contact points. - for idx, (link_name, W_p_C) in enumerate( - zip(parent_link_names, W_p_Ci, strict=True) + for idx, link_name, L_p_C in zip( + indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True ): rod_model.add_frame( frame=rod.Frame( name=f"contact_point_{idx}", attached_to=link_name, pose=rod.Pose( - relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(W_p_C) + relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C) ), ), ) @@ -125,7 +145,7 @@ def test_contact_jacobian_derivative( frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( - f"contact_point_{idx}" for idx in list(range(len(parent_link_names))) + f"contact_point_{idx}" for idx in indices_of_enabled_collidable_points ), )