diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index abd2a707b5..09432ab348 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -24,6 +24,6 @@ - [ ] My change requires a change to the documentation. - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). - [ ] I have updated the documentation accordingly. -- [ ] I have ensured `pytest` and `pytype` both pass. +- [ ] I have ensured `pytest` and `pytype` both pass (by running `make pytest` and `make type`). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fc505fbc5e..d11f2780b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -57,17 +57,17 @@ from stable_baselines import PPO2 In general, we recommend using pycharm to format everything in an efficient way. -Please documentation each function/method using the following template: +Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template: ```python -def my_function(arg1, arg2): +def my_function(arg1: type1, arg2: type2) -> returntype: """ Short description of the function. - :param arg1: (arg1 type) describe what is arg1 - :param arg2: (arg2 type) describe what is arg2 - :return: (return type) describe what is returned + :param arg1: (type1) describe what is arg1 + :param arg2: (type2) describe what is arg2 + :return: (returntype) describe what is returned """ ... return my_variable @@ -77,7 +77,7 @@ def my_function(arg1, arg2): Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process. -Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a , @araffin or @erniejunior ). +Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli). A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch. Note: in rare cases, we can create exception for codacy failure. @@ -88,15 +88,34 @@ All new features must add tests in the `tests/` folder ensuring that everything We use [pytest](https://pytest.org/). Also, when a bug fix is proposed, tests should be added to avoid regression. -To run tests with `pytest` and type checking with `pytype`: +To run tests with `pytest`: ``` -./scripts/run_tests.sh +make pytest ``` +Type checking with `pytype`: + +``` +make type +``` + +Build the documentation: + +``` +make doc +``` + +Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that): + +``` +make spelling +``` + + ## Changelog and Documentation -Please do not forget to update the changelog and add documentation if needed. +Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed. A README is present in the `docs/` folder for instructions on how to build the documentation. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..c8de067857 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +# Run pytest and coverage report +pytest: + ./scripts/run_tests.sh + +# Type check +type: + pytype + +# Build the doc +doc: + cd docs && make html + +# Check the spelling in the doc +spelling: + cd docs && make spelling + +# Clean the doc build folder +clean: + cd docs && make clean diff --git a/README.md b/README.md index 7925a89b4d..1c024316e1 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ Some of the baselines examples use [MuJoCo](http://www.mujoco.org) (multi-joint All unit tests in baselines can be run using pytest runner: ``` pip install pytest pytest-cov -pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. +make pytest ``` ## Projects Using Stable-Baselines diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index c5d745feb7..5967e60cd8 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -632,4 +632,4 @@ Bonus: Make a GIF of a Trained Agent obs, _, _ ,_ = model.env.step(action) img = model.env.render(mode='rgb_array') - imageio.mimsave('lander_a2c.gif', [np.array(img[0]) for i, img in enumerate(images) if i%2 == 0], fps=29) + imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7848c6c056..215589c88d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -15,8 +15,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Parallelized updating and sampling from the replay buffer in DQN. (@flodorner) - - Docker build script, `scripts/build_docker.sh`, can push images automatically. +- Added a seeding method for vectorized environments. (@NeoExtended) Bug Fixes: ^^^^^^^^^^ @@ -26,12 +26,15 @@ Bug Fixes: - Fixed Docker GPU run script, `scripts/run_docker_gpu.sh`, to work with new NVidia Container Toolkit. - Repeated calls to `RLModel.learn()` now preserve internal counters for some episode logging statistics that used to be zeroed at the start of every call. +- Fix `DummyVecEnv.render` for `num_envs > 1`. This used to print a warning and then not render at all. (@shwang) - Fixed a bug in PPO2, ACER, A2C, and ACKTR where repeated calls to `learn(total_timesteps)` reset the environment on every call, potentially biasing samples toward early episode timesteps. (@shwang) - - - Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated +- Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated `self.runner` instead of reinitializing a new Runner every time `learn()` is called. +- Fixed a bug in `check_env` where it would fail on high dimensional action spaces +- Fixed `Monitor.close()` that was not calling the parent method +- Fixed a bug in `BaseRLModel` when seeding vectorized environments. (@NeoExtended) Deprecations: ^^^^^^^^^^^^^ @@ -40,10 +43,12 @@ Others: ^^^^^^^ - Removed redundant return value from `a2c.utils::total_episode_reward_logger`. (@shwang) - Cleanup and refactoring in `common/identity_env.py` (@shwang) +- Added a Makefile to simplify common development tasks (build the doc, type check, run the tests) - Add windows CI (@ChengYen-Tang) Documentation: ^^^^^^^^^^^^^^ +- Fixed example for creating a GIF (@KuKuXia) Release 2.9.0 (2019-12-20) @@ -610,4 +615,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching -@flodorner @ChengYen-Tang +@flodorner @KuKuXia @NeoExtended @ChengYen-Tang diff --git a/stable_baselines/bench/monitor.py b/stable_baselines/bench/monitor.py index fd9542b0ba..8d460d2761 100644 --- a/stable_baselines/bench/monitor.py +++ b/stable_baselines/bench/monitor.py @@ -5,26 +5,33 @@ import os import time from glob import glob +from typing import Tuple, Dict, Any, List, Optional +import gym import pandas -from gym.core import Wrapper +import numpy as np -class Monitor(Wrapper): +class Monitor(gym.Wrapper): EXT = "monitor.csv" file_handler = None - def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), info_keywords=()): + def __init__(self, + env: gym.Env, + filename: Optional[str], + allow_early_resets: bool = True, + reset_keywords=(), + info_keywords=()): """ A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. - :param env: (Gym environment) The environment - :param filename: (str) the location to save a log file, can be None for no log + :param env: (gym.Env) The environment + :param filename: (Optional[str]) the location to save a log file, can be None for no log :param allow_early_resets: (bool) allows the reset of the environment before it is done :param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset :param info_keywords: (tuple) extra information to log, from the information return of environment.step """ - Wrapper.__init__(self, env=env) + super(Monitor, self).__init__(env=env) self.t_start = time.time() if filename is None: self.file_handler = None @@ -53,12 +60,12 @@ def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), in self.total_steps = 0 self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() - def reset(self, **kwargs): + def reset(self, **kwargs) -> np.ndarray: """ Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords - :return: ([int] or [float]) the first observation of the environment + :return: (np.ndarray) the first observation of the environment """ if not self.allow_early_resets and not self.needs_reset: raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, " @@ -68,16 +75,16 @@ def reset(self, **kwargs): for key in self.reset_keywords: value = kwargs.get(key) if value is None: - raise ValueError('Expected you to pass kwarg %s into reset' % key) + raise ValueError('Expected you to pass kwarg {} into reset'.format(key)) self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action): + def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]: """ Step the environment with the given action - :param action: ([int] or [float]) the action - :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information + :param action: (np.ndarray) the action + :return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") @@ -105,10 +112,11 @@ def close(self): """ Closes the environment """ + super(Monitor, self).close() if self.file_handler is not None: self.file_handler.close() - def get_total_steps(self): + def get_total_steps(self) -> int: """ Returns the total number of timesteps @@ -116,7 +124,7 @@ def get_total_steps(self): """ return self.total_steps - def get_episode_rewards(self): + def get_episode_rewards(self) -> List[float]: """ Returns the rewards of all the episodes @@ -124,7 +132,7 @@ def get_episode_rewards(self): """ return self.episode_rewards - def get_episode_lengths(self): + def get_episode_lengths(self) -> List[int]: """ Returns the number of timesteps of all the episodes @@ -132,7 +140,7 @@ def get_episode_lengths(self): """ return self.episode_lengths - def get_episode_times(self): + def get_episode_times(self) -> List[float]: """ Returns the runtime in seconds of all the episodes @@ -148,7 +156,7 @@ class LoadMonitorResultsError(Exception): pass -def get_monitor_files(path): +def get_monitor_files(path: str) -> List[str]: """ get all the monitor files in the given path @@ -158,12 +166,12 @@ def get_monitor_files(path): return glob(os.path.join(path, "*" + Monitor.EXT)) -def load_results(path): +def load_results(path: str) -> pandas.DataFrame: """ Load all Monitor logs from a given directory path matching ``*monitor.csv`` and ``*monitor.json`` :param path: (str) the directory path containing the log file(s) - :return: (Pandas DataFrame) the logged data + :return: (pandas.DataFrame) the logged data """ # get both csv and (old) json files monitor_files = (glob(os.path.join(path, "*monitor.json")) + get_monitor_files(path)) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 4d1643e95a..89a9d0c655 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -179,12 +179,7 @@ def set_random_seed(self, seed): # Seed python, numpy and tf random generator set_global_seeds(seed) if self.env is not None: - if isinstance(self.env, VecEnv): - # Use a different seed for each env - for idx in range(self.env.num_envs): - self.env.env_method("seed", seed + idx) - else: - self.env.seed(seed) + self.env.seed(seed) # Seed the action space # useful when selecting random actions self.env.action_space.seed(seed) diff --git a/stable_baselines/common/env_checker.py b/stable_baselines/common/env_checker.py index 6c6dd0fcbd..751b204f95 100644 --- a/stable_baselines/common/env_checker.py +++ b/stable_baselines/common/env_checker.py @@ -135,7 +135,7 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces -def _check_render(env: gym.Env, warn=True, headless=False) -> None: +def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: """ Check the declared render modes and the `render()`/`close()` method of the environment. @@ -163,7 +163,7 @@ def _check_render(env: gym.Env, warn=True, headless=False) -> None: env.close() -def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None: +def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: """ Check that an environment follows Gym API. This is particularly useful when using a custom environment. @@ -205,8 +205,8 @@ def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None: # Check for the action space, it may lead to hard-to-debug issues if (isinstance(action_space, spaces.Box) and - (np.abs(action_space.low) != np.abs(action_space.high) - or np.abs(action_space.low) > 1 or np.abs(action_space.high) > 1)): + (np.any(np.abs(action_space.low) != np.abs(action_space.high)) + or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.high) > 1))): warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " "cf https://stable-baselines.readthedocs.io/en/master/guide/rl_tips.html") diff --git a/stable_baselines/common/vec_env/base_vec_env.py b/stable_baselines/common/vec_env/base_vec_env.py index 1062b5ab12..189416e52c 100644 --- a/stable_baselines/common/vec_env/base_vec_env.py +++ b/stable_baselines/common/vec_env/base_vec_env.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod import inspect import pickle +from typing import Sequence, Optional, List, Union import cloudpickle +import numpy as np + from stable_baselines import logger +from stable_baselines.common.tile_images import tile_images class AlreadySteppingError(Exception): @@ -123,6 +127,18 @@ def env_method(self, method_name, *method_args, indices=None, **method_kwargs): """ pass + @abstractmethod + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + """ + Sets the random seeds for all environments, based on a given seed. + Each individual environment will still get its own seed, by incrementing the given seed. + + :param seed: (Optional[int]) The random seed. May be None for completely random seeding. + :return: (List[Union[None, int]]) Returns a list containing the seeds for each individual env. + Note that all list elements may be None, if the env does not return anything when being seeded. + """ + pass + def step(self, actions): """ Step the environments with the given action @@ -133,19 +149,34 @@ def step(self, actions): self.step_async(actions) return self.step_wait() - def get_images(self): + def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]: """ Return RGB images from each environment """ raise NotImplementedError - def render(self, *args, **kwargs): + def render(self, mode: str, *args, **kwargs): """ Gym environment rendering - :param mode: (str) the rendering type + :param mode: the rendering type """ - logger.warn('Render not defined for %s' % self) + try: + imgs = self.get_images(*args, **kwargs) + except NotImplementedError: + logger.warn('Render not defined for {}'.format(self)) + return + + # Create a big image by tiling images from subprocesses + bigimg = tile_images(imgs) + if mode == 'human': + import cv2 # pytype:disable=import-error + cv2.imshow('vecenv', bigimg[:, :, ::-1]) + cv2.waitKey(1) + elif mode == 'rgb_array': + return bigimg + else: + raise NotImplementedError @property def unwrapped(self): @@ -206,6 +237,9 @@ def reset(self): def step_wait(self): pass + def seed(self, seed=None): + return self.venv.seed(seed) + def close(self): return self.venv.close() diff --git a/stable_baselines/common/vec_env/dummy_vec_env.py b/stable_baselines/common/vec_env/dummy_vec_env.py index 2fb9d7b962..85bfb8c427 100644 --- a/stable_baselines/common/vec_env/dummy_vec_env.py +++ b/stable_baselines/common/vec_env/dummy_vec_env.py @@ -1,5 +1,6 @@ from collections import OrderedDict import numpy as np +from typing import Sequence from stable_baselines.common.vec_env.base_vec_env import VecEnv from stable_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info @@ -47,6 +48,12 @@ def step_wait(self): return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), self.buf_infos.copy()) + def seed(self, seed=None): + seeds = list() + for idx, env in enumerate(self.envs): + seeds.append(env.seed(seed + idx)) + return seeds + def reset(self): for env_idx in range(self.num_envs): obs = self.envs[env_idx].reset() @@ -57,10 +64,21 @@ def close(self): for env in self.envs: env.close() - def get_images(self): - return [env.render(mode='rgb_array') for env in self.envs] + def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]: + return [env.render(*args, mode='rgb_array', **kwargs) for env in self.envs] def render(self, *args, **kwargs): + """ + Gym environment rendering. If there are multiple environments then + they are tiled together in one image via `BaseVecEnv.render()`. + Otherwise (if `self.num_envs == 1`), we pass the render call directly to the + underlying environment. + + Therefore, some arguments such as `mode` will have values that are valid + only when `num_envs == 1`. + + :param mode: The rendering type. + """ if self.num_envs == 1: return self.envs[0].render(*args, **kwargs) else: diff --git a/stable_baselines/common/vec_env/subproc_vec_env.py b/stable_baselines/common/vec_env/subproc_vec_env.py index e725cd10c8..6d25741363 100644 --- a/stable_baselines/common/vec_env/subproc_vec_env.py +++ b/stable_baselines/common/vec_env/subproc_vec_env.py @@ -1,11 +1,11 @@ import multiprocessing from collections import OrderedDict +from typing import Sequence import gym import numpy as np from stable_baselines.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper -from stable_baselines.common.tile_images import tile_images def _worker(remote, parent_remote, env_fn_wrapper): @@ -21,6 +21,8 @@ def _worker(remote, parent_remote, env_fn_wrapper): info['terminal_observation'] = observation observation = env.reset() remote.send((observation, reward, done, info)) + elif cmd == 'seed': + remote.send(env.seed(data)) elif cmd == 'reset': observation = env.reset() remote.send(observation) @@ -107,6 +109,11 @@ def step_wait(self): obs, rews, dones, infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + def seed(self, seed=None): + for idx, remote in enumerate(self.remotes): + remote.send(('seed', seed + idx)) + return [remote.recv() for remote in self.remotes] + def reset(self): for remote in self.remotes: remote.send(('reset', None)) @@ -125,27 +132,12 @@ def close(self): process.join() self.closed = True - def render(self, mode='human', *args, **kwargs): + def get_images(self, *args, **kwargs) -> Sequence[np.ndarray]: for pipe in self.remotes: # gather images from subprocesses # `mode` will be taken into account later pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs}))) imgs = [pipe.recv() for pipe in self.remotes] - # Create a big image by tiling images from subprocesses - bigimg = tile_images(imgs) - if mode == 'human': - import cv2 # pytype:disable=import-error - cv2.imshow('vecenv', bigimg[:, :, ::-1]) - cv2.waitKey(1) - elif mode == 'rgb_array': - return bigimg - else: - raise NotImplementedError - - def get_images(self): - for pipe in self.remotes: - pipe.send(('render', {"mode": 'rgb_array'})) - imgs = [pipe.recv() for pipe in self.remotes] return imgs def get_attr(self, attr_name, indices=None): diff --git a/tests/test_envs.py b/tests/test_envs.py index 87397d8f23..d436f18062 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -38,6 +38,21 @@ def test_custom_envs(env_class): check_env(env) +def test_high_dimension_action_space(): + """ + Test for continuous action space + with more than one action. + """ + env = gym.make('Pendulum-v0') + # Patch the action space + env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32) + # Patch to avoid error + def patched_step(_action): + return env.observation_space.sample(), 0.0, False, {} + env.step = patched_step + check_env(env) + + @pytest.mark.parametrize("new_obs_space", [ # Small image spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), diff --git a/tests/test_mpi_adam.py b/tests/test_mpi_adam.py index 73dc9fb77e..cc62521e05 100644 --- a/tests/test_mpi_adam.py +++ b/tests/test_mpi_adam.py @@ -1,3 +1,4 @@ +import platform import subprocess from .test_common import _assert_eq @@ -5,14 +6,23 @@ def test_mpi_adam(): """Test RunningMeanStd object for MPI""" - return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', - 'python', '-m', 'stable_baselines.common.mpi_adam']) + if platform.system() == 'Windows': + return_code = subprocess.call(['mpiexec', '-np', '2', + 'python', '-m', 'stable_baselines.common.mpi_adam']) + else: + return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', + 'python', '-m', 'stable_baselines.common.mpi_adam']) _assert_eq(return_code, 0) def test_mpi_adam_ppo1(): """Running test for ppo1""" - return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', - 'python', '-m', - 'stable_baselines.ppo1.experiments.train_cartpole']) + if platform.system() == 'Windows': + return_code = subprocess.call(['mpiexec', '-np', '2', + 'python', '-m', + 'stable_baselines.ppo1.experiments.train_cartpole']) + else: + return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', + 'python', '-m', + 'stable_baselines.ppo1.experiments.train_cartpole']) _assert_eq(return_code, 0) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 6e97ed3c7c..cc2ea9fefd 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,3 +1,4 @@ +import platform import subprocess import gym @@ -125,8 +126,12 @@ def test_normalize_external(): def test_mpi_runningmeanstd(): """Test RunningMeanStd object for MPI""" - return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', - 'python', '-m', 'stable_baselines.common.mpi_running_mean_std']) + if platform.system() == 'Windows': + return_code = subprocess.call(['mpiexec', '-np', '2', + 'python', '-m', 'stable_baselines.common.mpi_running_mean_std']) + else: + return_code = subprocess.call(['mpirun', '--allow-run-as-root', '-np', '2', + 'python', '-m', 'stable_baselines.common.mpi_running_mean_std']) _assert_eq(return_code, 0) @@ -134,6 +139,11 @@ def test_mpi_moments(): """ test running mean std function """ - subprocess.check_call(['mpirun', '--allow-run-as-root', '-np', '3', 'python', '-c', - 'from stable_baselines.common.mpi_moments ' - 'import _helper_runningmeanstd; _helper_runningmeanstd()']) + if platform.system() == 'Windows': + subprocess.check_call(['mpiexec', '-np', '3', 'python', '-c', + 'from stable_baselines.common.mpi_moments ' + 'import _helper_runningmeanstd; _helper_runningmeanstd()']) + else: + subprocess.check_call(['mpirun', '--allow-run-as-root', '-np', '3', 'python', '-c', + 'from stable_baselines.common.mpi_moments ' + 'import _helper_runningmeanstd; _helper_runningmeanstd()'])