train.py
"""Training script for the Variational Recurrent Neural Network
(VRNN) from https://arxiv.org/abs/1506.02216, using unimodal
isotropic gaussian distributions for the inference, prior, and
generating models."""

import os

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from model import VRNN
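
# The kld_loss returned by model(data) below is, per the paper, the analytic
# KL divergence between two diagonal gaussians: the approximate posterior q
# and the learned prior p. A minimal reference sketch of that term follows;
# this helper is illustrative only and is not part of model.VRNN's API.
def kld_gauss_reference(mean_q, std_q, mean_p, std_p):
    """KL( N(mean_q, std_q**2) || N(mean_p, std_p**2) ), summed over all dims."""
    kld = (2 * torch.log(std_p) - 2 * torch.log(std_q)
           + (std_q.pow(2) + (mean_q - mean_p).pow(2)) / std_p.pow(2)
           - 1)
    return 0.5 * torch.sum(kld)
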
def train(epoch):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):

        # transforming data: (batch, 1, 28, 28) -> (seq, batch, elem), scaled to [0, 1]
        data = data.to(device)
        data = data.squeeze().transpose(0, 1)  # (seq, batch, elem)
        data = (data - data.min()) / (data.max() - data.min())

        # forward + backward + optimize
        optimizer.zero_grad()
        kld_loss, nll_loss, _, _ = model(data)
        loss = kld_loss + nll_loss
        loss.backward()

        # grad norm clipping; must run between backward() and step()
        # to affect the current update
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # printing
        if batch_idx % print_every == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f}'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                kld_loss.item() / batch_size,
                nll_loss.item() / batch_size))

            sample = model.sample(torch.tensor(28, device=device))
            plt.imshow(sample.to(torch.device('cpu')).numpy())
            plt.pause(1e-6)

        train_loss += loss.item()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
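
# Shape check for the reshaping in train() (illustrative only): a MNIST batch
# of shape (batch, 1, 28, 28) becomes (28, batch, 28), i.e. each image is fed
# to the VRNN as a length-28 sequence of 28-pixel rows.
#   >>> x = torch.zeros(8, 1, 28, 28)
#   >>> x.squeeze().transpose(0, 1).shape
#   torch.Size([28, 8, 28])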
def test(epoch):
    """uses test data to evaluate likelihood of the model"""
    mean_kld_loss, mean_nll_loss = 0, 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):

            data = data.to(device)
            data = data.squeeze().transpose(0, 1)
            data = (data - data.min()) / (data.max() - data.min())

            kld_loss, nll_loss, _, _ = model(data)
            mean_kld_loss += kld_loss.item()
            mean_nll_loss += nll_loss.item()

    mean_kld_loss /= len(test_loader.dataset)
    mean_nll_loss /= len(test_loader.dataset)

    print('====> Test set loss: KLD Loss = {:.4f}, NLL Loss = {:.4f}'.format(
        mean_kld_loss, mean_nll_loss))
# select device
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')
# hyperparameters
x_dim = 28
h_dim = 100
z_dim = 16
n_layers = 1
n_epochs = 25
clip = 10
learning_rate = 1e-3
batch_size = 8  # 128
seed = 128
print_every = 1000  # batches
save_every = 10  # epochs

# manual seed
torch.manual_seed(seed)
plt.ion()
# init model + optimizer + datasets
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

model = VRNN(x_dim, h_dim, z_dim, n_layers)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
os.makedirs('saves', exist_ok=True)

for epoch in range(1, n_epochs + 1):

    # training + testing
    train(epoch)
    test(epoch)

    # saving model
    if epoch % save_every == 1:
        fn = 'saves/vrnn_state_dict_' + str(epoch) + '.pth'
        torch.save(model.state_dict(), fn)
        print('Saved model to ' + fn)
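
# To resume from one of the checkpoints saved above (illustrative sketch; the
# epoch number in the path is an example):
#   model = VRNN(x_dim, h_dim, z_dim, n_layers).to(device)
#   model.load_state_dict(torch.load('saves/vrnn_state_dict_1.pth', map_location=device))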