Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Action Type #712

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Others:
- Removed redundant return value from ``a2c.utils::total_episode_reward_logger``. (@shwang)
- Cleanup and refactoring in ``common/identity_env.py`` (@shwang)
- Added a Makefile to simplify common development tasks (build the doc, type check, run the tests)
- Action Type Check: Assertion added to ``VecEnv` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue. (@mentalgear)


Documentation:
Expand Down
2 changes: 2 additions & 0 deletions stable_baselines/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def step(self, actions):
:param actions: ([int] or [float]) the action
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
"""
if not isinstance( actions,( list, np.ndarray ) ):
raise TypeError( "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." )
self.step_async(actions)
return self.step_wait()

Expand Down
19 changes: 19 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,22 @@ def wrong_step(_action):
check_env(env)


def test_action_format ( env, action ):
"""
Test if action format check works
:param env: (gym.Env)
:param action: (list, array)
"""

env = gym.make('CartPole-v0')

with pytest.raises(TypeError):
env.step( action )

with pytest.raises(not TypeError):
env.step( [ action ] )


def test_common_failures_step():
"""
Test that common failure cases of the `step` method are caught
Expand All @@ -147,3 +163,6 @@ def test_common_failures_step():
# Done is not a boolean
check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {}))
check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {}))

# action format must be [] or np.ndarray
test_action_format ( env, 0 )