# trpo_mpi.py (forked from openai/baselines)
from baselines.common import explained_variance, zipsame, dataset
from baselines import logger
import baselines.common.tf_util as U
import tensorflow as tf, numpy as np
import time
from baselines.common import colorize
from collections import deque
from baselines.common import set_global_seeds
from baselines.common.mpi_adam import MpiAdam
from baselines.common.cg import cg
from baselines.common.input import observation_placeholder
from baselines.common.policies import build_policy
from contextlib import contextmanager

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

def traj_segment_generator(pi, env, horizon, stochastic):
    # Initialize state variables
    t = 0
    ac = env.action_space.sample()
    new = True
    rew = 0.0
    ob = env.reset()

    cur_ep_ret = 0
    cur_ep_len = 0
    ep_rets = []
    ep_lens = []

    # Initialize history arrays
    obs = np.array([ob for _ in range(horizon)])
    rews = np.zeros(horizon, 'float32')
    vpreds = np.zeros(horizon, 'float32')
    news = np.zeros(horizon, 'int32')
    acs = np.array([ac for _ in range(horizon)])
    prevacs = acs.copy()

    while True:
        prevac = ac
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        # Slight weirdness here because we need value function at time T
        # before returning segment [0, T-1] so we get the correct
        # terminal value
        if t > 0 and t % horizon == 0:
            yield {"ob" : obs, "rew" : rews, "vpred" : vpreds, "new" : news,
                   "ac" : acs, "prevac" : prevacs, "nextvpred": vpred * (1 - new),
                   "ep_rets" : ep_rets, "ep_lens" : ep_lens}
            _, vpred, _, _ = pi.step(ob, stochastic=stochastic)
            # Be careful!!! if you change the downstream algorithm to aggregate
            # several of these batches, then be sure to do a deepcopy
            ep_rets = []
            ep_lens = []
        i = t % horizon
        obs[i] = ob
        vpreds[i] = vpred
        news[i] = new
        acs[i] = ac
        prevacs[i] = prevac

        ob, rew, new, _ = env.step(ac)
        rews[i] = rew

        cur_ep_ret += rew
        cur_ep_len += 1
        if new:
            ep_rets.append(cur_ep_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret = 0
            cur_ep_len = 0
            ob = env.reset()
        t += 1

def add_vtarg_and_adv(seg, gamma, lam):
    new = np.append(seg["new"], 0)  # last element is only used for last vtarg, but we already zeroed it if last new = 1
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, 'float32')
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1-new[t+1]
        delta = rew[t] + gamma * vpred[t+1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]

def learn(*,
        network,
        env,
        total_timesteps,
        timesteps_per_batch=1024,  # what to train on
        max_kl=0.001,
        cg_iters=10,
        gamma=0.99,
        lam=1.0,  # advantage estimation
        seed=None,
        ent_coef=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_episodes=0, max_iters=0,  # time constraint
        callback=None,
        load_path=None,
        **network_kwargs
        ):
    '''
    learn a policy function with the TRPO algorithm

    Parameters:
    ----------

    network                 neural network to learn. Can be either string ('mlp', 'cnn', 'lstm', 'lnlstm' for basic types)
                            or function that takes input placeholder and returns tuple (output, None) for feedforward nets
                            or (output, (state_placeholder, state_output, mask_placeholder)) for recurrent nets

    env                     environment (one of the gym environments or wrapped via a baselines.common.vec_env.VecEnv-type class)

    timesteps_per_batch     timesteps per gradient estimation batch

    max_kl                  max KL divergence between old policy and new policy ( KL(pi_old || pi) )

    ent_coef                coefficient of policy entropy term in the optimization objective

    cg_iters                number of iterations of conjugate gradient algorithm

    cg_damping              conjugate gradient damping

    vf_stepsize             learning rate for adam optimizer used to optimize value function loss

    vf_iters                number of value function optimization iterations per policy optimization step

    total_timesteps         max number of timesteps

    max_episodes            max number of episodes

    max_iters               maximum number of policy optimization iterations

    callback                function to be called with (locals(), globals()) at each policy optimization step

    load_path               str, path to load the model from (default: None, i.e. no model is loaded)

    **network_kwargs        keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network

    Returns:
    -------

    learnt model

    A minimal usage sketch is included at the bottom of this file.
    '''
    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0

    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=cpus_per_worker,
            intra_op_parallelism_threads=cpus_per_worker
    ))

    policy = build_policy(env, network, value_network='copy', **network_kwargs)
    set_global_seeds(seed)

    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space

    ob = observation_placeholder(ob_space)
    with tf.variable_scope("pi"):
        pi = policy(observ_placeholder=ob)
    with tf.variable_scope("oldpi"):
        oldpi = policy(observ_placeholder=ob)

    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = tf.reduce_mean(kloldnew)
    meanent = tf.reduce_mean(ent)
    entbonus = ent_coef * meanent

    vferr = tf.reduce_mean(tf.square(pi.vf - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = tf.reduce_mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = get_trainable_variables("pi")
    # var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    # vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    var_list = get_pi_trainable_variables("pi")
    vf_var_list = get_vf_trainable_variables("pi")

    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([tf.reduce_sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)])  #pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)
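    # For reference: feeding a flat vector v through flat_tangent makes gvp = grad(meankl) . v,
    # so fvp = grad(gvp) is the Hessian-vector product H v of the mean KL at the current
    # parameters, i.e. the Fisher-vector product that conjugate gradient needs, without ever
    # forming H explicitly. A damping term cg_damping * v is added to this product later,
    # inside fisher_vector_product in the training loop.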

    assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
        for (oldv, newv) in zipsame(get_variables("oldpi"), get_variables("pi"))])

    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        if MPI is not None:
            out = np.empty_like(x)
            MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
            out /= nworkers
        else:
            out = np.copy(x)
        return out

    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    th_init = get_flat()
    if MPI is not None:
        MPI.COMM_WORLD.Bcast(th_init, root=0)

    set_from_flat(th_init)
    vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    if sum([max_iters>0, total_timesteps>0, max_episodes>0])==0:
        # nothing to be done
        return pi

    assert sum([max_iters>0, total_timesteps>0, max_episodes>0]) < 2, \
        'out of max_iters, total_timesteps, and max_episodes only one should be specified'

    while True:
        if callback: callback(locals(), globals())
        if total_timesteps and timesteps_so_far >= total_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************"%iters_so_far)

        with timed("sampling"):
            seg = seg_gen.__next__()
        add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

        if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
        if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]
        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p

        assign_old_eq_new()  # set old parameter values to new parameter values
        with timed("computegrad"):
            *lossbefore, g = compute_lossandgrad(*args)
        lossbefore = allmean(np.array(lossbefore))
        g = allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with timed("cg"):
                stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
            assert np.isfinite(stepdir).all()
            shs = .5*stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
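            # Under the quadratic model KL ~= 0.5 * d^T H d, the CG direction d = stepdir is
            # rescaled by lm = sqrt(shs / max_kl) so that the full step saturates the trust
            # region: 0.5 * (d/lm)^T H (d/lm) = shs / lm**2 = max_kl. The line search below
            # then halves the step until the surrogate improves and KL stays below 1.5 * max_kl.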
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = get_flat()
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                set_from_flat(thnew)
                meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                set_from_flat(thbefore)

        if nworkers > 1 and iters_so_far % 20 == 0:
            paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
            assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with timed("vf"):
            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False, batch_size=64):
                    g = allmean(compute_vflossandgrad(mbob, mbret))
                    vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        if MPI is not None:
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        else:
            listoflrpairs = [lrlocal]

        lens, rews = map(flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank==0:
            logger.dump_tabular()

    return pi

def flatten_lists(listoflists):
    return [el for list_ in listoflists for el in list_]

def get_variables(scope):
    return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)

def get_trainable_variables(scope):
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)

def get_vf_trainable_variables(scope):
    return [v for v in get_trainable_variables(scope) if 'vf' in v.name[len(scope):].split('/')]

def get_pi_trainable_variables(scope):
    return [v for v in get_trainable_variables(scope) if 'pi' in v.name[len(scope):].split('/')]
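
# Minimal usage sketch. "CartPole-v1" and the "mlp" network below are placeholder choices;
# depending on the baselines version, the environment may need to be wrapped as a VecEnv
# (e.g. via baselines.common.cmd_util.make_vec_env) before being passed to learn().
if __name__ == "__main__":
    import gym
    env = gym.make("CartPole-v1")
    trained_pi = learn(network="mlp", env=env, total_timesteps=int(1e5),
                       timesteps_per_batch=1024, max_kl=0.001, gamma=0.99, lam=0.98)
    env.close()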