From aad197baf0a72e001e4e45465b3927b15a7ddb7f Mon Sep 17 00:00:00 2001 From: krishnan Date: Mon, 21 Sep 2020 11:32:30 -0700 Subject: [PATCH] Remove step in SAC._train_step arglist, fixes issue of redundant wrapping of HER replay_buffer --- docs/misc/changelog.rst | 3 ++- stable_baselines/her/her.py | 5 ++++- stable_baselines/sac/sac.py | 5 ++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 26fa36f504..0c5253ae9a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -65,7 +65,8 @@ Bug Fixes: - Fixed a bug in CloudPickleWrapper's (used by VecEnvs) ``__setstate___`` where loading was incorrectly using ``pickle.loads`` (@shwang). - Fixed a bug in ``SAC`` and ``TD3`` where the log timesteps was not correct(@YangRui2015) - Fixed a bug where the environment was reset twice when using ``evaluate_policy`` -- Fixed a bug where ``SAC`` uses wrong step to log to tensorboard after multiple calls to ``SAC.learn(..., reset_num_timesteps=True)`` +- Fixed a bug where ``SAC`` uses wrong step to log to tensorboard after multiple calls to ``SAC.learn(..., reset_num_timesteps=True)`` (@krishpop) +- Fixed issue where HER replay buffer wrapper is used multiple times after multiple calls to ``HER.learn`` (@krishpop) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 6a9e89f43d..5ce1ac9044 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -70,6 +70,7 @@ def _create_replay_wrapper(self, env): n_sampled_goal=self.n_sampled_goal, goal_selection_strategy=self.goal_selection_strategy, wrapped_env=self.env) + self.wrapped_buffer = False def set_env(self, env): assert not isinstance(env, VecEnvWrapper), "HER does not support VecEnvWrapper" @@ -108,9 +109,11 @@ def setup_model(self): def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="HER", reset_num_timesteps=True): + replay_wrapper = self.replay_wrapper if self.wrapped_buffer else None + self.wrapped_buffer = True return self.model.learn(total_timesteps, callback=callback, log_interval=log_interval, tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps, - replay_wrapper=self.replay_wrapper) + replay_wrapper=replay_wrapper) def _check_obs(self, observation): if isinstance(observation, dict): diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index e303dc7150..fbf9db0e61 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -313,8 +313,7 @@ def setup_model(self): self.summary = tf.summary.merge_all() - def _train_step(self, step, writer, learning_rate): - del step + def _train_step(self, writer, learning_rate): # Sample a batch from the replay buffer batch = self.replay_buffer.sample(self.batch_size, env=self._vec_normalize_env) batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = batch @@ -461,7 +460,7 @@ def learn(self, total_timesteps, callback=None, frac = 1.0 - step / total_timesteps current_lr = self.learning_rate(frac) # Update policy and critics (q functions) - mb_infos_vals.append(self._train_step(step, writer, current_lr)) + mb_infos_vals.append(self._train_step(writer, current_lr)) # Update target network if (step + grad_step) % self.target_update_interval == 0: # Update target network