From 65ed3969e8859092e32e0cf89ac42959a7f283d6 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Thu, 6 Jun 2019 11:03:13 +0200
Subject: [PATCH] Bug fix when not enough samples in the replay buffer (#354)

* Bug fix when not enough samples in the replay buffer

* Correct typo
---
 docs/misc/changelog.rst                 |  4 +++-
 stable_baselines/ddpg/ddpg.py           |  4 ++++
 stable_baselines/deepq/dqn.py           |  8 ++++++--
 stable_baselines/deepq/replay_buffer.py | 10 ++++++++++
 stable_baselines/her/replay_buffer.py   | 11 +++++++++++
 stable_baselines/sac/sac.py             |  5 ++++-
 tests/test_her.py                       | 22 ++++++++++++++++++++++
 7 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 0dfb0a7c67..6cdabe71e6 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -29,6 +29,8 @@ Pre-Release 2.6.0a0 (WIP)
 - **important change** switched to using dictionaries rather than lists when storing parameters, with tensorflow Variable names being the keys. (@Miffyli)
 - added specific hyperparameter for PPO2 to clip the value function (``cliprange_vf``)
 - fixed ``num_timesteps`` (total_timesteps) variable in PPO2 that was wrongly computed.
+- fixed a bug in DDPG/DQN/SAC when the number of samples in the replay buffer was lower than the batch size
+  (thanks to @dwiel for spotting the bug)
 
 **Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
 when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
@@ -342,4 +344,4 @@ In random order...
 Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
 @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
 @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
-@Miffyli
+@Miffyli @dwiel
diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py
index 6cb9d04daa..6ea80338b0 100644
--- a/stable_baselines/ddpg/ddpg.py
+++ b/stable_baselines/ddpg/ddpg.py
@@ -914,6 +914,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
                 epoch_critic_losses = []
                 epoch_adaptive_distances = []
                 for t_train in range(self.nb_train_steps):
+                    # Not enough samples in the replay buffer
+                    if not self.replay_buffer.can_sample(self.batch_size):
+                        break
+
                     # Adapt param noise, if necessary.
                     if len(self.replay_buffer) >= self.batch_size and \
                             t_train % self.param_noise_adaption_interval == 0:
diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py
index 4fd504c6cf..1369c972a4 100644
--- a/stable_baselines/deepq/dqn.py
+++ b/stable_baselines/deepq/dqn.py
@@ -229,7 +229,11 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
                     episode_rewards.append(0.0)
                     reset = True
 
-                if self.num_timesteps > self.learning_starts and self.num_timesteps % self.train_freq == 0:
+                # Do not train if the warmup phase is not over
+                # or if there are not enough samples in the replay buffer
+                can_sample = self.replay_buffer.can_sample(self.batch_size)
+                if can_sample and self.num_timesteps > self.learning_starts \
+                        and self.num_timesteps % self.train_freq == 0:
                     # Minimize the error in Bellman's equation on a batch sampled from replay buffer.
                     if self.prioritized_replay:
                         experience = self.replay_buffer.sample(self.batch_size,
@@ -261,7 +265,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
                         new_priorities = np.abs(td_errors) + self.prioritized_replay_eps
                         self.replay_buffer.update_priorities(batch_idxes, new_priorities)
 
-                if self.num_timesteps > self.learning_starts and \
+                if can_sample and self.num_timesteps > self.learning_starts and \
                         self.num_timesteps % self.target_network_update_freq == 0:
                     # Update target network periodically.
                     self.update_target(sess=self.sess)
diff --git a/stable_baselines/deepq/replay_buffer.py b/stable_baselines/deepq/replay_buffer.py
index a8771a7570..6c78328829 100644
--- a/stable_baselines/deepq/replay_buffer.py
+++ b/stable_baselines/deepq/replay_buffer.py
@@ -30,6 +30,16 @@ def buffer_size(self):
         """float: Max capacity of the buffer"""
         return self._maxsize
 
+    def can_sample(self, n_samples):
+        """
+        Check if n_samples samples can be sampled
+        from the buffer.
+
+        :param n_samples: (int)
+        :return: (bool)
+        """
+        return len(self) >= n_samples
+
     def is_full(self):
         """
         Check whether the replay buffer is full or not.
diff --git a/stable_baselines/her/replay_buffer.py b/stable_baselines/her/replay_buffer.py
index 82f61a7f49..b295890d13 100644
--- a/stable_baselines/her/replay_buffer.py
+++ b/stable_baselines/her/replay_buffer.py
@@ -82,9 +82,20 @@ def add(self, obs_t, action, reward, obs_tp1, done):
     def sample(self, *args, **kwargs):
         return self.replay_buffer.sample(*args, **kwargs)
 
+    def can_sample(self, n_samples):
+        """
+        Check if n_samples samples can be sampled
+        from the buffer.
+
+        :param n_samples: (int)
+        :return: (bool)
+        """
+        return self.replay_buffer.can_sample(n_samples)
+
     def __len__(self):
         return len(self.replay_buffer)
 
+
     def _sample_achieved_goal(self, episode_transitions, transition_idx):
         """
         Sample an achieved goal according to the sampling strategy.
diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py
index 524ee2d75e..99313efe85 100644
--- a/stable_baselines/sac/sac.py
+++ b/stable_baselines/sac/sac.py
@@ -436,7 +436,10 @@ def learn(self, total_timesteps, callback=None, seed=None,
                 mb_infos_vals = []
                 # Update policy, critics and target networks
                 for grad_step in range(self.gradient_steps):
-                    if self.num_timesteps < self.batch_size or self.num_timesteps < self.learning_starts:
+                    # Break if the warmup phase is not over
+                    # or if there are not enough samples in the replay buffer
+                    if not self.replay_buffer.can_sample(self.batch_size) \
+                            or self.num_timesteps < self.learning_starts:
                         break
                     n_updates += 1
                     # Compute current learning_rate
diff --git a/tests/test_her.py b/tests/test_her.py
index 58a38e67ac..269c9a24de 100644
--- a/tests/test_her.py
+++ b/tests/test_her.py
@@ -45,6 +45,28 @@ def test_her(model_class, goal_selection_strategy, discrete_obs_space):
     model.learn(1000)
 
 
+@pytest.mark.parametrize('model_class', [DDPG, SAC, DQN])
+def test_long_episode(model_class):
+    """
+    Check that the model does not break when the replay buffer is still empty
+    after the first rollout (because the episode is not over).
+    """
+    # n_bits > nb_rollout_steps
+    n_bits = 10
+    env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC],
+                         max_steps=n_bits)
+    kwargs = {}
+    if model_class == DDPG:
+        kwargs['nb_rollout_steps'] = 9  # < n_bits
+    elif model_class in [DQN, SAC]:
+        kwargs['batch_size'] = 8  # < n_bits
+        kwargs['learning_starts'] = 0
+
+    model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future',
+                verbose=0, **kwargs)
+    model.learn(200)
+
+
 @pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]])
 @pytest.mark.parametrize('model_class', [DQN, SAC, DDPG])
 def test_model_manipulation(model_class, goal_selection_strategy):
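
The guard introduced by this patch reduces to a single check: skip gradient updates until the replay buffer holds at least batch_size transitions. The sketch below illustrates that pattern in isolation; it is not the stable-baselines implementation, and the names SimpleReplayBuffer and run_gradient_steps are invented for this example.

import random


class SimpleReplayBuffer:
    """Toy FIFO replay buffer, used only to illustrate the can_sample() guard."""

    def __init__(self, maxsize):
        self._maxsize = maxsize
        self._storage = []

    def add(self, transition):
        # Drop the oldest transition once the buffer is full
        if len(self._storage) == self._maxsize:
            self._storage.pop(0)
        self._storage.append(transition)

    def __len__(self):
        return len(self._storage)

    def can_sample(self, n_samples):
        # Same check as the one added by this patch: a batch can only be
        # drawn once the buffer holds at least n_samples transitions.
        return len(self) >= n_samples

    def sample(self, n_samples):
        return random.sample(self._storage, n_samples)


def run_gradient_steps(replay_buffer, batch_size, n_train_steps):
    """Skip updates while the buffer is smaller than the batch size,
    mirroring the training loops patched in DDPG, DQN and SAC."""
    n_updates = 0
    for _ in range(n_train_steps):
        if not replay_buffer.can_sample(batch_size):
            break  # not enough samples yet: wait for more environment steps
        batch = replay_buffer.sample(batch_size)  # train on `batch` here
        n_updates += 1
    return n_updates


if __name__ == "__main__":
    buffer = SimpleReplayBuffer(maxsize=1000)
    # With an empty buffer no update is performed; sampling at this point
    # is what used to fail when an episode outlasted the first training call.
    assert run_gradient_steps(buffer, batch_size=8, n_train_steps=10) == 0
    for step in range(8):
        buffer.add((step, 0, 0.0, step + 1, False))
    # Once batch_size transitions are stored, training proceeds normally.
    assert run_gradient_steps(buffer, batch_size=8, n_train_steps=10) == 10

In the patch itself, ReplayBuffer.can_sample(n_samples) is simply len(self) >= n_samples, and the HER wrapper in stable_baselines/her/replay_buffer.py forwards can_sample to the buffer it wraps, so DDPG, DQN and SAC all rely on the same check.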