Skip to content

Commit

Permalink
update test_envs
Browse files Browse the repository at this point in the history
  • Loading branch information
wjxgeorge committed Nov 30, 2024
1 parent ce6734e commit 0d0662e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/envs/hand/test_reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def test_serialize_deserialize():
env1 = gym.make("HandReach-v2", distance_threshold=1e-6)
env1 = gym.make("HandReach-v3", distance_threshold=1e-6)
env1.reset()
env2 = pickle.loads(pickle.dumps(env1))

Expand Down
46 changes: 46 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pickle
import warnings

import numpy as np

import gymnasium as gym
import pytest
from gymnasium.envs.mujoco.utils import check_mujoco_reset_state
Expand Down Expand Up @@ -163,3 +165,47 @@ def test_pickle_env(env_spec):
data_equivalence(env.step(action), pickled_env.step(action))
env.close()
pickled_env.close()

_test_robot_env_reset_list = ['Fetch', 'HandReach']

@pytest.mark.parametrize(
"spec",
[spec for spec in non_mujoco_py_env_specs if np.any([tar in spec.id for tar in _test_robot_env_reset_list])],
ids=[spec.id for spec in non_mujoco_py_env_specs if np.any([tar in spec.id for tar in _test_robot_env_reset_list])]
)
def test_robot_env_reset(spec):
"""Checks initial state of robotic environment, i.e. Fetch and Shadow Dexterous Hand Reach,
whether their initial states align with the document."""

def _test_initial_states(env, seed=None):
diag_dict = {}

init_obs = env.reset(seed = seed)

if isinstance(init_obs[0], dict):
diag_dict.update(**init_obs[0])
elif isinstance(init_obs[0], np.ndarray):
diag_dict.update({'observation': init_obs[0]})
diag_dict.update({
'qpos': env.unwrapped.data.qpos,
'qvel': env.unwrapped.data.qvel,
'init_qpos': env.unwrapped.initial_qpos,
'init_qvel': env.unwrapped.initial_qvel
})

# exclude object location from environments
# if spec.id[:-3] in ['FetchPush', 'FetchPickAndPlace', 'FetchSlide',]:
if np.any([tar in spec.id for tar in ['FetchPush', 'FetchPickAndPlace', 'FetchSlide',]]):
diag_dict['qpos'] = np.delete(diag_dict['qpos'], np.s_[-7:-5])
diag_dict['init_qpos'] = np.delete(diag_dict['init_qpos'], np.s_[-7:-5])

# testing
assert np.allclose(diag_dict['qpos'], diag_dict['init_qpos'])
assert np.allclose(diag_dict['qvel'], diag_dict['init_qvel'])
return diag_dict

cur_env: gym.Env = spec.make()

_test_initial_states(cur_env, seed=24)
_test_initial_states(cur_env, seed=10)
return

0 comments on commit 0d0662e

Please sign in to comment.