This repository has been archived by the owner on Dec 11, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 462
/
ppo_agent.py
396 lines (321 loc) · 18.7 KB
/
ppo_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from collections import OrderedDict
from typing import Union
import numpy as np
from rl_coach.agents.actor_critic_agent import ActorCriticAgent
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.head_parameters import PPOHeadParameters, VHeadParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.base_parameters import AlgorithmParameters, NetworkParameters, \
AgentParameters, DistributedTaskParameters
from rl_coach.core_types import EnvironmentSteps, Batch
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
from rl_coach.exploration_policies.categorical import CategoricalParameters
from rl_coach.logger import screen
from rl_coach.memories.episodic.episodic_experience_replay import EpisodicExperienceReplayParameters
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace
from rl_coach.utils import force_list
class PPOCriticNetworkParameters(NetworkParameters):
    """Network configuration for the PPO critic (state-value) network.

    The observation embedder and the fully connected middleware both use tanh
    activations, and the network ends in a single V head. A target network is
    created as well.
    """
    def __init__(self):
        super().__init__()
        activation = 'tanh'

        # topology: observation embedder -> FC middleware -> value head
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function=activation)}
        self.middleware_parameters = FCMiddlewareParameters(activation_function=activation)
        self.heads_parameters = [VHeadParameters()]

        # training settings
        self.async_training = True
        self.l2_regularization = 0
        self.create_target_network = True
        self.batch_size = 128
class PPOActorNetworkParameters(NetworkParameters):
    """Network configuration for the PPO actor (policy) network.

    The network mirrors the critic's tanh embedder and middleware, ends in a PPO
    policy head, is trained with the Adam optimizer, and keeps a target network.
    """
    def __init__(self):
        super().__init__()
        activation = 'tanh'

        # topology: observation embedder -> FC middleware -> PPO policy head
        self.input_embedders_parameters = {'observation': InputEmbedderParameters(activation_function=activation)}
        self.middleware_parameters = FCMiddlewareParameters(activation_function=activation)
        self.heads_parameters = [PPOHeadParameters()]

        # training settings
        self.optimizer_type = 'Adam'
        self.async_training = True
        self.l2_regularization = 0
        self.create_target_network = True
        self.batch_size = 128
# The docstring is raw so the LaTeX backslashes survive: in a normal string "\b" in
# "\beta" is a backspace character and "\l" in "\lambda" is an invalid escape.
class PPOAlgorithmParameters(AlgorithmParameters):
    r"""
    :param policy_gradient_rescaler: (PolicyGradientRescaler)
        This represents how the critic will be used to update the actor. The critic value function is typically used
        to rescale the gradients calculated by the actor. There are several ways for doing this, such as using the
        advantage of the action, or the generalized advantage estimation (GAE) value.

    :param gae_lambda: (float)
        The :math:`\lambda` value is used within the GAE function in order to weight different bootstrap length
        estimations. Typical values are in the range 0.9-1, and define an exponential decay over the different
        n-step estimations.

    :param target_kl_divergence: (float)
        The target KL divergence between the current policy distribution and the new policy. PPO uses a heuristic
        to keep the KL divergence near this value by adapting the weight of the KL penalty: the weight is increased
        when the measured KL divergence is too high and decreased when it is too low.

    :param initial_kl_coefficient: (float)
        The initial weight that will be given to the KL divergence between the current and the new policy in the
        regularization factor.

    :param high_kl_penalty_coefficient: (float)
        The penalty that will be given for KL divergence values which are higher than what was defined as the target.

    :param clip_likelihood_ratio_using_epsilon: (float)
        If not None, the likelihood ratio between the current and new policy in the PPO loss function will be
        clipped to the range [1-clip_likelihood_ratio_using_epsilon, 1+clip_likelihood_ratio_using_epsilon].
        This is typically used in the Clipped PPO version of PPO, and should be set to None in regular PPO
        implementations.

    :param value_targets_mix_fraction: (float)
        The targets for the value network are an exponential weighted moving average which uses this mix fraction to
        define how much of the new targets will be taken into account when calculating the loss.
        This value should be set to the range (0,1], where 1 means that only the new targets will be taken into account.

    :param estimate_state_value_using_gae: (bool)
        If set to True, the state value will be estimated using the GAE technique.

    :param use_kl_regularization: (bool)
        If set to True, the loss function will be regularized using the KL divergence between the current and new
        policy, to bound the change of the policy during the network update.

    :param beta_entropy: (float)
        An entropy regularization term can be added to the loss function in order to control exploration. This term
        is weighted using the :math:`\beta` value defined by beta_entropy.
    """
    def __init__(self):
        super().__init__()
        self.policy_gradient_rescaler = PolicyGradientRescaler.GAE
        self.gae_lambda = 0.96
        self.target_kl_divergence = 0.01
        self.initial_kl_coefficient = 1.0
        self.high_kl_penalty_coefficient = 1000
        self.clip_likelihood_ratio_using_epsilon = None
        self.value_targets_mix_fraction = 0.1
        self.estimate_state_value_using_gae = True
        self.use_kl_regularization = True
        self.beta_entropy = 0.01
        self.num_consecutive_playing_steps = EnvironmentSteps(5000)
        # collect whole episodes before training (GAE in PPOAgent bootstraps only at game-over boundaries)
        self.act_for_full_episodes = True
