diff --git a/README.md b/README.md
index 663b559677..f03227f45c 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,8 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
+- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)
+
| **Features** | **Stable-Baselines** | **OpenAI Baselines** |
| --------------------------- | --------------------------------- | --------------------------------- |
@@ -33,7 +35,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
| PEP8 code style | :heavy_check_mark: | :heavy_check_mark: (5) |
| Custom callback | :heavy_check_mark: | :heavy_minus_sign: (6) |
-(1): Forked from previous version of OpenAI baselines, with now SAC in addition
+(1): Forked from a previous version of OpenAI Baselines, now with SAC and TD3 in addition
(2): Currently not available for DDPG, and only from the run script.
(3): Only via the run script.
(4): Rudimentary logging of training information (no loss nor graph).
@@ -156,15 +158,16 @@ All the following examples can be executed online using Google colab notebooks:
| PPO1 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: (4) |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
+| TD3 | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TRPO | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: (4) |
(1): Whether or not the algorithm has been refactored to fit the ```BaseRLModel``` class.
(2): Only implemented for TRPO.
-(3): Re-implemented from scratch
+(3): Re-implemented from scratch, now supports DQN, DDPG, SAC and TD3
(4): Multi Processing with [MPI](https://mpi4py.readthedocs.io/en/stable/).
(5): TODO, in project scope.
-NOTE: Soft Actor-Critic (SAC) was not part of the original baselines and HER was reimplemented from scratch.
+NOTE: Soft Actor-Critic (SAC) and Twin Delayed DDPG (TD3) were not part of the original baselines and HER was reimplemented from scratch.
Actions ```gym.spaces```:
* ```Box```: An N-dimensional box that contains every point in the action space.
@@ -220,4 +223,4 @@ If you want to contribute, please read **CONTRIBUTING.md** guide first.
Stable Baselines was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).
-Logo credits: L.M. Tenkes
+Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
index 354ade7f97..d3c29b15c5 100644
--- a/docs/guide/algos.rst
+++ b/docs/guide/algos.rst
@@ -25,6 +25,7 @@ GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️
PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_
PPO2 ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ✔️ ❌ ❌
+TD3 ✔️ ❌ ✔️ ❌ ❌
TRPO ✔️ ❌ ✔️ ✔ ✔️ [#f3]_
============ ======================== ========= =========== ============ ================
@@ -34,8 +35,8 @@ TRPO ✔️ ❌ ✔️ ✔
.. [#f4] TODO, in project scope.
.. note::
- Non-array spaces such as `Dict` or `Tuple` are not currently supported by any algorithm,
- except HER for dict when working with gym.GoalEnv
+ Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm,
+ except HER, which supports ``Dict`` observation spaces when working with ``gym.GoalEnv``
Actions ``gym.spaces``:
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index 42f38535ab..ad6757a038 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -50,7 +50,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
``load`` function re-creates model from scratch on each call, which can be slow.
If you need to e.g. evaluate same model with multiple different sets of parameters, consider
using ``load_parameters`` instead.
-
+
.. code-block:: python
import gym
@@ -318,7 +318,7 @@ Accessing and modifying model parameters
----------------------------------------
You can access model's parameters via ``load_parameters`` and ``get_parameters`` functions, which
-use dictionaries that map variable names to NumPy arrays.
+use dictionaries that map variable names to NumPy arrays.
These functions are useful when you need to e.g. evaluate large set of models with same network structure,
visualize different layers of the network or modify parameters manually.
@@ -326,7 +326,7 @@ visualize different layers of the network or modify parameters manually.
You can access original Tensorflow Variables with function ``get_parameter_list``.
Following example demonstrates reading parameters, modifying some of them and loading them to model
-by implementing `evolution strategy `_
+by implementing `evolution strategy `_
for solving ``CartPole-v1`` environment. The initial guess for parameters is obtained by running
A2C policy gradient updates on the model.
@@ -466,7 +466,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
import highway_env
import numpy as np
- from stable_baselines import HER, SAC, DDPG
+ from stable_baselines import HER, SAC, DDPG, TD3
from stable_baselines.ddpg import NormalActionNoise
env = gym.make("parking-v0")
diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst
index 6bcbf76f2f..3e20c712fb 100644
--- a/docs/guide/vec_envs.rst
+++ b/docs/guide/vec_envs.rst
@@ -36,6 +36,11 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
For more information, see Python's `multiprocessing guidelines `_.
+VecEnv
+------
+
+.. autoclass:: VecEnv
+ :members:
DummyVecEnv
-----------
diff --git a/docs/index.rst b/docs/index.rst
index 895f1e6a35..6605124803 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -30,6 +30,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
+- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)
.. toctree::
@@ -66,6 +67,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
modules/ppo1
modules/ppo2
modules/sac
+ modules/td3
modules/trpo
.. toctree::
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 2f434815b4..e84d1087e8 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -6,14 +6,17 @@ Changelog
For download links, please look at `Github release page `_.
-Pre-Release 2.6.1a0 (WIP)
+Pre-Release 2.7.0a0 (WIP)
--------------------------
+**Twin Delayed DDPG (TD3)**
+
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
+- added Twin Delayed DDPG (TD3) algorithm, with HER support
- Add support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian
policy in addition to the existing support for categorical stochastic policies.
@@ -34,6 +37,7 @@ Others:
- renamed some keys in ``traj_segment_generator`` to be more meaningful
- retrieve unnormalized reward when using Monitor wrapper with TRPO, PPO1 and GAIL
to display them in the logs (mean episode reward)
+- Clean up DDPG code (renamed variables)
Documentation:
^^^^^^^^^^^^^^
diff --git a/docs/modules/her.rst b/docs/modules/her.rst
index 7b8cb8ed05..8539dfaf9f 100644
--- a/docs/modules/her.rst
+++ b/docs/modules/her.rst
@@ -8,7 +8,7 @@ HER
`Hindsight Experience Replay (HER) `_
-HER is a method wrapper that works with Off policy methods (DQN, SAC and DDPG for example).
+HER is a method wrapper that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
.. note::
@@ -39,20 +39,20 @@ Notes
Can I use?
----------
-Please refer to the wrapped model (DQN, SAC or DDPG) for that section.
+Please refer to the wrapped model (DQN, SAC, TD3 or DDPG) for that section.
Example
-------
.. code-block:: python
- from stable_baselines import HER, DQN, SAC, DDPG
+ from stable_baselines import HER, DQN, SAC, DDPG, TD3
from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper
from stable_baselines.common.bit_flipping_env import BitFlippingEnv
- model_class = DQN # works also with SAC and DDPG
+ model_class = DQN # works also with SAC, DDPG and TD3
- env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS)
+ env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
# Available strategies (cf paper): future, final, episode, random
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
diff --git a/docs/modules/policies.rst b/docs/modules/policies.rst
index 48a97c8445..b630f57322 100644
--- a/docs/modules/policies.rst
+++ b/docs/modules/policies.rst
@@ -15,7 +15,7 @@ If you need more control on the policy architecture, you can also create a custo
CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints)
.. warning::
- For all algorithms (except DDPG and SAC), continuous actions are clipped during training and testing
+ For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing
(to avoid out-of-bound errors).
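+
+Here, "clipped" simply means the predicted action is projected back into the ``Box`` bounds,
+roughly as in the following sketch (illustrative only, not the exact library code):
+
+.. code-block:: python
+
+    import gym
+    import numpy as np
+
+    env = gym.make('Pendulum-v0')
+    # A deliberately out-of-bound action (Pendulum-v0 actions live in [-2, 2])
+    action = np.array([10.0])
+    # Project the action back into the Box bounds
+    clipped_action = np.clip(action, env.action_space.low, env.action_space.high)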
diff --git a/docs/modules/ppo2.rst b/docs/modules/ppo2.rst
index 44fef2903f..aebe57a314 100644
--- a/docs/modules/ppo2.rst
+++ b/docs/modules/ppo2.rst
@@ -8,7 +8,7 @@ PPO2
The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).
-The main idea is that after an update, the new policy should be not too far form the `old` policy.
+The main idea is that after an update, the new policy should not be too far from the old policy.
For that, PPO uses clipping to avoid having too large of an update.
.. note::
diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst
index 09d51b58c5..dbe1fae41e 100644
--- a/docs/modules/sac.rst
+++ b/docs/modules/sac.rst
@@ -5,8 +5,13 @@
SAC
===
+
`Soft Actor Critic (SAC) `_ Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
+SAC is the successor of `Soft Q-Learning SQL `_ and incorporates the double Q-learning trick from TD3.
+A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
+
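+For reference, the maximum entropy objective that SAC maximizes can be written as follows
+(a standard formulation; :math:`\alpha` is the entropy temperature and :math:`\mathcal{H}` denotes the policy entropy):
+
+.. math::
+
+    J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}(\pi(\cdot | s_t)) \right]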
+
.. warning::
The SAC model does not support ``stable_baselines.common.policies`` because it uses double q-values
diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst
new file mode 100644
index 0000000000..52a48c3c13
--- /dev/null
+++ b/docs/modules/td3.rst
@@ -0,0 +1,163 @@
+.. _td3:
+
+.. automodule:: stable_baselines.td3
+
+
+TD3
+===
+
+`Twin Delayed DDPG (TD3) `_ Addressing Function Approximation Error in Actor-Critic Methods.
+
+TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
+We recommend reading `OpenAI Spinning guide on TD3 `_ to learn more about those.
+
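+A minimal sketch of how the critic target combines target policy smoothing and clipped double Q-learning
+(delayed policy updates are controlled by the ``policy_delay`` parameter). The ``target_actor``, ``target_q1``
+and ``target_q2`` callables below are illustrative placeholders, not part of the library API:
+
+.. code-block:: python
+
+    import numpy as np
+
+    def td3_critic_target(reward, done, next_obs, target_actor, target_q1, target_q2,
+                          gamma=0.99, noise_std=0.2, noise_clip=0.5):
+        # Target policy smoothing: perturb the target action with clipped Gaussian noise
+        noise = np.clip(np.random.normal(0.0, noise_std), -noise_clip, noise_clip)
+        next_action = np.clip(target_actor(next_obs) + noise, -1.0, 1.0)
+        # Clipped double Q-learning: bootstrap from the smaller of the two target critics
+        min_q = np.minimum(target_q1(next_obs, next_action), target_q2(next_obs, next_action))
+        return reward + (1.0 - done) * gamma * min_q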
+
+.. warning::
+
+ The TD3 model does not support ``stable_baselines.common.policies`` because it uses double q-values
+ estimation, as a result it must use its own policy models (see :ref:`td3_policies`).
+
+
+.. rubric:: Available Policies
+
+.. autosummary::
+ :nosignatures:
+
+ MlpPolicy
+ LnMlpPolicy
+ CnnPolicy
+ LnCnnPolicy
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/pdf/1802.09477.pdf
+- OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
+- Original Implementation: https://github.com/sfujim/TD3
+
+.. note::
+
+ The default policies for TD3 differ a bit from other algorithms' MlpPolicy: they use ReLU instead of tanh activation,
+ to match the original paper.
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ❌
+- Gym spaces:
+
+
+============= ====== ===========
+Space Action Observation
+============= ====== ===========
+Discrete ❌ ✔️
+Box ✔️ ✔️
+MultiDiscrete ❌ ✔️
+MultiBinary ❌ ✔️
+============= ====== ===========
+
+
+Example
+-------
+
+.. code-block:: python
+
+ import gym
+ import numpy as np
+
+ from stable_baselines import TD3
+ from stable_baselines.td3.policies import MlpPolicy
+ from stable_baselines.common.vec_env import DummyVecEnv
+ from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
+
+ env = gym.make('Pendulum-v0')
+ env = DummyVecEnv([lambda: env])
+
+ # The noise objects for TD3
+ n_actions = env.action_space.shape[-1]
+ action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
+
+ model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
+ model.learn(total_timesteps=50000, log_interval=10)
+ model.save("td3_pendulum")
+
+ del model # remove to demonstrate saving and loading
+
+ model = TD3.load("td3_pendulum")
+
+ obs = env.reset()
+ while True:
+ action, _states = model.predict(obs)
+ obs, rewards, dones, info = env.step(action)
+ env.render()
+
+Parameters
+----------
+
+.. autoclass:: TD3
+ :members:
+ :inherited-members:
+
+.. _td3_policies:
+
+TD3 Policies
+-------------
+
+.. autoclass:: MlpPolicy
+ :members:
+ :inherited-members:
+
+
+.. autoclass:: LnMlpPolicy
+ :members:
+ :inherited-members:
+
+
+.. autoclass:: CnnPolicy
+ :members:
+ :inherited-members:
+
+
+.. autoclass:: LnCnnPolicy
+ :members:
+ :inherited-members:
+
+
+Custom Policy Network
+---------------------
+
+Similarly to the example given in the `examples <../guide/custom_policy.html>`_ page,
+you can easily define a custom architecture for the policy network:
+
+.. code-block:: python
+
+ import gym
+ import numpy as np
+
+ from stable_baselines import TD3
+ from stable_baselines.td3.policies import FeedForwardPolicy
+ from stable_baselines.common.vec_env import DummyVecEnv
+ from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
+
+ # Custom MLP policy with two layers
+ class CustomTD3Policy(FeedForwardPolicy):
+ def __init__(self, *args, **kwargs):
+ super(CustomTD3Policy, self).__init__(*args, **kwargs,
+ layers=[400, 300],
+ layer_norm=False,
+ feature_extraction="mlp")
+
+ # Create and wrap the environment
+ env = gym.make('Pendulum-v0')
+ env = DummyVecEnv([lambda: env])
+
+ # The noise objects for TD3
+ n_actions = env.action_space.shape[-1]
+ action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
+
+
+ model = TD3(CustomTD3Policy, env, action_noise=action_noise, verbose=1)
+ # Train the agent
+ model.learn(total_timesteps=80000)
diff --git a/setup.py b/setup.py
index 6fca25bd43..336c4f2f8f 100644
--- a/setup.py
+++ b/setup.py
@@ -50,6 +50,7 @@
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
+- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)
## Links
@@ -137,7 +138,7 @@
license="MIT",
long_description=long_description,
long_description_content_type='text/markdown',
- version="2.6.1a0",
+ version="2.7.0a0",
)
# python setup.py sdist
diff --git a/stable_baselines/__init__.py b/stable_baselines/__init__.py
index 67704e0846..93d1b0e193 100644
--- a/stable_baselines/__init__.py
+++ b/stable_baselines/__init__.py
@@ -7,6 +7,7 @@
from stable_baselines.gail import GAIL
from stable_baselines.ppo1 import PPO1
from stable_baselines.ppo2 import PPO2
+from stable_baselines.td3 import TD3
from stable_baselines.trpo_mpi import TRPO
from stable_baselines.sac import SAC
diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py
index f371758837..14e892a88b 100644
--- a/stable_baselines/common/base_class.py
+++ b/stable_baselines/common/base_class.py
@@ -428,7 +428,6 @@ def save(self, save_path):
:param save_path: (str or file-like object) the save location
"""
- # self._save_to_file(save_path, data={}, params=None)
raise NotImplementedError()
@classmethod
@@ -442,7 +441,6 @@ def load(cls, load_path, env=None, **kwargs):
(can be None if you only need prediction from a trained model)
:param kwargs: extra arguments to change the model when loading
"""
- # data, param = cls._load_from_file(load_path)
raise NotImplementedError()
@staticmethod
@@ -681,6 +679,14 @@ def save(self, save_path):
@classmethod
def load(cls, load_path, env=None, **kwargs):
+ """
+ Load the model from file
+
+ :param load_path: (str or file-like) the saved parameter location
+ :param env: (Gym Environment) the new environment to run the loaded model on
+ (can be None if you only need prediction from a trained model)
+ :param kwargs: extra arguments to change the model when loading
+ """
data, params = cls._load_from_file(load_path)
if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
@@ -712,7 +718,8 @@ class OffPolicyRLModel(BaseRLModel):
:param policy_base: (BasePolicy) the base policy used by this method
"""
- def __init__(self, policy, env, replay_buffer, verbose=0, *, requires_vec_env, policy_base, policy_kwargs=None):
+ def __init__(self, policy, env, replay_buffer=None, _init_setup_model=False, verbose=0, *,
+ requires_vec_env=False, policy_base=None, policy_kwargs=None):
super(OffPolicyRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
policy_base=policy_base, policy_kwargs=policy_kwargs)
@@ -740,10 +747,31 @@ def save(self, save_path):
pass
@classmethod
- @abstractmethod
def load(cls, load_path, env=None, **kwargs):
- pass
+ """
+ Load the model from file
+
+ :param load_path: (str or file-like) the saved parameter location
+ :param env: (Gym Environment) the new environment to run the loaded model on
+ (can be None if you only need prediction from a trained model)
+ :param kwargs: extra arguments to change the model when loading
+ """
+ data, params = cls._load_from_file(load_path)
+
+ if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
+ raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
+ "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
+ kwargs['policy_kwargs']))
+
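+ # Recreate the model without building the network, restore the saved attributes,
+ # then attach the (possibly new) env, build the graph and load the saved weights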
+ model = cls(policy=data["policy"], env=None, _init_setup_model=False)
+ model.__dict__.update(data)
+ model.__dict__.update(kwargs)
+ model.set_env(env)
+ model.setup_model()
+
+ model.load_parameters(params)
+
+ return model
class _UnvecWrapper(VecEnvWrapper):
def __init__(self, venv):
diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py
index 0401350445..eedc38301b 100644
--- a/stable_baselines/ddpg/ddpg.py
+++ b/stable_baselines/ddpg/ddpg.py
@@ -98,9 +98,20 @@ def get_target_updates(_vars, target_vars, tau, verbose=0):
return tf.group(*init_updates), tf.group(*soft_updates)
+def get_perturbable_vars(scope):
+ """
+ Get the trainable variables that can be perturbed when using
+ parameter noise.
+
+ :param scope: (str) tensorflow scope of the variables
+ :return: ([tf.Variables])
+ """
+ return [var for var in tf_util.get_trainable_vars(scope) if 'LayerNorm' not in var.name]
+
+
def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev, verbose=0):
"""
- get the actor update, with noise.
+ Get the actor update, with noise.
:param actor: (str) the actor
:param perturbed_actor: (str) the perturbed actor
@@ -108,19 +119,15 @@ def get_perturbed_actor_updates(actor, perturbed_actor, param_noise_stddev, verb
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:return: (TensorFlow Operation) the update function
"""
- # TODO: simplify this to this:
- # assert len(actor.vars) == len(perturbed_actor.vars)
- # assert len(actor.perturbable_vars) == len(perturbed_actor.perturbable_vars)
-
assert len(tf_util.get_globals_vars(actor)) == len(tf_util.get_globals_vars(perturbed_actor))
- assert len([var for var in tf_util.get_trainable_vars(actor) if 'LayerNorm' not in var.name]) == \
- len([var for var in tf_util.get_trainable_vars(perturbed_actor) if 'LayerNorm' not in var.name])
+ assert len(get_perturbable_vars(actor)) == len(get_perturbable_vars(perturbed_actor))
updates = []
for var, perturbed_var in zip(tf_util.get_globals_vars(actor), tf_util.get_globals_vars(perturbed_actor)):
- if var in [var for var in tf_util.get_trainable_vars(actor) if 'LayerNorm' not in var.name]:
+ if var in get_perturbable_vars(actor):
if verbose >= 2:
logger.info(' {} <- {} + noise'.format(perturbed_var.name, var.name))
+ # Add gaussian noise to the parameter
updates.append(tf.assign(perturbed_var,
var + tf.random_normal(tf.shape(var), mean=0., stddev=param_noise_stddev)))
else:
@@ -278,7 +285,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n
self.action_noise_ph = None
self.obs_adapt_noise = None
self.action_adapt_noise = None
- self.terminals1 = None
+ self.terminals_ph = None
self.rewards = None
self.actions = None
self.critic_target = None
@@ -341,9 +348,9 @@ def setup_model(self):
self.obs_target = self.target_policy.obs_ph
self.action_target = self.target_policy.action_ph
- normalized_obs0 = tf.clip_by_value(normalize(self.policy_tf.processed_obs, self.obs_rms),
+ normalized_obs = tf.clip_by_value(normalize(self.policy_tf.processed_obs, self.obs_rms),
self.observation_range[0], self.observation_range[1])
- normalized_obs1 = tf.clip_by_value(normalize(self.target_policy.processed_obs, self.obs_rms),
+ normalized_next_obs = tf.clip_by_value(normalize(self.target_policy.processed_obs, self.obs_rms),
self.observation_range[0], self.observation_range[1])
if self.param_noise is not None:
@@ -363,7 +370,7 @@ def setup_model(self):
# Inputs.
self.obs_train = self.policy_tf.obs_ph
self.action_train_ph = self.policy_tf.action_ph
- self.terminals1 = tf.placeholder(tf.float32, shape=(None, 1), name='terminals1')
+ self.terminals_ph = tf.placeholder(tf.float32, shape=(None, 1), name='terminals')
self.rewards = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
self.actions = tf.placeholder(tf.float32, shape=(None,) + self.action_space.shape, name='actions')
self.critic_target = tf.placeholder(tf.float32, shape=(None, 1), name='critic_target')
@@ -371,18 +378,18 @@ def setup_model(self):
# Create networks and core TF parts that are shared across setup parts.
with tf.variable_scope("model", reuse=False):
- self.actor_tf = self.policy_tf.make_actor(normalized_obs0)
- self.normalized_critic_tf = self.policy_tf.make_critic(normalized_obs0, self.actions)
- self.normalized_critic_with_actor_tf = self.policy_tf.make_critic(normalized_obs0,
+ self.actor_tf = self.policy_tf.make_actor(normalized_obs)
+ self.normalized_critic_tf = self.policy_tf.make_critic(normalized_obs, self.actions)
+ self.normalized_critic_with_actor_tf = self.policy_tf.make_critic(normalized_obs,
self.actor_tf,
reuse=True)
# Noise setup
if self.param_noise is not None:
- self._setup_param_noise(normalized_obs0)
+ self._setup_param_noise(normalized_obs)
with tf.variable_scope("target", reuse=False):
- critic_target = self.target_policy.make_critic(normalized_obs1,
- self.target_policy.make_actor(normalized_obs1))
+ critic_target = self.target_policy.make_critic(normalized_next_obs,
+ self.target_policy.make_actor(normalized_next_obs))
with tf.variable_scope("loss", reuse=False):
self.critic_tf = denormalize(
@@ -394,8 +401,8 @@ def setup_model(self):
self.return_range[0], self.return_range[1]),
self.ret_rms)
- q_obs1 = denormalize(critic_target, self.ret_rms)
- self.target_q = self.rewards + (1. - self.terminals1) * self.gamma * q_obs1
+ q_next_obs = denormalize(critic_target, self.ret_rms)
+ self.target_q = self.rewards + (1. - self.terminals_ph) * self.gamma * q_next_obs
tf.summary.scalar('critic_target', tf.reduce_mean(self.critic_target))
if self.full_tensorboard_log:
@@ -449,19 +456,19 @@ def _setup_target_network_updates(self):
self.target_init_updates = init_updates
self.target_soft_updates = soft_updates
- def _setup_param_noise(self, normalized_obs0):
+ def _setup_param_noise(self, normalized_obs):
"""
- set the parameter noise operations
+ Setup the parameter noise operations
- :param normalized_obs0: (TensorFlow Tensor) the normalized observation
+ :param normalized_obs: (TensorFlow Tensor) the normalized observation
"""
assert self.param_noise is not None
with tf.variable_scope("noise", reuse=False):
- self.perturbed_actor_tf = self.param_noise_actor.make_actor(normalized_obs0)
+ self.perturbed_actor_tf = self.param_noise_actor.make_actor(normalized_obs)
with tf.variable_scope("noise_adapt", reuse=False):
- adaptive_actor_tf = self.adaptive_param_noise_actor.make_actor(normalized_obs0)
+ adaptive_actor_tf = self.adaptive_param_noise_actor.make_actor(normalized_obs)
with tf.variable_scope("noise_update_func", reuse=False):
if self.verbose >= 2:
@@ -549,10 +556,24 @@ def _setup_popart(self):
def _setup_stats(self):
"""
- setup the running means and std of the inputs and outputs of the model
+ Setup the stat logger for DDPG.
"""
- ops = []
- names = []
+ ops = [
+ tf.reduce_mean(self.critic_tf),
+ reduce_std(self.critic_tf),
+ tf.reduce_mean(self.critic_with_actor_tf),
+ reduce_std(self.critic_with_actor_tf),
+ tf.reduce_mean(self.actor_tf),
+ reduce_std(self.actor_tf)
+ ]
+ names = [
+ 'reference_Q_mean',
+ 'reference_Q_std',
+ 'reference_actor_Q_mean',
+ 'reference_actor_Q_std',
+ 'reference_action_mean',
+ 'reference_action_std'
+ ]
if self.normalize_returns:
ops += [self.ret_rms.mean, self.ret_rms.std]
@@ -562,26 +583,9 @@ def _setup_stats(self):
ops += [tf.reduce_mean(self.obs_rms.mean), tf.reduce_mean(self.obs_rms.std)]
names += ['obs_rms_mean', 'obs_rms_std']
- ops += [tf.reduce_mean(self.critic_tf)]
- names += ['reference_Q_mean']
- ops += [reduce_std(self.critic_tf)]
- names += ['reference_Q_std']
-
- ops += [tf.reduce_mean(self.critic_with_actor_tf)]
- names += ['reference_actor_Q_mean']
- ops += [reduce_std(self.critic_with_actor_tf)]
- names += ['reference_actor_Q_std']
-
- ops += [tf.reduce_mean(self.actor_tf)]
- names += ['reference_action_mean']
- ops += [reduce_std(self.actor_tf)]
- names += ['reference_action_std']
-
if self.param_noise:
- ops += [tf.reduce_mean(self.perturbed_actor_tf)]
- names += ['reference_perturbed_action_mean']
- ops += [reduce_std(self.perturbed_actor_tf)]
- names += ['reference_perturbed_action_std']
+ ops += [tf.reduce_mean(self.perturbed_actor_tf), reduce_std(self.perturbed_actor_tf)]
+ names += ['reference_perturbed_action_mean', 'reference_perturbed_action_std']
self.stats_ops = ops
self.stats_names = names
@@ -617,20 +621,20 @@ def _policy(self, obs, apply_noise=True, compute_q=True):
action = np.clip(action, -1, 1)
return action, q_value
- def _store_transition(self, obs0, action, reward, obs1, terminal1):
+ def _store_transition(self, obs, action, reward, next_obs, done):
"""
Store a transition in the replay buffer
- :param obs0: ([float] or [int]) the last observation
+ :param obs: ([float] or [int]) the last observation
:param action: ([float]) the action
:param reward: (float) the reward
- :param obs1: ([float] or [int]) the current observation
- :param terminal1: (bool) Whether the episode is over
+ :param next_obs: ([float] or [int]) the current observation
+ :param done: (bool) Whether the episode is over
"""
reward *= self.reward_scale
- self.replay_buffer.add(obs0, action, reward, obs1, float(terminal1))
+ self.replay_buffer.add(obs, action, reward, next_obs, float(done))
if self.normalize_observations:
- self.obs_rms.update(np.array([obs0]))
+ self.obs_rms.update(np.array([obs]))
def _train_step(self, step, writer, log=False):
"""
@@ -642,17 +646,17 @@ def _train_step(self, step, writer, log=False):
:return: (float, float) critic loss, actor loss
"""
# Get a batch
- obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size)
+ obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size)
# Reshape to match previous behavior and placeholder shape
rewards = rewards.reshape(-1, 1)
- terminals1 = terminals1.reshape(-1, 1)
+ terminals = terminals.reshape(-1, 1)
if self.normalize_returns and self.enable_popart:
old_mean, old_std, target_q = self.sess.run([self.ret_rms.mean, self.ret_rms.std, self.target_q],
feed_dict={
- self.obs_target: obs1,
+ self.obs_target: next_obs,
self.rewards: rewards,
- self.terminals1: terminals1
+ self.terminals_ph: terminals
})
self.ret_rms.update(target_q.flatten())
self.sess.run(self.renormalize_q_outputs_op, feed_dict={
@@ -662,15 +666,15 @@ def _train_step(self, step, writer, log=False):
else:
target_q = self.sess.run(self.target_q, feed_dict={
- self.obs_target: obs1,
+ self.obs_target: next_obs,
self.rewards: rewards,
- self.terminals1: terminals1
+ self.terminals_ph: terminals
})
# Get all gradients and perform a synced update.
ops = [self.actor_grads, self.actor_loss, self.critic_grads, self.critic_loss]
td_map = {
- self.obs_train: obs0,
+ self.obs_train: obs,
self.actions: actions,
self.action_train_ph: actions,
self.rewards: rewards,
@@ -727,13 +731,13 @@ def _get_stats(self):
if self.stats_sample is None:
# Get a sample and keep that fixed for all further computations.
# This allows us to estimate the change in value for the same set of inputs.
- obs0, actions, rewards, obs1, terminals1 = self.replay_buffer.sample(batch_size=self.batch_size)
+ obs, actions, rewards, next_obs, terminals = self.replay_buffer.sample(batch_size=self.batch_size)
self.stats_sample = {
- 'obs0': obs0,
+ 'obs': obs,
'actions': actions,
'rewards': rewards,
- 'obs1': obs1,
- 'terminals1': terminals1
+ 'next_obs': next_obs,
+ 'terminals': terminals
}
feed_dict = {
@@ -746,7 +750,7 @@ def _get_stats(self):
for placeholder in [self.obs_train, self.obs_target, self.obs_adapt_noise, self.obs_noise]:
if placeholder is not None:
- feed_dict[placeholder] = self.stats_sample['obs0']
+ feed_dict[placeholder] = self.stats_sample['obs']
values = self.sess.run(self.stats_ops, feed_dict=feed_dict)
@@ -769,12 +773,12 @@ def _adapt_param_noise(self):
return 0.
# Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
- obs0, *_ = self.replay_buffer.sample(batch_size=self.batch_size)
+ obs, *_ = self.replay_buffer.sample(batch_size=self.batch_size)
self.sess.run(self.perturb_adaptive_policy_ops, feed_dict={
self.param_noise_stddev: self.param_noise.current_stddev,
})
distance = self.sess.run(self.adaptive_policy_distance, feed_dict={
- self.obs_adapt_noise: obs0, self.obs_train: obs0,
+ self.obs_adapt_noise: obs, self.obs_train: obs,
self.param_noise_stddev: self.param_noise.current_stddev,
})
@@ -1040,7 +1044,7 @@ def predict(self, observation, state=None, mask=None, deterministic=True):
return actions, None
def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
- observation = np.array(observation)
+ _ = np.array(observation)
if actions is not None:
raise ValueError("Error: DDPG does not have action probabilities.")
diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py
index 70a3ef462d..55d1928bcf 100644
--- a/stable_baselines/deepq/dqn.py
+++ b/stable_baselines/deepq/dqn.py
@@ -362,22 +362,3 @@ def save(self, save_path):
params_to_save = self.get_parameters()
self._save_to_file(save_path, data=data, params=params_to_save)
-
- @classmethod
- def load(cls, load_path, env=None, **kwargs):
- data, params = cls._load_from_file(load_path)
-
- if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
- raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
- "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
- kwargs['policy_kwargs']))
-
- model = cls(policy=data["policy"], env=env, _init_setup_model=False)
- model.__dict__.update(data)
- model.__dict__.update(kwargs)
- model.set_env(env)
- model.setup_model()
-
- model.load_parameters(params)
-
- return model
diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py
index 2b5c8f88d2..fc53e77b0f 100644
--- a/stable_baselines/her/her.py
+++ b/stable_baselines/her/her.py
@@ -47,7 +47,7 @@ def __init__(self, policy, env, model_class, n_sampled_goal=4,
self._create_replay_wrapper(self.env)
assert issubclass(model_class, OffPolicyRLModel), \
- "Error: HER only works with Off policy model (such as DDPG, SAC and DQN)."
+ "Error: HER only works with Off policy model (such as DDPG, SAC, TD3 and DQN)."
self.model = self.model_class(policy, self.env, *args, **kwargs)
# Patch to support saving/loading
diff --git a/stable_baselines/sac/policies.py b/stable_baselines/sac/policies.py
index 53fefd4db9..2d2c5053cc 100644
--- a/stable_baselines/sac/policies.py
+++ b/stable_baselines/sac/policies.py
@@ -175,7 +175,6 @@ class FeedForwardPolicy(SACPolicy):
:param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
:param layer_norm: (bool) enable layer normalisation
:param reg_weight: (float) Regularization loss weight for the policy parameters
- :param reg_weight: (float) Regularization loss weight for the policy parameters
:param act_fun: (tf.func) the activation function to use in the neural network.
:param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
"""
diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py
index 7804a6c922..4a323ea3a6 100644
--- a/stable_baselines/sac/sac.py
+++ b/stable_baselines/sac/sac.py
@@ -221,7 +221,7 @@ def setup_model(self):
# Take the min of the two Q-Values (Double-Q Learning)
min_qf_pi = tf.minimum(qf1_pi, qf2_pi)
- # Targets for Q and V regression
+ # Target for Q value regression
q_backup = tf.stop_gradient(
self.rewards_ph +
(1 - self.terminals_ph) * self.gamma * self.value_target
@@ -250,6 +250,8 @@ def setup_model(self):
# policy_loss = (policy_kl_loss + policy_regularization_loss)
policy_loss = policy_kl_loss
+
+ # Target for value fn regression
# We update the vf towards the min of two Q-functions in order to
# reduce overestimation bias from function approximation error.
v_backup = tf.stop_gradient(min_qf_pi - self.ent_coef * logp_pi)
@@ -553,22 +555,3 @@ def save(self, save_path):
params_to_save = self.get_parameters()
self._save_to_file(save_path, data=data, params=params_to_save)
-
- @classmethod
- def load(cls, load_path, env=None, **kwargs):
- data, params = cls._load_from_file(load_path)
-
- if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
- raise ValueError("The specified policy kwargs do not equal the stored policy kwargs. "
- "Stored kwargs: {}, specified kwargs: {}".format(data['policy_kwargs'],
- kwargs['policy_kwargs']))
-
- model = cls(policy=data["policy"], env=env, _init_setup_model=False)
- model.__dict__.update(data)
- model.__dict__.update(kwargs)
- model.set_env(env)
- model.setup_model()
-
- model.load_parameters(params)
-
- return model
diff --git a/stable_baselines/td3/__init__.py b/stable_baselines/td3/__init__.py
new file mode 100644
index 0000000000..95b15ff518
--- /dev/null
+++ b/stable_baselines/td3/__init__.py
@@ -0,0 +1,3 @@
+from stable_baselines.td3.td3 import TD3
+from stable_baselines.td3.policies import MlpPolicy, CnnPolicy, LnMlpPolicy, LnCnnPolicy
+from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
diff --git a/stable_baselines/td3/policies.py b/stable_baselines/td3/policies.py
new file mode 100644
index 0000000000..9e0c83fd6b
--- /dev/null
+++ b/stable_baselines/td3/policies.py
@@ -0,0 +1,244 @@
+import tensorflow as tf
+import numpy as np
+from gym.spaces import Box
+
+from stable_baselines.common.policies import BasePolicy, nature_cnn, register_policy
+from stable_baselines.sac.policies import mlp
+
+
+class TD3Policy(BasePolicy):
+ """
+ Policy object that implements a TD3-like actor critic
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param scale: (bool) whether or not to scale the input
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, scale=False):
+ super(TD3Policy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=scale)
+ assert isinstance(ac_space, Box), "Error: the action space must be of type gym.spaces.Box"
+ assert (np.abs(ac_space.low) == ac_space.high).all(), "Error: the action space low and high must be symmetric"
+
+ self.qf1 = None
+ self.qf2 = None
+ self.policy = None
+
+ def make_actor(self, obs=None, reuse=False, scope="pi"):
+ """
+ Creates an actor object
+
+ :param obs: (TensorFlow Tensor) The observation placeholder (can be None for default placeholder)
+ :param reuse: (bool) whether or not to reuse parameters
+ :param scope: (str) the scope name of the actor
+ :return: (TensorFlow Tensor) the output tensor
+ """
+ raise NotImplementedError
+
+ def make_critics(self, obs=None, action=None, reuse=False,
+ scope="qvalues_fn"):
+ """
+ Creates the two Q-Values approximator
+
+ :param obs: (TensorFlow Tensor) The observation placeholder (can be None for default placeholder)
+ :param action: (TensorFlow Tensor) The action placeholder
+ :param reuse: (bool) whether or not to reuse parameters
+ :param scope: (str) the scope name
+ :return: ([tf.Tensor]) the two Q-values (qf1, qf2)
+ """
+ raise NotImplementedError
+
+ def step(self, obs, state=None, mask=None):
+ """
+ Returns the policy for a single step
+
+ :param obs: ([float] or [int]) The current observation of the environment
+ :param state: ([float]) The last states (used in recurrent policies)
+ :param mask: ([float]) The last masks (used in recurrent policies)
+ :return: ([float]) actions
+ """
+ raise NotImplementedError
+
+ def proba_step(self, obs, state=None, mask=None):
+ """
+ Returns the policy for a single step
+
+ :param obs: ([float] or [int]) The current observation of the environment
+ :param state: ([float]) The last states (used in recurrent policies)
+ :param mask: ([float]) The last masks (used in recurrent policies)
+ :return: ([float]) actions
+ """
+ return self.step(obs, state, mask)
+
+
+class FeedForwardPolicy(TD3Policy):
+ """
+ Policy object that implements a DDPG-like actor critic, using a feed forward neural network.
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param layers: ([int]) The size of the Neural network for the policy (if None, default to [64, 64])
+ :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
+ :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
+ :param layer_norm: (bool) enable layer normalisation
+ :param act_fun: (tf.func) the activation function to use in the neural network.
+ :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, layers=None,
+ cnn_extractor=nature_cnn, feature_extraction="cnn",
+ layer_norm=False, act_fun=tf.nn.relu, **kwargs):
+ super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
+ reuse=reuse, scale=(feature_extraction == "cnn"))
+
+ self._kwargs_check(feature_extraction, kwargs)
+ self.layer_norm = layer_norm
+ self.feature_extraction = feature_extraction
+ self.cnn_kwargs = kwargs
+ self.cnn_extractor = cnn_extractor
+ self.reuse = reuse
+ if layers is None:
+ layers = [64, 64]
+ self.layers = layers
+
+ assert len(layers) >= 1, "Error: must have at least one hidden layer for the policy."
+
+ self.activ_fn = act_fun
+
+ def make_actor(self, obs=None, reuse=False, scope="pi"):
+ if obs is None:
+ obs = self.processed_obs
+
+ with tf.variable_scope(scope, reuse=reuse):
+ if self.feature_extraction == "cnn":
+ pi_h = self.cnn_extractor(obs, **self.cnn_kwargs)
+ else:
+ pi_h = tf.layers.flatten(obs)
+
+ pi_h = mlp(pi_h, self.layers, self.activ_fn, layer_norm=self.layer_norm)
+
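+ # tanh squashes the actor output to [-1, 1]; TD3 rescales it to the action space bounds when stepping the env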
+ self.policy = policy = tf.layers.dense(pi_h, self.ac_space.shape[0], activation=tf.tanh)
+
+ return policy
+
+ def make_critics(self, obs=None, action=None, reuse=False, scope="values_fn"):
+ if obs is None:
+ obs = self.processed_obs
+
+ with tf.variable_scope(scope, reuse=reuse):
+ if self.feature_extraction == "cnn":
+ critics_h = self.cnn_extractor(obs, **self.cnn_kwargs)
+ else:
+ critics_h = tf.layers.flatten(obs)
+
+ # Concatenate preprocessed state and action
+ qf_h = tf.concat([critics_h, action], axis=-1)
+
+ # Double Q values to reduce overestimation
+ with tf.variable_scope('qf1', reuse=reuse):
+ qf1_h = mlp(qf_h, self.layers, self.activ_fn, layer_norm=self.layer_norm)
+ qf1 = tf.layers.dense(qf1_h, 1, name="qf1")
+
+ with tf.variable_scope('qf2', reuse=reuse):
+ qf2_h = mlp(qf_h, self.layers, self.activ_fn, layer_norm=self.layer_norm)
+ qf2 = tf.layers.dense(qf2_h, 1, name="qf2")
+
+ self.qf1 = qf1
+ self.qf2 = qf2
+
+ return self.qf1, self.qf2
+
+ def step(self, obs, state=None, mask=None):
+ return self.sess.run(self.policy, {self.obs_ph: obs})
+
+
+class CnnPolicy(FeedForwardPolicy):
+ """
+ Policy object that implements actor critic, using a CNN (the nature CNN)
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
+ super(CnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
+ feature_extraction="cnn", **_kwargs)
+
+
+class LnCnnPolicy(FeedForwardPolicy):
+ """
+ Policy object that implements actor critic, using a CNN (the nature CNN), with layer normalisation
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
+ super(LnCnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
+ feature_extraction="cnn", layer_norm=True, **_kwargs)
+
+
+class MlpPolicy(FeedForwardPolicy):
+ """
+ Policy object that implements actor critic, using a MLP (2 layers of 64)
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
+ super(MlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
+ feature_extraction="mlp", **_kwargs)
+
+
+class LnMlpPolicy(FeedForwardPolicy):
+ """
+ Policy object that implements actor critic, using a MLP (2 layers of 64), with layer normalisation
+
+ :param sess: (TensorFlow session) The current TensorFlow session
+ :param ob_space: (Gym Space) The observation space of the environment
+ :param ac_space: (Gym Space) The action space of the environment
+ :param n_env: (int) The number of environments to run
+ :param n_steps: (int) The number of steps to run for each environment
+ :param n_batch: (int) The number of batch to run (n_envs * n_steps)
+ :param reuse: (bool) If the policy is reusable or not
+ :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
+ """
+
+ def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, **_kwargs):
+ super(LnMlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
+ feature_extraction="mlp", layer_norm=True, **_kwargs)
+
+
+register_policy("CnnPolicy", CnnPolicy)
+register_policy("LnCnnPolicy", LnCnnPolicy)
+register_policy("MlpPolicy", MlpPolicy)
+register_policy("LnMlpPolicy", LnMlpPolicy)
diff --git a/stable_baselines/td3/td3.py b/stable_baselines/td3/td3.py
new file mode 100644
index 0000000000..905a3eb8b7
--- /dev/null
+++ b/stable_baselines/td3/td3.py
@@ -0,0 +1,475 @@
+import sys
+import time
+import multiprocessing
+from collections import deque
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from stable_baselines.a2c.utils import total_episode_reward_logger
+from stable_baselines.common import tf_util, OffPolicyRLModel, SetVerbosity, TensorboardWriter
+from stable_baselines.common.vec_env import VecEnv
+from stable_baselines.deepq.replay_buffer import ReplayBuffer
+from stable_baselines.ppo2.ppo2 import safe_mean, get_schedule_fn
+from stable_baselines.sac.sac import get_vars
+from stable_baselines.td3.policies import TD3Policy
+from stable_baselines import logger
+
+
+class TD3(OffPolicyRLModel):
+ """
+ Twin Delayed DDPG (TD3)
+ Addressing Function Approximation Error in Actor-Critic Methods.
+
+ Original implementation: https://github.com/sfujim/TD3
+ Paper: https://arxiv.org/pdf/1802.09477.pdf
+ Introduction to TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
+
+ :param policy: (TD3Policy or str) The policy model to use (MlpPolicy, CnnPolicy, LnMlpPolicy, ...)
+ :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
+ :param gamma: (float) the discount factor
+ :param learning_rate: (float or callable) learning rate for adam optimizer,
+ the same learning rate will be used for all networks (Q-Values and Actor networks)
+ it can be a function of the current progress (from 1 to 0)
+ :param buffer_size: (int) size of the replay buffer
+ :param batch_size: (int) Minibatch size for each gradient update
+ :param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1)
+ :param policy_delay: (int) Policy and target networks will only be updated once every policy_delay
+ training steps. The Q values will be updated policy_delay times more often (i.e. every training step).
+ :param action_noise: (ActionNoise) the action noise type. Cf DDPG for the different action noise type.
+ :param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
+ (smoothing noise)
+ :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
+ :param train_freq: (int) Update the model every `train_freq` steps.
+ :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
+ :param gradient_steps: (int) How many gradient updates to perform after each step
+ :param random_exploration: (float) Probability of taking a random action (as in an epsilon-greedy strategy)
+ This is not needed for TD3 normally but can help with exploration when using HER + TD3.
+ This hack was present in the original OpenAI Baselines repo (DDPG + HER)
+ :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
+ :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
+ :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
+ :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
+ :param full_tensorboard_log: (bool) enable additional logging when using tensorboard
+ Note: this has no effect on TD3 logging for now
+ """
+
+ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=50000,
+ learning_starts=100, train_freq=100, gradient_steps=100, batch_size=128,
+ tau=0.005, policy_delay=2, action_noise=None,
+ target_policy_noise=0.2, target_noise_clip=0.5,
+ random_exploration=0.0, verbose=0, tensorboard_log=None,
+ _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False):
+
+ super(TD3, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose,
+ policy_base=TD3Policy, requires_vec_env=False, policy_kwargs=policy_kwargs)
+
+ self.buffer_size = buffer_size
+ self.learning_rate = learning_rate
+ self.learning_starts = learning_starts
+ self.train_freq = train_freq
+ self.batch_size = batch_size
+ self.tau = tau
+ self.gradient_steps = gradient_steps
+ self.gamma = gamma
+ self.action_noise = action_noise
+ self.random_exploration = random_exploration
+ self.policy_delay = policy_delay
+ self.target_noise_clip = target_noise_clip
+ self.target_policy_noise = target_policy_noise
+
+ self.graph = None
+ self.replay_buffer = None
+ self.episode_reward = None
+ self.sess = None
+ self.tensorboard_log = tensorboard_log
+ self.verbose = verbose
+ self.params = None
+ self.summary = None
+ self.policy_tf = None
+ self.full_tensorboard_log = full_tensorboard_log
+
+ self.obs_target = None
+ self.target_policy_tf = None
+ self.actions_ph = None
+ self.rewards_ph = None
+ self.terminals_ph = None
+ self.observations_ph = None
+ self.action_target = None
+ self.next_observations_ph = None
+ self.step_ops = None
+ self.target_ops = None
+ self.infos_names = None
+ self.target_params = None
+ self.learning_rate_ph = None
+ self.processed_obs_ph = None
+ self.processed_next_obs_ph = None
+ self.policy_out = None
+ self.policy_train_op = None
+ self.policy_loss = None
+
+ if _init_setup_model:
+ self.setup_model()
+
+ def _get_pretrain_placeholders(self):
+ policy = self.policy_tf
+ # Rescale the policy output from [-1, 1] to the action space bounds
+ policy_out = self.policy_out * np.abs(self.action_space.low)
+ return policy.obs_ph, self.actions_ph, policy_out
+
+ def setup_model(self):
+ with SetVerbosity(self.verbose):
+ self.graph = tf.Graph()
+ with self.graph.as_default():
+ n_cpu = multiprocessing.cpu_count()
+ if sys.platform == 'darwin':
+ n_cpu //= 2
+ self.sess = tf_util.make_session(num_cpu=n_cpu, graph=self.graph)
+
+ self.replay_buffer = ReplayBuffer(self.buffer_size)
+
+ with tf.variable_scope("input", reuse=False):
+ # Create policy and target TF objects
+ self.policy_tf = self.policy(self.sess, self.observation_space, self.action_space,
+ **self.policy_kwargs)
+ self.target_policy_tf = self.policy(self.sess, self.observation_space, self.action_space,
+ **self.policy_kwargs)
+
+ # Initialize Placeholders
+ self.observations_ph = self.policy_tf.obs_ph
+ # Normalized observation for pixels
+ self.processed_obs_ph = self.policy_tf.processed_obs
+ self.next_observations_ph = self.target_policy_tf.obs_ph
+ self.processed_next_obs_ph = self.target_policy_tf.processed_obs
+ self.action_target = self.target_policy_tf.action_ph
+ self.terminals_ph = tf.placeholder(tf.float32, shape=(None, 1), name='terminals')
+ self.rewards_ph = tf.placeholder(tf.float32, shape=(None, 1), name='rewards')
+ self.actions_ph = tf.placeholder(tf.float32, shape=(None,) + self.action_space.shape,
+ name='actions')
+ self.learning_rate_ph = tf.placeholder(tf.float32, [], name="learning_rate_ph")
+
+ with tf.variable_scope("model", reuse=False):
+ # Create the policy
+ self.policy_out = policy_out = self.policy_tf.make_actor(self.processed_obs_ph)
+ # Use two Q-functions to improve performance by reducing overestimation bias
+ qf1, qf2 = self.policy_tf.make_critics(self.processed_obs_ph, self.actions_ph)
+ # Q value when following the current policy
+ qf1_pi, _ = self.policy_tf.make_critics(self.processed_obs_ph,
+ policy_out, reuse=True)
+
+ with tf.variable_scope("target", reuse=False):
+ # Create target networks
+ target_policy_out = self.target_policy_tf.make_actor(self.processed_next_obs_ph)
+ # Target policy smoothing, by adding clipped noise to target actions
+ target_noise = tf.random_normal(tf.shape(target_policy_out), stddev=self.target_policy_noise)
+ target_noise = tf.clip_by_value(target_noise, -self.target_noise_clip, self.target_noise_clip)
+ # Clip the noisy action to remain in the bounds [-1, 1] (output of a tanh)
+ noisy_target_action = tf.clip_by_value(target_policy_out + target_noise, -1, 1)
+ # Q values when following the target policy
+ qf1_target, qf2_target = self.target_policy_tf.make_critics(self.processed_next_obs_ph,
+ noisy_target_action)
+
+ with tf.variable_scope("loss", reuse=False):
+ # Take the min of the two target Q-Values (clipped Double-Q Learning)
+ min_qf_target = tf.minimum(qf1_target, qf2_target)
+
+ # Targets for Q value regression
+ q_backup = tf.stop_gradient(
+ self.rewards_ph +
+ (1 - self.terminals_ph) * self.gamma * min_qf_target
+ )
+
+ # Compute Q-Function loss
+ qf1_loss = tf.reduce_mean((q_backup - qf1) ** 2)
+ qf2_loss = tf.reduce_mean((q_backup - qf2) ** 2)
+
+ qvalues_losses = qf1_loss + qf2_loss
+
+ # Policy loss: maximise q value
+ self.policy_loss = policy_loss = -tf.reduce_mean(qf1_pi)
+
+ # Policy train op
+ # will be called only every n training steps,
+ # where n is the policy delay
+ policy_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph)
+ policy_train_op = policy_optimizer.minimize(policy_loss, var_list=get_vars('model/pi'))
+ self.policy_train_op = policy_train_op
+
+ # Q Values optimizer
+ qvalues_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_ph)
+ qvalues_params = get_vars('model/values_fn/')
+
+ # Q Values and policy target params
+ source_params = get_vars("model/")
+ target_params = get_vars("target/")
+
+ # Polyak averaging for target variables
+ self.target_ops = [
+ tf.assign(target, (1 - self.tau) * target + self.tau * source)
+ for target, source in zip(target_params, source_params)
+ ]
+
+ # Initializing target to match source variables
+ target_init_op = [
+ tf.assign(target, source)
+ for target, source in zip(target_params, source_params)
+ ]
+
+ train_values_op = qvalues_optimizer.minimize(qvalues_losses, var_list=qvalues_params)
+
+ self.infos_names = ['qf1_loss', 'qf2_loss']
+ # All ops to call during one training step
+ self.step_ops = [qf1_loss, qf2_loss,
+ qf1, qf2, train_values_op]
+
+ # Monitor losses and entropy in tensorboard
+ tf.summary.scalar('policy_loss', policy_loss)
+ tf.summary.scalar('qf1_loss', qf1_loss)
+ tf.summary.scalar('qf2_loss', qf2_loss)
+ tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))
+
+ # Retrieve parameters that must be saved
+ self.params = get_vars("model")
+ self.target_params = get_vars("target/")
+
+ # Initialize Variables and target network
+ with self.sess.as_default():
+ self.sess.run(tf.global_variables_initializer())
+ self.sess.run(target_init_op)
+
+ self.summary = tf.summary.merge_all()
+
+ def _train_step(self, step, writer, learning_rate, update_policy):
+ # Sample a batch from the replay buffer
+ batch = self.replay_buffer.sample(self.batch_size)
+ batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = batch
+
+ feed_dict = {
+ self.observations_ph: batch_obs,
+ self.actions_ph: batch_actions,
+ self.next_observations_ph: batch_next_obs,
+ self.rewards_ph: batch_rewards.reshape(self.batch_size, -1),
+ self.terminals_ph: batch_dones.reshape(self.batch_size, -1),
+ self.learning_rate_ph: learning_rate
+ }
+
+ step_ops = self.step_ops
+ if update_policy:
+ # Update policy and target networks
+ step_ops = step_ops + [self.policy_train_op, self.target_ops, self.policy_loss]
+
+ # Do one gradient step
+ # and optionally compute log for tensorboard
+ if writer is not None:
+ out = self.sess.run([self.summary] + step_ops, feed_dict)
+ summary = out.pop(0)
+ writer.add_summary(summary, step)
+ else:
+ out = self.sess.run(step_ops, feed_dict)
+
+ # Unpack to monitor losses
+ qf1_loss, qf2_loss, *_values = out
+
+ return qf1_loss, qf2_loss
+
+ def learn(self, total_timesteps, callback=None, seed=None,
+ log_interval=4, tb_log_name="TD3", reset_num_timesteps=True, replay_wrapper=None):
+
+ new_tb_log = self._init_num_timesteps(reset_num_timesteps)
+
+ if replay_wrapper is not None:
+ self.replay_buffer = replay_wrapper(self.replay_buffer)
+
+ with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
+ as writer:
+
+ self._setup_learn(seed)
+
+ # Transform to callable if needed
+ self.learning_rate = get_schedule_fn(self.learning_rate)
+ # Initial learning rate
+ current_lr = self.learning_rate(1)
+
+ start_time = time.time()
+ episode_rewards = [0.0]
+ episode_successes = []
+ if self.action_noise is not None:
+ self.action_noise.reset()
+ obs = self.env.reset()
+ self.episode_reward = np.zeros((1,))
+ ep_info_buf = deque(maxlen=100)
+ n_updates = 0
+ infos_values = []
+
+ for step in range(total_timesteps):
+ if callback is not None:
+ # Only stop training if return value is False, not when it is None. This is for backwards
+ # compatibility with callbacks that have no return statement.
+ if callback(locals(), globals()) is False:
+ break
+
+ # Before training starts, randomly sample actions
+ # from a uniform distribution for better exploration.
+ # Afterwards, use the learned policy
+ # if random_exploration is set to 0 (normal setting)
+ if (self.num_timesteps < self.learning_starts
+ or np.random.rand() < self.random_exploration):
+ # No need to rescale when sampling a random action
+ rescaled_action = action = self.env.action_space.sample()
+ else:
+ action = self.policy_tf.step(obs[None]).flatten()
+ # Add noise to the action; as the policy
+ # is deterministic, noise is required for exploration
+ if self.action_noise is not None:
+ action = np.clip(action + self.action_noise(), -1, 1)
+ # Rescale from [-1, 1] to the correct bounds
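+ # (this assumes a symmetric action space, i.e. high == -low)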
+ rescaled_action = action * np.abs(self.action_space.low)
+
+ assert action.shape == self.env.action_space.shape
+
+ new_obs, reward, done, info = self.env.step(rescaled_action)
+
+ # Store transition in the replay buffer.
+ self.replay_buffer.add(obs, action, reward, new_obs, float(done))
+ obs = new_obs
+
+ # Retrieve reward and episode length if using Monitor wrapper
+ maybe_ep_info = info.get('episode')
+ if maybe_ep_info is not None:
+ ep_info_buf.extend([maybe_ep_info])
+
+ if writer is not None:
+ # Write reward per episode to tensorboard
+ ep_reward = np.array([reward]).reshape((1, -1))
+ ep_done = np.array([done]).reshape((1, -1))
+ self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_reward,
+ ep_done, writer, self.num_timesteps)
+
+ if step % self.train_freq == 0:
+ mb_infos_vals = []
+ # Update policy, critics and target networks
+ for grad_step in range(self.gradient_steps):
+ # Break if the warmup phase is not over
+ # or if there are not enough samples in the replay buffer
+ if not self.replay_buffer.can_sample(self.batch_size) \
+ or self.num_timesteps < self.learning_starts:
+ break
+ n_updates += 1
+ # Compute current learning_rate
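+ # (frac decreases linearly from 1 to 0 over the course of training)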
+ frac = 1.0 - step / total_timesteps
+ current_lr = self.learning_rate(frac)
+ # Update policy and critics (q functions)
+ # Note: the policy is updated less frequently than the Q functions;
+ # this is controlled by the `policy_delay` parameter
+ mb_infos_vals.append(
+ self._train_step(step, writer, current_lr, (step + grad_step) % self.policy_delay == 0))
+
+ # Log losses, useful to monitor training
+ if len(mb_infos_vals) > 0:
+ infos_values = np.mean(mb_infos_vals, axis=0)
+
+ episode_rewards[-1] += reward
+ if done:
+ if self.action_noise is not None:
+ self.action_noise.reset()
+ if not isinstance(self.env, VecEnv):
+ obs = self.env.reset()
+ episode_rewards.append(0.0)
+
+ maybe_is_success = info.get('is_success')
+ if maybe_is_success is not None:
+ episode_successes.append(float(maybe_is_success))
+
+ if len(episode_rewards[-101:-1]) == 0:
+ mean_reward = -np.inf
+ else:
+ mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1)
+
+ num_episodes = len(episode_rewards)
+ self.num_timesteps += 1
+ # Display training infos
+ if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
+ fps = int(step / (time.time() - start_time))
+ logger.logkv("episodes", num_episodes)
+ logger.logkv("mean 100 episode reward", mean_reward)
+ if len(ep_info_buf) > 0 and len(ep_info_buf[0]) > 0:
+ logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in ep_info_buf]))
+ logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in ep_info_buf]))
+ logger.logkv("n_updates", n_updates)
+ logger.logkv("current_lr", current_lr)
+ logger.logkv("fps", fps)
+ logger.logkv('time_elapsed', int(time.time() - start_time))
+ if len(episode_successes) > 0:
+ logger.logkv("success rate", np.mean(episode_successes[-100:]))
+ if len(infos_values) > 0:
+ for (name, val) in zip(self.infos_names, infos_values):
+ logger.logkv(name, val)
+ logger.logkv("total timesteps", self.num_timesteps)
+ logger.dumpkvs()
+ # Reset infos:
+ infos_values = []
+ return self
+
+ def action_probability(self, observation, state=None, mask=None, actions=None, logp=False):
+ _ = np.array(observation)
+
+ if actions is not None:
+ raise ValueError("Error: TD3 does not have action probabilities.")
+
+ # here there are no action probabilities, as TD3 is deterministic and does not use a probability distribution
+ warnings.warn("Warning: action probability is meaningless for TD3. Returning None")
+ return None
+
+ def predict(self, observation, state=None, mask=None, deterministic=True):
+ observation = np.array(observation)
+ vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
+
+ observation = observation.reshape((-1,) + self.observation_space.shape)
+ actions = self.policy_tf.step(observation)
+
+ if self.action_noise is not None and not deterministic:
+ actions = np.clip(actions + self.action_noise(), -1, 1)
+
+ actions = actions.reshape((-1,) + self.action_space.shape) # reshape to the correct action shape
+ actions = actions * np.abs(self.action_space.low) # scale the output for the prediction
+
+ if not vectorized_env:
+ actions = actions[0]
+
+ return actions, None
+
+ def get_parameter_list(self):
+ return (self.params +
+ self.target_params)
+
+ def save(self, save_path):
+ data = {
+ "learning_rate": self.learning_rate,
+ "buffer_size": self.buffer_size,
+ "learning_starts": self.learning_starts,
+ "train_freq": self.train_freq,
+ "batch_size": self.batch_size,
+ "tau": self.tau,
+ # Should we also store the replay buffer?
+ # this may lead to high memory usage
+ # with all transitions inside
+ # "replay_buffer": self.replay_buffer
+ "policy_delay": self.policy_delay,
+ "target_noise_clip": self.target_noise_clip,
+ "target_policy_noise": self.target_policy_noise,
+ "gamma": self.gamma,
+ "verbose": self.verbose,
+ "observation_space": self.observation_space,
+ "action_space": self.action_space,
+ "policy": self.policy,
+ "n_envs": self.n_envs,
+ "action_noise": self.action_noise,
+ "random_exploration": self.random_exploration,
+ "_vectorize_action": self._vectorize_action,
+ "policy_kwargs": self.policy_kwargs
+ }
+
+ params_to_save = self.get_parameters()
+
+ self._save_to_file(save_path, data=data, params=params_to_save)
diff --git a/tests/test_continuous.py b/tests/test_continuous.py
index 07220fa5c7..be54642cb8 100644
--- a/tests/test_continuous.py
+++ b/tests/test_continuous.py
@@ -5,14 +5,10 @@
import pytest
import numpy as np
-from stable_baselines import A2C, SAC
+from stable_baselines import A2C, SAC, DDPG, PPO1, PPO2, TRPO, TD3
# TODO: add support for continuous actions
# from stable_baselines.acer import ACER
# from stable_baselines.acktr import ACKTR
-from stable_baselines.ddpg import DDPG
-from stable_baselines.ppo1 import PPO1
-from stable_baselines.ppo2 import PPO2
-from stable_baselines.trpo_mpi import TRPO
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common.identity_env import IdentityEnvBox
@@ -31,6 +27,7 @@
PPO1,
PPO2,
SAC,
+ TD3,
TRPO
]
@@ -87,7 +84,7 @@ def test_model_manipulation(request, model_class):
with pytest.warns(None) as record:
act_prob = model.action_probability(obs)
- if model_class in [DDPG, SAC]:
+ if model_class in [DDPG, SAC, TD3]:
# check that only one warning was raised
assert len(record) == 1, "No warning was raised for {}".format(model_class)
assert act_prob is None, "Error: action_probability should be None for {}".format(model_class)
@@ -104,7 +101,7 @@ def test_model_manipulation(request, model_class):
observations = observations.reshape((-1, 1))
actions = np.array([env.action_space.sample() for _ in range(10)])
- if model_class in [DDPG, SAC]:
+ if model_class in [DDPG, SAC, TD3]:
with pytest.raises(ValueError):
model.action_probability(observations, actions=actions)
else:
diff --git a/tests/test_gail.py b/tests/test_gail.py
index f36fcbc78a..2223948215 100644
--- a/tests/test_gail.py
+++ b/tests/test_gail.py
@@ -4,7 +4,8 @@
import numpy as np
import pytest
-from stable_baselines import A2C, ACER, ACKTR, GAIL, DDPG, DQN, PPO1, PPO2, TRPO, SAC
+from stable_baselines import A2C, ACER, ACKTR, GAIL, DDPG, DQN, PPO1, PPO2,\
+ TD3, TRPO, SAC
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines.gail import ExpertDataset, generate_expert_traj
@@ -105,7 +106,7 @@ def test_pretrain_images():
del dataset, model, env
-@pytest.mark.parametrize("model_class", [A2C, GAIL, DDPG, PPO1, PPO2, SAC, TRPO])
+@pytest.mark.parametrize("model_class", [A2C, GAIL, DDPG, PPO1, PPO2, SAC, TD3, TRPO])
def test_behavior_cloning_box(model_class):
"""
Behavior cloning with continuous actions.
diff --git a/tests/test_her.py b/tests/test_her.py
index 269c9a24de..bf24a75b96 100644
--- a/tests/test_her.py
+++ b/tests/test_her.py
@@ -2,7 +2,7 @@
import pytest
-from stable_baselines import HER, DQN, SAC, DDPG
+from stable_baselines import HER, DQN, SAC, DDPG, TD3
from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper
from stable_baselines.her.replay_buffer import KEY_TO_GOAL_STRATEGY
from stable_baselines.common.bit_flipping_env import BitFlippingEnv
@@ -32,20 +32,20 @@ def model_predict(model, env, n_steps, additional_check=None):
@pytest.mark.parametrize('goal_selection_strategy', list(GoalSelectionStrategy))
-@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG])
+@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3])
@pytest.mark.parametrize('discrete_obs_space', [False, True])
def test_her(model_class, goal_selection_strategy, discrete_obs_space):
- env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC],
+ env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3],
max_steps=N_BITS, discrete_obs_space=discrete_obs_space)
# Take random actions 10% of the time
- kwargs = {'random_exploration': 0.1} if model_class in [DDPG, SAC] else {}
+ kwargs = {'random_exploration': 0.1} if model_class in [DDPG, SAC, TD3] else {}
model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy,
verbose=0, **kwargs)
model.learn(1000)
-@pytest.mark.parametrize('model_class', [DDPG, SAC, DQN])
+@pytest.mark.parametrize('model_class', [DDPG, SAC, DQN, TD3])
def test_long_episode(model_class):
"""
Check that the model does not break when the replay buffer is still empty
@@ -53,12 +53,12 @@ def test_long_episode(model_class):
"""
# n_bits > nb_rollout_steps
n_bits = 10
- env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC],
+ env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3],
max_steps=n_bits)
kwargs = {}
if model_class == DDPG:
kwargs['nb_rollout_steps'] = 9 # < n_bits
- elif model_class in [DQN, SAC]:
+ elif model_class in [DQN, SAC, TD3]:
kwargs['batch_size'] = 8 # < n_bits
kwargs['learning_starts'] = 0
@@ -68,9 +68,9 @@ def test_long_episode(model_class):
@pytest.mark.parametrize('goal_selection_strategy', [list(KEY_TO_GOAL_STRATEGY.keys())[0]])
-@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG])
+@pytest.mark.parametrize('model_class', [DQN, SAC, DDPG, TD3])
def test_model_manipulation(model_class, goal_selection_strategy):
- env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS)
+ env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
env = DummyVecEnv([lambda: env])
model = HER('MlpPolicy', env, model_class, n_sampled_goal=3, goal_selection_strategy=goal_selection_strategy,
@@ -93,7 +93,7 @@ def test_model_manipulation(model_class, goal_selection_strategy):
with pytest.raises(ValueError):
model.predict(env.reset())
- env_ = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS)
+ env_ = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
env_ = HERGoalEnvWrapper(env_)
model_predict(model, env_, n_steps=100, additional_check=None)
@@ -107,7 +107,7 @@ def test_model_manipulation(model_class, goal_selection_strategy):
del model
- env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS)
+ env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
model = HER.load('./test_her', env=env)
model.learn(1000)
diff --git a/tests/test_identity.py b/tests/test_identity.py
index 33fd6db64c..8a7cd51d29 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -1,8 +1,8 @@
import pytest
import numpy as np
-from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, TRPO
-from stable_baselines.ddpg import AdaptiveParamNoiseSpec
+from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, SAC, PPO1, PPO2, TD3, TRPO
+from stable_baselines.ddpg import NormalActionNoise
from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common import set_global_seeds
@@ -62,17 +62,21 @@ def test_identity(model_name):
@pytest.mark.slow
-def test_identity_ddpg():
+@pytest.mark.parametrize("model_class", [DDPG, TD3, SAC])
+def test_identity_continuous(model_class):
"""
Test if the algorithm (with a given policy)
can learn an identity transformation (i.e. return observation as an action)
"""
env = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)])
- std = 0.2
- param_noise = AdaptiveParamNoiseSpec(initial_stddev=float(std), desired_action_stddev=float(std))
+ if model_class in [DDPG, TD3]:
+ n_actions = 1
+ action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
+ else:
+ action_noise = None
- model = DDPG("MlpPolicy", env, gamma=0.0, param_noise=param_noise, memory_limit=int(1e6))
+ model = model_class("MlpPolicy", env, gamma=0.1, action_noise=action_noise, buffer_size=int(1e6))
model.learn(total_timesteps=20000, seed=0)
n_trials = 1000