diff --git a/bsuite/baselines/actor_critic_jax/__init__.py b/bsuite/baselines/jax/actor_critic/__init__.py similarity index 84% rename from bsuite/baselines/actor_critic_jax/__init__.py rename to bsuite/baselines/jax/actor_critic/__init__.py index 8a9f9683..ddca6c3c 100644 --- a/bsuite/baselines/actor_critic_jax/__init__.py +++ b/bsuite/baselines/jax/actor_critic/__init__.py @@ -16,5 +16,5 @@ # ============================================================================ """A simple actor-critic implementation in JAX.""" -from bsuite.baselines.actor_critic_jax.actor_critic import ActorCritic -from bsuite.baselines.actor_critic_jax.actor_critic import default_agent +from bsuite.baselines.jax.actor_critic.agent import ActorCritic +from bsuite.baselines.jax.actor_critic.agent import default_agent diff --git a/bsuite/baselines/actor_critic_jax/actor_critic.py b/bsuite/baselines/jax/actor_critic/agent.py similarity index 100% rename from bsuite/baselines/actor_critic_jax/actor_critic.py rename to bsuite/baselines/jax/actor_critic/agent.py diff --git a/bsuite/baselines/actor_critic_jax/run.py b/bsuite/baselines/jax/actor_critic/run.py similarity index 96% rename from bsuite/baselines/actor_critic_jax/run.py rename to bsuite/baselines/jax/actor_critic/run.py index fafac303..7a61ada5 100644 --- a/bsuite/baselines/actor_critic_jax/run.py +++ b/bsuite/baselines/jax/actor_critic/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import actor_critic_jax from bsuite.baselines import experiment +from bsuite.baselines.jax import actor_critic from bsuite.baselines.utils import pool # Internal imports. @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str: overwrite=FLAGS.overwrite, ) - agent = actor_critic_jax.default_agent( + agent = actor_critic.default_agent( env.observation_spec(), env.action_spec()) num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes') diff --git a/bsuite/baselines/actor_critic_rnn/run_test.py b/bsuite/baselines/jax/actor_critic/run_test.py similarity index 96% rename from bsuite/baselines/actor_critic_rnn/run_test.py rename to bsuite/baselines/jax/actor_critic/run_test.py index a4b44560..a5f86e95 100644 --- a/bsuite/baselines/actor_critic_rnn/run_test.py +++ b/bsuite/baselines/jax/actor_critic/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.actor_critic_rnn import run +from bsuite.baselines.jax.actor_critic import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/actor_critic_rnn_jax/__init__.py b/bsuite/baselines/jax/actor_critic_rnn/__init__.py similarity index 82% rename from bsuite/baselines/actor_critic_rnn_jax/__init__.py rename to bsuite/baselines/jax/actor_critic_rnn/__init__.py index 77486a1c..8f5f4425 100644 --- a/bsuite/baselines/actor_critic_rnn_jax/__init__.py +++ b/bsuite/baselines/jax/actor_critic_rnn/__init__.py @@ -16,5 +16,5 @@ # ============================================================================ """A simple actor-critic implementation in JAX.""" -from bsuite.baselines.actor_critic_rnn_jax.actor_critic_rnn import ActorCriticRNN -from bsuite.baselines.actor_critic_rnn_jax.actor_critic_rnn import default_agent +from bsuite.baselines.jax.actor_critic_rnn.agent import ActorCriticRNN +from bsuite.baselines.jax.actor_critic_rnn.agent import default_agent diff --git a/bsuite/baselines/actor_critic_rnn_jax/actor_critic_rnn.py b/bsuite/baselines/jax/actor_critic_rnn/agent.py similarity index 100% rename from bsuite/baselines/actor_critic_rnn_jax/actor_critic_rnn.py rename to bsuite/baselines/jax/actor_critic_rnn/agent.py diff --git a/bsuite/baselines/actor_critic_rnn_jax/run.py b/bsuite/baselines/jax/actor_critic_rnn/run.py similarity index 96% rename from bsuite/baselines/actor_critic_rnn_jax/run.py rename to bsuite/baselines/jax/actor_critic_rnn/run.py index 46faec73..3b213b9e 100644 --- a/bsuite/baselines/actor_critic_rnn_jax/run.py +++ b/bsuite/baselines/jax/actor_critic_rnn/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import actor_critic_rnn_jax from bsuite.baselines import experiment +from bsuite.baselines.jax import actor_critic_rnn from bsuite.baselines.utils import pool # Internal imports. @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str: overwrite=FLAGS.overwrite, ) - agent = actor_critic_rnn_jax.default_agent( + agent = actor_critic_rnn.default_agent( env.observation_spec(), env.action_spec()) num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes') diff --git a/bsuite/baselines/actor_critic_rnn_jax/run_test.py b/bsuite/baselines/jax/actor_critic_rnn/run_test.py similarity index 95% rename from bsuite/baselines/actor_critic_rnn_jax/run_test.py rename to bsuite/baselines/jax/actor_critic_rnn/run_test.py index 68cb3224..ce992aa7 100644 --- a/bsuite/baselines/actor_critic_rnn_jax/run_test.py +++ b/bsuite/baselines/jax/actor_critic_rnn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.actor_critic_rnn_jax import run +from bsuite.baselines.jax.actor_critic_rnn import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/dqn_jax/__init__.py b/bsuite/baselines/jax/dqn/__init__.py similarity index 88% rename from bsuite/baselines/dqn_jax/__init__.py rename to bsuite/baselines/jax/dqn/__init__.py index 830ea1da..17463d5f 100644 --- a/bsuite/baselines/dqn_jax/__init__.py +++ b/bsuite/baselines/jax/dqn/__init__.py @@ -16,5 +16,5 @@ # ============================================================================ """A simple DQN agent implemented in JAX.""" -from bsuite.baselines.dqn_jax.dqn import default_agent -from bsuite.baselines.dqn_jax.dqn import DQN +from bsuite.baselines.jax.dqn.agent import default_agent +from bsuite.baselines.jax.dqn.agent import DQN diff --git a/bsuite/baselines/dqn_jax/dqn.py b/bsuite/baselines/jax/dqn/agent.py similarity index 100% rename from bsuite/baselines/dqn_jax/dqn.py rename to bsuite/baselines/jax/dqn/agent.py diff --git a/bsuite/baselines/dqn_jax/run.py b/bsuite/baselines/jax/dqn/run.py similarity index 95% rename from bsuite/baselines/dqn_jax/run.py rename to bsuite/baselines/jax/dqn/run.py index f0f704c1..f9f70d58 100644 --- a/bsuite/baselines/dqn_jax/run.py +++ b/bsuite/baselines/jax/dqn/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import dqn_jax from bsuite.baselines import experiment +from bsuite.baselines.jax import dqn from bsuite.baselines.utils import pool # Internal imports. @@ -52,7 +52,7 @@ def run(bsuite_id: str) -> str: overwrite=FLAGS.overwrite, ) - agent = dqn_jax.default_agent(env.observation_spec(), env.action_spec()) + agent = dqn.default_agent(env.observation_spec(), env.action_spec()) num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes') experiment.run( diff --git a/bsuite/baselines/dqn_jax/run_test.py b/bsuite/baselines/jax/dqn/run_test.py similarity index 96% rename from bsuite/baselines/dqn_jax/run_test.py rename to bsuite/baselines/jax/dqn/run_test.py index 8b0e4ba2..b2eeafe5 100644 --- a/bsuite/baselines/dqn_jax/run_test.py +++ b/bsuite/baselines/jax/dqn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.dqn_jax import run +from bsuite.baselines.jax.dqn import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/random/__init__.py b/bsuite/baselines/random/__init__.py index e6137b9a..6255c10f 100644 --- a/bsuite/baselines/random/__init__.py +++ b/bsuite/baselines/random/__init__.py @@ -16,5 +16,5 @@ # ============================================================================ """An agent that takes uniformly random actions.""" -from bsuite.baselines.random.random import default_agent -from bsuite.baselines.random.random import Random +from bsuite.baselines.random.agent import default_agent +from bsuite.baselines.random.agent import Random diff --git a/bsuite/baselines/random/random.py b/bsuite/baselines/random/agent.py similarity index 100% rename from bsuite/baselines/random/random.py rename to bsuite/baselines/random/agent.py diff --git a/bsuite/baselines/actor_critic/__init__.py b/bsuite/baselines/tf/actor_critic/__init__.py similarity index 79% rename from bsuite/baselines/actor_critic/__init__.py rename to bsuite/baselines/tf/actor_critic/__init__.py index 3db3a000..4195235b 100644 --- a/bsuite/baselines/actor_critic/__init__.py +++ b/bsuite/baselines/tf/actor_critic/__init__.py @@ -16,6 +16,6 @@ # ============================================================================ """A simple TensorFlow 2-based implementation of the actor-critic algorithm.""" -from bsuite.baselines.actor_critic.actor_critic import ActorCritic -from bsuite.baselines.actor_critic.actor_critic import default_agent -from bsuite.baselines.actor_critic.actor_critic import PolicyValueNet +from bsuite.baselines.tf.actor_critic.agent import ActorCritic +from bsuite.baselines.tf.actor_critic.agent import default_agent +from bsuite.baselines.tf.actor_critic.agent import PolicyValueNet diff --git a/bsuite/baselines/actor_critic/actor_critic.py b/bsuite/baselines/tf/actor_critic/agent.py similarity index 100% rename from bsuite/baselines/actor_critic/actor_critic.py rename to bsuite/baselines/tf/actor_critic/agent.py diff --git a/bsuite/baselines/actor_critic/run.py b/bsuite/baselines/tf/actor_critic/run.py similarity index 98% rename from bsuite/baselines/actor_critic/run.py rename to bsuite/baselines/tf/actor_critic/run.py index f1f690ba..e0ed11c2 100644 --- a/bsuite/baselines/actor_critic/run.py +++ b/bsuite/baselines/tf/actor_critic/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import actor_critic from bsuite.baselines import experiment +from bsuite.baselines.tf import actor_critic from bsuite.baselines.utils import pool import sonnet as snt diff --git a/bsuite/baselines/actor_critic/run_test.py b/bsuite/baselines/tf/actor_critic/run_test.py similarity index 96% rename from bsuite/baselines/actor_critic/run_test.py rename to bsuite/baselines/tf/actor_critic/run_test.py index c2582b32..eb6d86f1 100644 --- a/bsuite/baselines/actor_critic/run_test.py +++ b/bsuite/baselines/tf/actor_critic/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.actor_critic import run +from bsuite.baselines.tf.actor_critic import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/actor_critic_rnn/__init__.py b/bsuite/baselines/tf/actor_critic_rnn/__init__.py similarity index 77% rename from bsuite/baselines/actor_critic_rnn/__init__.py rename to bsuite/baselines/tf/actor_critic_rnn/__init__.py index cf4ef034..e4da78ad 100644 --- a/bsuite/baselines/actor_critic_rnn/__init__.py +++ b/bsuite/baselines/tf/actor_critic_rnn/__init__.py @@ -16,6 +16,6 @@ # ============================================================================ """A simple TensorFlow 2-based implementation of a recurrent actor-critic.""" -from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import ActorCriticRNN -from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import default_agent -from bsuite.baselines.actor_critic_rnn.actor_critic_rnn import PolicyValueRNN +from bsuite.baselines.tf.actor_critic_rnn.agent import ActorCriticRNN +from bsuite.baselines.tf.actor_critic_rnn.agent import default_agent +from bsuite.baselines.tf.actor_critic_rnn.agent import PolicyValueRNN diff --git a/bsuite/baselines/actor_critic_rnn/actor_critic_rnn.py b/bsuite/baselines/tf/actor_critic_rnn/agent.py similarity index 100% rename from bsuite/baselines/actor_critic_rnn/actor_critic_rnn.py rename to bsuite/baselines/tf/actor_critic_rnn/agent.py diff --git a/bsuite/baselines/actor_critic_rnn/run.py b/bsuite/baselines/tf/actor_critic_rnn/run.py similarity index 98% rename from bsuite/baselines/actor_critic_rnn/run.py rename to bsuite/baselines/tf/actor_critic_rnn/run.py index af37be48..4179b344 100644 --- a/bsuite/baselines/actor_critic_rnn/run.py +++ b/bsuite/baselines/tf/actor_critic_rnn/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import actor_critic_rnn from bsuite.baselines import experiment +from bsuite.baselines.tf import actor_critic_rnn from bsuite.baselines.utils import pool import sonnet as snt diff --git a/bsuite/baselines/actor_critic_jax/run_test.py b/bsuite/baselines/tf/actor_critic_rnn/run_test.py similarity index 95% rename from bsuite/baselines/actor_critic_jax/run_test.py rename to bsuite/baselines/tf/actor_critic_rnn/run_test.py index bcfc061f..52239911 100644 --- a/bsuite/baselines/actor_critic_jax/run_test.py +++ b/bsuite/baselines/tf/actor_critic_rnn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.actor_critic_jax import run +from bsuite.baselines.tf.actor_critic_rnn import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/boot_dqn/__init__.py b/bsuite/baselines/tf/boot_dqn/__init__.py similarity index 81% rename from bsuite/baselines/boot_dqn/__init__.py rename to bsuite/baselines/tf/boot_dqn/__init__.py index 25addaf5..aabe0f14 100644 --- a/bsuite/baselines/boot_dqn/__init__.py +++ b/bsuite/baselines/tf/boot_dqn/__init__.py @@ -16,6 +16,6 @@ # ============================================================================ """A simple implementation of Bootstrapped DQN with prior networks.""" -from bsuite.baselines.boot_dqn.boot_dqn import BootstrappedDqn -from bsuite.baselines.boot_dqn.boot_dqn import default_agent -from bsuite.baselines.boot_dqn.boot_dqn import make_ensemble +from bsuite.baselines.tf.boot_dqn.agent import BootstrappedDqn +from bsuite.baselines.tf.boot_dqn.agent import default_agent +from bsuite.baselines.tf.boot_dqn.agent import make_ensemble diff --git a/bsuite/baselines/boot_dqn/boot_dqn.py b/bsuite/baselines/tf/boot_dqn/agent.py similarity index 100% rename from bsuite/baselines/boot_dqn/boot_dqn.py rename to bsuite/baselines/tf/boot_dqn/agent.py diff --git a/bsuite/baselines/boot_dqn/run.py b/bsuite/baselines/tf/boot_dqn/run.py similarity index 99% rename from bsuite/baselines/boot_dqn/run.py rename to bsuite/baselines/tf/boot_dqn/run.py index 974736f9..cbc4b650 100644 --- a/bsuite/baselines/boot_dqn/run.py +++ b/bsuite/baselines/tf/boot_dqn/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import boot_dqn from bsuite.baselines import experiment +from bsuite.baselines.tf import boot_dqn from bsuite.baselines.utils import pool import sonnet as snt diff --git a/bsuite/baselines/boot_dqn/run_test.py b/bsuite/baselines/tf/boot_dqn/run_test.py similarity index 96% rename from bsuite/baselines/boot_dqn/run_test.py rename to bsuite/baselines/tf/boot_dqn/run_test.py index d2b636a5..3b07b6a0 100644 --- a/bsuite/baselines/boot_dqn/run_test.py +++ b/bsuite/baselines/tf/boot_dqn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.boot_dqn import run +from bsuite.baselines.tf.boot_dqn import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/dqn/__init__.py b/bsuite/baselines/tf/dqn/__init__.py similarity index 88% rename from bsuite/baselines/dqn/__init__.py rename to bsuite/baselines/tf/dqn/__init__.py index 2570438b..a51860ab 100644 --- a/bsuite/baselines/dqn/__init__.py +++ b/bsuite/baselines/tf/dqn/__init__.py @@ -16,5 +16,5 @@ # ============================================================================ """A simple TensorFlow 2-based DQN implementation.""" -from bsuite.baselines.dqn.dqn import default_agent -from bsuite.baselines.dqn.dqn import DQN +from bsuite.baselines.tf.dqn.agent import default_agent +from bsuite.baselines.tf.dqn.agent import DQN diff --git a/bsuite/baselines/dqn/dqn.py b/bsuite/baselines/tf/dqn/agent.py similarity index 100% rename from bsuite/baselines/dqn/dqn.py rename to bsuite/baselines/tf/dqn/agent.py diff --git a/bsuite/baselines/dqn/run.py b/bsuite/baselines/tf/dqn/run.py similarity index 99% rename from bsuite/baselines/dqn/run.py rename to bsuite/baselines/tf/dqn/run.py index e010a20e..d097f677 100644 --- a/bsuite/baselines/dqn/run.py +++ b/bsuite/baselines/tf/dqn/run.py @@ -22,8 +22,8 @@ import bsuite from bsuite import sweep -from bsuite.baselines import dqn from bsuite.baselines import experiment +from bsuite.baselines.tf import dqn from bsuite.baselines.utils import pool import sonnet as snt diff --git a/bsuite/baselines/dqn/run_test.py b/bsuite/baselines/tf/dqn/run_test.py similarity index 96% rename from bsuite/baselines/dqn/run_test.py rename to bsuite/baselines/tf/dqn/run_test.py index c193b777..6d7c5960 100644 --- a/bsuite/baselines/dqn/run_test.py +++ b/bsuite/baselines/tf/dqn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.dqn import run +from bsuite.baselines.tf.dqn import run FLAGS = flags.FLAGS diff --git a/bsuite/baselines/dopamine_dqn/__init__.py b/bsuite/baselines/third_party/dopamine_dqn/__init__.py similarity index 100% rename from bsuite/baselines/dopamine_dqn/__init__.py rename to bsuite/baselines/third_party/dopamine_dqn/__init__.py diff --git a/bsuite/baselines/dopamine_dqn/run.py b/bsuite/baselines/third_party/dopamine_dqn/run.py similarity index 100% rename from bsuite/baselines/dopamine_dqn/run.py rename to bsuite/baselines/third_party/dopamine_dqn/run.py diff --git a/bsuite/baselines/dopamine_dqn/run_test.py b/bsuite/baselines/third_party/dopamine_dqn/run_test.py similarity index 95% rename from bsuite/baselines/dopamine_dqn/run_test.py rename to bsuite/baselines/third_party/dopamine_dqn/run_test.py index c83d475b..0a5274a1 100644 --- a/bsuite/baselines/dopamine_dqn/run_test.py +++ b/bsuite/baselines/third_party/dopamine_dqn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.dopamine_dqn import run +from bsuite.baselines.third_party.dopamine_dqn import run import tensorflow.compat.v1 as tf diff --git a/bsuite/baselines/openai_dqn/__init__.py b/bsuite/baselines/third_party/openai_dqn/__init__.py similarity index 98% rename from bsuite/baselines/openai_dqn/__init__.py rename to bsuite/baselines/third_party/openai_dqn/__init__.py index e9575713..767d200a 100644 --- a/bsuite/baselines/openai_dqn/__init__.py +++ b/bsuite/baselines/third_party/openai_dqn/__init__.py @@ -1,3 +1,4 @@ +# python3 # pylint: disable=g-bad-file-header # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # diff --git a/bsuite/baselines/openai_dqn/run.py b/bsuite/baselines/third_party/openai_dqn/run.py similarity index 100% rename from bsuite/baselines/openai_dqn/run.py rename to bsuite/baselines/third_party/openai_dqn/run.py diff --git a/bsuite/baselines/openai_dqn/run_test.py b/bsuite/baselines/third_party/openai_dqn/run_test.py similarity index 95% rename from bsuite/baselines/openai_dqn/run_test.py rename to bsuite/baselines/third_party/openai_dqn/run_test.py index e83e76d1..c9c46944 100644 --- a/bsuite/baselines/openai_dqn/run_test.py +++ b/bsuite/baselines/third_party/openai_dqn/run_test.py @@ -21,7 +21,7 @@ from absl.testing import parameterized from bsuite import sweep -from bsuite.baselines.openai_dqn import run +from bsuite.baselines.third_party.openai_dqn import run import tensorflow.compat.v1 as tf diff --git a/bsuite/baselines/openai_ppo/__init__.py b/bsuite/baselines/third_party/openai_ppo/__init__.py similarity index 98% rename from bsuite/baselines/openai_ppo/__init__.py rename to bsuite/baselines/third_party/openai_ppo/__init__.py index e9575713..767d200a 100644 --- a/bsuite/baselines/openai_ppo/__init__.py +++ b/bsuite/baselines/third_party/openai_ppo/__init__.py @@ -1,3 +1,4 @@ +# python3 # pylint: disable=g-bad-file-header # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # diff --git a/bsuite/baselines/openai_ppo/run.py b/bsuite/baselines/third_party/openai_ppo/run.py similarity index 100% rename from bsuite/baselines/openai_ppo/run.py rename to bsuite/baselines/third_party/openai_ppo/run.py diff --git a/bsuite/baselines/openai_ppo/run_test.py b/bsuite/baselines/third_party/openai_ppo/run_test.py similarity index 94% rename from bsuite/baselines/openai_ppo/run_test.py rename to bsuite/baselines/third_party/openai_ppo/run_test.py index 7cd70832..b002ed7c 100644 --- a/bsuite/baselines/openai_ppo/run_test.py +++ b/bsuite/baselines/third_party/openai_ppo/run_test.py @@ -1,3 +1,4 @@ +# python3 # pylint: disable=g-bad-file-header # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # @@ -18,7 +19,7 @@ from absl import flags from absl.testing import absltest -from bsuite.baselines.openai_ppo import run +from bsuite.baselines.third_party.openai_ppo import run FLAGS = flags.FLAGS