Commit

Twin Delayed DDPG (TD3) (#392)
* Add TD3

* Update defaults

* Add tests for TD3

* Update doc and bump version

* Clean up td3

* Enable deterministic=False for TD3

* Move load method to base class

* Fix codacy complain

* Improve doc

* Fix default args

* Clean up DDPG

* Minor: improve doc

* Doc fix for TD3

* Typo in comment

Co-Authored-By: Ashley Hill <[email protected]>

* Add noise imports to TD3 init
araffin authored and hill-a committed Jul 28, 2019
1 parent 2d8f49d commit 3e261db
Showing 26 changed files with 1,063 additions and 159 deletions.
11 changes: 7 additions & 4 deletions README.md
@@ -20,6 +20,8 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)


| **Features** | **Stable-Baselines** | **OpenAI Baselines** |
| --------------------------- | --------------------------------- | --------------------------------- |
@@ -33,7 +35,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
| PEP8 code style | :heavy_check_mark: | :heavy_check_mark: <sup>(5)</sup> |
| Custom callback | :heavy_check_mark: | :heavy_minus_sign: <sup>(6)</sup> |

<sup><sup>(1): Forked from previous version of OpenAI baselines, with now SAC in addition</sup></sup><br>
<sup><sup>(1): Forked from previous version of OpenAI baselines, with now SAC and TD3 in addition</sup></sup><br>
<sup><sup>(2): Currently not available for DDPG, and only from the run script. </sup></sup><br>
<sup><sup>(3): Only via the run script.</sup></sup><br>
<sup><sup>(4): Rudimentary logging of training information (no loss nor graph). </sup></sup><br>
@@ -156,15 +158,16 @@ All the following examples can be executed online using Google colab notebooks:
| PPO1 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: <sup>(4)</sup> |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TD3 | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TRPO | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: <sup>(4)</sup> |

<sup><sup>(1): Whether or not the algorithm has been refactored to fit the ```BaseRLModel``` class.</sup></sup><br>
<sup><sup>(2): Only implemented for TRPO.</sup></sup><br>
<sup><sup>(3): Re-implemented from scratch</sup></sup><br>
<sup><sup>(3): Re-implemented from scratch, now supports DQN, DDPG, SAC and TD3</sup></sup><br>
<sup><sup>(4): Multi Processing with [MPI](https://mpi4py.readthedocs.io/en/stable/).</sup></sup><br>
<sup><sup>(5): TODO, in project scope.</sup></sup>

NOTE: Soft Actor-Critic (SAC) was not part of the original baselines and HER was reimplemented from scratch.
NOTE: Soft Actor-Critic (SAC) and Twin Delayed DDPG (TD3) were not part of the original baselines and HER was reimplemented from scratch.

Actions ```gym.spaces```:
* ```Box```: An N-dimensional box that contains every point in the action space.
@@ -220,4 +223,4 @@ If you want to contribute, please read **CONTRIBUTING.md** guide first.

Stable Baselines was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).

Logo credits: L.M. Tenkes
Logo credits: [L.M. Tenkes](https://www.instagram.com/lucillehue/)
5 changes: 3 additions & 2 deletions docs/guide/algos.rst
@@ -25,6 +25,7 @@ GAIL [#f2]_ ✔️ ✔️ ✔️ ✔️
PPO1 ✔️ ❌ ✔️ ✔️ ✔️ [#f3]_
PPO2 ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ✔️ ❌ ❌
TD3 ✔️ ❌ ✔️ ❌ ❌
TRPO ✔️ ❌ ✔️ ✔ ✔️ [#f3]_
============ ======================== ========= =========== ============ ================

@@ -34,8 +35,8 @@ TRPO ✔️ ❌ ✔️ ✔
.. [#f4] TODO, in project scope.
.. note::
Non-array spaces such as `Dict` or `Tuple` are not currently supported by any algorithm,
except HER for dict when working with gym.GoalEnv
Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm,
except HER for dict when working with ``gym.GoalEnv``

Actions ``gym.spaces``:

8 changes: 4 additions & 4 deletions docs/guide/examples.rst
@@ -50,7 +50,7 @@ In the following example, we will train, save and load a DQN model on the Lunar
The ``load`` function re-creates the model from scratch on each call, which can be slow.
If you need to, e.g., evaluate the same model with multiple different sets of parameters, consider
using ``load_parameters`` instead.

.. code-block:: python
import gym
@@ -318,15 +318,15 @@ Accessing and modifying model parameters
----------------------------------------

You can access model's parameters via ``load_parameters`` and ``get_parameters`` functions, which
use dictionaries that map variable names to NumPy arrays.
use dictionaries that map variable names to NumPy arrays.

These functions are useful when you need to, e.g., evaluate a large set of models with the same network structure,
visualize different layers of the network or modify parameters manually.

You can access the original TensorFlow variables with the ``get_parameter_list`` function.

Following example demonstrates reading parameters, modifying some of them and loading them to model
by implementing `evolution strategy <http://blog.otoro.net/2017/10/29/visual-evolution-strategies/>`_
by implementing `evolution strategy <http://blog.otoro.net/2017/10/29/visual-evolution-strategies/>`_
for solving ``CartPole-v1`` environment. The initial guess for parameters is obtained by running
A2C policy gradient updates on the model.
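
The read, modify and load cycle described above can be sketched without the library. In this minimal example a plain dictionary of lists stands in for the dictionary of NumPy arrays returned by ``get_parameters``, and the helper names ``perturb`` and ``evolve`` are hypothetical; it is a simple (1+lambda) evolution strategy, not the code used in the actual example.

```python
import random


def perturb(params, sigma, rng):
    """Return a copy of `params` with Gaussian noise added to every value."""
    return {name: [w + rng.gauss(0.0, sigma) for w in values]
            for name, values in params.items()}


def evolve(params, fitness, generations=50, population=20, sigma=0.1, seed=0):
    """Simple (1+lambda) evolution strategy over a parameter dictionary.

    `fitness` maps a parameter dictionary to a score (higher is better).
    The parent is only replaced when a candidate scores strictly better.
    """
    rng = random.Random(seed)
    best, best_score = params, fitness(params)
    for _ in range(generations):
        # Sample a population of noisy copies of the current parent
        candidates = [perturb(best, sigma, rng) for _ in range(population)]
        scored = [(fitness(candidate), candidate) for candidate in candidates]
        top_score, top = max(scored, key=lambda pair: pair[0])
        if top_score > best_score:
            best, best_score = top, top_score
    return best, best_score
```

With a real model, the fitness function would presumably load the candidate with ``load_parameters`` and return the mean episode reward.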

@@ -466,7 +466,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
import highway_env
import numpy as np
from stable_baselines import HER, SAC, DDPG
from stable_baselines import HER, SAC, DDPG, TD3
from stable_baselines.ddpg import NormalActionNoise
env = gym.make("parking-v0")
5 changes: 5 additions & 0 deletions docs/guide/vec_envs.rst
@@ -36,6 +36,11 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️

For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.

VecEnv
------

.. autoclass:: VecEnv
    :members:

DummyVecEnv
-----------
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -30,6 +30,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)


.. toctree::
@@ -66,6 +67,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
modules/ppo1
modules/ppo2
modules/sac
modules/td3
modules/trpo

.. toctree::
6 changes: 5 additions & 1 deletion docs/misc/changelog.rst
@@ -6,14 +6,17 @@ Changelog
For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.


Pre-Release 2.6.1a0 (WIP)
Pre-Release 2.7.0a0 (WIP)
--------------------------

**Twin Delayed DDPG (TD3)**

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- added Twin Delayed DDPG (TD3) algorithm, with HER support

- Add support for continuous action spaces to `action_probability`, computing the PDF of a Gaussian
policy in addition to the existing support for categorical stochastic policies.
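
For reference, the density of an action under a diagonal Gaussian policy can be computed as in the sketch below. The function name ``diag_gaussian_pdf`` is hypothetical and this is not the library's implementation of ``action_probability``; it only illustrates the quantity being computed.

```python
import math


def diag_gaussian_pdf(action, mean, std):
    """Probability density of `action` under a diagonal Gaussian.

    All three arguments are sequences of the same length; the density is
    the product of the independent one-dimensional Gaussian densities.
    """
    density = 1.0
    for a, m, s in zip(action, mean, std):
        z = (a - m) / s
        density *= math.exp(-0.5 * z * z) / (s * math.sqrt(2.0 * math.pi))
    return density
```

Unlike the probabilities of a categorical policy, this value is a density and can be larger than 1 when the standard deviation is small.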
@@ -34,6 +37,7 @@ Others:
- renamed some keys in ``traj_segment_generator`` to be more meaningful
- retrieve unnormalized reward when using Monitor wrapper with TRPO, PPO1 and GAIL
to display them in the logs (mean episode reward)
- Clean up DDPG code (renamed variables)

Documentation:
^^^^^^^^^^^^^^
10 changes: 5 additions & 5 deletions docs/modules/her.rst
@@ -8,7 +8,7 @@ HER

`Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_

HER is a method wrapper that works with Off policy methods (DQN, SAC and DDPG for example).
HER is a method wrapper that works with Off policy methods (DQN, SAC, TD3 and DDPG for example).

.. note::

@@ -39,20 +39,20 @@ Notes
Can I use?
----------

Please refer to the wrapped model (DQN, SAC or DDPG) for that section.
Please refer to the wrapped model (DQN, SAC, TD3 or DDPG) for that section.

Example
-------

.. code-block:: python
from stable_baselines import HER, DQN, SAC, DDPG
from stable_baselines import HER, DQN, SAC, DDPG, TD3
from stable_baselines.her import GoalSelectionStrategy, HERGoalEnvWrapper
from stable_baselines.common.bit_flipping_env import BitFlippingEnv
model_class = DQN # works also with SAC and DDPG
model_class = DQN # works also with SAC, DDPG and TD3
env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC], max_steps=N_BITS)
env = BitFlippingEnv(N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
# Available strategies (cf paper): future, final, episode, random
goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
2 changes: 1 addition & 1 deletion docs/modules/policies.rst
@@ -15,7 +15,7 @@ If you need more control on the policy architecture, you can also create a custo
CnnPolicies are for images only. MlpPolicies are made for other types of features (e.g. robot joints).

.. warning::
For all algorithms (except DDPG and SAC), continuous actions are clipped during training and testing
For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing
(to avoid out of bound error).


2 changes: 1 addition & 1 deletion docs/modules/ppo2.rst
@@ -8,7 +8,7 @@ PPO2
The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).

The main idea is that after an update, the new policy should be not too far form the `old` policy.
The main idea is that after an update, the new policy should not be too far from the old policy.
To achieve this, PPO uses clipping to avoid too large an update.
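
The clipping idea can be illustrated for a single sample with the sketch below. The function name ``clipped_surrogate`` and the default ``clip_range`` value are assumptions made for illustration; this is the clipped surrogate objective from the PPO paper, not the PPO2 code itself.

```python
def clipped_surrogate(ratio, advantage, clip_range=0.2):
    """PPO clipped surrogate objective for one sample.

    `ratio` is new_policy_prob / old_policy_prob for the taken action.
    Taking the minimum removes any incentive to move the ratio outside
    [1 - clip_range, 1 + clip_range] in the direction that the
    advantage favours.
    """
    clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)
```

For example, with a positive advantage the objective stops growing once the ratio exceeds ``1 + clip_range``, so the update gains nothing from moving further away from the old policy.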

.. note::
5 changes: 5 additions & 0 deletions docs/modules/sac.rst
@@ -5,8 +5,13 @@

SAC
===

`Soft Actor Critic (SAC) <https://spinningup.openai.com/en/latest/algorithms/sac.html>`_ Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of `Soft Q-Learning SQL <https://arxiv.org/abs/1702.08165>`_ and incorporates the double Q-learning trick from TD3.
A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
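
The trade-off between return and entropy can be made concrete with a small sketch. It uses a discrete action distribution for simplicity (SAC itself targets continuous actions), and the names ``entropy`` and ``soft_objective`` as well as the default ``alpha`` and ``gamma`` values are assumptions chosen for illustration.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a discrete action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def soft_objective(rewards, action_probs, alpha=0.2, gamma=0.99):
    """Discounted sum of reward plus alpha times the policy entropy.

    `rewards[t]` is the reward at step t and `action_probs[t]` is the
    policy's action distribution at that step. `alpha` sets how much
    randomness in the policy is worth relative to reward.
    """
    total = 0.0
    for t, (reward, probs) in enumerate(zip(rewards, action_probs)):
        total += gamma ** t * (reward + alpha * entropy(probs))
    return total
```

With identical rewards, a uniform policy scores higher than a deterministic one under this objective, which is what encourages exploration.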


.. warning::

The SAC model does not support ``stable_baselines.common.policies`` because it uses double q-values
163 changes: 163 additions & 0 deletions docs/modules/td3.rst
@@ -0,0 +1,163 @@
.. _td3:

.. automodule:: stable_baselines.td3


TD3
===

`Twin Delayed DDPG (TD3) <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ Addressing Function Approximation Error in Actor-Critic Methods.

TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
We recommend reading the `OpenAI Spinning Up guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
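
The three tricks can be illustrated with a minimal, framework-free sketch for a single transition with a scalar action. This is not the stable-baselines implementation: the function names ``td3_target`` and ``should_update_actor`` and all default values are assumptions chosen for illustration.

```python
import random


def td3_target(reward, done, next_action, q1_target, q2_target,
               gamma=0.99, noise_std=0.2, noise_clip=0.5,
               action_low=-1.0, action_high=1.0, rng=random):
    """Critic target for one transition with a scalar action.

    `next_action` is the target actor's action in the next state.
    `q1_target` and `q2_target` are callables mapping an action to the
    Q-value estimate of each target critic in the next state.
    """
    # Target policy smoothing: add clipped Gaussian noise to the target action
    noise = rng.gauss(0.0, noise_std)
    noise = max(-noise_clip, min(noise_clip, noise))
    smoothed = max(action_low, min(action_high, next_action + noise))
    # Clipped double Q-learning: use the smaller of the two target critics
    min_q = min(q1_target(smoothed), q2_target(smoothed))
    # No bootstrapping when the episode terminated (done == 1.0)
    return reward + gamma * (1.0 - done) * min_q


def should_update_actor(step, policy_delay=2):
    """Delayed policy update: update the actor once every `policy_delay` critic updates."""
    return step % policy_delay == 0
```

Taking the minimum of the two critics counteracts the overestimation bias of a single critic, while the delayed actor update lets the value estimates settle before the policy is changed.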


.. warning::

    The TD3 model does not support ``stable_baselines.common.policies`` because it uses double Q-value
    estimation; as a result, it must use its own policy models (see :ref:`td3_policies`).


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    LnMlpPolicy
    CnnPolicy
    LnCnnPolicy

Notes
-----

- Original paper: https://arxiv.org/pdf/1802.09477.pdf
- OpenAI Spinning Up guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
- Original Implementation: https://github.com/sfujim/TD3

.. note::

    The default policies for TD3 differ slightly from the other MlpPolicy classes: they use ReLU instead of tanh activation,
    to match the original paper.


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌     ✔️
Box           ✔️     ✔️
MultiDiscrete ❌     ✔️
MultiBinary   ❌     ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines import TD3
    from stable_baselines.td3.policies import MlpPolicy
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    env = gym.make('Pendulum-v0')
    env = DummyVecEnv([lambda: env])

    # The noise objects for TD3
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=50000, log_interval=10)
    model.save("td3_pendulum")

    del model # remove to demonstrate saving and loading

    model = TD3.load("td3_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()

Parameters
----------

.. autoclass:: TD3
    :members:
    :inherited-members:

.. _td3_policies:

TD3 Policies
-------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:


.. autoclass:: LnMlpPolicy
    :members:
    :inherited-members:


.. autoclass:: CnnPolicy
    :members:
    :inherited-members:


.. autoclass:: LnCnnPolicy
    :members:
    :inherited-members:


Custom Policy Network
---------------------

Similarly to the example given in the `examples <../guide/custom_policy.html>`_ page,
you can easily define a custom architecture for the policy network:

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines import TD3
    from stable_baselines.td3.policies import FeedForwardPolicy
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    # Custom MLP policy with two layers
    class CustomTD3Policy(FeedForwardPolicy):
        def __init__(self, *args, **kwargs):
            super(CustomTD3Policy, self).__init__(*args, **kwargs,
                                                  layers=[400, 300],
                                                  layer_norm=False,
                                                  feature_extraction="mlp")

    # Create and wrap the environment
    env = gym.make('Pendulum-v0')
    env = DummyVecEnv([lambda: env])

    # The noise objects for TD3
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = TD3(CustomTD3Policy, env, action_noise=action_noise, verbose=1)
    # Train the agent
    model.learn(total_timesteps=80000)

3 changes: 2 additions & 1 deletion setup.py
@@ -50,6 +50,7 @@
- PEP8 compliant (unified code style)
- Documented functions and classes
- More tests & more code coverage
- Additional algorithms: SAC and TD3 (+ HER support for DQN, DDPG, SAC and TD3)
## Links
@@ -137,7 +138,7 @@
license="MIT",
long_description=long_description,
long_description_content_type='text/markdown',
version="2.6.1a0",
version="2.7.0a0",
)

# python setup.py sdist