Skip to content

Commit

Permalink
Fix 166 and tests (#169)
Browse files Browse the repository at this point in the history
* initial tests

* finish tests

* remove task during step

* fix tests

---------

Co-authored-by: rodrigodelazcano <[email protected]>
  • Loading branch information
rodrigodelazcano and rodrigodelazcano authored Sep 4, 2023
1 parent 58d5cfe commit 166e998
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 5 deletions.
10 changes: 5 additions & 5 deletions gymnasium_robotics/envs/franka_kitchen/kitchen_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def compute_reward(
desired_goal: "dict[str, np.ndarray]",
info: "dict[str, Any]",
):
for task in info["tasks_to_complete"]:
self.step_task_completions.clear()
for task in self.tasks_to_complete:
distance = np.linalg.norm(achieved_goal[task] - desired_goal[task])
complete = distance < BONUS_THRESH
if complete:
Expand Down Expand Up @@ -394,7 +395,6 @@ def _get_obs(self, robot_obs):
def step(self, action):
robot_obs, _, terminated, truncated, info = self.robot_env.step(action)
obs = self._get_obs(robot_obs)
info = {"tasks_to_complete": list(self.tasks_to_complete)}

reward = self.compute_reward(obs["achieved_goal"], self.goal, info)

Expand All @@ -405,13 +405,13 @@ def step(self, action):
for element in self.step_task_completions
]

info = {"tasks_to_complete": list(self.tasks_to_complete)}
info["step_task_completions"] = self.step_task_completions.copy()

for task in self.step_task_completions:
if task not in self.episode_task_completions:
self.episode_task_completions.append(task)
info["episode_task_completions"] = self.episode_task_completions
self.step_task_completions.clear()
if self.terminate_on_tasks_completed:
# terminate if there are no more tasks to complete
terminated = len(self.episode_task_completions) == len(self.goal.keys())
Expand All @@ -423,9 +423,9 @@ def reset(self, *, seed: Optional[int] = None, **kwargs):
self.episode_task_completions.clear()
robot_obs, _ = self.robot_env.reset(seed=seed)
obs = self._get_obs(robot_obs)
self.task_to_complete = set(self.goal.keys())
self.tasks_to_complete = set(self.goal.keys())
info = {
"tasks_to_complete": self.task_to_complete,
"tasks_to_complete": list(self.tasks_to_complete),
"episode_task_completions": [],
"step_task_completions": [],
}
Expand Down
123 changes: 123 additions & 0 deletions tests/envs/franka_kitchen/test_kitchen_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from copy import deepcopy

import gymnasium as gym
import pytest

from gymnasium_robotics.envs.franka_kitchen.kitchen_env import (
OBS_ELEMENT_GOALS,
OBS_ELEMENT_INDICES,
)

TASKS = ["microwave", "kettle"]


@pytest.mark.parametrize(
"remove_task_when_completed, terminate_on_tasks_completed",
[[True, True], [False, False]],
)
def test_task_completion(remove_task_when_completed, terminate_on_tasks_completed):
"""This test checks the different task completion configurations for the FrankaKitchen-v1 environment.
The test checks if the info items returned in each step (`tasks_to_complete`, `step_task_completions`, `episode_task_completions`) are correct and correspond
to the behavior of the environment configured at initialization with the arguments: `remove_task_when_completed` and `terminate_on_tasks_completed`.
"""
env = gym.make(
"FrankaKitchen-v1",
tasks_to_complete=TASKS,
remove_task_when_completed=remove_task_when_completed,
terminate_on_tasks_completed=terminate_on_tasks_completed,
)
# Test task completion for 3 consecutive episodes
for _ in range(3):
tasks_to_complete = deepcopy(TASKS)
completed_tasks = set()
_, info = env.reset()
assert set(info["tasks_to_complete"]) == set(
TASKS
), f"The item `tasks_to_complete` returned by info when the environment is reset: {set(info['tasks_to_complete'])}, must be equal to the `task_to_complete` argument used to initialize the environment: {tasks_to_complete}."
assert (
len(info["step_task_completions"]) == 0
), f"The key `step_task_completions` returned by info when the environment is reset: {set(info['step_task_completions'])}, must be empty."
assert (
len(info["episode_task_completions"]) == 0
), f"The key `episode_task_completions` returned by info when the environment is reset: {set(info['episode_task_completions'])}, must be empty."

terminated = False

# Complete a task sequentially for each environment step
for task in TASKS:
# Force task to be achieved
env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
_, _, terminated, _, info = env.step(env.action_space.sample())
completed_tasks.add(task)

assert (
set(info["episode_task_completions"]) == completed_tasks
), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}."
if remove_task_when_completed:
tasks_to_complete.remove(task)
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=True` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the tasks that haven't been completed yet: {tasks_to_complete}."
assert set(info["step_task_completions"]) == {
task
}, f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {task}."

else:
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}."
assert (
set(info["step_task_completions"]) == completed_tasks
), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}."

if terminate_on_tasks_completed:
assert (
terminated
), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed."
else:
assert (
not terminated
), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed."

# Complete a task during the same environment step
for _ in range(3):
tasks_to_complete = deepcopy(TASKS)
completed_tasks = set()
_, info = env.reset()

terminated = False

# Complete a task sequentially for each environment step
for task in TASKS:
# Force task to be achieved
env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
completed_tasks.add(task)

_, _, terminated, _, info = env.step(env.action_space.sample())
assert (
set(info["step_task_completions"]) == completed_tasks
), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}."
assert (
set(info["episode_task_completions"]) == completed_tasks
), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}."
if remove_task_when_completed:
assert (
len(info["tasks_to_complete"]) == 0
), f"If environment is initialized with `remove_task_when_completed=True` and all tasks were completed the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be empty."

else:
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}."

if terminate_on_tasks_completed:
assert (
terminated
), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed."
else:
assert (
not terminated
), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed."

env.close()

0 comments on commit 166e998

Please sign in to comment.