diff --git a/docs/guide/checking_nan.rst b/docs/guide/checking_nan.rst
new file mode 100644
index 0000000000..c51bdc0a63
--- /dev/null
+++ b/docs/guide/checking_nan.rst
@@ -0,0 +1,253 @@
+Dealing with NaNs and infs
+==========================
+
+During the training of a model on a given environment, it is possible that the RL model becomes completely
+corrupted when a NaN or an inf is fed to it or returned by it.
+
+How and why?
+------------
+
+The issue arises when NaNs or infs do not crash the program, but simply get propagated through the training,
+until all the floating point numbers converge to NaN or inf. This is in line with the
+`IEEE Standard for Floating-Point Arithmetic (IEEE 754) <https://ieeexplore.ieee.org/document/4610935>`_, which specifies:
+
+.. note::
+    Five possible exceptions can occur:
+        - Invalid operation (:math:`\sqrt{-1}`, :math:`\infty \times 0`, :math:`\infty - \infty`, ...) returns NaN
+        - Division by zero:
+            - if the dividend is not zero (:math:`1/0`, :math:`-2/0`, ...), it returns :math:`\pm\infty`
+            - if the dividend is also zero (:math:`0/0`), it returns NaN (invalid operation)
+        - Overflow (exponent too high to represent) returns :math:`\pm\infty`
+        - Underflow (exponent too low to represent) returns :math:`0`
+        - Inexact (not representable exactly in base 2, e.g. :math:`1/5`) returns the rounded value (e.g. :code:`assert (1/5) * 3 == 0.6000000000000001`)
+
+Of these, only ``Division by zero`` will signal an exception by default; the rest quietly propagate invalid values.
+
+In Python, dividing by zero will indeed raise the exception ``ZeroDivisionError: float division by zero``,
+but the other cases are ignored.
+
+NumPy will, by default, only warn (``RuntimeWarning: invalid value encountered``)
+but will not halt the code.
+
+Worst of all, TensorFlow will not signal anything by default:
+
+.. code-block:: python
+
+    import tensorflow as tf
+    import numpy as np
+
+    print("tensorflow test:")
+
+    a = tf.constant(1.0)
+    b = tf.constant(0.0)
+    c = a / b
+
+    sess = tf.Session()
+    val = sess.run(c)  # this will be quiet
+    print(val)
+    sess.close()
+
+    print("\r\nnumpy test:")
+
+    a = np.float64(1.0)
+    b = np.float64(0.0)
+    val = a / b  # this will warn
+    print(val)
+
+    print("\r\npure python test:")
+
+    a = 1.0
+    b = 0.0
+    val = a / b  # this will raise an exception and halt.
+    print(val)
+
+Unfortunately, most of the floating point operations are handled by TensorFlow and NumPy,
+meaning you might get little to no warning when an invalid value occurs.
+
+Numpy parameters
+----------------
+
+NumPy has a convenient way of dealing with invalid values: `numpy.seterr <https://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html>`_,
+which defines, for the Python process, how it should handle floating point errors.
+
+.. code-block:: python
+
+    import numpy as np
+
+    np.seterr(all='raise')  # define before your code.
+
+    print("numpy test:")
+
+    a = np.float64(1.0)
+    b = np.float64(0.0)
+    val = a / b  # this will now raise an exception instead of a warning.
+    print(val)
+
+This will also catch overflow issues on floating point numbers:
+
+.. code-block:: python
+
+    import numpy as np
+
+    np.seterr(all='raise')  # define before your code.
+
+    print("numpy overflow test:")
+
+    a = np.float64(10)
+    b = np.float64(1000)
+    val = a ** b  # this will now raise an exception
+    print(val)
+
+However, it will not catch the propagation of invalid values that are already present:
+
+.. code-block:: python
+
+    import numpy as np
+
+    np.seterr(all='raise')  # define before your code.
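+    # np.seterr only configures how floating point *errors* are handled; adding a
+    # number to an already-present quiet NaN is a valid IEEE 754 operation, so the
+    # NaN below propagates without any warning or exception.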
+ + print("numpy propagation test:") + + a = np.float64('NaN') + b = np.float64(1.0) + val = a + b # this will neither warn nor raise anything + print(val) + +Tensorflow parameters +--------------------- + +Tensorflow can add checks for detecting and dealing with invalid value: `tf.add_check_numerics_ops `_ and `tf.check_numerics `_, +however they will add operations to the Tensorflow graph and raise the computation time. + +.. code-block:: python + + import tensorflow as tf + + print("tensorflow test:") + + a = tf.constant(1.0) + b = tf.constant(0.0) + c = a / b + + check_nan = tf.add_check_numerics_ops() # add after your graph definition. + + sess = tf.Session() + val, _ = sess.run([c, check_nan]) # this will now raise an exception + print(val) + sess.close() + +but this will also avoid overflow issues on floating point numbers: + +.. code-block:: python + + import tensorflow as tf + + print("tensorflow overflow test:") + + check_nan = [] # the list of check_numerics operations + + a = tf.constant(10) + b = tf.constant(1000) + c = a ** b + + check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations + + sess = tf.Session() + val, _ = sess.run([c] + check_nan) # this will now raise an exception + print(val) + sess.close() + +and catch propagation issues: + +.. code-block:: python + + import tensorflow as tf + + print("tensorflow propagation test:") + + check_nan = [] # the list of check_numerics operations + + a = tf.constant('NaN') + b = tf.constant(1.0) + c = a + b + + check_nan.append(tf.check_numerics(c, "")) # check the 'c' operations + + sess = tf.Session() + val, _ = sess.run([c] + check_nan) # this will now raise an exception + print(val) + sess.close() + + +VecCheckNan Wrapper +------------------- + +In order to find when and from where the invalid value originated from, stable-baselines comes with a ``VecCheckNan`` wrapper. + +It will monitor the actions, observations, and rewards, indicating what action or observation caused it and from what. + +.. code-block:: python + + import gym + from gym import spaces + import numpy as np + + from stable_baselines import PPO2 + from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan + + class NanAndInfEnv(gym.Env): + """Custom Environment that raised NaNs and Infs""" + metadata = {'render.modes': ['human']} + + def __init__(self): + super(NanAndInfEnv, self).__init__() + self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) + + def step(self, _action): + randf = np.random.rand() + if randf > 0.99: + obs = float('NaN') + elif randf > 0.98: + obs = float('inf') + else: + obs = randf + return [obs], 0.0, False, {} + + def reset(self): + return [0.0] + + def render(self, mode='human', close=False): + pass + + # Create environment + env = DummyVecEnv([lambda: NanAndInfEnv()]) + env = VecCheckNan(env, raise_exception=True) + + # Instantiate the agent + model = PPO2('MlpPolicy', env) + + # Train the agent + model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment. + +RL Model hyperparameters +------------------------ + +Depending on your hyperparameters, NaN can occurs much more often. +A great example of this: https://github.com/hill-a/stable-baselines/issues/340 + +Be aware, the hyperparameters given by default seem to work in most cases, +however your environment might not play nice with them. 
+If this is the case, try to read up on the effect each hyperparameter has on the model,
+so that you can tune them to get a stable model. Alternatively, you can try automatic hyperparameter tuning (included in the rl zoo).
+
+Missing values from datasets
+----------------------------
+
+If your environment is generated from an external dataset, do not forget to make sure your dataset does not contain NaNs,
+as some datasets fill missing values with NaN as a surrogate value.
+
+Here is some reading material about finding NaNs: https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html
+
+And about filling the missing values with something else (imputation): https://towardsdatascience.com/how-to-handle-missing-data-8646b18db0d4
diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst
index 91dbd48408..1876bb3924 100644
--- a/docs/guide/vec_envs.rst
+++ b/docs/guide/vec_envs.rst
@@ -5,8 +5,8 @@
 Vectorized Environments
 =======================
 
-Vectorized Environments are a method for multiprocess training. Instead of training an RL agent
-on 1 environment, it allows us to train it on `n` environments using `n` processes.
+Vectorized Environments are a method for stacking multiple independent environments into a single environment.
+Instead of training an RL agent on 1 environment per step, it allows us to train it on `n` environments per step.
 Because of this, `actions` passed to the environment are now a vector (of dimension `n`).
 It is the same for `observations`, `rewards` and end of episode signals (`dones`).
 In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
@@ -69,3 +69,10 @@ VecVideoRecorder
 
 .. autoclass:: VecVideoRecorder
   :members:
+
+
+VecCheckNan
+~~~~~~~~~~~
+
+.. autoclass:: VecCheckNan
+  :members:
diff --git a/docs/index.rst b/docs/index.rst
index a37c782355..895f1e6a35 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -47,6 +47,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
    guide/tensorboard
    guide/rl_zoo
    guide/pretrain
+   guide/checking_nan
 
 
 .. toctree::
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index f354b08ea9..dbb9c52b80 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -6,7 +6,7 @@ Changelog
 
 For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.
 
-Pre-Release 2.6.0a0 (WIP)
+Pre-Release 2.6.0a1 (WIP)
 -------------------------
 
 **Hindsight Experience Replay (HER) - Reloaded | get/load parameters**
@@ -35,6 +35,9 @@ Pre-Release 2.6.0a0 (WIP)
   ``find_trainable_params`` was returning all trainable variables, discarding the scope argument.
   This bug was causing the model to save duplicated parameters (for DDPG and SAC) but did not affect the performance.
+- added a guide for dealing with ``NaN`` and ``inf``
+- added ``VecCheckNan`` wrapper
+- updated the vec_env documentation
 
 **Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
 when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
diff --git a/stable_baselines/common/vec_env/__init__.py b/stable_baselines/common/vec_env/__init__.py
index b8597551d3..29638fa722 100644
--- a/stable_baselines/common/vec_env/__init__.py
+++ b/stable_baselines/common/vec_env/__init__.py
@@ -6,3 +6,4 @@
 from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from stable_baselines.common.vec_env.vec_normalize import VecNormalize
 from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
+from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan
diff --git a/stable_baselines/common/vec_env/dummy_vec_env.py b/stable_baselines/common/vec_env/dummy_vec_env.py
index 52b3d8c940..2d811f4071 100644
--- a/stable_baselines/common/vec_env/dummy_vec_env.py
+++ b/stable_baselines/common/vec_env/dummy_vec_env.py
@@ -7,7 +7,10 @@
 
 class DummyVecEnv(VecEnv):
     """
-    Creates a simple vectorized wrapper for multiple environments
+    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
+    Python process. This is useful for computationally simple environments such as ``cartpole-v1``, as the overhead of
+    multiprocessing or multithreading outweighs the environment computation time. It can also be used for RL methods
+    that require a vectorized environment, but when you want a single environment to train with.
 
     :param env_fns: ([Gym Environment]) the list of environments to vectorize
     """
diff --git a/stable_baselines/common/vec_env/subproc_vec_env.py b/stable_baselines/common/vec_env/subproc_vec_env.py
index 3144bd0c14..dd989515e0 100644
--- a/stable_baselines/common/vec_env/subproc_vec_env.py
+++ b/stable_baselines/common/vec_env/subproc_vec_env.py
@@ -44,7 +44,11 @@ def _worker(remote, parent_remote, env_fn_wrapper):
 
 class SubprocVecEnv(VecEnv):
     """
-    Creates a multiprocess vectorized wrapper for multiple environments
+    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
+    process, allowing a significant speed up when the environment is computationally complex.
+
+    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
+    number of logical cores on your CPU.
 
     .. warning::
 
diff --git a/stable_baselines/common/vec_env/vec_check_nan.py b/stable_baselines/common/vec_env/vec_check_nan.py
new file mode 100644
index 0000000000..b7440e4ed7
--- /dev/null
+++ b/stable_baselines/common/vec_env/vec_check_nan.py
@@ -0,0 +1,86 @@
+import warnings
+
+import numpy as np
+
+from stable_baselines.common.vec_env import VecEnvWrapper
+
+
+class VecCheckNan(VecEnvWrapper):
+    """
+    NaN and inf checking wrapper for vectorized environments; it raises a warning by default,
+    letting you know from where the NaN or inf originated.
+
+    :param venv: (VecEnv) the vectorized environment to wrap
+    :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning
+    :param warn_once: (bool) Whether or not to only warn once.
+    :param check_inf: (bool) Whether or not to check for +inf or -inf as well
+    """
+
+    def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True):
+        VecEnvWrapper.__init__(self, venv)
+        self.raise_exception = raise_exception
+        self.warn_once = warn_once
+        self.check_inf = check_inf
+        self._actions = None
+        self._observations = None
+        self._user_warned = False
+
+    def step_async(self, actions):
+        self._check_val(async_step=True, actions=actions)
+
+        self._actions = actions
+        self.venv.step_async(actions)
+
+    def step_wait(self):
+        observations, rewards, news, infos = self.venv.step_wait()
+
+        self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
+
+        self._observations = observations
+        return observations, rewards, news, infos
+
+    def reset(self):
+        observations = self.venv.reset()
+        self._actions = None
+
+        self._check_val(async_step=False, observations=observations)
+
+        self._observations = observations
+        return observations
+
+    def _check_val(self, *, async_step, **kwargs):
+        # in warning-only mode with warn_once set, stop checking once the user has been warned
+        if not self.raise_exception and self.warn_once and self._user_warned:
+            return
+
+        found = []
+        for name, val in kwargs.items():
+            has_nan = np.any(np.isnan(val))
+            has_inf = self.check_inf and np.any(np.isinf(val))
+            if has_inf:
+                found.append((name, "inf"))
+            if has_nan:
+                found.append((name, "nan"))
+
+        if found:
+            self._user_warned = True
+            msg = ""
+            for i, (name, type_val) in enumerate(found):
+                msg += "found {} in {}".format(type_val, name)
+                if i != len(found) - 1:
+                    msg += ", "
+
+            msg += ".\r\nOriginated from the "
+
+            if not async_step:
+                if self._actions is None:
+                    msg += "environment observation (at reset)"
+                else:
+                    msg += "environment, last given value was: \r\n\taction={}".format(self._actions)
+            else:
+                msg += "RL model, last given value was: \r\n\tobservations={}".format(self._observations)
+
+            if self.raise_exception:
+                raise ValueError(msg)
+            else:
+                warnings.warn(msg, UserWarning)
diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py
new file mode 100644
index 0000000000..7f27a152ab
--- /dev/null
+++ b/tests/test_vec_check_nan.py
@@ -0,0 +1,70 @@
+import gym
+from gym import spaces
+import numpy as np
+
+from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan
+
+
+class NanAndInfEnv(gym.Env):
+    """Custom Environment that raises NaNs and Infs"""
+    metadata = {'render.modes': ['human']}
+
+    def __init__(self):
+        super(NanAndInfEnv, self).__init__()
+        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
+        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
+
+    @staticmethod
+    def step(action):
+        if all(np.array(action) > 0):
+            obs = float('NaN')
+        elif all(np.array(action) < 0):
+            obs = float('inf')
+        else:
+            obs = 0
+        return [obs], 0.0, False, {}
+
+    @staticmethod
+    def reset():
+        return [0.0]
+
+    def render(self, mode='human', close=False):
+        pass
+
+
+def test_check_nan():
+    """Test VecCheckNan Object"""
+
+    env = DummyVecEnv([NanAndInfEnv])
+    env = VecCheckNan(env, raise_exception=True)
+
+    env.step([[0]])
+
+    try:
+        env.step([[float('NaN')]])
+    except ValueError:
+        pass
+    else:
+        assert False
+
+    try:
+        env.step([[float('inf')]])
+    except ValueError:
+        pass
+    else:
+        assert False
+
+    try:
+        env.step([[-1]])
+    except ValueError:
+        pass
+    else:
+        assert False
+
+    try:
+        env.step([[1]])
+    except ValueError:
+        pass
+    else:
+        assert False
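The test above only exercises ``raise_exception=True``. Below is a minimal sketch (not part of the patch) of how the default,
warning-only mode could be checked: with ``raise_exception=False`` the wrapper emits a ``UserWarning`` instead of raising and,
with ``warn_once=True`` (the default), it only warns on the first offending step. The function name ``test_check_nan_warning``
is hypothetical, and the sketch assumes the ``NanAndInfEnv`` defined in the test file above.

.. code-block:: python

    import warnings

    from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan


    def test_check_nan_warning():
        """Sketch: VecCheckNan in its default, warning-only mode (hypothetical test)."""
        env = DummyVecEnv([NanAndInfEnv])  # NanAndInfEnv as defined in tests/test_vec_check_nan.py
        env = VecCheckNan(env)  # raise_exception=False, warn_once=True by default

        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            env.step([[float('NaN')]])  # NaN action -> the wrapper should emit a UserWarning
            env.step([[float('NaN')]])  # warn_once=True -> no second warning expected

        user_warnings = [w for w in recorded if issubclass(w.category, UserWarning)]
        assert len(user_warnings) == 1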