fixed some issues + acktr working with atari
hill-a committed Jul 5, 2018
1 parent cb3842e commit b1df898
Showing 16 changed files with 322 additions and 332 deletions.
32 changes: 16 additions & 16 deletions baselines/a2c/policies.py
@@ -6,17 +6,17 @@
 from baselines.common.input import observation_input


-def nature_cnn(unscaled_images):
+def nature_cnn(unscaled_images, **kwargs):
     """
     CNN from Nature paper.
     """
     scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
     activ = tf.nn.relu
-    h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)))
-    h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)))
-    h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)))
+    h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), **kwargs))
+    h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **kwargs))
+    h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **kwargs))
     h3 = conv_to_fc(h3)
-    return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
+    return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2), **kwargs))


 class A2CPolicy(object):
@@ -96,10 +96,10 @@ def value(self, obs, state, mask):


 class LnLstmPolicy(A2CPolicy):
-    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
+    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False, **kwargs):
         super(LnLstmPolicy, self).__init__(sess, ob_space, ac_space, nbatch, nsteps, nlstm, reuse)
         with tf.variable_scope("model", reuse=reuse):
-            h = nature_cnn(self.processed_x)
+            h = nature_cnn(self.processed_x, **kwargs)
             xs = batch_to_seq(h, self.nenv, nsteps)
             ms = batch_to_seq(self.masks_ph, self.nenv, nsteps)
             h5, self.snew = lnlstm(xs, ms, self.states_ph, 'lstm1', nh=nlstm)
@@ -122,10 +122,10 @@ def value(self, obs, state, mask):


 class LstmPolicy(A2CPolicy):
-    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
+    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False, **kwargs):
         super(LstmPolicy, self).__init__(sess, ob_space, ac_space, nbatch, nsteps, nlstm, reuse)
         with tf.variable_scope("model", reuse=reuse):
-            h = nature_cnn(self.obs_ph)
+            h = nature_cnn(self.obs_ph, **kwargs)
             xs = batch_to_seq(h, self.nenv, nsteps)
             ms = batch_to_seq(self.masks_ph, self.nenv, nsteps)
             h5, self.snew = lstm(xs, ms, self.states_ph, 'lstm1', nh=nlstm)
@@ -148,10 +148,10 @@ def value(self, obs, state, mask):


 class CnnPolicy(A2CPolicy):
-    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
+    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False, **kwargs):
         super(CnnPolicy, self).__init__(sess, ob_space, ac_space, nbatch, nsteps, nlstm, reuse)
         with tf.variable_scope("model", reuse=reuse):
-            h = nature_cnn(self.processed_x)
+            h = nature_cnn(self.processed_x, **kwargs)
             vf = fc(h, 'v', 1)[:, 0]
             self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01)

@@ -169,15 +169,15 @@ def value(self, obs, *args, **kwargs):


 class MlpPolicy(A2CPolicy):
-    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
+    def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False, **kwargs):
         super(MlpPolicy, self).__init__(sess, ob_space, ac_space, nbatch, nsteps, nlstm, reuse)
         with tf.variable_scope("model", reuse=reuse):
             activ = tf.tanh
             processed_x = tf.layers.flatten(self.processed_x)
-            pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
-            pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2)))
-            vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
-            vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2)))
+            pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2), **kwargs))
+            pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2), **kwargs))
+            vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2), **kwargs))
+            vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2), **kwargs))
             vf = fc(vf_h2, 'vf', 1)[:, 0]

             self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01)
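For context, a minimal sketch (not part of the commit) of what the new trailing **kwargs enables: anything extra passed to a policy constructor is forwarded to nature_cnn and from there into the conv/fc helpers, so per-layer options no longer need to be baked in with functools.partial (compare the run_atari.py change further down). The environment id and batch sizes below are arbitrary, and whether a given keyword is valid depends on the conv/fc signatures in this fork.

    import tensorflow as tf
    from baselines.a2c.policies import CnnPolicy
    from baselines.common.cmd_util import make_atari_env
    from baselines.common.vec_env.vec_frame_stack import VecFrameStack

    env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=1, seed=0), 4)
    sess = tf.Session()
    # An extra keyword such as one_dim_bias=True (the option run_atari.py used to
    # inject via functools.partial) would ride along in **kwargs and reach the
    # conv()/fc() calls inside nature_cnn -- assuming those helpers accept it.
    policy = CnnPolicy(sess, env.observation_space, env.action_space,
                       nbatch=1, nsteps=1, reuse=False)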
7 changes: 5 additions & 2 deletions baselines/acer/acer_simple.py
@@ -199,12 +199,15 @@ def custom_getter(getter, *args, **kwargs):

         def train(obs, actions, rewards, dones, mus, states, masks, steps):
             cur_lr = lr.value_steps(steps)
-            td_map = {train_model.obs_ph: obs, polyak_model.obs_ph: obs, action_ph: actions, reward_ph: rewards, done_ph: dones, mu_ph: mus, learning_rate_ph: cur_lr}
-            if len(states) == 0:
+            td_map = {train_model.obs_ph: obs, polyak_model.obs_ph: obs, action_ph: actions, reward_ph: rewards,
+                      done_ph: dones, mu_ph: mus, learning_rate_ph: cur_lr}
+
+            if len(states) != 0:
                 td_map[train_model.states_ph] = states
                 td_map[train_model.masks_ph] = masks
                 td_map[polyak_model.states_ph] = states
                 td_map[polyak_model.masks_ph] = masks
+
             return names_ops, sess.run(run_ops, td_map)[1:]  # strip off _train

         def save(save_path):
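Why the flipped condition matters (inferred from the change, not stated in the commit message): a feed-forward policy's runner returns an empty states list, while a recurrent (LSTM) policy returns a non-empty state array. The old len(states) == 0 check therefore fed the recurrent placeholders exactly when there was nothing to feed and skipped them when there was. A small standalone illustration of the corrected logic, using hypothetical names:

    import numpy as np

    def add_recurrent_feeds(td_map, states_ph, masks_ph, states, masks):
        # Mirrors the corrected check above: only recurrent policies carry state.
        if len(states) != 0:
            td_map[states_ph] = states
            td_map[masks_ph] = masks
        return td_map

    # Feed-forward policy: the runner hands back [], so nothing extra is fed.
    print(add_recurrent_feeds({}, 'states_ph', 'masks_ph', [], []))
    # LSTM policy: a real state array is fed together with the episode-start masks.
    print(add_recurrent_feeds({}, 'states_ph', 'masks_ph', np.zeros((2, 512)), [False, False]))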
7 changes: 4 additions & 3 deletions baselines/acktr/acktr_disc.py
@@ -56,6 +56,7 @@ def __init__(self, policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=32
             stats_decay=0.99, async=1, cold_iter=10,
             max_grad_norm=max_grad_norm)

+        optim.compute_and_apply_stats(self.joint_fisher, var_list=params)
         train_op, q_runner = optim.apply_gradients(list(zip(grads, params)))
         self.q_runner = q_runner
         self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
@@ -65,10 +66,10 @@ def train(obs, states, rewards, masks, actions, values):
             for step in range(len(obs)):
                 cur_lr = self.lr.value()

-            td_map = {train_model.X: obs, action_ph: actions, advs_ph: advs, rewards_ph: rewards, pg_lr_ph: cur_lr}
+            td_map = {train_model.obs_ph: obs, action_ph: actions, advs_ph: advs, rewards_ph: rewards, pg_lr_ph: cur_lr}
             if states is not None:
-                td_map[train_model.S] = states
-                td_map[train_model.M] = masks
+                td_map[train_model.states_ph] = states
+                td_map[train_model.masks_ph] = masks

             policy_loss, value_loss, policy_entropy, _ = sess.run(
                 [pg_loss, vf_loss, entropy, train_op],
22 changes: 9 additions & 13 deletions baselines/acktr/kfac.py
@@ -118,10 +118,12 @@ def search_factors(gradient, graph):
                              if 'gradientsSampled' in _i.name if 'Shape' not in _i.name]
             if len(b_inputs_list) > 0:
                 b_tensor = b_inputs_list[0]
-                b_tensor_shape = fprop_op.outputs[0].get_shape()
-                if len(b_tensor.get_shape()) > 0 and b_tensor.get_shape()[0].value is None:
-                    b_tensor.set_shape(b_tensor_shape)
-                b_tensors.append(b_tensor)
+                # only if tensor shape is defined, usually this will prevent tensor like Sum:0 to be used.
+                if b_tensor.get_shape():
+                    b_tensor_shape = fprop_op.outputs[0].get_shape()
+                    if len(b_tensor.get_shape()) > 0 and b_tensor.get_shape()[0].value is None:
+                        b_tensor.set_shape(b_tensor_shape)
+                    b_tensors.append(b_tensor)
             fprop_op_name = op_types.append('UNK-' + fprop_op.op_def.name)

         return {'opName': fprop_op_name, 'op': fprop_op, 'fpropFactors': f_tensors, 'bpropFactors': b_tensors}
@@ -540,16 +542,11 @@ def compute_stats_eigen(self):
         # TO-DO: figure out why this op has delays (possibly moving
         # eigenvectors around?)
         with tf.device('/cpu:0'):
-            # stats = [copyStats(self.fStats), copyStats(self.bStats)]
-            # stats = [self.fStats, self.bStats]

             stats_eigen = self.stats_eigen
             computed_eigen = {}
             eigen_reverse_lookup = {}
             update_ops = []
-            # sync copied stats
-            # with tf.control_dependencies(removeNone(stats[0]) +
-            # removeNone(stats[1])):
             with tf.control_dependencies([]):
                 for stats_var in stats_eigen:
                     if stats_var not in computed_eigen:
@@ -793,8 +790,8 @@ def apply_gradients_kfac(self, grads):
             factor_ops_dummy = self.compute_stats_eigen()

             # define a queue for the list of factor loading tensors
-            queue = tf.FIFOQueue(1, [item.dtype for item in factor_ops_dummy], shapes=[
-                item.get_shape() for item in factor_ops_dummy])
+            queue = tf.FIFOQueue(1, [item.dtype for item in factor_ops_dummy],
+                                 shapes=[item.get_shape() for item in factor_ops_dummy])
             enqueue_op = tf.cond(
                 tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update), tf.convert_to_tensor(
                     0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)),
@@ -877,8 +874,7 @@ def update_optim_op():
         return tf.group(*update_ops), qr

     def apply_gradients(self, grads):
-        cold_optim = tf.train.MomentumOptimizer(
-            self._cold_lr, self._momentum)
+        cold_optim = tf.train.MomentumOptimizer(self._cold_lr, self._momentum)

         def cold_sgd_start():
             sgd_grads, sgd_var = zip(*grads)
5 changes: 1 addition & 4 deletions baselines/acktr/run_atari.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3

-from functools import partial
-
 from baselines import logger
 from baselines.acktr.acktr_disc import learn
 from baselines.common.cmd_util import make_atari_env, atari_arg_parser
@@ -11,8 +9,7 @@

 def train(env_id, num_timesteps, seed, num_cpu):
     env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
-    policy_fn = partial(CnnPolicy, one_dim_bias=True)
-    learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu)
+    learn(CnnPolicy, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu)
     env.close()


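Putting the pieces together, a minimal ACKTR-on-Atari script in the spirit of the simplified run_atari.py above (not part of the commit; the environment id, seed, and step count are placeholders, and CnnPolicy is imported here from the a2c policies file changed above because the actual import line sits outside this hunk):

    #!/usr/bin/env python3
    from baselines import logger
    from baselines.a2c.policies import CnnPolicy
    from baselines.acktr.acktr_disc import learn
    from baselines.common.cmd_util import make_atari_env
    from baselines.common.vec_env.vec_frame_stack import VecFrameStack


    def main():
        logger.configure()
        num_cpu = 4
        # Frame-stacked, vectorized Atari env, as in train() above.
        env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=num_cpu, seed=0), 4)
        # The policy class is passed directly; no functools.partial wrapper is needed
        # now that the policy constructors accept **kwargs.
        learn(CnnPolicy, env, seed=0, total_timesteps=int(1e6 * 1.1), nprocs=num_cpu)
        env.close()


    if __name__ == '__main__':
        main()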
4 changes: 0 additions & 4 deletions baselines/common/tests/test_atari.py
@@ -20,10 +20,6 @@


 def clear_tf_session():
-    sess = tf.get_default_session()
-    while sess is not None:
-        sess.close()
-        sess = tf.get_default_session()
     tf.reset_default_graph()


(The remaining 10 changed files are not shown here.)
