Skip to content

Commit

Permalink
adapt tests for use of bullet client wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
mlaux1 committed Feb 11, 2024
1 parent 1ced7fa commit 926e31b
Show file tree
Hide file tree
Showing 18 changed files with 46 additions and 26 deletions.
13 changes: 7 additions & 6 deletions tests/envs/test_floating_mia_grasp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,32 @@ def test_initial_sensor_info(env: FloatingMiaGraspEnv):
def test_episode_reproducibility():
observations = []
termination_flags = []
actions = []

env = FloatingMiaGraspEnv(
verbose=False,
horizon=10,
horizon=3,
gui=False,
observable_object_pos=True,
object_name="insole_on_conveyor_belt/back",
difficulty_mode="hard",
)
env = RescaleAction(env, 0., 1.)

env.action_space.seed(SEED)

for _ in range(2):
observation, _ = env.reset(seed=SEED)
env.action_space.seed(SEED)

observations.append([observation])
terminated = False
termination_flags.append([terminated])
actions.append([])
while not terminated:
action = env.action_space.sample()
actions[-1].append(action)
observation, reward, terminated, truncated, info = env.step(action)

observations[-1].append(observation)
termination_flags[-1].append(terminated)

assert_allclose(actions[0], actions[1])
assert_allclose(observations[0], observations[1])
assert_allclose(termination_flags[0], termination_flags[1])

Expand Down
14 changes: 11 additions & 3 deletions tests/envs/test_parallel_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@


def test_parallel_envs():
env = gymnasium.make("FloatingMiaGraspInsole-v0", gui=False)
env2 = gymnasium.make("FloatingMiaGraspInsole-v0", gui=False)
env = gymnasium.make(
"FloatingMiaGraspInsole-v0",
gui=False,
horizon=10
)
env2 = gymnasium.make(
"FloatingMiaGraspInsole-v0",
gui=False,
horizon=10
)

obs, info = env.reset(seed=SEED)
num_steps = 0
Expand All @@ -26,7 +34,7 @@ def test_parallel_envs():

obs, reward, terminated, truncated, _ = env.step(action)
num_steps += 1
obs2, reward2, terminated2, truncated2, _ = env.step(action2)
obs2, reward2, terminated2, truncated2, _ = env2.step(action2)
num_steps2 += 1
episode_return += reward
episode_return2 += reward2
Expand Down
3 changes: 2 additions & 1 deletion tests/objects/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


def test_box(simulation):
obj, _, _ = ObjectFactory().create("box", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"box", object_position=TEST_POS, object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]))

Expand Down
3 changes: 2 additions & 1 deletion tests/objects/test_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


def test_capsule_creation(simulation):
obj, _, _ = ObjectFactory().create("capsule", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"capsule", object_position=TEST_POS, object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]))

Expand Down
3 changes: 2 additions & 1 deletion tests/objects/test_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


def test_cylinder_creation(simulation):
obj, _, _ = ObjectFactory().create("cylinder", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"cylinder", object_position=TEST_POS, object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]))

Expand Down
3 changes: 2 additions & 1 deletion tests/objects/test_insole.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


def test_insole_creation(simulation):
obj, _, _ = ObjectFactory().create("insole", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"insole", object_position=TEST_POS, object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]), decimal=1)

Expand Down
2 changes: 1 addition & 1 deletion tests/objects/test_insole_on_conveyor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@pytest.mark.skip("TODO")
def test_insole_ob_conveyor_creation(simulation):
obj, _, _ = ObjectFactory().create(
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"insole_on_conveyor_belt/back",
object_position=TEST_POS,
object_orientation=TEST_ORN)
Expand Down
2 changes: 1 addition & 1 deletion tests/objects/test_pillow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def test_pillow(simulation):
obj, _, _ = ObjectFactory().create("pillow_small", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create("pillow_small", object_position=TEST_POS, object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]), decimal=1)

Expand Down
5 changes: 4 additions & 1 deletion tests/objects/test_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@


def test_sphere(simulation):
obj, _, _ = ObjectFactory().create("sphere", object_position=TEST_POS, object_orientation=TEST_ORN)
obj, _, _ = ObjectFactory(simulation.pb_client).create(
"sphere",
object_position=TEST_POS,
object_orientation=TEST_ORN)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]))

Expand Down
8 changes: 6 additions & 2 deletions tests/objects/test_urdf_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@


def test_urdf_object_creation(simulation):
obj = UrdfObject("plane.urdf", world_pos=TEST_POS, world_orn=TEST_ORN, fixed=True,
client_id=simulation.get_physics_client_id())
obj = UrdfObject(
"plane.urdf",
simulation.pb_client,
world_pos=TEST_POS,
world_orn=TEST_ORN,
fixed=True)
pose = obj.get_pose()
assert_array_almost_equal(pose, np.array([0, 0, 1, 1, 0, 0, 0]))

Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_mia_hand_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.fixture
def robot(simulation):
robot = MiaHandPosition(
pb_client_id=simulation.get_physics_client_id(),
pb_client=simulation.pb_client,
world_pos=TEST_POS,
world_orn=TEST_ORN,
base_commands=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_mia_hand_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.fixture
def robot(simulation):
robot = MiaHandVelocity(
pb_client_id=simulation.get_physics_client_id(),
pb_client=simulation.pb_client,
world_pos=TEST_POS,
world_orn=TEST_ORN,
base_commands=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_shadow_hand_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.fixture
def robot(simulation):
robot = ShadowHandPosition(
pb_client_id=simulation.get_physics_client_id(),
pb_client=simulation.pb_client,
world_pos=TEST_POS,
world_orn=TEST_ORN,
base_commands=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_shadow_hand_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.fixture
def robot(simulation):
robot = ShadowHandVelocity(
pb_client_id=simulation.get_physics_client_id(),
pb_client=simulation.pb_client,
world_pos=TEST_POS,
world_orn=TEST_ORN,
base_commands=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_ur10_shadow_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def robot(simulation):
robot = UR10ShadowPosition(pb_client_id=simulation.get_physics_client_id())
robot = UR10ShadowPosition(pb_client=simulation.pb_client)
return robot


Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_ur10_shadow_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def robot(simulation):
robot = UR10ShadowVelocity(pb_client_id=simulation.get_physics_client_id())
robot = UR10ShadowVelocity(pb_client=simulation.pb_client)
simulation.add_robot(robot)

return robot
Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_ur5_mia_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def robot(simulation):
robot = UR5MiaPosition(pb_client_id=simulation.get_physics_client_id())
robot = UR5MiaPosition(pb_client=simulation.pb_client)

return robot

Expand Down
2 changes: 1 addition & 1 deletion tests/robots/test_ur5_mia_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def robot(simulation):
robot = UR5MiaVelocity(simulation.get_physics_client_id())
robot = UR5MiaVelocity(simulation.pb_client)

return robot

Expand Down

0 comments on commit 926e31b

Please sign in to comment.