Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalize minor changes for v0.5 release #272

Merged
merged 22 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
11a0b00
Remove unnecessary lambdas
flferretti Oct 22, 2024
a39abfe
Move `MujocoCamera` to `mujoco.utils` and define `MujocoCameraType`
flferretti Oct 22, 2024
58bba61
Improve log formatting
flferretti Oct 22, 2024
08f3f9e
Use bitwise OR to verify tracing
flferretti Oct 22, 2024
82529c4
Simplify `JointDescription` hash
flferretti Oct 22, 2024
5d251ec
Use bitwise NOT instead of `jnp.logical_not`
flferretti Oct 22, 2024
09f6530
Simplify index extraction checks in `api.link` and `api.joint`
flferretti Oct 22, 2024
32dcd40
Use bitwise operators instead of `jnp.logical_*`
flferretti Oct 22, 2024
2d9ad59
Fix import in mujoco.model module
flferretti Oct 23, 2024
15b42f2
Fix return type of `__iter__`
flferretti Oct 23, 2024
9296e62
Avoid explicit calls to `dict.keys()` when iterating
flferretti Oct 23, 2024
1ca0b0f
Merge chained context managers
flferretti Oct 23, 2024
f6698aa
Test passing params and solver options to `RigidContacts` model
flferretti Oct 25, 2024
76ad446
Test passing params and solver options to `RelaxedRigidContacts` model
flferretti Oct 25, 2024
59d7091
Reduce default sphere collidable points
flferretti Oct 30, 2024
0fad8ab
Allow not passing `x` and `y` for flat terrains normal
flferretti Oct 30, 2024
aca7807
Fix `ZeroDivision` error in `PlaneTerrain.height`
flferretti Oct 30, 2024
9f8b0b7
Simplify position limit extraction logic
flferretti Nov 5, 2024
b76d7a7
Minor updates to jaxsim.api.model
diegoferigo Oct 22, 2024
b3c59ea
Minor updated to kin_dyn_parameters
diegoferigo Nov 14, 2024
f00c3d4
Remove JAX deprecated methods
flferretti Nov 14, 2024
6e097d5
Speed up `js.link.idxs_to_names`
flferretti Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[joint_index < 0, joint_index >= model.number_of_joints()]
).any(),
condition=joint_index < 0,
msg="Invalid joint index '{idx}'",
idx=joint_index,
)
Expand Down Expand Up @@ -123,10 +121,7 @@ def position_limit(
"""

if model.number_of_joints() == 0:
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max

return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

exceptions.raise_value_error_if(
condition=jnp.array(
Expand All @@ -136,8 +131,12 @@ def position_limit(
idx=joint_index,
)

s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
s_min = jnp.atleast_1d(
model.kin_dyn_parameters.joint_parameters.position_limits_min
)[joint_index]
s_max = jnp.atleast_1d(
model.kin_dyn_parameters.joint_parameters.position_limits_max
)[joint_index]

return s_min.astype(float), s_max.astype(float)

Expand Down
10 changes: 6 additions & 4 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ def joint_transforms_and_motion_subspaces(
# Helpers to update parameters
# ============================

def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters:
def set_link_mass(
self, link_index: jtp.IntLike, mass: jtp.FloatLike
) -> KynDynParameters:
"""
Set the mass of a link.

Expand All @@ -457,7 +459,7 @@ def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameter
return self.replace(link_parameters=link_parameters)

def set_link_inertia(
self, link_index: int, inertia: jtp.MatrixLike
self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
) -> KynDynParameters:
r"""
Set the inertia tensor of a link.
Expand Down Expand Up @@ -593,10 +595,10 @@ def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParamet
"""

# Extract the link parameters from the 6D spatial inertia.
m, L_p_CoM, I = Inertia.to_params(M=M)
m, L_p_CoM, I_CoM = Inertia.to_params(M=M)

# Extract only the necessary elements of the inertia tensor.
inertia_elements = I[jnp.triu_indices(3)]
inertia_elements = I_CoM[jnp.triu_indices(3)]

return LinkParameters(
index=jnp.array(index).squeeze().astype(int),
Expand Down
7 changes: 3 additions & 4 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import numpy as np

import jaxsim.api as js
import jaxsim.rbda
Expand Down Expand Up @@ -54,9 +55,7 @@ def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
"""

exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
condition=link_index < 0,
xela-95 marked this conversation as resolved.
Show resolved Hide resolved
msg="Invalid link index '{idx}'",
idx=link_index,
)
Expand Down Expand Up @@ -98,7 +97,7 @@ def idxs_to_names(
The names of the links.
"""

return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])


