From c6a3c601e445f84bd4d479f3b48a9072dfbad1eb Mon Sep 17 00:00:00 2001 From: Melvin Laux Date: Mon, 27 May 2024 18:35:07 +0200 Subject: [PATCH] observable objects in ur+mia --- deformable_gym/envs/ur5_mia_grasp_env.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/deformable_gym/envs/ur5_mia_grasp_env.py b/deformable_gym/envs/ur5_mia_grasp_env.py index c31591a..68df1d3 100644 --- a/deformable_gym/envs/ur5_mia_grasp_env.py +++ b/deformable_gym/envs/ur5_mia_grasp_env.py @@ -85,6 +85,7 @@ def __init__( object_name: str = "insole", thumb_adducted: bool = True, object_scale: float = 1.0, + observable_object_pos: bool = False, **kwargs ): @@ -100,6 +101,7 @@ def __init__( ) self.robot = self._create_robot() + self._observable_object_pos = observable_object_pos limits = pbh.get_limit_array(self.robot.motors.values()) @@ -111,6 +113,12 @@ def __init__( np.array([2, 2, 2]), np.ones(4), limits[1][6:], np.array([5, 5, 5])], axis=0) + if self._observable_object_pos: + lower_observations = np.append( + lower_observations, -np.full(3, 2.)) + upper_observations = np.append( + upper_observations, np.full(3, 2.)) + self.observation_space = spaces.Box( low=lower_observations, high=upper_observations) @@ -174,7 +182,13 @@ def _get_observation(self): ee_pose = self.robot.get_ee_pose() sensor_readings = self.robot.get_sensor_readings() - return np.concatenate([ee_pose, joint_pos, sensor_readings], axis=0) + obs = np.concatenate([ee_pose, joint_pos, sensor_readings], axis=0) + + if self._observable_object_pos: + obj_pos = self.object_to_grasp.get_pose()[:3] + obs = np.append(obs, obj_pos) + + return obs def calculate_reward(self, state, action, next_state, terminated): """