diff --git a/gymnasium_robotics/envs/franka_kitchen/kitchen_env.py b/gymnasium_robotics/envs/franka_kitchen/kitchen_env.py index 47e57548..beba0af9 100644 --- a/gymnasium_robotics/envs/franka_kitchen/kitchen_env.py +++ b/gymnasium_robotics/envs/franka_kitchen/kitchen_env.py @@ -355,6 +355,7 @@ def compute_reward( desired_goal: "dict[str, np.ndarray]", info: "dict[str, Any]", ): + self.step_task_completions.clear() for task in self.tasks_to_complete: distance = np.linalg.norm(achieved_goal[task] - desired_goal[task]) complete = distance < BONUS_THRESH @@ -394,7 +395,7 @@ def _get_obs(self, robot_obs): def step(self, action): robot_obs, _, terminated, truncated, info = self.robot_env.step(action) obs = self._get_obs(robot_obs) - + reward = self.compute_reward(obs["achieved_goal"], self.goal, info) if self.remove_task_when_completed: @@ -411,7 +412,6 @@ def step(self, action): if task not in self.episode_task_completions: self.episode_task_completions.append(task) info["episode_task_completions"] = self.episode_task_completions - self.step_task_completions.clear() if self.terminate_on_tasks_completed: # terminate if there are no more tasks to complete terminated = len(self.episode_task_completions) == len(self.goal.keys()) @@ -425,7 +425,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs): obs = self._get_obs(robot_obs) self.tasks_to_complete = set(self.goal.keys()) info = { - "tasks_to_complete": self.tasks_to_complete, + "tasks_to_complete": list(self.tasks_to_complete), "episode_task_completions": [], "step_task_completions": [], } diff --git a/tests/envs/franka_kitchen/test_kitchen_env.py b/tests/envs/franka_kitchen/test_kitchen_env.py index 23dca2c6..e6616cf6 100644 --- a/tests/envs/franka_kitchen/test_kitchen_env.py +++ b/tests/envs/franka_kitchen/test_kitchen_env.py @@ -16,8 +16,7 @@ [[True, True], [False, False]], ) def test_task_completion(remove_task_when_completed, terminate_on_tasks_completed): - """ - This test checks the different task completion configurations for the FrankaKitchen-v1 environment. + """This test checks the different task completion configurations for the FrankaKitchen-v1 environment. The test checks if the info items returned in each step (`tasks_to_complete`, `step_task_completions`, `episode_task_completions`) are correct and correspond to the behavior of the environment configured at initialization with the arguments: `remove_task_when_completed` and `terminate_on_tasks_completed`. @@ -52,9 +51,6 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete _, _, terminated, _, info = env.step(env.action_space.sample()) completed_tasks.add(task) - assert set(info["step_task_completions"]) == { - task - }, f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {task}." assert ( set(info["episode_task_completions"]) == completed_tasks ), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}." @@ -63,11 +59,17 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete assert set(info["tasks_to_complete"]) == set( tasks_to_complete ), f"If environment is initialized with `remove_task_when_completed=True` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the tasks that haven't been completed yet: {tasks_to_complete}." + assert set(info["step_task_completions"]) == { + task + }, f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {task}." else: assert set(info["tasks_to_complete"]) == set( tasks_to_complete ), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}." + assert ( + set(info["step_task_completions"]) == completed_tasks + ), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}." if terminate_on_tasks_completed: assert ( @@ -79,43 +81,43 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete ), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed." # Complete a task during the same environment step - # for _ in range(3): - # tasks_to_complete = deepcopy(TASKS) - # completed_tasks = set() - # _, info = env.reset() - - # terminated = False - - # # Complete a task sequentially for each environment step - # for task in TASKS: - # # Force task to be achieved - # env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task] - # _, _, terminated, _, info = env.step(env.action_space.sample()) - # completed_tasks.add(task) - - # assert ( - # set(info["step_task_completions"]) == completed_tasks - # ), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}." - # assert ( - # set(info["episode_task_completions"]) == completed_tasks - # ), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}." - # if remove_task_when_completed: - # assert ( - # len(info["tasks_to_complete"]) == 0 - # ), f"If environment is initialized with `remove_task_when_completed=True` and all tasks were completed the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be empty." - - # else: - # assert set(info["tasks_to_complete"]) == set( - # tasks_to_complete - # ), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}." - - # if terminate_on_tasks_completed: - # assert ( - # terminated - # ), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed." - # else: - # assert ( - # not terminated - # ), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed." + for _ in range(3): + tasks_to_complete = deepcopy(TASKS) + completed_tasks = set() + _, info = env.reset() + + terminated = False + + # Complete a task sequentially for each environment step + for task in TASKS: + # Force task to be achieved + env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task] + completed_tasks.add(task) + + _, _, terminated, _, info = env.step(env.action_space.sample()) + assert ( + set(info["step_task_completions"]) == completed_tasks + ), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}." + assert ( + set(info["episode_task_completions"]) == completed_tasks + ), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}." + if remove_task_when_completed: + assert ( + len(info["tasks_to_complete"]) == 0 + ), f"If environment is initialized with `remove_task_when_completed=True` and all tasks were completed the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be empty." + + else: + assert set(info["tasks_to_complete"]) == set( + tasks_to_complete + ), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}." + + if terminate_on_tasks_completed: + assert ( + terminated + ), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed." + else: + assert ( + not terminated + ), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed." env.close()