diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 22379b94f7..c539f109f7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -85,6 +85,7 @@ Others: - Removed redundant return value from ``a2c.utils::total_episode_reward_logger``. (@shwang) - Cleanup and refactoring in ``common/identity_env.py`` (@shwang) - Added a Makefile to simplify common development tasks (build the doc, type check, run the tests) +- Action Type Check: Type check added to ``VecEnv.step`` to verify that the action is of type list or np.ndarray. Otherwise a ``TypeError`` with a developer-friendly message is raised explaining how to fix the issue. (@mentalgear) Documentation: diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index 189416e52c..7231675ad1 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -146,6 +146,8 @@ def step(self, actions): :param actions: ([int] or [float]) the action :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information """ + if not isinstance( actions,( list, np.ndarray ) ): + raise TypeError( "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." 
) self.step_async(actions) return self.step_wait() diff --git a/tests/test_envs.py b/tests/test_envs.py index d436f18062..31d6d6eb46 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -127,6 +127,27 @@ def wrong_step(_action): check_env(env) +def test_action_format(env=None, action=0): + """ + Test that ``VecEnv.step`` only accepts actions given as a list or a np.ndarray + + :param env: (VecEnv) vectorized env to test, a CartPole ``DummyVecEnv`` is created when None + :param action: (int) bare action, neither a list nor a np.ndarray + """ + from stable_baselines.common.vec_env import DummyVecEnv + + if env is None: + env = DummyVecEnv([lambda: gym.make('CartPole-v0')]) + env.reset() + + # A bare action must be rejected with a TypeError + with pytest.raises(TypeError): + env.step(action) + + # The same action wrapped in a list must be accepted (no exception) + env.step([action]) + + def test_common_failures_step(): """ Test that common failure cases of the `step` method are caught @@ -147,3 +168,6 @@ def test_common_failures_step(): # Done is not a boolean check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + + # action format must be [] or np.ndarray (checked on a VecEnv, not on the raw gym env) + test_action_format(action=0)