Skip to content

Commit

Permalink
Avoid unnecessary jax.vmap in link_forces_from_contact_forces
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Nov 21, 2024
1 parent 0f3c5e9 commit ceb4d20
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,12 @@ def link_forces_from_contact_forces(
]

# Convert the contact forces to inertial-fixed representation.
W_f_C = jax.vmap(
lambda f_C, W_H_C: (
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=f_C,
other_representation=data.velocity_representation,
transform=W_H_C,
is_force=True,
)
)
)(f_C, W_H_C)
W_f_C = ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=f_C,
other_representation=data.velocity_representation,
transform=W_H_C,
is_force=True,
)

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
Expand Down Expand Up @@ -270,16 +266,12 @@ def link_forces_from_contact_forces(
)

# Convert the inertial-fixed link forces to the velocity representation of data.
f_L = jax.vmap(
lambda W_f_L, W_H_L: (
ModelDataWithVelocityRepresentation.inertial_to_other_representation(
array=W_f_L,
other_representation=data.velocity_representation,
transform=W_H_L,
is_force=True,
)
)
)(W_f_L, W_H_L)
f_L = ModelDataWithVelocityRepresentation.inertial_to_other_representation(
array=W_f_L,
other_representation=data.velocity_representation,
transform=W_H_L,
is_force=True,
)

return f_L

Expand Down

0 comments on commit ceb4d20

Please sign in to comment.