Skip to content

Commit

Permalink
Fix normalization for off-policy algorithms (#732)
Browse files Browse the repository at this point in the history
* Fix normalization for off-policy algorithms
  • Loading branch information
araffin authored Mar 10, 2020
1 parent 63b1885 commit ed4e377
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 38 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ Bug Fixes:
Only ``TRPO`` and ``PPO1`` update it differently (after synchronization) because they rely on MPI
- Fixed bug in ``TRPO`` with NaN standardized advantages (@richardwu)
- Fixed partial minibatch computation in ExpertDataset (@richardwu)
- Fixed normalization (with ``VecNormalize``) for off-policy algorithms
- Fixed ``sync_envs_normalization`` to sync the reward normalization too

Deprecations:
^^^^^^^^^^^^^
Expand Down
16 changes: 15 additions & 1 deletion stable_baselines/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from stable_baselines.common.save_util import data_to_json, json_to_data, params_to_bytes, bytes_to_params
from stable_baselines.common.policies import get_policy_from_name, ActorCriticPolicy
from stable_baselines.common.runners import AbstractEnvRunner
from stable_baselines.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv
from stable_baselines.common.vec_env import (VecEnvWrapper, VecEnv, DummyVecEnv,
VecNormalize, unwrap_vec_normalize)
from stable_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback
from stable_baselines import logger

Expand Down Expand Up @@ -91,6 +92,9 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base,
" environment.")
self.n_envs = 1

# Get VecNormalize object if it exists
self._vec_normalize_env = unwrap_vec_normalize(self.env)

def get_env(self):
"""
returns the current environment (can be None if not defined)
Expand All @@ -99,6 +103,15 @@ def get_env(self):
"""
return self.env

def get_vec_normalize_env(self) -> Optional[VecNormalize]:
"""
Return the ``VecNormalize`` wrapper of the training env
if it exists.
:return: Optional[VecNormalize] The ``VecNormalize`` env.
"""
return self._vec_normalize_env

def set_env(self, env):
"""
Checks the validity of the environment, and if it is coherent, set it as the current environment.
Expand Down Expand Up @@ -142,6 +155,7 @@ def set_env(self, env):
self.n_envs = 1

self.env = env
self._vec_normalize_env = unwrap_vec_normalize(env)

# Invalidated by environment change.
self.episode_reward = None
Expand Down
53 changes: 41 additions & 12 deletions stable_baselines/common/buffers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import random
from typing import Optional, List, Union

import numpy as np

from stable_baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
from stable_baselines.common.vec_env import VecNormalize


class ReplayBuffer(object):
def __init__(self, size):
def __init__(self, size: int):
"""
Implements a ring buffer (FIFO).
Expand All @@ -17,7 +19,7 @@ def __init__(self, size):
self._maxsize = size
self._next_idx = 0

def __len__(self):
def __len__(self) -> int:
return len(self._storage)

@property
Expand All @@ -26,11 +28,11 @@ def storage(self):
return self._storage

@property
def buffer_size(self):
def buffer_size(self) -> int:
"""float: Max capacity of the buffer"""
return self._maxsize

def can_sample(self, n_samples):
def can_sample(self, n_samples: int) -> bool:
"""
Check if n_samples samples can be sampled
from the buffer.
Expand All @@ -40,7 +42,7 @@ def can_sample(self, n_samples):
"""
return len(self) >= n_samples

