Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigodelazcano committed Sep 4, 2023
1 parent b315bc6 commit 2783a51
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 46 deletions.
6 changes: 3 additions & 3 deletions gymnasium_robotics/envs/franka_kitchen/kitchen_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def compute_reward(
desired_goal: "dict[str, np.ndarray]",
info: "dict[str, Any]",
):
self.step_task_completions.clear()
for task in self.tasks_to_complete:
distance = np.linalg.norm(achieved_goal[task] - desired_goal[task])
complete = distance < BONUS_THRESH
Expand Down Expand Up @@ -394,7 +395,7 @@ def _get_obs(self, robot_obs):
def step(self, action):
robot_obs, _, terminated, truncated, info = self.robot_env.step(action)
obs = self._get_obs(robot_obs)

reward = self.compute_reward(obs["achieved_goal"], self.goal, info)

if self.remove_task_when_completed:
Expand All @@ -411,7 +412,6 @@ def step(self, action):
if task not in self.episode_task_completions:
self.episode_task_completions.append(task)
info["episode_task_completions"] = self.episode_task_completions
self.step_task_completions.clear()
if self.terminate_on_tasks_completed:
# terminate if there are no more tasks to complete
terminated = len(self.episode_task_completions) == len(self.goal.keys())
Expand All @@ -425,7 +425,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs):
obs = self._get_obs(robot_obs)
self.tasks_to_complete = set(self.goal.keys())
info = {
"tasks_to_complete": self.tasks_to_complete,
"tasks_to_complete": list(self.tasks_to_complete),
"episode_task_completions": [],
"step_task_completions": [],
}
Expand Down
88 changes: 45 additions & 43 deletions tests/envs/franka_kitchen/test_kitchen_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
[[True, True], [False, False]],
)
def test_task_completion(remove_task_when_completed, terminate_on_tasks_completed):
"""
This test checks the different task completion configurations for the FrankaKitchen-v1 environment.
"""This test checks the different task completion configurations for the FrankaKitchen-v1 environment.
The test checks if the info items returned in each step (`tasks_to_complete`, `step_task_completions`, `episode_task_completions`) are correct and correspond
to the behavior of the environment configured at initialization with the arguments: `remove_task_when_completed` and `terminate_on_tasks_completed`.
Expand Down Expand Up @@ -52,9 +51,6 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete
_, _, terminated, _, info = env.step(env.action_space.sample())
completed_tasks.add(task)

assert set(info["step_task_completions"]) == {
task
}, f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {task}."
assert (
set(info["episode_task_completions"]) == completed_tasks
), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}."
Expand All @@ -63,11 +59,17 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=True` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the tasks that haven't been completed yet: {tasks_to_complete}."
assert set(info["step_task_completions"]) == {
task
}, f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {task}."

else:
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}."
assert (
set(info["step_task_completions"]) == completed_tasks
), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}."

if terminate_on_tasks_completed:
assert (
Expand All @@ -79,43 +81,43 @@ def test_task_completion(remove_task_when_completed, terminate_on_tasks_complete
), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed."

# Complete a task during the same environment step
# for _ in range(3):
# tasks_to_complete = deepcopy(TASKS)
# completed_tasks = set()
# _, info = env.reset()

# terminated = False

# # Complete a task sequentially for each environment step
# for task in TASKS:
# # Force task to be achieved
# env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
# _, _, terminated, _, info = env.step(env.action_space.sample())
# completed_tasks.add(task)

# assert (
# set(info["step_task_completions"]) == completed_tasks
# ), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}."
# assert (
# set(info["episode_task_completions"]) == completed_tasks
# ), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}."
# if remove_task_when_completed:
# assert (
# len(info["tasks_to_complete"]) == 0
# ), f"If environment is initialized with `remove_task_when_completed=True` and all tasks were completed the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be empty."

# else:
# assert set(info["tasks_to_complete"]) == set(
# tasks_to_complete
# ), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}."

# if terminate_on_tasks_completed:
# assert (
# terminated
# ), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed."
# else:
# assert (
# not terminated
# ), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed."
for _ in range(3):
tasks_to_complete = deepcopy(TASKS)
completed_tasks = set()
_, info = env.reset()

terminated = False

# Complete a task sequentially for each environment step
for task in TASKS:
# Force task to be achieved
env.data.qpos[OBS_ELEMENT_INDICES[task]] = OBS_ELEMENT_GOALS[task]
completed_tasks.add(task)

_, _, terminated, _, info = env.step(env.action_space.sample())
assert (
set(info["step_task_completions"]) == completed_tasks
), f"The key `step_task_completions` returned by info: {set(info['step_task_completions'])}, must be equal to the tasks completed after the current step: {completed_tasks}."
assert (
set(info["episode_task_completions"]) == completed_tasks
), f"The key `episode_task_completions` returned by info: {set(info['episode_task_completions'])}, must be equal to the tasks along the current episode: {completed_tasks}."
if remove_task_when_completed:
assert (
len(info["tasks_to_complete"]) == 0
), f"If environment is initialized with `remove_task_when_completed=True` and all tasks were completed the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be empty."

else:
assert set(info["tasks_to_complete"]) == set(
tasks_to_complete
), f"If environment is initialized with `remove_task_when_completed=False` the item `tasks_to_complete` returned by info: {set(info['tasks_to_complete'])}, must be equal to the set of tasks the environment was initialized with: {tasks_to_complete}."

if terminate_on_tasks_completed:
assert (
terminated
), "If the environment is initialized with `terminate_on_tasks_complete=True`, the episode must terminate after all tasks are completed."
else:
assert (
not terminated
), "If the environment is initialized with `terminate_on_tasks_complete=False`, the episode must not terminate after all tasks are completed."

env.close()

0 comments on commit 2783a51

Please sign in to comment.