Skip to content

Commit

Permalink
Avoid unnecessary jax.vmap in RelaxedRigidContacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Nov 22, 2024
1 parent c8c9451 commit 515759c
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,12 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
CW_fl_C = solution.reshape(-1, 3)

# Convert the contact forces from mixed to inertial-fixed representation.
W_f_C = jax.vmap(
lambda CW_fl_C, W_H_C: (
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
transform=W_H_C,
other_representation=VelRepr.Mixed,
is_force=True,
)
),
)(CW_fl_C, W_H_C)
W_f_C = ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=jnp.hstack((CW_fl_C, jnp.zeros_like(CW_fl_C))),
transform=W_H_C,
other_representation=VelRepr.Mixed,
is_force=True,
)

return W_f_C, {}

Expand Down

0 comments on commit 515759c

Please sign in to comment.