def is_full(self):
def is_full(self) -> int:
"""
Check whether the replay buffer is full or not.
Expand Down Expand Up @@ -86,7 +88,27 @@ def extend(self, obs_t, action, reward, obs_tp1, done):
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize

def _encode_sample(self, idxes):
@staticmethod
def _normalize_obs(obs: np.ndarray,
env: Optional[VecNormalize] = None) -> np.ndarray:
"""
Helper for normalizing the observation.
"""
if env is not None:
return env.normalize_obs(obs)
return obs

@staticmethod
def _normalize_reward(reward: np.ndarray,
env: Optional[VecNormalize] = None) -> np.ndarray:
"""
Helper for normalizing the reward.
"""
if env is not None:
return env.normalize_reward(reward)
return reward

def _encode_sample(self, idxes: Union[List[int], np.ndarray], env: Optional[VecNormalize] = None):
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
for i in idxes:
data = self._storage[i]
Expand All @@ -96,13 +118,19 @@ def _encode_sample(self, idxes):
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)
return (self._normalize_obs(np.array(obses_t), env),
np.array(actions),
self._normalize_reward(np.array(rewards), env),
self._normalize_obs(np.array(obses_tp1), env),
np.array(dones))

def sample(self, batch_size, **_kwargs):
def sample(self, batch_size: int, env: Optional[VecNormalize] = None, **_kwargs):
"""
Sample a batch of experiences.
:param batch_size: (int) How many transitions to sample.
:param env: (Optional[VecNormalize]) associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
- obs_batch: (np.ndarray) batch of observations
- act_batch: (numpy float) batch of actions executed given obs_batch
Expand All @@ -112,7 +140,7 @@ def sample(self, batch_size, **_kwargs):
and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
return self._encode_sample(idxes)
return self._encode_sample(idxes, env=env)


class PrioritizedReplayBuffer(ReplayBuffer):
Expand Down Expand Up @@ -181,7 +209,7 @@ def _sample_proportional(self, batch_size):
idx = self._it_sum.find_prefixsum_idx(mass)
return idx

def sample(self, batch_size, beta=0):
def sample(self, batch_size: int, beta: float = 0, env: Optional[VecNormalize] = None):
"""
Sample a batch of experiences.
Expand All @@ -191,6 +219,8 @@ def sample(self, batch_size, beta=0):
:param batch_size: (int) How many transitions to sample.
:param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction)
:param env: (Optional[VecNormalize]) associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
- obs_batch: (np.ndarray) batch of observations
- act_batch: (numpy float) batch of actions executed given obs_batch
Expand All @@ -210,7 +240,7 @@ def sample(self, batch_size, beta=0):
max_weight = (p_min * len(self._storage)) ** (-beta)
p_sample = self._it_sum[idxes] / self._it_sum.sum()
weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
encoded_sample = self._encode_sample(idxes)
encoded_sample = self._encode_sample(idxes, env=env)
return tuple(list(encoded_sample) + [weights, idxes])

def update_priorities(self, idxes, priorities):
Expand All @@ -232,4 +262,3 @@ def update_priorities(self, idxes, priorities):
self._it_min[idxes] = priorities ** self._alpha

self._max_priority = max(self._max_priority, np.max(priorities))

3 changes: 2 additions & 1 deletion stable_baselines/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ def sync_envs_normalization(env: Union[gym.Env, VecEnv], eval_env: Union[gym.Env
return
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
# No need to sync the reward scaling
# sync reward and observation scaling
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
# Make pytype happy, in theory env and eval_env have the same type
assert isinstance(eval_env_tmp, VecEnvWrapper), "the second env differs from the first env"
Expand Down
41 changes: 30 additions & 11 deletions stable_baselines/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,8 @@ def _train_step(self, step, writer, log=False):
:return: (float, float) critic loss, actor loss
"""
# Get a batch
obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size)
obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size,
env=self._vec_normalize_env)
# Reshape to match previous behavior and placeholder shape
rewards = rewards.reshape(-1, 1)
terminals = terminals.reshape(-1, 1)
Expand Down Expand Up @@ -735,7 +736,8 @@ def _get_stats(self):
if self.stats_sample is None:
# Get a sample and keep that fixed for all further computations.
# This allows us to estimate the change in value for the same set of inputs.
obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size)
obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size,
env=self._vec_normalize_env)
self.stats_sample = {
'obs': obs,
'actions': actions,
Expand Down Expand Up @@ -777,7 +779,7 @@ def _adapt_param_noise(self):
return 0.

# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
obs, *_ = self.replay_buffer.sample(batch_size=self.batch_size)
obs, *_ = self.replay_buffer.sample(batch_size=self.batch_size, env=self._vec_normalize_env)
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
self.param_noise_stddev: self.param_noise.current_stddev,
})
Expand Down Expand Up @@ -832,6 +834,9 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
# Prepare everything.
self._reset()
obs = self.env.reset()
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
obs_ = self._vec_normalize_env.get_original_obs().squeeze()
eval_obs = None
if self.eval_env is not None:
eval_obs = self.eval_env.reset()
Expand Down Expand Up @@ -894,23 +899,37 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
callback.on_training_end()
return self

