-
Notifications
You must be signed in to change notification settings - Fork 4
/
mpo.py
311 lines (253 loc) · 12.9 KB
/
mpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import os
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import minimize
from torch.distributions import Categorical
from torch.nn.utils import clip_grad_norm_

from traj_buffer import TrajBuffer
class MPO(object):
    """
    Maximum A Posteriori Policy Optimization (MPO) ; Discrete action-space ; Retrace

    Params:
        env: vectorised environment; this class reads ``env.num_envs`` and
            calls ``env.reset()`` / ``env.step(action)``
        actor: actor network; must provide ``action``, ``evaluate_action``,
            ``get_action_prob``, ``forward`` and ``parameters``
        critic: critic network returning one Q-value per action
        obs_shape: size of a flat observation (used as an int in reshapes)
        action_shape: number of discrete actions
        dual_constraint: ε, the constant multiplying η in the dual g(η) (E-step)
        kl_constraint: ε_kl, the KL bound enforced in the M-step
        learning_rate: γ, the discount factor used in the Retrace recursion
            (the name is historical; it is NOT an optimizer step size)
        alpha: step size of the update of the M-step Lagrange multiplier η_kl
        episodes: number of iterations to sample episodes + do updates
        sample_episodes: number of episodes to sample per iteration
        episode_length: length of each episode
        lagrange_it: number of Lagrangian optimization steps per update
        runs: amount of training updates before updating target parameters
        device: pytorch device
        save_path: path to save model to
    """

    def __init__(self, env, actor, critic, obs_shape, action_shape,
                 dual_constraint=0.1, kl_constraint=0.01,
                 learning_rate=0.99, alpha=1.0,
                 episodes=1000, sample_episodes=1, episode_length=1000,
                 lagrange_it=5, runs=50, device='cpu',
                 save_path="./model/mpo"):
        # initialize env
        self.env = env

        # initialize some hyperparameters
        self.α = alpha
        self.ε = dual_constraint
        self.ε_kl = kl_constraint
        self.γ = learning_rate
        self.episodes = episodes
        self.sample_episodes = sample_episodes
        self.episode_length = episode_length
        self.lagrange_it = lagrange_it
        # number of transitions in one sampled trajectory batch
        self.mb_size = episode_length * env.num_envs
        self.runs = runs
        self.device = device

        # initialize networks and optimizers
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.critic = critic
        self.target_critic = self._frozen_copy(critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        self.actor = actor
        self.target_actor = self._frozen_copy(actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # initialize Lagrange multipliers
        self.η = np.random.rand()
        self.η_kl = 0.0

        # buffer and others
        self.buffer = TrajBuffer(env, episode_length, 100000)
        self.save_path = save_path

    @staticmethod
    def _frozen_copy(network):
        """Return a deep copy of ``network`` whose parameters do not require grad."""
        target = deepcopy(network)
        for param in target.parameters():
            param.requires_grad = False
        return target

    def _sample_trajectory(self):
        """
        Roll out ``sample_episodes`` episodes with the target actor and store
        each of them in the trajectory buffer.

        Returns:
            The rewards summed over all steps and episodes, divided by
            ``episode_length * sample_episodes`` (one entry per environment if
            ``env.step`` returns a reward vector).
        """
        mean_reward = 0
        for _ in range(self.sample_episodes):
            obs = self.env.reset()
            done = False
            obs_b = np.zeros([self.episode_length, self.env.num_envs, self.obs_shape])
            action_b = np.zeros([self.episode_length, self.env.num_envs])
            reward_b = np.zeros([self.episode_length, self.env.num_envs])
            prob_b = np.zeros([self.episode_length, self.env.num_envs, self.action_shape])
            done_b = np.zeros([self.episode_length, self.env.num_envs])

            for steps in range(self.episode_length):
                obs_tensor = torch.from_numpy(np.expand_dims(obs, axis=0)).to(self.device).float()
                with torch.no_grad():
                    action, prob = self.target_actor.action(obs_tensor)
                action = np.reshape(action.cpu().numpy(), -1)
                prob = prob.cpu().numpy()

                # done_b[t] holds the flag produced by the PREVIOUS step, i.e.
                # whether obs_b[t] is the first observation after a terminal.
                obs_b[steps] = obs
                done_b[steps] = done

                obs, reward, done, _ = self.env.step(action)
                mean_reward += reward

                action_b[steps] = action
                reward_b[steps] = reward
                prob_b[steps] = prob

            self.buffer.put(obs_b, action_b, reward_b, prob_b, done_b)

        return mean_reward / self.episode_length / self.sample_episodes

    def _update_critic_retrace(self, state_batch, action_batch, policies_batch, reward_batch, done_batch):
        """
        Do one critic update towards Retrace targets.

        Expected layouts (time-major, as produced by ``train``):
            state_batch:    (T, n_envs, obs_shape)
            action_batch:   (T, n_envs)
            policies_batch: (T, n_envs, n_actions)  behaviour probabilities
            reward_batch:   (T, n_envs, 1)
            done_batch:     (T, n_envs, 1)

        The last time step is only used to bootstrap, so T - 1 transitions are
        trained on.

        Returns:
            The detached scalar critic loss (computed before the optimizer step).
        """
        state_batch_last = state_batch[-1]
        state_batch = state_batch[:-1]
        action_batch = action_batch[:-1]
        policies_batch = policies_batch[:-1]
        reward_batch = reward_batch[:-1]

        action_size = policies_batch.shape[-1]
        nsteps = state_batch.shape[0]
        n_envs = state_batch.shape[1]

        flat_states = state_batch.view(-1, self.obs_shape)
        flat_actions = action_batch.view(-1, 1)
        action_index = flat_actions.long()

        self.critic_optimizer.zero_grad()

        with torch.no_grad():
            policies, _, _ = self.actor.evaluate_action(flat_states, flat_actions)
            policies = policies.view(-1, action_size)
            old_policies = policies_batch.view(-1, action_size)

            # Truncated importance weight min(1, π/μ) of the action actually taken.
            rho = policies / (old_policies + 1e-10)
            rho_bar = rho.gather(1, action_index).clamp(max=1.0).view(nsteps, n_envs, 1)

            # Bootstrap value of the state following the last trained transition.
            next_qval = self.critic(state_batch_last)
            next_policies = self.actor.get_action_prob(state_batch_last)
            next_val = (next_qval * next_policies).sum(1, keepdim=True)

        # Needs grad: this is the quantity the loss is back-propagated through.
        qval = self.critic(flat_states).view(-1, action_size)
        q_i = qval.gather(1, action_index)

        with torch.no_grad():
            val = (qval * policies).sum(1, keepdim=True).view(nsteps, n_envs, 1)
            q_i_seq = q_i.view(nsteps, n_envs, 1)

            q_retraces = reward_batch.new_zeros((nsteps, n_envs, 1))
            q_ret = next_val
            for step in reversed(range(nsteps)):
                # done_batch[step + 1] is the terminal flag produced by `step`.
                q_ret = reward_batch[step] + self.γ * q_ret * (1 - done_batch[step + 1])
                q_retraces[step] = q_ret
                # Off-policy correction carried to the previous time step.
                q_ret = rho_bar[step] * (q_ret - q_i_seq[step]) + val[step]

        q_loss = (q_i - q_retraces.view(-1, 1)).pow(2).mean() * 0.5
        q_loss.backward()
        clip_grad_norm_(self.critic.parameters(), 5.0)
        self.critic_optimizer.step()

        return q_loss.detach()

    def _categorical_kl(self, p1, p2):
        """
        Mean KL(p1 || p2) between categorical distributions given as
        probabilities along the last dimension. Probabilities are clamped from
        below at 1e-4 to keep the logarithm finite.
        """
        p1 = torch.clamp_min(p1, 0.0001)
        p2 = torch.clamp_min(p2, 0.0001)
        return torch.mean((p1 * torch.log(p1 / p2)).sum(dim=-1))

    def _update_param(self):
        """Hard-copy the online actor/critic parameters into the target networks."""
        # Update policy parameters
        for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data)
        # Update critic parameters
        for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)

    def _update_temperature(self, target_q_np, b_prob_np):
        r"""
        E-step: minimise the dual function of the non-parametric variational
        distribution and store the minimiser in ``self.η``.

            g(η) = η*ε + η \sum \log (\sum \exp(Q(a, s)/η))

        Both arrays are laid out as (n_actions, batch).
        """
        # Subtracting the per-state maximum keeps exp() from overflowing; it
        # does not depend on η, so it is computed once.
        max_q = np.max(target_q_np, 0)
        mean_max_q = np.mean(max_q)
        shifted_q = target_q_np - max_q

        def dual(η):
            # scipy passes a length-1 array; return a plain scalar.
            eta = float(np.asarray(η).reshape(-1)[0])
            log_sum = np.log(np.sum(b_prob_np * np.exp(shifted_q / eta), axis=0))
            return eta * self.ε + mean_max_q + eta * np.mean(log_sum)

        bounds = [(1e-6, None)]
        res = minimize(dual, np.array([self.η]), method='SLSQP', bounds=bounds)
        # Store a Python float: a numpy scalar would end up pickled inside the
        # checkpoint, which torch.load rejects when weights_only=True.
        self.η = float(res.x[0])

    def train(self):
        """
        Run the full training loop: sample trajectories, update the critic
        with Retrace, then alternate the E-step (temperature η) and the M-step
        (policy + η_kl). The target networks are synchronised and a checkpoint
        is written once per outer iteration.
        """
        start_time = time.time()

        # Row a holds action index a for every sample of the batch; identical
        # for every update, so it is built once.
        actions = torch.arange(self.action_shape)[..., None].expand(
            self.action_shape, self.mb_size).to(self.device)

        for episode in range(self.episodes):
            # Update replay buffer
            mean_reward = self._sample_trajectory()
            mean_q_loss = 0.0
            mean_policy = 0.0
            policy_updates = 0

            # Find better policy by gradient descent
            for _ in range(self.runs):
                state_batch, action_batch, reward_batch, policies_batch, done_batch = self.buffer.get()

                state_batch = torch.from_numpy(state_batch).to(self.device).float()
                action_batch = torch.from_numpy(action_batch).to(self.device).float()
                reward_batch = torch.from_numpy(reward_batch).to(self.device).float()
                policies_batch = torch.from_numpy(policies_batch).to(self.device).float()
                done_batch = torch.from_numpy(done_batch).to(self.device).float()

                reward_batch = torch.unsqueeze(reward_batch, dim=-1)
                done_batch = torch.unsqueeze(done_batch, dim=-1)

                # Update Q-function
                q_loss = self._update_critic_retrace(
                    state_batch, action_batch, policies_batch, reward_batch, done_batch)
                mean_q_loss += q_loss.item()

                # Sample values
                state_batch = state_batch.view(self.mb_size, *tuple(state_batch.shape[2:]))

                with torch.no_grad():
                    b_p = self.target_actor.forward(state_batch)
                    b = Categorical(probs=b_p)
                    b_prob = b.expand((self.action_shape, self.mb_size)).log_prob(actions).exp()
                    target_q = self.target_critic.forward(state_batch)
                    target_q = target_q.transpose(0, 1)

                # E-step: update the temperature η, then re-weight the actions
                self._update_temperature(target_q.cpu().numpy(), b_prob.cpu().numpy())
                qij = torch.softmax(target_q / self.η, dim=0)

                # M-step: update policy based on lagrangian
                for _ in range(self.lagrange_it):
                    π_p = self.actor.forward(state_batch)
                    π = Categorical(probs=π_p)
                    loss_p = torch.mean(
                        qij * π.expand((self.action_shape, self.mb_size)).log_prob(actions)
                    )

                    kl = self._categorical_kl(p1=π_p, p2=b_p)

                    # Update lagrange multiplier by gradient descent, keeping it >= 0
                    self.η_kl -= self.α * (self.ε_kl - kl).detach().item()
                    if self.η_kl < 0.0:
                        self.η_kl = 0.0

                    self.actor_optimizer.zero_grad()
                    loss_policy = -(loss_p + self.η_kl * (self.ε_kl - kl))
                    loss_policy.backward()
                    clip_grad_norm_(self.actor.parameters(), 5.0)
                    self.actor_optimizer.step()

                    mean_policy += loss_policy.item()
                    policy_updates += 1

            # Update target parameters
            self._update_param()

            # mean_reward is already normalised by _sample_trajectory; only
            # the average over environments is left to take.
            print(f"Episode = {episode} ; "
                  f"Mean reward = {np.mean(mean_reward)} ; "
                  f"Mean Q loss = {mean_q_loss / max(self.runs, 1)} ; "
                  f"Policy loss = {mean_policy / max(policy_updates, 1)} ; "
                  f"η = {self.η} ; η_kl = {self.η_kl} ; "
                  f"time = {(time.time() - start_time):.2f}")

            # Save model
            self.save_model()

    def load_model(self):
        """
        Restore networks, optimizers and Lagrange multipliers from
        ``self.save_path`` and put all four networks into training mode.
        """
        # map_location lets a checkpoint written on another device be loaded
        # onto the device this instance was configured with.
        checkpoint = torch.load(self.save_path, map_location=self.device)
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(checkpoint['target_critic_state_dict'])
        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.target_actor.load_state_dict(checkpoint['target_actor_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optim_state_dict'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optim_state_dict'])
        self.η = checkpoint['lagrange_η']
        self.η_kl = checkpoint['lagrange_η_kl']
        self.critic.train()
        self.target_critic.train()
        self.actor.train()
        self.target_actor.train()

    def save_model(self):
        """
        Write networks, optimizers and Lagrange multipliers to
        ``self.save_path``, creating the parent directory if it is missing.
        """
        data = {
            'critic_state_dict': self.critic.state_dict(),
            'target_critic_state_dict': self.target_critic.state_dict(),
            'actor_state_dict': self.actor.state_dict(),
            'target_actor_state_dict': self.target_actor.state_dict(),
            'critic_optim_state_dict': self.critic_optimizer.state_dict(),
            'actor_optim_state_dict': self.actor_optimizer.state_dict(),
            # plain floats keep the checkpoint free of numpy pickles
            'lagrange_η': float(self.η),
            'lagrange_η_kl': float(self.η_kl)
        }
        directory = os.path.dirname(self.save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(data, self.save_path)