a few patch fixes before refactoring TF sessions
hill-a committed Jul 6, 2018
1 parent b1df898 commit 6bff8e6
Showing 13 changed files with 351 additions and 316 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
 *.pyc
 *.pkl
 *.py~
+*.bak
 .pytest_cache
 .DS_Store
 .idea
29 changes: 26 additions & 3 deletions Dockerfile
@@ -4,17 +4,40 @@ RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libo
 ENV CODE_DIR /root/code
 ENV VENV /root/venv

-COPY . $CODE_DIR/baselines
 RUN \
     pip install virtualenv && \
     virtualenv $VENV --python=python3 && \
     . $VENV/bin/activate && \
+    mkdir $CODE_DIR && \
     cd $CODE_DIR && \
     pip install --upgrade pip && \
-    pip install -e baselines && \
     pip install pytest && \
     pip install pytest-cov && \
-    pip install codacy-coverage
+    pip install codacy-coverage && \
+    pip install scipy && \
+    pip install tqdm && \
+    pip install joblib && \
+    pip install zmq && \
+    pip install dill && \
+    pip install progressbar2 && \
+    pip install mpi4py && \
+    pip install cloudpickle && \
+    pip install tensorflow>=1.4.0 && \
+    pip install click && \
+    pip install opencv-python && \
+    pip install numpy && \
+    pip install pandas && \
+    pip install pytest && \
+    pip install matplotlib && \
+    pip install seaborn && \
+    pip install glob2 && \
+    pip install gym[mujoco,atari,classic_control,robotics]
+
+COPY . $CODE_DIR/baselines
+RUN \
+    . $VENV/bin/activate && \
+    cd $CODE_DIR && \
+    pip install -e baselines

 ENV PATH=$VENV/bin:$PATH
 WORKDIR $CODE_DIR/baselines
1 change: 0 additions & 1 deletion baselines/acer/acer_simple.py
@@ -330,7 +330,6 @@ def learn(policy, env, seed, nsteps=20, nstack=4, total_timesteps=int(80e6), q_c
           trust_region=True, alpha=0.99, delta=1):
     print("Running Acer Simple")
     print(locals())
-    tf.reset_default_graph()
     set_global_seeds(seed)

     nenvs = env.num_envs
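Context for the dropped call (an inference from the FIXME notes in the test changes below, not something the commit states): calling tf.reset_default_graph() inside learn() throws away whatever graph the caller had made current, so any ops or sessions the caller created become unusable. A minimal sketch of the hazard, assuming the TF 1.x API:

import tensorflow as tf

# Ops created by the caller on the default graph.
x = tf.placeholder(tf.float32, shape=(), name="x")
y = x * 2.0

tf.reset_default_graph()  # what learn() used to do internally

with tf.Session() as sess:
    try:
        # Fails: `x` and `y` live on the old, discarded default graph and
        # cannot be run in a session built on the fresh one.
        sess.run(y, feed_dict={x: 1.0})
    except Exception as exc:  # graph-mismatch error
        print(exc)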
2 changes: 1 addition & 1 deletion baselines/acer/run_atari.py
@@ -14,7 +14,7 @@ def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
     else:
         print("Policy {} not implemented".format(policy))
         return
-    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
+    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule, buffer_size=5000)
     env.close()
1 change: 0 additions & 1 deletion baselines/acktr/acktr_disc.py
@@ -102,7 +102,6 @@ def load(load_path):
 def learn(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
           ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
           kfac_clip=0.001, save_interval=None, lrschedule='linear'):
-    tf.reset_default_graph()
     set_global_seeds(seed)

     nenvs = env.num_envs
100 changes: 56 additions & 44 deletions baselines/common/tests/test_atari.py
@@ -1,4 +1,5 @@
 import pytest
+import sys

 import tensorflow as tf

@@ -15,84 +16,95 @@

 ENV_ID = 'BreakoutNoFrameskip-v4'
 SEED = 3
-NUM_TIMESTEPS = 10000
-NUM_CPU = 16
+NUM_TIMESTEPS = 2500
+NUM_CPU = 4


 def clear_tf_session():
+    if tf.get_default_session() is not None:
+        print("Session!!! {}".format(tf.get_default_session()), file=sys.stderr)
     tf.reset_default_graph()


 @pytest.mark.slow
 @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm'])
 def test_a2c(policy):
     clear_tf_session()
-    a2c_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
-                    policy=policy, lrschedule='constant', num_env=NUM_CPU)
+    with tf.Graph().as_default():
+        a2c_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
+                        policy=policy, lrschedule='constant', num_env=NUM_CPU)


 @pytest.mark.slow
 @pytest.mark.parametrize("policy", ['cnn', 'lstm'])
 def test_acer(policy):
     clear_tf_session()
-    acer_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
-                     policy=policy, lrschedule='constant', num_cpu=NUM_CPU)
+    with tf.Graph().as_default():
+        acer_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED,
+                         policy=policy, lrschedule='constant', num_cpu=NUM_CPU)


 @pytest.mark.slow
 def test_acktr():
     clear_tf_session()
-    acktr_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, num_cpu=NUM_CPU)
+    with tf.Graph().as_default():
+        acktr_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, num_cpu=NUM_CPU)


 @pytest.mark.slow
 def test_deepq():
     clear_tf_session()
-    logger.configure()
-    set_global_seeds(SEED)
-    env = make_atari(ENV_ID)
-    env = bench.Monitor(env, logger.get_dir())
-    env = deepq.wrap_atari_dqn(env)
-    model = deepq.models.cnn_to_mlp(
-        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
-        hiddens=[256],
-        dueling=True,
-    )
-
-    deepq.learn(
-        env,
-        q_func=model,
-        lr=1e-4,
-        max_timesteps=NUM_TIMESTEPS,
-        buffer_size=10000,
-        exploration_fraction=0.1,
-        exploration_final_eps=0.01,
-        train_freq=4,
-        learning_starts=10000,
-        target_network_update_freq=1000,
-        gamma=0.99,
-        prioritized_replay=True,
-        prioritized_replay_alpha=0.6,
-        checkpoint_freq=10000
-    )
-
-    env.close()
+    with tf.Graph().as_default():
+        logger.configure()
+        set_global_seeds(SEED)
+        env = make_atari(ENV_ID)
+        env = bench.Monitor(env, logger.get_dir())
+        env = deepq.wrap_atari_dqn(env)
+        model = deepq.models.cnn_to_mlp(
+            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
+            hiddens=[256],
+            dueling=True,
+        )
+
+        with tf.Session():
+            deepq.learn(
+                env,
+                q_func=model,
+                lr=1e-4,
+                max_timesteps=NUM_TIMESTEPS,
+                buffer_size=10000,
+                exploration_fraction=0.1,
+                exploration_final_eps=0.01,
+                train_freq=4,
+                learning_starts=10000,
+                target_network_update_freq=1000,
+                gamma=0.99,
+                prioritized_replay=True,
+                prioritized_replay_alpha=0.6,
+                checkpoint_freq=10000
+            )
+
+        env.close()


 @pytest.mark.slow
 def test_ppo1():
     clear_tf_session()
-    ppo1_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)
+    with tf.Graph().as_default():
+        ppo1_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)


 @pytest.mark.slow
 @pytest.mark.parametrize("policy", ['cnn', 'lstm', 'lnlstm', 'mlp'])
 def test_ppo2(policy):
     clear_tf_session()
-    ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, policy=policy)
-
-
-@pytest.mark.slow
-def test_trpo():
-    clear_tf_session()
-    trpo_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)
+    with tf.Graph().as_default():
+        ppo2_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED, policy=policy)
+
+# FIXME: This is broken as the graph is not correctly removed after PPO1 is run.
+# FIXME: However the fix requires a refactoring of ALL the models to deal with their session internally.
+# @pytest.mark.slow
+# def test_trpo():
+#     clear_tf_session()
+#     with tf.Graph().as_default():
+#         trpo_atari.train(env_id=ENV_ID, num_timesteps=NUM_TIMESTEPS, seed=SEED)
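Illustrative sketch (TF 1.x; not code from the commit) of the isolation pattern the tests adopt: each training call runs under its own tf.Graph(), so ops created by one test cannot leak into the graph seen by the next. fake_train below is a hypothetical stand-in for calls like a2c_atari.train:

import tensorflow as tf

def fake_train():
    # Hypothetical stand-in for a full training call.
    tf.Variable(0, name="global_step")

with tf.Graph().as_default():
    fake_train()
    # The ops exist only on the temporary graph inside the block.
    print(len(tf.get_default_graph().get_operations()) > 0)  # True

# In a fresh process, the outer default graph remains empty.
print(len(tf.get_default_graph().get_operations()) == 0)  # True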
37 changes: 20 additions & 17 deletions baselines/deepq/experiments/run_atari.py
@@ -1,5 +1,7 @@
 import argparse

+import tensorflow as tf
+
 from baselines import deepq, bench, logger
 from baselines.common import set_global_seeds
 from baselines.common.atari_wrappers import make_atari
@@ -28,23 +30,24 @@ def main():
         dueling=bool(args.dueling),
     )

-    deepq.learn(
-        env,
-        q_func=model,
-        lr=1e-4,
-        max_timesteps=args.num_timesteps,
-        buffer_size=10000,
-        exploration_fraction=0.1,
-        exploration_final_eps=0.01,
-        train_freq=4,
-        learning_starts=10000,
-        target_network_update_freq=1000,
-        gamma=0.99,
-        prioritized_replay=bool(args.prioritized),
-        prioritized_replay_alpha=args.prioritized_replay_alpha,
-        checkpoint_freq=args.checkpoint_freq,
-        checkpoint_path=args.checkpoint_path,
-    )
+    with tf.Session():
+        deepq.learn(
+            env,
+            q_func=model,
+            lr=1e-4,
+            max_timesteps=args.num_timesteps,
+            buffer_size=10000,
+            exploration_fraction=0.1,
+            exploration_final_eps=0.01,
+            train_freq=4,
+            learning_starts=10000,
+            target_network_update_freq=1000,
+            gamma=0.99,
+            prioritized_replay=bool(args.prioritized),
+            prioritized_replay_alpha=args.prioritized_replay_alpha,
+            checkpoint_freq=args.checkpoint_freq,
+            checkpoint_path=args.checkpoint_path,
+        )

     env.close()
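A sketch of what the new `with tf.Session():` wrapper provides, assuming (consistent with this commit's change) that deepq.learn runs its ops through the default session rather than owning one (TF 1.x):

import tensorflow as tf

print(tf.get_default_session())  # None: no session is active here

with tf.Session() as sess:
    # Inside the block, this session is the thread's default, so library
    # code can retrieve it via tf.get_default_session() and run ops on it.
    assert tf.get_default_session() is sess
    print(tf.constant(2).eval())  # Tensor.eval() runs on the default session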
(Diffs for the remaining 6 changed files were not loaded.)