Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/interference bug #30

Merged
merged 20 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion deformable_gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .ur5_mia_grasp_env import UR5MiaGraspEnv
from .ur10_shadow_grasp_env import UR10ShadowGraspEnv


__all__ = [
"BaseBulletEnv",
"FloatingMiaGraspEnv",
Expand Down
98 changes: 67 additions & 31 deletions deformable_gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import numpy.typing as npt
import pybullet as pb
import pytransform3d.rotations as pr

from gymnasium import spaces
from deformable_gym.robots.bullet_robot import BulletRobot
from pybullet_utils import bullet_client as bc

from deformable_gym.envs.bullet_simulation import BulletSimulation
from deformable_gym.helpers.pybullet_helper import MultibodyPose
from deformable_gym.robots.bullet_robot import BulletRobot


class BaseBulletEnv(gym.Env, abc.ABC):
Expand Down Expand Up @@ -53,8 +54,14 @@ def __init__(
mode = pb.GUI if gui else pb.DIRECT

self.simulation = BulletSimulation(
soft=soft, time_delta=time_delta, real_time=real_time, mode=mode,
verbose_dt=verbose_dt, pybullet_options=pybullet_options)
soft=soft,
time_delta=time_delta,
real_time=real_time,
mode=mode,
verbose_dt=verbose_dt,
pybullet_options=pybullet_options)

self.pb_client = self.simulation.pb_client

# TODO should we make this configurable? this results in 100 Hz
self.simulation.timing.add_subsystem("time_step", 100)
Expand All @@ -67,9 +74,11 @@ def _create_robot(self):

def _load_objects(self):
"""Load objects to PyBullet simulation."""
self.plane = pb.loadURDF("plane.urdf",
(0, 0, 0),
useFixedBase=1)
self.plane = self.pb_client.loadURDF(
"plane.urdf",
(0, 0, 0),
useFixedBase=1
)

def _hard_reset(self):
"""Hard reset the PyBullet simulation and reload all objects. May be
Expand All @@ -87,7 +96,7 @@ def reset(self, seed=None, options=None) -> npt.ArrayLike:

:return: Initial state.
"""
super().reset(seed=seed)
super().reset(seed=seed, options=options)

if options is not None and "hard_reset" in options:
self._hard_reset()
Expand Down Expand Up @@ -129,18 +138,25 @@ def _get_observation(self) -> npt.ArrayLike:
"""
return self.robot.get_joint_positions()

def _get_info(self):
def _get_info(
self,
observation: npt.ArrayLike = None,
action: npt.ArrayLike = None,
reward: float = None,
next_observation: npt.ArrayLike = None
) -> dict:
"""Returns the current environment state.

:return: Dictionary with additional information (empty here).
"""
return {}

def _is_terminated(self,
observation: npt.ArrayLike,
action: npt.ArrayLike,
next_observation: npt.ArrayLike
) -> bool:
def _is_terminated(
self,
observation: npt.ArrayLike,
action: npt.ArrayLike,
next_observation: npt.ArrayLike
) -> bool:
"""Checks whether the current episode is terminated.

:param observation: observation before action was taken
Expand All @@ -150,10 +166,12 @@ def _is_terminated(self,
"""
return self.step_counter >= self.horizon

def _is_truncated(self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike) -> bool:
def _is_truncated(
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike
) -> bool:
"""Checks whether the current episode is truncated.

:param state: State
Expand Down Expand Up @@ -196,22 +214,31 @@ def step(self, action: npt.ArrayLike):
truncated = self._is_truncated(observation, action, next_observation)

# calculate the reward
reward = self.calculate_reward(observation, action, next_observation, terminated)
reward = self.calculate_reward(
observation, action, next_observation, terminated)

info = self._get_info(observation, action, reward)

if self.verbose:
print(f"Finished environment step: {next_observation=}, {reward=}, {terminated=}, {truncated=}")
print(f"Finished environment step: "
f"{next_observation=}, "
f"{reward=}, "
f"{terminated=}, "
f"{truncated=}")

return next_observation, reward, terminated, truncated, {}
return next_observation, reward, terminated, truncated, info

def close(self):
self.simulation.disconnect()

@abc.abstractmethod
def calculate_reward(self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike,
terminated: bool):
def calculate_reward(
self,
state: npt.ArrayLike,
action: npt.ArrayLike,
next_state: npt.ArrayLike,
terminated: bool
) -> float:
"""Calculate reward.

:param state: State of the environment.
Expand Down Expand Up @@ -280,19 +307,28 @@ def _init_hand_pose(self, robot: BulletRobot):
:param robot: Floating hand.
"""
desired_robot2world_pos = self.hand_world_pose[:3]
desired_robot2world_orn = pb.getQuaternionFromEuler(self.hand_world_pose[3:])
self.multibody_pose = MultibodyPose(robot.get_id(), desired_robot2world_pos, desired_robot2world_orn)
desired_robot2world_orn = pb.getQuaternionFromEuler(
self.hand_world_pose[3:])
self.multibody_pose = MultibodyPose(
robot.get_id(),
desired_robot2world_pos,
desired_robot2world_orn
)

def set_world_pose(self, world_pose):
"""Set pose of the hand.

:param world_pose: world pose of hand given as position and
quaternion: (x, y, z, qw, qx, qy, qz)
"""
self.hand_world_pose = MultibodyPose.external_pose_to_internal_pose(world_pose)
self.hand_world_pose = MultibodyPose.external_pose_to_internal_pose(
world_pose)
desired_robot2world_pos = self.hand_world_pose[:3]
desired_robot2world_orn = pb.getQuaternionFromEuler(self.hand_world_pose[3:])
self.multibody_pose.set_pose(desired_robot2world_pos, desired_robot2world_orn)
desired_robot2world_orn = pb.getQuaternionFromEuler(
self.hand_world_pose[3:])
self.multibody_pose.set_pose(
desired_robot2world_pos,
desired_robot2world_orn)

def get_world_pose(self):
"""Get pose of the hand.
Expand Down
95 changes: 56 additions & 39 deletions deformable_gym/envs/bullet_simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pybullet as pb
import pybullet_data

from pybullet_utils import bullet_client as bc
from deformable_gym.robots.bullet_robot import BulletRobot


Expand All @@ -16,9 +18,13 @@ class BulletSimulation:
connection command.
"""
def __init__(
self, time_delta: float = 0.001, mode: int = pb.GUI,
gravity: float = -9.81, soft: bool = False,
real_time: bool = False, verbose_dt: float = 0.01,
self,
time_delta: float = 0.001,
mode: int = pb.GUI,
gravity: float = -9.81,
soft: bool = False,
real_time: bool = False,
verbose_dt: float = 0.01,
pybullet_options: str = ""):

self.time_delta = time_delta
Expand All @@ -27,33 +33,37 @@ def __init__(
self.soft = soft
self.real_time = real_time

self._client = pb.connect(self.mode, options=pybullet_options)
self.timing = BulletTiming(dt=time_delta, verbose_dt=verbose_dt,
client_id=self._client)
self.pb_client = bc.BulletClient(
connection_mode=self.mode,
options=pybullet_options
)
self.pb_client.setAdditionalSearchPath(pybullet_data.getDataPath())

pb.setAdditionalSearchPath(pybullet_data.getDataPath())
self.timing = BulletTiming(
pb_client=self.pb_client,
dt=time_delta,
verbose_dt=verbose_dt,
)

self.reset()

self.camera = BulletCamera()
self.camera = BulletCamera(self.pb_client)

def reset(self):
"""Reset and initialize simulation."""
if self.soft:
pb.resetSimulation(pb.RESET_USE_DEFORMABLE_WORLD)
self.pb_client.resetSimulation(pb.RESET_USE_DEFORMABLE_WORLD)
else:
pb.resetSimulation()
self.pb_client.resetSimulation()

print(f"resetting client {self._client}")
print(f"resetting client {self.pb_client}")

pb.setGravity(0, 0, self.gravity, physicsClientId=self._client)
pb.setRealTimeSimulation(self.real_time, physicsClientId=self._client)
pb.setTimeStep(self.time_delta, physicsClientId=self._client)
self.pb_client.setGravity(0, 0, self.gravity)
self.pb_client.setRealTimeSimulation(self.real_time)
self.pb_client.setTimeStep(self.time_delta)

pb.configureDebugVisualizer(
pb.COV_ENABLE_RENDERING, 1, physicsClientId=self._client)
pb.configureDebugVisualizer(
pb.COV_ENABLE_GUI, 0, physicsClientId=self._client)
self.pb_client.configureDebugVisualizer(pb.COV_ENABLE_RENDERING, 1)
self.pb_client.configureDebugVisualizer(pb.COV_ENABLE_GUI, 0)

def add_robot(self, robot: BulletRobot):
"""Add robot to this simulation.
Expand Down Expand Up @@ -88,34 +98,32 @@ def simulate_time(self, time):
for _ in range(int(time / self.time_delta)):
self.timing.step()

def get_physics_client_id(self):
"""Get physics client ID of PyBullet instance.

:return: Physics client ID.
"""
return self._client

def disconnect(self):
def disconnect(self) -> None:
"""Shut down physics client instance."""
pb.disconnect(self._client)
self.pb_client.disconnect()


class BulletTiming:
"""This class handles all timing issues for a single BulletSimulation."""
def __init__(self, dt=0.001, verbose_dt=0.01, client_id=0):
def __init__(
self,
pb_client: bc.BulletClient,
dt: float = 0.001,
verbose_dt: float = 0.01,
):

"""
Create new BulletTiming instance.

:param dt: The time delta used in the BulletSimulation.
:param verbose_dt: Time after which debug info is printed.
:param client_id: PyBullet instance ID.
:param pb_client: PyBullet client (BulletClient) used to step the simulation.
"""

# initialise values
self.dt = dt
self.verbose_dt = verbose_dt
self._client = client_id
self._pb_client = pb_client
self.time_step = 0
self.sim_time = 0.0

Expand All @@ -135,8 +143,8 @@ def add_subsystem(self, name, frequency, callback=None):
is triggered.
"""
if name not in self.subsystems.keys():
self.subsystems[name] = (max(1, round(1.0/frequency/self.dt)),
callback)
self.subsystems[name] = (
max(1, round(1.0/frequency/self.dt)), callback)

def remove_subsystem(self, name):
"""
Expand Down Expand Up @@ -167,7 +175,7 @@ def step(self):
"""
triggers = self.get_triggered_subsystems()
self._run_callbacks(triggers)
pb.stepSimulation(physicsClientId=self._client)
self._pb_client.stepSimulation()
self.time_step += 1
self.sim_time += self.dt

Expand All @@ -182,33 +190,42 @@ def reset(self):

class BulletCamera:
"""This class handles all camera operations for one BulletSimulation."""
def __init__(self, position=(0, 0, 0), pitch=-52, yaw=30, distance=3):
def __init__(
self,
pb_client: bc.BulletClient,
position: tuple = (0, 0, 0),
pitch: int = -52,
yaw: int = 30,
distance: int = 3,
):
self.position = position
self.pitch = pitch
self.yaw = yaw
self.distance = distance
self.pb_client = pb_client

self._active = False
self._logging_id = None

pb.resetDebugVisualizerCamera(distance, yaw, pitch, position)
self.pb_client.resetDebugVisualizerCamera(distance, yaw, pitch, position)

def start_recording(self, path):
if not self._active:
self._logging_id = pb.startStateLogging(pb.STATE_LOGGING_VIDEO_MP4,
path)
self._logging_id = self.pb_client.startStateLogging(
pb.STATE_LOGGING_VIDEO_MP4, path)
return self._logging_id
else:
return None

def stop_recording(self):
if self._active:
pb.stopStateLogging(self._logging_id)
self.pb_client.stopStateLogging(self._logging_id)

def reset(self, position, pitch, yaw, distance):
self.position = position
self.pitch = pitch
self.yaw = yaw
self.distance = distance

pb.resetDebugVisualizerCamera(distance, yaw, pitch, position)
self.pb_client.resetDebugVisualizerCamera(
distance, yaw, pitch, position)
Loading
Loading