# =========
Expand Down
29 changes: 10 additions & 19 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def name(self) -> str:

return self.model_name

def number_of_links(self) -> jtp.Int:
def number_of_links(self) -> int:
"""
Return the number of links in the model.

Expand All @@ -317,7 +317,7 @@ def number_of_links(self) -> jtp.Int:

return self.kin_dyn_parameters.number_of_links()

def number_of_joints(self) -> jtp.Int:
def number_of_joints(self) -> int:
"""
Return the number of joints in the model.

Expand Down Expand Up @@ -419,7 +419,7 @@ def frame_names(self) -> tuple[str, ...]:
def reduce(
model: JaxSimModel,
considered_joints: tuple[str, ...],
locked_joint_positions: dict[str, jtp.Float] | None = None,
locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
) -> JaxSimModel:
"""
Reduce the model by lumping together the links connected by removed joints.
Expand Down Expand Up @@ -1038,12 +1038,7 @@ def to_active(
C_v̇_WB = to_active(
W_v̇_WB=W_v̇_WB,
W_H_C=W_H_C,
W_v_WB=jnp.hstack(
[
data.state.physics_model.base_linear_velocity,
data.state.physics_model.base_angular_velocity,
]
),
W_v_WB=W_v_WB,
W_v_WC=W_v_WC,
)

Expand Down Expand Up @@ -2274,16 +2269,12 @@ def step(
# Raise runtime error for not supported case in which Rigid contacts and
# Baumgarte stabilization are enabled and used with ForwardEuler integrator.
jaxsim.exceptions.raise_runtime_error_if(
condition=jnp.logical_and(
isinstance(
integrator,
jaxsim.integrators.fixed_step.ForwardEuler
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
),
jnp.array(
[data_tf.contacts_params.K, data_tf.contacts_params.D]
).any(),
),
condition=isinstance(
integrator,
jaxsim.integrators.fixed_step.ForwardEuler
| jaxsim.integrators.fixed_step.ForwardEulerSO3,
)
& ((data_tf.contacts_params.K > 0) | (data_tf.contacts_params.D > 0)),
msg="Baumgarte stabilization is not supported with ForwardEuler integrators",
)

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def apply_frame_forces(
]

exceptions.raise_value_error_if(
condition=jnp.logical_not(data.valid(model=model)),
condition=~data.valid(model=model),
msg="The provided data is not valid for the model",
)
W_H_Fi = jax.vmap(
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _compute_next_state(
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

# Initialize the carry of the for loop with the stacked kᵢ vectors.
carry0 = jax.tree_map(
carry0 = jax.tree.map(
lambda l: jnp.zeros((c.size, *l.shape), dtype=l.dtype), x0
)

Expand Down Expand Up @@ -507,7 +507,7 @@ def post_process_state(

# We assume that the initial quaternion is already unary.
exceptions.raise_runtime_error_if(
condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
condition=~jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0),
msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
)

Expand Down
18 changes: 6 additions & 12 deletions src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def compute_pytree_scale(
"""

# Consider a zero second pytree, if not given.
x2 = jax.tree.map(lambda l: jnp.zeros_like(l), x1) if x2 is None else x2
x2 = jax.tree.map(jnp.zeros_like, x1) if x2 is None else x2

# Compute the scaling factors of the initial state and its derivative.
compute_scale = lambda l1, l2: atol + jnp.maximum(jnp.abs(l1), jnp.abs(l2)) * rtol
Expand Down Expand Up @@ -199,9 +199,7 @@ def flatten(pytree) -> jax.Array:

# Consider a zero estimated final state, if not given.
xf_estimate = (
jax.tree.map(lambda l: jnp.zeros_like(l), xf)
if xf_estimate is None
else xf_estimate
jax.tree.map(jnp.zeros_like, xf) if xf_estimate is None else xf_estimate
)

# Estimate the error.
Expand Down Expand Up @@ -483,14 +481,10 @@ def reject_step():
metadata_next,
discarded_steps,
) = jax.lax.cond(
pred=jnp.array(
[
discarded_steps >= self.max_step_rejections,
local_error <= 1.0,
Δt_next < self.dt_min,
integrator_init,
]
).any(),
pred=discarded_steps
>= self.max_step_rejections | local_error
<= 1.0 | Δt_next
< self.dt_min | integrator_init,
true_fun=accept_step,
false_fun=reject_step,
)
Expand Down
Loading