Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce joint position limits #275

Merged
merged 8 commits into from
Nov 5, 2024
Merged

Enforce joint position limits #275

merged 8 commits into from
Nov 5, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Oct 25, 2024

This PR adds support to joint limit torques.

Example behaviour:

single_pendulum.mp4
MWE

import jaxsim.api as js
import rod
from rod.builder.primitives import BoxBuilder, PrimitiveBuilder
import jaxsim.mujoco
import numpy as np
from pathlib import Path
import jaxsim
import jax.numpy as jnp
import os

os.environ["MUJOCO_GL"] = "egl"

base_height = 2.15
upper_height = 1.0

# ===================
# Create the builders
# ===================

base_builder = BoxBuilder(
    name="base",
    mass=1.0,
    x=0.15,
    y=0.15,
    z=base_height,
)

upper_builder = BoxBuilder(
    name="upper",
    mass=0.5,
    x=0.15,
    y=0.15,
    z=upper_height,
)

# =================
# Create the joints
# =================

fixed = rod.Joint(
    name="fixed_joint",
    type="fixed",
    parent="world",
    child=base_builder.name,
)

pivot = rod.Joint(
    name="upper_joint",
    type="revolute",
    parent=base_builder.name,
    child=upper_builder.name,
    axis=rod.Axis(
        xyz=rod.Xyz([1, 0, 0]),
        limit=rod.Limit(
            lower=-1.5708,
            upper=1.5708,
            stiffness=50.0,
            dissipation=0.1,
        ),
    ),
)

# ================
# Create the links
# ================

base = (
    base_builder.build_link(
        name=base_builder.name,
        pose=PrimitiveBuilder.build_pose(pos=np.array([0, 0, base_height / 2])),
    )
    .add_inertial()
    .add_visual()
    .add_collision()
    .build()
)

upper_pose = PrimitiveBuilder.build_pose(pos=np.array([0, 0, upper_height / 2]))

upper = (
    upper_builder.build_link(
        name=upper_builder.name,
        pose=PrimitiveBuilder.build_pose(
            relative_to=base.name, pos=np.array([0, 0, upper_height])
        ),
    )
    .add_inertial(pose=upper_pose)
    .add_visual(pose=upper_pose)
    .add_collision(pose=upper_pose)
    .build()
)

rod_model = rod.Sdf(
    version="1.10",
    model=rod.Model(
        name="single_pendulum",
        link=[base, upper],
        joint=[fixed, pivot],
    ),
)

rod_model.model.resolve_frames()

model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_model,
    time_step=0.01,
    terrain=jaxsim.terrain.FlatTerrain.build(height=-1e3),
)

data = js.data.JaxSimModelData.build(model=model)

data = data.reset_joint_positions(jnp.array([0.5]))

mjcf_string, assets = jaxsim.mujoco.loaders.RodModelToMjcf.convert(
    rod_model=rod_model.model,
    cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(
        camera_name="pendulum_camera",
        lookat=js.link.com_position(
            model=model,
            data=data,
            link_index=js.link.name_to_idx(model=model, link_name="base"),
            in_link_frame=False,
        ),
        distance=3,
        azimut=150,
        elevation=-10,
    ),
)

mj_model_helper = jaxsim.mujoco.model.MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string,
    assets=assets,
)

recorder = jaxsim.mujoco.MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=int(1 / model.time_step),
    width=320 * 2,
    height=240 * 2,
)

integrator = jaxsim.integrators.fixed_step.Heun2.build(
    fsal_enabled_if_supported=False,
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data,
        system_dynamics=js.ode.system_dynamics,
    ),
)

integrator_state = integrator.init(x0=data.state, t0=0.0, dt=model.time_step)

joint_positions = []

for _ in range(1000):
    data, integrator_state = js.model.step(
        model=model,
        data=data,
        integrator_state=integrator_state,
        integrator=integrator,
    )

    joint_positions.append(data.joint_positions())
    recorder.record_frame(camera_name="pendulum_camera")

    mj_model_helper.set_joint_positions(
        joint_names=model.joint_names(), positions=data.joint_positions()
    )

    print(f"Step: {_}/1000, Joint Position: {data.joint_positions()}", end="\r")

recorder.write_video(path=Path.cwd() / Path("single_pendulum.mp4"), exist_ok=True)


📚 Documentation preview 📚: https://jaxsim--275.org.readthedocs.build//275/

@flferretti flferretti self-assigned this Oct 25, 2024
@flferretti flferretti force-pushed the feature/joint_limits branch 2 times, most recently from 3d94caa to 21d6e3d Compare October 28, 2024 16:54
@flferretti flferretti marked this pull request as ready for review October 28, 2024 16:54
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! I remember that with the old OOP APIs, the position limits enforced in this way were quite delicate. Regardless, it's well worth to have the previous logic ported to the new functional APIs. Can you make sure that your logic matches the previous one (that was only considering the damping component)?

Check also the previous PR #22.

# =====================
# Joint position limits
# =====================
if physics_model.dofs() > 0:
# Get the joint position limits
s_min, s_max = jnp.array(
[j.position_limit for j in physics_model.description.joints_dict.values()]
).T
# Get the spring/damper parameters of joint limits enforcement
k_damper = jnp.array(list(physics_model._joint_limit_damper.values()))
# Compute the joint torques that enforce joint limits
s = ode_state.physics_model.joint_positions
tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0)
tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0)
tau_limit = tau_max + tau_min

tests/test_simulations.py Outdated Show resolved Hide resolved
tests/conftest.py Show resolved Hide resolved
tests/test_simulations.py Outdated Show resolved Hide resolved
@flferretti
Copy link
Collaborator Author

@diegoferigo in 6a7dba2 I've added a fix for single-joint models, would you mind taking a look at that as well? Thanks :)

src/jaxsim/api/joint.py Outdated Show resolved Hide resolved
@flferretti flferretti merged commit b274aab into main Nov 5, 2024
24 checks passed
@flferretti flferretti deleted the feature/joint_limits branch November 5, 2024 13:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants