From d3c17c50bf111a91bb2cdf4977d5fb3aaef8a415 Mon Sep 17 00:00:00 2001 From: MentalGear Date: Thu, 27 Feb 2020 19:48:42 +0100 Subject: [PATCH 1/5] Check Action Type https://github.com/hill-a/stable-baselines/issues/707 --- stable_baselines/common/vec_env/base_vec_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index 189416e52c..4ebb438aa5 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -146,6 +146,7 @@ def step(self, actions): :param actions: ([int] or [float]) the action :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information """ + assert isinstance( actions, ( list, np.ndarray ) ), "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." self.step_async(actions) return self.step_wait() From 79990ac02067febc8664e724f4cb12738c0e3124 Mon Sep 17 00:00:00 2001 From: MentalGear Date: Thu, 27 Feb 2020 22:55:22 +0100 Subject: [PATCH 2/5] Update base_vec_env.py --- stable_baselines/common/vec_env/base_vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index 4ebb438aa5..fdb73ec575 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -146,7 +146,7 @@ def step(self, actions): :param actions: ([int] or [float]) the action :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information """ - assert isinstance( actions, ( list, np.ndarray ) ), "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." + assert isinstance( actions,( list, np.ndarray ) ), "Action must be of type list or np.ndarray. 
Try wrapping your action variable in a list [ ] to fix this issue." self.step_async(actions) return self.step_wait() From 1e29461b9c57914e397f567aa7af9d20a03de769 Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 27 Feb 2020 23:06:53 +0100 Subject: [PATCH 3/5] Update changelog.rst --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8ca14d3042..e8f92ec584 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -57,6 +57,7 @@ Others: - Removed redundant return value from ``a2c.utils::total_episode_reward_logger``. (@shwang) - Cleanup and refactoring in ``common/identity_env.py`` (@shwang) - Added a Makefile to simplify common development tasks (build the doc, type check, run the tests) +- Action Type Check: Assertion added to ``VecEnv` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue. Documentation: ^^^^^^^^^^^^^^ From 5ed28e5496e0f67d4549610aff83316a20336561 Mon Sep 17 00:00:00 2001 From: Tom Date: Sat, 29 Feb 2020 17:46:49 +0100 Subject: [PATCH 4/5] Added Tests, changed changlog --- docs/misc/changelog.rst | 2 +- stable_baselines/common/vec_env/base_vec_env.py | 3 ++- tests/test_envs.py | 17 +++++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e8f92ec584..6390bc21bc 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -57,7 +57,7 @@ Others: - Removed redundant return value from ``a2c.utils::total_episode_reward_logger``. (@shwang) - Cleanup and refactoring in ``common/identity_env.py`` (@shwang) - Added a Makefile to simplify common development tasks (build the doc, type check, run the tests) -- Action Type Check: Assertion added to ``VecEnv` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue. 
+- Action Type Check: Assertion added to ``VecEnv`` to check if action is of type list or np.ndarray. Otherwise a developer friendly message is displayed on how to fix the issue. (@mentalgear) Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index fdb73ec575..7231675ad1 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -146,7 +146,8 @@ def step(self, actions): :param actions: ([int] or [float]) the action :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information """ - assert isinstance( actions,( list, np.ndarray ) ), "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." + if not isinstance( actions,( list, np.ndarray ) ): + raise TypeError( "Action must be of type list or np.ndarray. Try wrapping your action variable in a list [ ] to fix this issue." ) self.step_async(actions) return self.step_wait() diff --git a/tests/test_envs.py b/tests/test_envs.py index d436f18062..eb355ecb0e 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -127,6 +127,20 @@ def wrong_step(_action): check_env(env) +def test_action_format ( env, action ): + """ + Helper to check that the error is caught. 
+ :param env: (gym.Env) + :param new_step_return: (tuple) + """ + + with pytest.raises(TypeError): + step( 0 ) + + with pytest.raises(not TypeError): + step( [ 0 ] ) + + def test_common_failures_step(): """ Test that common failure cases of the `step` method are caught @@ -147,3 +161,6 @@ def test_common_failures_step(): # Done is not a boolean check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + + # action format must be [] or np.ndarray + test_action_format ( env, 0 ) From ccd4587979cd1c3ffc295272218521ee365bac20 Mon Sep 17 00:00:00 2001 From: Tom Date: Sat, 29 Feb 2020 19:04:29 +0100 Subject: [PATCH 5/5] Update test_envs.py --- tests/test_envs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index eb355ecb0e..31d6d6eb46 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -129,16 +129,18 @@ def wrong_step(_action): def test_action_format ( env, action ): """ - Helper to check that the error is caught. + Test if action format check works :param env: (gym.Env) - :param new_step_return: (tuple) + :param action: (list, array) """ + env = gym.make('CartPole-v0') + with pytest.raises(TypeError): - step( 0 ) + env.step( action ) with pytest.raises(not TypeError): - step( [ 0 ] ) + env.step( [ action ] ) def test_common_failures_step():