Skip to content

Commit

Permalink
Merge pull request #46 from dfki-ric/fix/blacken-codestyle
Browse files Browse the repository at this point in the history
Fix/blacken codestyle
  • Loading branch information
mlaux1 authored Jun 26, 2024
2 parents a301611 + 6bcef83 commit 11d8f69
Show file tree
Hide file tree
Showing 60 changed files with 1,757 additions and 1,093 deletions.
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
exclude: ^(object_data/|robots/|doc/)
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# DeformableGym

[![Tests](https://github.com/dfki-ric/deformable_gym/actions/workflows/test.yaml/badge.svg)](https://github.com/dfki-ric/deformable_gym/actions/workflows/test.yaml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
# DeformableGym

This repository contains a collection of [gymnasium](https://github.com/Farama-Foundation/Gymnasium) environments built with [PyBullet](https://pybullet.org/). In these environments, the agent
needs to learn to grasp deformable object such as shoe insoles or pillows from sparse reward signals.
Expand Down
12 changes: 8 additions & 4 deletions deformable_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
register(
id="FloatingMiaGraspInsole-v0",
entry_point="deformable_gym.envs.floating_mia_grasp_env:FloatingMiaGraspEnv",
kwargs={"object_name": "insole_on_conveyor_belt/back",
"observable_object_pos": True}
kwargs={
"object_name": "insole_on_conveyor_belt/back",
"observable_object_pos": True,
},
)

register(
id="FloatingShadowGraspInsole-v0",
entry_point="deformable_gym.envs.floating_shadow_grasp_env:FloatingShadowGraspEnv",
kwargs={"object_name": "insole_on_conveyor_belt/back",
"observable_object_pos": True}
kwargs={
"object_name": "insole_on_conveyor_belt/back",
"observable_object_pos": True,
},
)
3 changes: 2 additions & 1 deletion deformable_gym/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Deformable Gym environments."""

from .base_env import BaseBulletEnv
from .floating_mia_grasp_env import FloatingMiaGraspEnv
from .floating_shadow_grasp_env import FloatingShadowGraspEnv
Expand All @@ -10,5 +11,5 @@
"FloatingMiaGraspEnv",
"FloatingShadowGraspEnv",
"UR5MiaGraspEnv",
"UR10ShadowGraspEnv"
"UR10ShadowGraspEnv",
]
98 changes: 52 additions & 46 deletions deformable_gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BaseBulletEnv(gym.Env, abc.ABC):
:param verbose_dt: Time after which simulation info should be printed.
:param pybullet_options: Options to pass to pybullet.connect().
"""

gui: bool
verbose: bool
horizon: int
Expand All @@ -37,15 +38,16 @@ class BaseBulletEnv(gym.Env, abc.ABC):
action_space: spaces.Box

def __init__(
self,
gui: bool = True,
real_time: bool = False,
horizon: int = 100,
soft: bool = False,
verbose: bool = False,
time_delta: float = 0.001,
verbose_dt: float = 10.00,
pybullet_options: str = ""):
self,
gui: bool = True,
real_time: bool = False,
horizon: int = 100,
soft: bool = False,
verbose: bool = False,
time_delta: float = 0.001,
verbose_dt: float = 10.00,
pybullet_options: str = "",
):

self.gui = gui
self.verbose = verbose
Expand All @@ -59,7 +61,8 @@ def __init__(
real_time=real_time,
mode=mode,
verbose_dt=verbose_dt,
pybullet_options=pybullet_options)
pybullet_options=pybullet_options,
)

self.pb_client = self.simulation.pb_client

Expand All @@ -75,9 +78,7 @@ def _create_robot(self):
def _load_objects(self):
"""Load objects to PyBullet simulation."""
self.plane = self.pb_client.loadURDF(
"plane.urdf",
(0, 0, 0),
useFixedBase=1
"plane.urdf", (0, 0, 0), useFixedBase=1
)

def _hard_reset(self):
Expand Down Expand Up @@ -140,11 +141,11 @@ def _get_observation(self) -> npt.ArrayLike:
return self.robot.get_joint_positions()

def _get_info(
self,
observation: npt.ArrayLike = None,
action: npt.ArrayLike = None,
reward: float = None,
next_observation: npt.ArrayLike = None
self,
observation: npt.ArrayLike = None,
action: npt.ArrayLike = None,
reward: float = None,
next_observation: npt.ArrayLike = None,
) -> dict:
"""Returns the current environment state.
Expand All @@ -153,10 +154,10 @@ def _get_info(
return {}

def _is_terminated(
self,
observation: npt.ArrayLike,
action: npt.ArrayLike,
next_observation: npt.ArrayLike
self,
observation: npt.ArrayLike,
action: npt.ArrayLike,
next_observation: npt.ArrayLike,
) -> bool:
"""Checks whether the current episode is terminated.
Expand All @@ -168,10 +169,10 @@ def _is_terminated(
return self.step_counter >= self.horizon

def _is_truncated(
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike,
) -> bool:
"""Checks whether the current episode is truncated.
Expand Down Expand Up @@ -216,16 +217,19 @@ def step(self, action: npt.ArrayLike):

# calculate the reward
reward = self.calculate_reward(
observation, action, next_observation, terminated)
observation, action, next_observation, terminated
)

info = self._get_info(observation, action, reward)

if self.verbose:
print(f"Finished environment step: "
f"{next_observation=}, "
f"{reward=}, "
f"{terminated=}, "
f"{truncated=}")
print(
f"Finished environment step: "
f"{next_observation=}, "
f"{reward=}, "
f"{terminated=}, "
f"{truncated=}"
)

return next_observation, reward, terminated, truncated, info

Expand All @@ -234,11 +238,11 @@ def close(self):

@abc.abstractmethod
def calculate_reward(
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike,
terminated: bool
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike,
terminated: bool,
) -> float:
"""Calculate reward.
Expand Down Expand Up @@ -295,7 +299,8 @@ def get_object_pose(self):
quaternion: (x, y, z, qw, qx, qy, qz)
"""
return MultibodyPose.internal_pose_to_external_pose(
np.hstack((self.object_position, self.object_orientation)))
np.hstack((self.object_position, self.object_orientation))
)


class FloatingHandMixin:
Expand All @@ -309,11 +314,10 @@ def _init_hand_pose(self, robot: BulletRobot):
"""
desired_robot2world_pos = self.hand_world_pose[:3]
desired_robot2world_orn = pb.getQuaternionFromEuler(
self.hand_world_pose[3:])
self.hand_world_pose[3:]
)
self.multibody_pose = MultibodyPose(
robot.get_id(),
desired_robot2world_pos,
desired_robot2world_orn
robot.get_id(), desired_robot2world_pos, desired_robot2world_orn
)

def set_world_pose(self, world_pose):
Expand All @@ -323,13 +327,15 @@ def set_world_pose(self, world_pose):
quaternion: (x, y, z, qw, qx, qy, qz)
"""
self.hand_world_pose = MultibodyPose.external_pose_to_internal_pose(
world_pose)
world_pose
)
desired_robot2world_pos = self.hand_world_pose[:3]
desired_robot2world_orn = pb.getQuaternionFromEuler(
self.hand_world_pose[3:])
self.hand_world_pose[3:]
)
self.multibody_pose.set_pose(
desired_robot2world_pos,
desired_robot2world_orn)
desired_robot2world_pos, desired_robot2world_orn
)

def get_world_pose(self):
"""Get pose of the hand.
Expand Down
68 changes: 39 additions & 29 deletions deformable_gym/envs/bullet_simulation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pybullet as pb
import pybullet_data

from pybullet_utils import bullet_client as bc
from ..robots.bullet_robot import BulletRobot

from ..helpers import pybullet_helper as pbh
from ..robots.bullet_robot import BulletRobot


class BulletSimulation:
Expand All @@ -18,15 +18,17 @@ class BulletSimulation:
:param pybullet_options: Options that should be passed to PyBullet
connection command.
"""

def __init__(
self,
time_delta: float = 0.001,
mode: int = pb.GUI,
gravity: float = -9.81,
soft: bool = False,
real_time: bool = False,
verbose_dt: float = 0.01,
pybullet_options: str = ""):
self,
time_delta: float = 0.001,
mode: int = pb.GUI,
gravity: float = -9.81,
soft: bool = False,
real_time: bool = False,
verbose_dt: float = 0.01,
pybullet_options: str = "",
):

self.time_delta = time_delta
self.mode = mode
Expand All @@ -36,8 +38,7 @@ def __init__(

with pbh.stdout_redirected():
self.pb_client = bc.BulletClient(
connection_mode=self.mode,
options=pybullet_options
connection_mode=self.mode, options=pybullet_options
)
self.pb_client.setAdditionalSearchPath(pybullet_data.getDataPath())

Expand Down Expand Up @@ -107,13 +108,13 @@ def disconnect(self) -> None:

class BulletTiming:
"""This class handles all timing issues for a single BulletSimulation."""

def __init__(
self,
pb_client: bc.BulletClient,
dt: float = 0.001,
verbose_dt: float = 0.01,
self,
pb_client: bc.BulletClient,
dt: float = 0.001,
verbose_dt: float = 0.01,
):

"""
Create new BulletTiming instance.
Expand Down Expand Up @@ -146,7 +147,9 @@ def add_subsystem(self, name, frequency, callback=None):
"""
if name not in self.subsystems.keys():
self.subsystems[name] = (
max(1, round(1.0/frequency/self.dt)), callback)
max(1, round(1.0 / frequency / self.dt)),
callback,
)

def remove_subsystem(self, name):
"""
Expand Down Expand Up @@ -182,8 +185,10 @@ def step(self):
self.sim_time += self.dt

if (self.sim_time % self.verbose_dt) < self.dt:
print(f"Step: {self.time_step}, Time: {self.sim_time}, "
f"Triggers: {triggers}")
print(
f"Step: {self.time_step}, Time: {self.sim_time}, "
f"Triggers: {triggers}"
)

def reset(self):
self.time_step = 0
Expand All @@ -192,13 +197,14 @@ def reset(self):

class BulletCamera:
"""This class handles all camera operations for one BulletSimulation."""

def __init__(
self,
pb_client: bc.BulletClient,
position: tuple = (0, 0, 0),
pitch: int = -52,
yaw: int = 30,
distance: int = 3,
self,
pb_client: bc.BulletClient,
position: tuple = (0, 0, 0),
pitch: int = -52,
yaw: int = 30,
distance: int = 3,
):
self.position = position
self.pitch = pitch
Expand All @@ -209,12 +215,15 @@ def __init__(
self._active = False
self._logging_id = None

self.pb_client.resetDebugVisualizerCamera(distance, yaw, pitch, position)
self.pb_client.resetDebugVisualizerCamera(
distance, yaw, pitch, position
)

def start_recording(self, path):
if not self._active:
self._logging_id = self.pb_client.startStateLogging(
pb.STATE_LOGGING_VIDEO_MP4, path)
pb.STATE_LOGGING_VIDEO_MP4, path
)
return self._logging_id
else:
return None
Expand All @@ -230,4 +239,5 @@ def reset(self, position, pitch, yaw, distance):
self.distance = distance

self.pb_client.resetDebugVisualizerCamera(
distance, yaw, pitch, position)
distance, yaw, pitch, position
)
Loading

0 comments on commit 11d8f69

Please sign in to comment.