made.py
import numpy as np
import torch
import torch.nn.functional as F
from torch import distributions as tdib
from torch import nn

from gms import common


class MADE(common.Autoreg):
    DG = common.AttrDict()
    DG.hidden_size = 1024

    def __init__(self, G):
        super().__init__(G)
        self.nin = 784
        self.nout = 784
        self.hidden_sizes = [G.hidden_size] * 3
        # simple MLP neural net w/ masked layers
        net = []
        hs = [self.nin] + self.hidden_sizes + [self.nout]
        for h0, h1 in zip(hs, hs[1:]):
            net.extend(
                [
                    MaskedLinear(h0, h1),
                    nn.ReLU(),
                ]
            )
        net.pop()  # no activation on the last layer
        self.net = nn.Sequential(*net)
        self.m = {}
        self.create_mask()  # builds the initial self.m connectivity

    def create_mask(self):
        """
        Build the MADE connectivity masks.

        The output that predicts pixel 0 must never see information from pixels
        1-783, the output for pixel 1 may only see pixel 0, and so on. The hidden
        neurons are free to be used however the network likes, but no information
        may flow into the output for pixel i from pixel i itself or from any pixel
        that comes later in the raster order. The mask construction below enforces
        exactly that constraint.
        """
        L = len(self.hidden_sizes)
        # fix the input ordering and sample a connectivity rank for every hidden neuron
        self.m[-1] = np.arange(self.nin)
        for l in range(L):
            self.m[l] = np.random.randint(
                self.m[l - 1].min(), self.nin - 1, size=self.hidden_sizes[l]
            )
        # construct the mask matrices:
        # hidden layers connect a neuron only to upstream neurons of equal or lower rank,
        # and the output layer uses a strict inequality so output i never sees pixel i itself
        masks = [self.m[l - 1][:, None] <= self.m[l][None, :] for l in range(L)]
        masks.append(self.m[L - 1][:, None] < self.m[-1][None, :])
        # set the masks in all MaskedLinear layers
        layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)]
        for l, m in zip(layers, masks):
            l.set_mask(m)

    def loss(self, x, y=None):
        x = x.to(self.G.device)
        x = x.view(-1, 784)  # flatten image
        logits = self.net(x)
        loss = -tdib.Bernoulli(logits=logits).log_prob(x).mean()
        return loss, {'nlogp': loss}

    def sample(self, n):
        samples = torch.zeros(n, 784).to(self.G.device)
        # Generate the pixels one at a time in raster order: sample pixel 0, then
        # sample pixel 1 conditioned on pixel 0, then pixel 2 conditioned on both,
        # and so on. This is valid because the masks enforce the same constraint the
        # model was trained under: output i never sees pixel i or anything after it,
        # so its prediction is well defined as soon as pixels 0..i-1 are filled in.
        steps = []
        with torch.no_grad():
            for i in range(784):
                logits = self.net(samples)[:, i]
                probs = torch.sigmoid(logits)
                samples[:, i] = torch.bernoulli(probs)
                # clone so each recorded step is a snapshot, not a view of the evolving canvas
                steps += [samples.view(n, 1, 28, 28).clone().cpu()]
                # plt.imsave(f'gifs/{i}.png', x.numpy())
        samples = samples.view(n, 1, 28, 28)
        return samples.cpu(), torch.stack(steps)


class MaskedLinear(nn.Linear):
    """same as Linear except has a configurable mask on the weights"""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('mask', torch.ones(out_features, in_features))

    def set_mask(self, mask):
        self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))

    def forward(self, input):
        return F.linear(input, self.mask * self.weight, self.bias)
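

# Minimal usage sketch / smoke test. In the repo this model is normally driven by a
# training harness, so treat the block below as an illustration only. It assumes that
# `common.AttrDict` supports attribute-style assignment (as the DG defaults above do)
# and that `common.Autoreg.__init__` requires no config fields beyond `hidden_size`
# and `device`; the real config may carry more entries.
if __name__ == '__main__':
    G = common.AttrDict()
    G.hidden_size = 1024
    G.device = 'cpu'  # keep the sketch device-agnostic

    model = MADE(G)

    # One loss evaluation on a fake batch of binarized 28x28 images.
    x = (torch.rand(8, 1, 28, 28) < 0.5).float()
    loss, metrics = model.loss(x)
    print('nlogp:', metrics['nlogp'].item())

    # Check the autoregressive property: the logit for pixel i must have zero
    # gradient w.r.t. pixels i..783, because the masks block every path from them.
    probe = torch.rand(1, 784, requires_grad=True)
    i = 100  # arbitrary pixel to probe
    (grad,) = torch.autograd.grad(model.net(probe)[0, i], probe)
    assert grad[0, i:].abs().max().item() == 0.0

    # Ancestral sampling: pixels are drawn one at a time in raster order, so this
    # takes 784 forward passes. `steps` stacks the recorded canvases.
    samples, steps = model.sample(n=4)
    print(samples.shape, steps.shape)  # (4, 1, 28, 28) and (784, 4, 1, 28, 28)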