if writer is not None:
ep_rew = np.array([reward]).reshape((1, -1))
ep_done = np.array([done]).reshape((1, -1))
tf_util.total_episode_reward_logger(self.episode_reward, ep_rew, ep_done,
writer, self.num_timesteps)
step += 1
total_steps += 1
if rank == 0 and self.render:
self.env.render()
episode_reward += reward
episode_step += 1

# Book-keeping.
epoch_actions.append(action)
epoch_qs.append(q_value)
self._store_transition(obs, action, reward, new_obs, done)

# Store only the unnormalized version
if self._vec_normalize_env is not None:
new_obs_ = self._vec_normalize_env.get_original_obs().squeeze()
reward_ = self._vec_normalize_env.get_original_reward().squeeze()
else:
# Avoid changing the original ones
obs_, new_obs_, reward_ = obs, new_obs, reward

self._store_transition(obs_, action, reward_, new_obs_, done)
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
obs_ = new_obs_

episode_reward += reward_
episode_step += 1

if writer is not None:
ep_rew = np.array([reward_]).reshape((1, -1))
ep_done = np.array([done]).reshape((1, -1))
tf_util.total_episode_reward_logger(self.episode_reward, ep_rew, ep_done,
writer, self.num_timesteps)

if done:
# Episode done.
Expand Down
25 changes: 20 additions & 5 deletions stable_baselines/deepq/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D

reset = True
obs = self.env.reset()
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
obs_ = self._vec_normalize_env.get_original_obs().squeeze()

for _ in range(total_timesteps):
# Take action and update exploration to the newest value
Expand Down Expand Up @@ -221,17 +224,27 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
if callback.on_step() is False:
break

# Store only the unnormalized version
if self._vec_normalize_env is not None:
new_obs_ = self._vec_normalize_env.get_original_obs().squeeze()
reward_ = self._vec_normalize_env.get_original_reward().squeeze()
else:
# Avoid changing the original ones
obs_, new_obs_, reward_ = obs, new_obs, rew
# Store transition in the replay buffer.
self.replay_buffer.add(obs, action, rew, new_obs, float(done))
self.replay_buffer.add(obs_, action, reward_, new_obs_, float(done))
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
obs_ = new_obs_

if writer is not None:
ep_rew = np.array([rew]).reshape((1, -1))
ep_rew = np.array([reward_]).reshape((1, -1))
ep_done = np.array([done]).reshape((1, -1))
tf_util.total_episode_reward_logger(self.episode_reward, ep_rew, ep_done, writer,
self.num_timesteps)

episode_rewards[-1] += rew
episode_rewards[-1] += reward_
if done:
maybe_is_success = info.get('is_success')
if maybe_is_success is not None:
Expand All @@ -254,10 +267,12 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
assert self.beta_schedule is not None, \
"BUG: should be LinearSchedule when self.prioritized_replay True"
experience = self.replay_buffer.sample(self.batch_size,
beta=self.beta_schedule.value(self.num_timesteps))
beta=self.beta_schedule.value(self.num_timesteps),
env=self._vec_normalize_env)
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = self.replay_buffer.sample(self.batch_size)
obses_t, actions, rewards, obses_tp1, dones = self.replay_buffer.sample(self.batch_size,
env=self._vec_normalize_env)
weights, batch_idxes = np.ones_like(rewards), None
# pytype:enable=bad-unpacking

Expand Down
Loading

0 comments on commit ed4e377

Please sign in to comment.