diff --git a/stable_baselines/td3/td3.py b/stable_baselines/td3/td3.py
index 307c76bc24..de6d6cbd48 100644
--- a/stable_baselines/td3/td3.py
+++ b/stable_baselines/td3/td3.py
@@ -63,7 +63,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=5000
                  _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False):
 
         super(TD3, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose,
-                                  policy_base=TD3Policy, requires_vec_env=False, policy_kwargs=policy_kwargs)
+                                  policy_base=TD3Policy, requires_vec_env=True, policy_kwargs=policy_kwargs)
 
         self.buffer_size = buffer_size
         self.learning_rate = learning_rate
@@ -294,31 +294,29 @@ def learn(self, total_timesteps, callback=None, seed=None,
             start_time = time.time()
             episode_rewards = [0.0]
             episode_successes = []
+            obs = self.env.reset()
             if self.action_noise is not None:
                 self.action_noise.reset()
-            obs = self.env.reset()
             self.episode_reward = np.zeros((1,))
             ep_info_buf = deque(maxlen=100)
             n_updates = 0
             infos_values = []
 
-            for step in range(total_timesteps):
+            for step in range(0, total_timesteps, self.n_envs):
                 if callback is not None:
                     # Only stop training if return value is False, not when it is None. This is for backwards
                     # compatibility with callbacks that have no return statement.
                     if callback(locals(), globals()) is False:
                         break
 
-                # Before training starts, randomly sample actions
-                # from a uniform distribution for better exploration.
-                # Afterwards, use the learned policy
-                # if random_exploration is set to 0 (normal setting)
+                prev_obs = obs
                 if (self.num_timesteps < self.learning_starts
-                    or np.random.rand() < self.random_exploration):
+                        or np.random.rand() < self.random_exploration):
                     # No need to rescale when sampling random action
-                    rescaled_action = action = self.env.action_space.sample()
+                    rescaled_action = action = [self.env.action_space.sample() for _ in range(self.n_envs)]
                 else:
-                    action = self.policy_tf.step(obs[None]).flatten()
+                    action = self.policy_tf.step(prev_obs).flatten()
+                    action = [np.array([a]) for a in action]
                 # Add noise to the action, as the policy
                 # is deterministic, this is required for exploration
                 if self.action_noise is not None:
@@ -326,89 +324,91 @@ def learn(self, total_timesteps, callback=None, seed=None,
                 # Rescale from [-1, 1] to the correct bounds
                 rescaled_action = action * np.abs(self.action_space.low)
 
-                assert action.shape == self.env.action_space.shape
-
-                new_obs, reward, done, info = self.env.step(rescaled_action)
-
-                # Store transition in the replay buffer.
-                self.replay_buffer.add(obs, action, reward, new_obs, float(done))
-                obs = new_obs
-
-                # Retrieve reward and episode length if using Monitor wrapper
-                maybe_ep_info = info.get('episode')
-                if maybe_ep_info is not None:
-                    ep_info_buf.extend([maybe_ep_info])
-
-                if writer is not None:
-                    # Write reward per episode to tensorboard
-                    ep_reward = np.array([reward]).reshape((1, -1))
-                    ep_done = np.array([done]).reshape((1, -1))
-                    self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_reward,
-                                                                      ep_done, writer, self.num_timesteps)
-
-                if step % self.train_freq == 0:
-                    mb_infos_vals = []
-                    # Update policy, critics and target networks
-                    for grad_step in range(self.gradient_steps):
-                        # Break if the warmup phase is not over
-                        # or if there are not enough samples in the replay buffer
-                        if not self.replay_buffer.can_sample(self.batch_size) \
-                           or self.num_timesteps < self.learning_starts:
-                            break
-                        n_updates += 1
-                        # Compute current learning_rate
-                        frac = 1.0 - step / total_timesteps
-                        current_lr = self.learning_rate(frac)
-                        # Update policy and critics (q functions)
-                        # Note: the policy is updated less frequently than the Q functions
-                        # this is controlled by the `policy_delay` parameter
-                        mb_infos_vals.append(
-                            self._train_step(step, writer, current_lr, (step + grad_step) % self.policy_delay == 0))
-
-                    # Log losses and entropy, useful for monitor training
-                    if len(mb_infos_vals) > 0:
-                        infos_values = np.mean(mb_infos_vals, axis=0)
-
-                episode_rewards[-1] += reward
-                if done:
-                    if self.action_noise is not None:
-                        self.action_noise.reset()
-                    if not isinstance(self.env, VecEnv):
-                        obs = self.env.reset()
-                    episode_rewards.append(0.0)
+                for i in range(self.n_envs):
+                    assert action[i].shape == self.env.action_space.shape
+
+                obs, reward, done, info = self.env.step(rescaled_action)
+
+                for i in range(self.n_envs):
+                    # Store transition in the replay buffer.
+                    self.replay_buffer.add(prev_obs[i], action[i], reward[i], obs[i], float(done[i]))
+                    # Retrieve reward and episode length if using Monitor wrapper
+                    maybe_ep_info = info[i].get('episode')
+                    if maybe_ep_info is not None:
+                        ep_info_buf.extend([maybe_ep_info])
+                    if writer is not None:
+                        # Write reward per episode to tensorboard
+                        ep_reward = np.array([reward[i]]).reshape((1, -1))
+                        ep_done = np.array([done[i]]).reshape((1, -1))
+                        self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_reward,
+                                                                          ep_done, writer, self.num_timesteps)
+
+                    if step % self.train_freq == 0:
+                        mb_infos_vals = []
+                        # Update policy, critics and target networks
+                        for grad_step in range(self.gradient_steps):
+                            # Break if the warmup phase is not over
+                            # or if there are not enough samples in the replay buffer
+                            if not self.replay_buffer.can_sample(self.batch_size) \
+                               or self.num_timesteps < self.learning_starts:
+                                break
+                            n_updates += 1
+                            # Compute current learning_rate
+                            frac = 1.0 - step / total_timesteps
+                            current_lr = self.learning_rate(frac)
+                            # Update policy and critics (q functions)
+                            # Note: the policy is updated less frequently than the Q functions
+                            # this is controlled by the `policy_delay` parameter
+                            mb_infos_vals.append(
+                                self._train_step(step, writer, current_lr, (step + grad_step) % self.policy_delay == 0))
+
+                        # Log losses and entropy, useful for monitor training
+                        if len(mb_infos_vals) > 0:
+                            infos_values = np.mean(mb_infos_vals, axis=0)
+
+                    episode_rewards[-1] += reward[i]
+                    if done[i]:
+                        if self.action_noise is not None:
+                            self.action_noise.reset()
+                        if not isinstance(self.env, VecEnv):
+                            runner.reset()
+                        episode_rewards.append(0.0)
+
+                        maybe_is_success = info[i].get('is_success')
+                        if maybe_is_success is not None:
+                            episode_successes.append(float(maybe_is_success))
+
+                    if len(episode_rewards[-101:-1]) == 0:
+                        mean_reward = -np.inf
+                    else:
+                        mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
+
+                    num_episodes = len(episode_rewards)
+                    self.num_timesteps += 1
+                    # Display training infos
+                    if self.verbose >= 1 and done[i] and log_interval is not None and len(episode_rewards) % log_interval == 0:
+                        fps = int(step / (time.time() - start_time))
+                        logger.logkv("episodes", num_episodes)
+                        logger.logkv("mean 100 episode reward", mean_reward)
+                        if len(ep_info_buf) > 0 and len(ep_info_buf[0]) > 0:
+                            logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in ep_info_buf]))
+                            logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in ep_info_buf]))
+                        logger.logkv("n_updates", n_updates)
+                        logger.logkv("current_lr", current_lr)
+                        logger.logkv("fps", fps)
+                        logger.logkv('time_elapsed', int(time.time() - start_time))
+                        if len(episode_successes) > 0:
+                            logger.logkv("success rate", np.mean(episode_successes[-100:]))
+                        if len(infos_values) > 0:
+                            for (name, val) in zip(self.infos_names, infos_values):
+                                logger.logkv(name, val)
+                        logger.logkv("total timesteps", self.num_timesteps)
+                        logger.dumpkvs()
+                        # Reset infos:
+                        infos_values = []
+
+                step += 1
-                    maybe_is_success = info.get('is_success')
-                    if maybe_is_success is not None:
-                        episode_successes.append(float(maybe_is_success))
-
-                if len(episode_rewards[-101:-1]) == 0:
-                    mean_reward = -np.inf
-                else:
-                    mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
-
-                num_episodes = len(episode_rewards)
-                self.num_timesteps += 1
-                # Display training infos
-                if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
-                    fps = int(step / (time.time() - start_time))
-                    logger.logkv("episodes", num_episodes)
-                    logger.logkv("mean 100 episode reward", mean_reward)
-                    if len(ep_info_buf) > 0 and len(ep_info_buf[0]) > 0:
-                        logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in ep_info_buf]))
-                        logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in ep_info_buf]))
-                    logger.logkv("n_updates", n_updates)
-                    logger.logkv("current_lr", current_lr)
-                    logger.logkv("fps", fps)
-                    logger.logkv('time_elapsed', int(time.time() - start_time))
-                    if len(episode_successes) > 0:
-                        logger.logkv("success rate", np.mean(episode_successes[-100:]))
-                    if len(infos_values) > 0:
-                        for (name, val) in zip(self.infos_names, infos_values):
-                            logger.logkv(name, val)
-                    logger.logkv("total timesteps", self.num_timesteps)
-                    logger.dumpkvs()
-                    # Reset infos:
-                    infos_values = []
 
         return self
 
     def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):