Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Re-organize baselines into subdirectories according to their provenance/libraries used.
Browse files Browse the repository at this point in the history

- tf: TensorFlow 2/Sonnet 2/TRFL-based agents.
- jax: JAX/Haiku/rlax-based agents.
- third_party: Agents created by third parties (not DeepMind).

Also adopt more standard naming practice within each agent folder (agent.py).

PiperOrigin-RevId: 305674544
Change-Id: I3d4f076fb96d2e0250cfbb3f1adf163ce6932e97
  • Loading branch information
aslanides authored and copybara-github committed Apr 9, 2020
1 parent 8118f60 commit 0fdf026
Show file tree
Hide file tree
Showing 39 changed files with 42 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# ============================================================================
"""A simple actor-critic implementation in JAX."""

from bsuite.baselines.actor_critic_jax.actor_critic import ActorCritic
from bsuite.baselines.actor_critic_jax.actor_critic import default_agent
from bsuite.baselines.jax.actor_critic.agent import ActorCritic
from bsuite.baselines.jax.actor_critic.agent import default_agent
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import actor_critic_jax
from bsuite.baselines import experiment
from bsuite.baselines.jax import actor_critic
from bsuite.baselines.utils import pool

# Internal imports.
Expand Down Expand Up @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str:
overwrite=FLAGS.overwrite,
)

agent = actor_critic_jax.default_agent(
agent = actor_critic.default_agent(
env.observation_spec(), env.action_spec())

num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.actor_critic_rnn import run
from bsuite.baselines.jax.actor_critic import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# ============================================================================
"""A simple actor-critic implementation in JAX."""

from bsuite.baselines.actor_critic_rnn_jax.actor_critic_rnn import ActorCriticRNN
from bsuite.baselines.actor_critic_rnn_jax.actor_critic_rnn import default_agent
from bsuite.baselines.jax.actor_critic_rnn.agent import ActorCriticRNN
from bsuite.baselines.jax.actor_critic_rnn.agent import default_agent
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import actor_critic_rnn_jax
from bsuite.baselines import experiment
from bsuite.baselines.jax import actor_critic_rnn
from bsuite.baselines.utils import pool

# Internal imports.
Expand Down Expand Up @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str:
overwrite=FLAGS.overwrite,
)

agent = actor_critic_rnn_jax.default_agent(
agent = actor_critic_rnn.default_agent(
env.observation_spec(), env.action_spec())

num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.actor_critic_rnn_jax import run
from bsuite.baselines.jax.actor_critic_rnn import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# ============================================================================
"""A simple DQN agent implemented in JAX."""

from bsuite.baselines.dqn_jax.dqn import default_agent
from bsuite.baselines.dqn_jax.dqn import DQN
from bsuite.baselines.jax.dqn.agent import default_agent
from bsuite.baselines.jax.dqn.agent import DQN
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import dqn_jax
from bsuite.baselines import experiment
from bsuite.baselines.jax import dqn
from bsuite.baselines.utils import pool

# Internal imports.
Expand Down Expand Up @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str:
overwrite=FLAGS.overwrite,
)

agent = dqn_jax.default_agent(env.observation_spec(), env.action_spec())
agent = dqn.default_agent(env.observation_spec(), env.action_spec())

num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
experiment.run(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.dqn_jax import run
from bsuite.baselines.jax.dqn import run

FLAGS = flags.FLAGS

Expand Down
4 changes: 2 additions & 2 deletions bsuite/baselines/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# ============================================================================
"""An agent that takes uniformly random actions."""

from bsuite.baselines.random.random import default_agent
from bsuite.baselines.random.random import Random
from bsuite.baselines.random.agent import default_agent
from bsuite.baselines.random.agent import Random
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
# ============================================================================
"""A simple TensorFlow 2-based implementation of the actor-critic algorithm."""

from bsuite.baselines.actor_critic.actor_critic import ActorCritic
from bsuite.baselines.actor_critic.actor_critic import default_agent
from bsuite.baselines.actor_critic.actor_critic import PolicyValueNet
from bsuite.baselines.tf.actor_critic.agent import ActorCritic
from bsuite.baselines.tf.actor_critic.agent import default_agent
from bsuite.baselines.tf.actor_critic.agent import PolicyValueNet
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import actor_critic
from bsuite.baselines import experiment
from bsuite.baselines.tf import actor_critic
from bsuite.baselines.utils import pool

import sonnet as snt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.actor_critic import run
from bsuite.baselines.tf.actor_critic import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
# ============================================================================
"""A simple TensorFlow 2-based implementation of a recurrent actor-critic."""

from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import ActorCriticRNN
from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import default_agent
from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import PolicyValueRNN
from bsuite.baselines.tf.actor_critic_rnn.agent import ActorCriticRNN
from bsuite.baselines.tf.actor_critic_rnn.agent import default_agent
from bsuite.baselines.tf.actor_critic_rnn.agent import PolicyValueRNN
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import actor_critic_rnn
from bsuite.baselines import experiment
from bsuite.baselines.tf import actor_critic_rnn
from bsuite.baselines.utils import pool

import sonnet as snt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.actor_critic_jax import run
from bsuite.baselines.tf.actor_critic_rnn import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
# ============================================================================
"""A simple implementation of Bootstrapped DQN with prior networks."""

from bsuite.baselines.boot_dqn.boot_dqn import BootstrappedDqn
from bsuite.baselines.boot_dqn.boot_dqn import default_agent
from bsuite.baselines.boot_dqn.boot_dqn import make_ensemble
from bsuite.baselines.tf.boot_dqn.agent import BootstrappedDqn
from bsuite.baselines.tf.boot_dqn.agent import default_agent
from bsuite.baselines.tf.boot_dqn.agent import make_ensemble
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import boot_dqn
from bsuite.baselines import experiment
from bsuite.baselines.tf import boot_dqn
from bsuite.baselines.utils import pool

import sonnet as snt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.boot_dqn import run
from bsuite.baselines.tf.boot_dqn import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# ============================================================================
"""A simple TensorFlow 2-based DQN implementation."""

from bsuite.baselines.dqn.dqn import default_agent
from bsuite.baselines.dqn.dqn import DQN
from bsuite.baselines.tf.dqn.agent import default_agent
from bsuite.baselines.tf.dqn.agent import DQN
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import bsuite
from bsuite import sweep

from bsuite.baselines import dqn
from bsuite.baselines import experiment
from bsuite.baselines.tf import dqn
from bsuite.baselines.utils import pool

import sonnet as snt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.dqn import run
from bsuite.baselines.tf.dqn import run

FLAGS = flags.FLAGS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.dopamine_dqn import run
from bsuite.baselines.third_party.dopamine_dqn import run

import tensorflow.compat.v1 as tf

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import parameterized

from bsuite import sweep
from bsuite.baselines.openai_dqn import run
from bsuite.baselines.third_party.openai_dqn import run

import tensorflow.compat.v1 as tf

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
Expand All @@ -18,7 +19,7 @@
from absl import flags
from absl.testing import absltest

from bsuite.baselines.openai_ppo import run
from bsuite.baselines.third_party.openai_ppo import run

FLAGS = flags.FLAGS

Expand Down

0 comments on commit 0fdf026

Please sign in to comment.