-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
115 lines (80 loc) · 3.57 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
from gym.spaces import Box, Discrete
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
def mlp(sizes, activation=nn.ReLU, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Each consecutive pair of entries in ``sizes`` becomes one ``nn.Linear``
    layer followed by an activation module. ``activation`` is used after
    every layer except the last, which is followed by ``output_activation``.

    Args:
        sizes: Iterable of layer widths, input width first. Fewer than two
            entries yields an empty ``nn.Sequential``.
        activation: Zero-argument callable (typically an ``nn.Module``
            class) producing the hidden-layer activation.
        output_activation: Zero-argument callable producing the activation
            placed after the final linear layer.

    Returns:
        ``nn.Sequential`` containing the alternating linear/activation modules.
    """
    # Materialize once so generators and other one-shot iterables work too.
    sizes = list(sizes)
    last_idx = len(sizes) - 2
    layers = []
    for idx, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
        act = activation if idx < last_idx else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers)
class RewNet(nn.Module):
    """Reward model: an MLP whose final layer has a single output unit.

    Args:
        obs_dim: Width of the network's input layer.
        hidden_sizes: Iterable of hidden-layer widths.
    """

    def __init__(self, obs_dim, hidden_sizes):
        super().__init__()
        layer_widths = [obs_dim, *hidden_sizes, 1]
        self.reward_net = mlp(layer_widths)

    def forward(self, next_obs):
        """Return the reward network's output for ``next_obs``."""
        reward = self.reward_net(next_obs)
        return reward
class Actor(nn.Module):
    """Base class for policies.

    Subclasses supply ``_distribution`` and ``_log_prob_from_distribution``;
    both raise ``NotImplementedError`` here.
    """

    def _distribution(self, obs):
        """Return the action distribution for ``obs`` (subclass hook)."""
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-likelihood of ``act`` under ``pi`` (subclass hook)."""
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Build the action distribution and optionally score actions.

        Args:
            obs: Observations passed to ``_distribution``.
            act: Optional actions; when given, their log-likelihood under
                the distribution is computed.

        Returns:
            Tuple ``(pi, logp_a)`` where ``logp_a`` is ``None`` if ``act``
            was not provided.
        """
        pi = self._distribution(obs)
        if act is None:
            return pi, None
        return pi, self._log_prob_from_distribution(pi, act)
class MLPCategoricalActor(Actor):
    """Policy over discrete actions: an MLP emits logits for a ``Categorical``.

    Args:
        obs_dim: Width of the network's input layer.
        act_dim: Number of logits produced by the final layer.
        hidden_sizes: Iterable of hidden-layer widths.
        activation: Hidden-layer activation forwarded to ``mlp``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation=nn.ReLU):
        super().__init__()
        layer_widths = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = mlp(layer_widths, activation)

    def _distribution(self, obs):
        """Return a ``Categorical`` built from the network's logits for ``obs``."""
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        """Return ``pi.log_prob(act)`` unchanged."""
        log_prob = pi.log_prob(act)
        return log_prob
class MLPGaussianActor(Actor):
    """Gaussian policy: an MLP predicts the mean; log-std is a free parameter.

    ``log_std`` is a learnable vector of length ``act_dim`` initialised to
    -0.5 and does not depend on the observation.

    Args:
        obs_dim: Width of the network's input layer.
        act_dim: Width of the mean network's output layer.
        hidden_sizes: Iterable of hidden-layer widths.
        activation: Hidden-layer activation forwarded to ``mlp``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation=nn.ReLU):
        super().__init__()
        initial_log_std = np.full(act_dim, -0.5, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(initial_log_std))
        layer_widths = [obs_dim, *hidden_sizes, act_dim]
        self.mu_net = mlp(layer_widths, activation)

    def _distribution(self, obs):
        """Return ``Normal(mu_net(obs), exp(log_std))``."""
        mean = self.mu_net(obs)
        return Normal(mean, torch.exp(self.log_std))

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act`` summed over the last axis."""
        # Torch's Normal scores each component separately, so the last
        # axis is summed to obtain one log-probability per action.
        per_component = pi.log_prob(act)
        return per_component.sum(dim=-1)
class MLPCritic(nn.Module):
    """Value function: an MLP whose final layer has a single output unit.

    Args:
        obs_dim: Width of the network's input layer.
        hidden_sizes: Iterable of hidden-layer widths.
        activation: Hidden-layer activation forwarded to ``mlp``.
    """

    def __init__(self, obs_dim, hidden_sizes, activation=nn.ReLU):
        super().__init__()
        layer_widths = [obs_dim, *hidden_sizes, 1]
        self.v_net = mlp(layer_widths, activation)

    def forward(self, obs):
        """Return the value network's output with the last axis squeezed."""
        values = self.v_net(obs)
        # Squeezing the last axis is critical so v has the right shape.
        return values.squeeze(-1)
class MLPActorCritic(nn.Module):
    """Bundle a policy (``pi``) and a value function (``v``) built from MLPs.

    The policy class is chosen from the type of ``action_space``:
    ``Box`` selects ``MLPGaussianActor`` and ``Discrete`` selects
    ``MLPCategoricalActor``.

    Args:
        observation_space: Space whose ``shape[0]`` gives the input width.
        action_space: ``Box`` or ``Discrete`` action space.
        hidden_sizes: Hidden-layer widths shared by policy and critic.
        activation: Hidden-layer activation shared by policy and critic.

    Raises:
        TypeError: If ``action_space`` is neither ``Box`` nor ``Discrete``.
    """

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(64, 64), activation=nn.Tanh):
        super().__init__()

        obs_dim = observation_space.shape[0]

        # Policy builder depends on the action space.
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(
                obs_dim, action_space.shape[0], hidden_sizes, activation)
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(
                obs_dim, action_space.n, hidden_sizes, activation)
        else:
            # Fail at construction rather than leaving ``self.pi`` unset and
            # surfacing an AttributeError on the first call to ``step``.
            raise TypeError(
                "Unsupported action space type: "
                f"{type(action_space).__name__}; expected Box or Discrete."
            )

        # Build the value function.
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        """Sample an action and evaluate it without tracking gradients.

        Args:
            obs: Observation passed to both the policy and the critic.

        Returns:
            Tuple ``(action, value, logp_action)`` of NumPy arrays.
        """
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        # ``.cpu()`` is a no-op for CPU tensors and lets this also work when
        # the model lives on another device, where ``.numpy()`` alone raises.
        return a.cpu().numpy(), v.cpu().numpy(), logp_a.cpu().numpy()

    def act(self, obs):
        """Return only the sampled action from ``step``."""
        return self.step(obs)[0]