# models.py
import pytorch_lightning as pl
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F


class VAE(pl.LightningModule):
    def __init__(self, max_len, vocab_len, latent_dim, embedding_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.max_len = max_len
        self.vocab_len = vocab_len
        # Token embedding; index 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_len, embedding_dim, padding_idx=0)
        # Encoder maps the flattened embedded sequence to 2 * latent_dim
        # values (mu and log_var, concatenated).
        self.encoder = nn.Sequential(
            nn.Linear(max_len * embedding_dim, 2000),
            nn.ReLU(),
            nn.Linear(2000, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, latent_dim * 2),
        )
        # Decoder mirrors the encoder, producing one logit per sequence
        # position and vocabulary entry.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 2000),
            nn.ReLU(),
            nn.Linear(2000, max_len * vocab_len),
        )

    def encode(self, x):
        # Embed the token indices, flatten, and split the encoder output
        # into mu and log_var.
        x = self.encoder(self.embedding(x).view((len(x), -1))).view((-1, 2, self.latent_dim))
        mu, log_var = x[:, 0, :], x[:, 1, :]
        # Reparameterization trick: sample z = mu + eps * sigma.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std, mu, log_var

    def decode(self, z):
        # Per-position log-softmax over the vocabulary, flattened back out.
        return F.log_softmax(self.decoder(z).view((-1, self.max_len, self.vocab_len)),
                             dim=2).view((-1, self.max_len * self.vocab_len))

    def forward(self, x):
        z, mu, log_var = self.encode(x)
        return self.decode(z), z, mu, log_var

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.0001)
        return {'optimizer': optimizer}

    def loss_function(self, pred, target, mu, log_var, batch_size, p):
        # Reconstruction term: negative log-likelihood of the target tokens.
        nll = F.nll_loss(pred, target)
        # KL divergence to the standard-normal prior, averaged over the batch
        # and the vocabulary dimension of `pred`.
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / (batch_size * pred.shape[1])
        # p interpolates between reconstruction and KL (beta-VAE-style weighting).
        return (1 - p) * nll + p * kld, nll, kld

    def training_step(self, train_batch, batch_idx):
        out, z, mu, log_var = self(train_batch)
        p = 0.1  # KL weight used during training
        loss, nll, kld = self.loss_function(out.reshape((-1, self.vocab_len)),
                                            train_batch.flatten(), mu, log_var,
                                            len(train_batch), p)
        self.log('train_loss', loss)
        self.log('train_nll', nll)
        self.log('train_kld', kld)
        return loss

    def validation_step(self, val_batch, batch_idx):
        out, z, mu, log_var = self(val_batch)
        # Validation weights the KL term more heavily (p = 0.5) and also logs
        # the mean posterior statistics.
        loss, nll, kld = self.loss_function(out.reshape((-1, self.vocab_len)),
                                            val_batch.flatten(), mu, log_var,
                                            len(val_batch), 0.5)
        self.log('val_loss', loss)
        self.log('val_nll', nll)
        self.log('val_kld', kld)
        self.log('val_mu', torch.mean(mu))
        self.log('val_logvar', torch.mean(log_var))
        return loss


class PropertyPredictor(pl.LightningModule):
    # Simple MLP regressor mapping a fixed-size input vector to one scalar
    # property value.
    def __init__(self, in_dim, learning_rate=0.001):
        super(PropertyPredictor, self).__init__()
        self.learning_rate = learning_rate
        self.fc = nn.Sequential(
            nn.Linear(in_dim, 1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 1),
        )

    def forward(self, x):
        return self.fc(x)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def loss_function(self, pred, real):
        return F.mse_loss(pred, real)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss_function(out, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss_function(out, y)
        self.log('val_loss', loss)
        return loss
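

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of how these models might be trained and
# used together. The random data, sequence/vocabulary sizes, and the idea of
# fitting PropertyPredictor on the encoder's latent means are illustrative
# assumptions, not part of the original file.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    MAX_LEN, VOCAB_LEN = 120, 40  # assumed sizes for illustration
    # Random token indices stand in for a real tokenized dataset.
    data = torch.randint(0, VOCAB_LEN, (1024, MAX_LEN))
    train_data, val_data = data[:896], data[896:]

    model = VAE(max_len=MAX_LEN, vocab_len=VOCAB_LEN, latent_dim=64, embedding_dim=32)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model,
                DataLoader(train_data, batch_size=64, shuffle=True),
                DataLoader(val_data, batch_size=64))

    # Sample latents from the prior and greedily decode token sequences.
    model.eval()
    with torch.no_grad():
        z = torch.randn(8, model.latent_dim)
        log_probs = model.decode(z).view(-1, MAX_LEN, VOCAB_LEN)
        tokens = log_probs.argmax(dim=2)  # (8, MAX_LEN) token indices

    # Fit the property predictor on latent means against placeholder targets.
    with torch.no_grad():
        _, mu, _ = model.encode(train_data)
    targets = torch.randn(len(mu), 1)
    predictor = PropertyPredictor(in_dim=model.latent_dim)
    pl.Trainer(max_epochs=1).fit(
        predictor,
        DataLoader(TensorDataset(mu, targets), batch_size=64, shuffle=True))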