Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce joint position limits #275

Merged
merged 8 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ def position_limit(
The position limits of the joint.
"""

if model.number_of_joints() <= 1:
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
if model.number_of_joints() == 0:
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max

return jnp.atleast_1d(s_min).astype(float), jnp.atleast_1d(s_max).astype(float)

exceptions.raise_value_error_if(
condition=jnp.array(
Expand Down
33 changes: 32 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,40 @@ def system_acceleration(
# Enforce joint limits
# ====================

# TODO: enforce joint limits
τ_position_limit = jnp.zeros_like(τ_references).astype(float)

if model.dofs() > 0:

# Stiffness and damper parameters for the joint position limits.
k_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_spring
).astype(float)
d_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_damper
).astype(float)

# Compute the joint position limit violations.
lower_violation = jnp.clip(
data.state.physics_model.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
max=0.0,
)

upper_violation = jnp.clip(
data.state.physics_model.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
min=0.0,
)

# Compute the joint position limit torque.
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)

τ_position_limit -= (
jnp.positive(τ_position_limit)
* jnp.diag(d_j)
@ data.state.physics_model.joint_velocities
)

# ====================
# Joint friction model
# ====================
Expand Down
5 changes: 3 additions & 2 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import os
import pathlib
from typing import NamedTuple

Expand Down Expand Up @@ -273,14 +274,14 @@ def extract_model_data(
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.dissipation is not None
else 0.0
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_DAMPER", 0.0)
),
position_limit_spring=float(
j.axis.limit.stiffness
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.stiffness is not None
else 0.0
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_SPRING", 0.0)
),
)
for j in sdf_model.joints()
Expand Down
112 changes: 112 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,116 @@ def jaxsim_model_ur10() -> js.model.JaxSimModel:
return build_jaxsim_model(model_description=model_urdf_path)


@pytest.fixture(scope="session")
def jaxsim_model_single_pendulum() -> js.model.JaxSimModel:
flferretti marked this conversation as resolved.
Show resolved Hide resolved
"""
Fixture providing the JaxSim model of a single pendulum.
Returns:
The JaxSim model of a single pendulum.
"""

import numpy as np
import rod.builder.primitives

base_height = 2.15
upper_height = 1.0

# ===================
# Create the builders
# ===================

base_builder = rod.builder.primitives.BoxBuilder(
name="base",
mass=1.0,
x=0.15,
y=0.15,
z=base_height,
)

upper_builder = rod.builder.primitives.BoxBuilder(
name="upper",
mass=0.5,
x=0.15,
y=0.15,
z=upper_height,
)

# =================
# Create the joints
# =================

fixed = rod.Joint(
name="fixed_joint",
type="fixed",
parent="world",
child=base_builder.name,
)

pivot = rod.Joint(
name="upper_joint",
type="continuous",
parent=base_builder.name,
child=upper_builder.name,
axis=rod.Axis(
xyz=rod.Xyz([1, 0, 0]),
limit=rod.Limit(),
),
)

# ================
# Create the links
# ================

base = (
base_builder.build_link(
name=base_builder.name,
pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
pos=np.array([0, 0, base_height / 2])
),
)
.add_inertial()
.add_visual()
.add_collision()
.build()
)

upper_pose = rod.builder.primitives.PrimitiveBuilder.build_pose(
pos=np.array([0, 0, upper_height / 2])
)

upper = (
upper_builder.build_link(
name=upper_builder.name,
pose=rod.builder.primitives.PrimitiveBuilder.build_pose(
relative_to=base.name, pos=np.array([0, 0, upper_height])
),
)
.add_inertial(pose=upper_pose)
.add_visual(pose=upper_pose)
.add_collision(pose=upper_pose)
.build()
)

rod_model = rod.Sdf(
version="1.10",
model=rod.Model(
name="single_pendulum",
link=[base, upper],
joint=[fixed, pivot],
),
)

rod_model.model.resolve_frames()

urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
sdf=rod_model.models()[0]
)

model = build_jaxsim_model(model_description=urdf_string)

return model


# ============================
# Collections of JaxSim models
# ============================
Expand Down Expand Up @@ -280,6 +390,8 @@ def get_jaxsim_model_fixture(
return request.getfixturevalue(jaxsim_model_ergocub_reduced.__name__)
case "ur10":
return request.getfixturevalue(jaxsim_model_ur10.__name__)
case "single_pendulum":
return request.getfixturevalue(jaxsim_model_single_pendulum.__name__)
case _:
raise ValueError(model_name)

Expand Down
56 changes: 56 additions & 0 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import jaxsim.api as js
Expand Down Expand Up @@ -370,3 +371,58 @@ def test_simulation_with_relaxed_rigid_contacts(
assert data_tf.base_position()[2] + max_penetration == pytest.approx(
box_height / 2, abs=0.000_100
)


def test_joint_limits(
jaxsim_model_single_pendulum: js.model.JaxSimModel,
):

model = jaxsim_model_single_pendulum

with model.editable(validate=False) as model:
model.kin_dyn_parameters.joint_parameters.position_limits_max = jnp.atleast_1d(
jnp.array(1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limits_min = jnp.atleast_1d(
jnp.array(-1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limit_spring = (
jnp.atleast_1d(jnp.array(75.0))
)
model.kin_dyn_parameters.joint_parameters.position_limit_damper = (
jnp.atleast_1d(jnp.array(0.1))
)

position_limits_min, position_limits_max = js.joint.position_limits(model=model)

data = js.data.JaxSimModelData.build(
model=model,
velocity_representation=VelRepr.Inertial,
)

theta = 10 * np.pi / 180

# Define a tolerance since the spring-damper model does
# not guarantee that the joint position will be exactly
# below the limit.
tolerance = theta * 0.10

# Test minimum joint position limits.
data_t0 = data.reset_joint_positions(positions=position_limits_min - theta)

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.005, tf=3.0)

assert (
np.min(np.array(data_tf.joint_positions()), axis=0) + tolerance
>= position_limits_min
)

# Test maximum joint position limits.
data_t0 = data.reset_joint_positions(positions=position_limits_max - theta)

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=3.0)

assert (
np.max(np.array(data_tf.joint_positions()), axis=0) - tolerance
<= position_limits_max
)