class PPOAgentParameters(AgentParameters):
    """Complete parameter bundle for a PPO agent.

    It combines the PPO algorithm settings, categorical exploration for discrete
    action spaces, additive noise for box action spaces, an episodic experience
    replay memory, and separate "critic" and "actor" networks.
    """
    def __init__(self):
        # built in the same order the original keyword arguments were evaluated
        algorithm_params = PPOAlgorithmParameters()
        exploration_by_space = {
            DiscreteActionSpace: CategoricalParameters(),
            BoxActionSpace: AdditiveNoiseParameters(),
        }
        memory_params = EpisodicExperienceReplayParameters()
        network_params = {
            "critic": PPOCriticNetworkParameters(),
            "actor": PPOActorNetworkParameters(),
        }
        super().__init__(algorithm=algorithm_params,
                         exploration=exploration_by_space,
                         memory=memory_params,
                         networks=network_params)

    @property
    def path(self):
        """Import path ("module:Class") of the agent these parameters configure."""
        return 'rl_coach.agents.ppo_agent:PPOAgent'
# Proximal Policy Optimization - https://arxiv.org/pdf/1707.06347.pdf
class PPOAgent(ActorCriticAgent):
    """PPO agent with separate actor ("actor") and critic ("critic") networks.

    Each training step does the following:
    - writes standardized advantages into the stored transitions;
    - fits the critic for one epoch;
    - fits the actor for ten epochs.

    After training, the KL penalty coefficient is optionally adapted and the
    memory is cleaned.
    """
    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)

        # signals definition
        self.value_loss = self.register_signal('Value Loss')
        self.policy_loss = self.register_signal('Policy Loss')
        self.kl_divergence = self.register_signal('KL Divergence')
        self.total_kl_divergence_during_training_process = 0.0
        self.unclipped_grads = self.register_signal('Grads (unclipped)')

    @property
    def is_on_policy(self) -> bool:
        return True

    def fill_advantages(self, batch):
        """Compute standardized advantages and store them in ``transition.info['advantage']``.

        :param batch: list of transitions to evaluate. The advantages are written
            into ``self.memory.transitions`` (in order), not into ``batch`` itself.
        """
        batch = Batch(batch)
        network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

        # * Found not to have any impact *
        # current_states_with_timestep = self.concat_state_and_timestep(batch)

        # squeeze() collapses a single-transition batch to a 0-d array; keep it sliceable
        current_state_values = np.atleast_1d(
            self.networks['critic'].online_network.predict(batch.states(network_keys)).squeeze())
        total_returns = batch.n_step_discounted_rewards()

        # calculate advantages
        advantages = []
        if self.policy_gradient_rescaler == PolicyGradientRescaler.A_VALUE:
            advantages = total_returns - current_state_values
        elif self.policy_gradient_rescaler == PolicyGradientRescaler.GAE:
            rewards = batch.rewards()
            episode_start_idx = 0
            advantages = np.array([])
            # current_state_values[batch.game_overs()] = 0
            for idx, game_over in enumerate(batch.game_overs()):
                if not game_over:
                    continue
                # the episode ends here, so the rollout is bootstrapped with a value of 0
                value_bootstrapping = np.zeros((1,))
                rollout_state_values = np.append(current_state_values[episode_start_idx:idx + 1], value_bootstrapping)
                rollout_advantages, _ = self.get_general_advantage_estimation_values(
                    rewards[episode_start_idx:idx + 1], rollout_state_values)
                episode_start_idx = idx + 1
                advantages = np.append(advantages, rollout_advantages)
            # NOTE: transitions after the last game-over get no advantage; presumably
            # harmless because act_for_full_episodes collects whole episodes — verify.
        else:
            screen.warning("WARNING: The requested policy gradient rescaler is not available")

        # standardize (zero mean, unit variance)
        advantages = np.asarray(advantages)
        if advantages.size > 0:
            centered = advantages - np.mean(advantages)
            advantages_std = np.std(advantages)
            # If all advantages are equal, std is 0 and the division would make every advantage NaN.
            # Centering alone already maps them to 0 in that case.
            advantages = centered / advantages_std if advantages_std > 0 else centered

        # TODO: this will be problematic with a shared memory
        for transition, advantage in zip(self.memory.transitions, advantages):
            transition.info['advantage'] = advantage

        self.action_advantages.add_sample(advantages)

    def train_value_network(self, dataset, epochs):
        """Fit the critic to the n-step discounted returns of ``dataset``.

        :param dataset: list of transitions to train on
        :param epochs: number of passes over ``dataset``
        :return: the value loss averaged over all mini-batch updates
        """
        loss = []
        batch = Batch(dataset)
        critic = self.networks['critic']
        network_keys = self.ap.network_wrappers['critic'].input_embedders_parameters.keys()

        # * Found not to have any impact *
        # add a timestep to the observation
        # current_states_with_timestep = self.concat_state_and_timestep(dataset)

        mix_fraction = self.ap.algorithm.value_targets_mix_fraction
        total_returns = batch.n_step_discounted_rewards(True)
        all_states = batch.states(network_keys)
        uses_lbfgs = critic.online_network.optimizer_type == 'LBFGS'

        # LBFGS optimizes on the whole batch; first-order optimizers use mini-batches
        curr_batch_size = batch.size if uses_lbfgs else self.ap.network_wrappers['critic'].batch_size
        # A dataset smaller than one mini-batch is trained on as a single short batch.
        # Otherwise no update would run and the returned loss would be NaN.
        num_batches = max(1, batch.size // curr_batch_size)

        for _ in range(epochs):
            for i in range(num_batches):
                start, end = i * curr_batch_size, (i + 1) * curr_batch_size
                current_states_batch = {k: v[start:end] for k, v in all_states.items()}
                total_return_batch = total_returns[start:end]
                old_policy_values = force_list(critic.target_network.predict(current_states_batch).squeeze())

                if uses_lbfgs:
                    # NOTE(review): if predict() returns shape (B, 1) this broadcasts against the (B,)
                    # returns to (B, B) — confirm the expected shape of the LBFGS targets.
                    current_values = critic.online_network.predict(current_states_batch)
                    targets = current_values * (1 - mix_fraction) + total_return_batch * mix_fraction
                else:
                    targets = total_return_batch

                # feed the target-network predictions into any 'output_0_<i>' inputs the online network declares
                inputs = copy.copy(current_states_batch)
                for input_index, old_value in enumerate(old_policy_values):
                    name = 'output_0_{}'.format(input_index)
                    if name in critic.online_network.inputs:
                        inputs[name] = old_value

                value_loss = critic.online_network.accumulate_gradients(inputs, targets)
                critic.apply_gradients_to_online_network()
                if isinstance(self.ap.task_parameters, DistributedTaskParameters):
                    critic.apply_gradients_to_global_network()
                critic.online_network.reset_accumulated_gradients()

                loss.append([value_loss[0]])
        loss = np.mean(loss, 0)
        return loss

    def concat_state_and_timestep(self, dataset):
        """Return each transition's observation with its ``info['timestep']`` appended, plus a trailing axis."""
        current_states_with_timestep = [np.append(transition.state['observation'], transition.info['timestep'])
                                        for transition in dataset]
        current_states_with_timestep = np.expand_dims(current_states_with_timestep, -1)
        return current_states_with_timestep

    def train_policy_network(self, dataset, epochs):
        """Fit the actor using the advantages stored by :meth:`fill_advantages`.

        :param dataset: list of transitions whose ``info['advantage']`` has been filled
        :param epochs: number of passes over ``dataset`` (must be at least 1)
        :return: the total loss averaged over the mini-batches of the last epoch
        """
        actor = self.networks['actor']
        actor_parameters = self.ap.network_wrappers['actor']
        network_keys = actor_parameters.input_embedders_parameters.keys()
        batch_size = actor_parameters.batch_size
        # A dataset smaller than one mini-batch is trained on as a single short batch.
        # Otherwise no update would run and indexing the NaN averages below would raise.
        num_batches = max(1, len(dataset) // batch_size)

        for j in range(epochs):
            loss = {
                'total_loss': [],
                'policy_losses': [],
                'unclipped_grads': [],
                'fetch_result': []
            }

            # NOTE: mini-batches are taken in memory order (no shuffling)
            for i in range(num_batches):
                batch = Batch(dataset[i * batch_size:(i + 1) * batch_size])
                advantages = batch.info('advantage')
                actions = batch.actions()
                if not isinstance(self.spaces.action, DiscreteActionSpace) and len(actions.shape) == 1:
                    actions = np.expand_dims(actions, -1)

                states = batch.states(network_keys)

                # get old policy probabilities and distribution
                old_policy = force_list(actor.target_network.predict(states))

                # calculate gradients and apply on both the local policy network and on the global policy network
                policy_head = actor.online_network.output_heads[0]
                fetches = [policy_head.kl_divergence, policy_head.entropy]

                inputs = copy.copy(states)
                inputs['output_0_0'] = actions

                # old_policy_distribution needs to be represented as a list, because in the event of discrete controls,
                # it has just a mean. otherwise, it has both a mean and standard deviation
                for input_index, old_policy_part in enumerate(old_policy):
                    inputs['output_0_{}'.format(input_index + 1)] = old_policy_part

                total_loss, policy_losses, unclipped_grads, fetch_result = \
                    actor.online_network.accumulate_gradients(inputs, [advantages], additional_fetches=fetches)

                actor.apply_gradients_to_online_network()
                if isinstance(self.ap.task_parameters, DistributedTaskParameters):
                    actor.apply_gradients_to_global_network()

                actor.online_network.reset_accumulated_gradients()

                loss['total_loss'].append(total_loss)
                loss['policy_losses'].append(policy_losses)
                loss['unclipped_grads'].append(unclipped_grads)
                loss['fetch_result'].append(fetch_result)

                self.unclipped_grads.add_sample(unclipped_grads)

            for key in loss.keys():
                loss[key] = np.mean(loss[key], 0)

            # NOTE(review): this logs the *critic's* learning rate under "Policy training" and reads
            # self.ap.learning_rate, which is not defined in this file — confirm the intended source.
            if self.ap.network_wrappers['critic'].learning_rate_decay_rate != 0:
                curr_learning_rate = self.networks['critic'].online_network.get_variable_value(self.ap.learning_rate)
                self.curr_learning_rate.add_sample(curr_learning_rate)
            else:
                curr_learning_rate = self.ap.network_wrappers['critic'].learning_rate

            # log training parameters
            screen.log_dict(
                OrderedDict([
                    ("Surrogate loss", loss['policy_losses'][0]),
                    ("KL divergence", loss['fetch_result'][0]),
                    ("Entropy", loss['fetch_result'][1]),
                    ("training epoch", j),
                    ("learning_rate", curr_learning_rate)
                ]),
                prefix="Policy training"
            )

        # only the last epoch's statistics are kept (see update_kl_coefficient)
        self.total_kl_divergence_during_training_process = loss['fetch_result'][0]
        self.entropy.add_sample(loss['fetch_result'][1])
        self.kl_divergence.add_sample(loss['fetch_result'][0])
        return loss['total_loss']

    def update_kl_coefficient(self):
        """Scale the actor head's KL coefficient by 1.5 towards the target KL divergence.

        The coefficient is multiplied by 1.5 when the last measured KL exceeds
        1.3x the target, and divided by 1.5 when it is below 0.7x the target.
        """
        # John Schulman takes the mean kl divergence only over the last epoch which is strange but we will follow
        # his implementation for now because we know it works well
        screen.log_title("KL = {}".format(self.total_kl_divergence_during_training_process))

        # update kl coefficient
        kl_target = self.ap.algorithm.target_kl_divergence
        kl_coefficient = self.networks['actor'].online_network.get_variable_value(
            self.networks['actor'].online_network.output_heads[0].kl_coefficient)
        new_kl_coefficient = kl_coefficient
        if self.total_kl_divergence_during_training_process > 1.3 * kl_target:
            # kl too high => increase regularization
            new_kl_coefficient *= 1.5
        elif self.total_kl_divergence_during_training_process < 0.7 * kl_target:
            # kl too low => decrease regularization
            new_kl_coefficient /= 1.5

        # update the kl coefficient variable
        if kl_coefficient != new_kl_coefficient:
            self.networks['actor'].online_network.set_variable_value(
                self.networks['actor'].online_network.output_heads[0].assign_kl_coefficient,
                new_kl_coefficient,
                self.networks['actor'].online_network.output_heads[0].kl_coefficient_ph)

            screen.log_title("KL penalty coefficient change = {} -> {}".format(kl_coefficient, new_kl_coefficient))

    def post_training_commands(self):
        """Adapt the KL coefficient (if KL regularization is enabled), then clean the memory."""
        if self.ap.algorithm.use_kl_regularization:
            self.update_kl_coefficient()

        # clean memory
        self.call_memory('clean')

    def train(self):
        """Run a PPO training phase if one is due.

        :return: 0 when no training was due. Otherwise, the last training step's
            value loss followed by its policy loss, concatenated with ``np.append``.
        """
        loss = 0
        if not self._should_train():
            return loss

        for network in self.networks.values():
            network.set_is_training(True)

        for training_step in range(self.ap.algorithm.num_consecutive_training_steps):
            self.networks['actor'].sync()
            self.networks['critic'].sync()

            dataset = self.memory.transitions

            self.fill_advantages(dataset)

            # take only the requested number of steps
            dataset = dataset[:self.ap.algorithm.num_consecutive_playing_steps.num_steps]

            value_loss = self.train_value_network(dataset, 1)
            policy_loss = self.train_policy_network(dataset, 10)

            self.value_loss.add_sample(value_loss)
            self.policy_loss.add_sample(policy_loss)

        for network in self.networks.values():
            network.set_is_training(False)

        self.post_training_commands()
        self.training_iteration += 1
        self.update_log()  # should be done in order to update the data that has been accumulated * while not playing *
        return np.append(value_loss, policy_loss)

    def get_prediction(self, states):
        """Run the online actor network on ``states`` prepared for inference."""
        tf_input_state = self.prepare_batch_for_inference(states, "actor")
        return self.networks['actor'].online_network.predict(tf_input_state)