gan_cifar10.py

import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import models
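
# Example invocation (a sketch; the values shown are just the script's own defaults):
#   python gan_cifar10.py --outf tmp/cifar10 --batch_size 64 --epochs 50 --niterD 5
# Assumes a local models.py providing Resnet_G and Convnet_D (architectures from the WGAN-GP paper).
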
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--outf', default='tmp/cifar10', help='where to save results')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--niterD', type=int, default=5, help='no. updates of D per update of G')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--alpha', type=float, default=0.0, help='Lagrange multiplier')
    parser.add_argument('--rho', type=float, default=1e-5, help='quadratic weight penalty')
    args = parser.parse_args()

    cudnn.benchmark = True
    os.system('mkdir -p {}'.format(args.outf))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    IMAGE_SIZE = 32
    dataset = dset.CIFAR10(root='cifar10', download=True,
                           transform=transforms.Compose([
                               transforms.Resize(IMAGE_SIZE),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                             shuffle=True, num_workers=2, drop_last=True)
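    # drop_last=True keeps every batch at exactly batch_size samples, matching the
    # preallocated noise tensor z below.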

    # Resnet and Convnet from the WGAN-GP paper
    netG = models.Resnet_G().to(device)
    netD = models.Convnet_D().to(device)

    NZ = 128
    z = torch.FloatTensor(args.batch_size, NZ).to(device)
    alpha = torch.tensor(args.alpha).to(device)

    optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.5, 0.9), amsgrad=True)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.5, 0.9), amsgrad=True)

    losses = []
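
    # Critic objective as encoded below (an augmented-Lagrangian reading of the code):
    # maximize E[D(x_real)] - E[D(x_fake)] subject to omega = 1, where omega is the
    # average squared gradient norm of D over the real and fake batches; alpha is the
    # Lagrange multiplier and rho the quadratic penalty weight.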
    for epoch in range(args.epochs):
        for i, data in enumerate(dataloader):
            # --- train D
            for _ in range(args.niterD):
                optimizerD.zero_grad()
                x_real = data[0].to(device)
                x_fake = netG(z.normal_(0, 1)).detach()
                x_real.requires_grad_()  # to compute gradD_real
                x_fake.requires_grad_()  # to compute gradD_fake
                y_real = netD(x_real)
                y_fake = netD(x_fake)
                lossE = y_real.mean() - y_fake.mean()
                # grad() does not broadcast over the batch, so we differentiate the sum;
                # the per-sample gradients are unchanged
                gradD_real = torch.autograd.grad(y_real.sum(), x_real, create_graph=True)[0]
                gradD_fake = torch.autograd.grad(y_fake.sum(), x_fake, create_graph=True)[0]
                omega = 0.5*(gradD_real.view(gradD_real.size(0), -1).pow(2).sum(dim=1).mean() +
                             gradD_fake.view(gradD_fake.size(0), -1).pow(2).sum(dim=1).mean())
                loss = -lossE - alpha*(1.0 - omega) + 0.5*args.rho*(1.0 - omega).pow(2)
                loss.backward()
                optimizerD.step()
                # multiplier update for the constraint omega = 1
                alpha -= args.rho*(1.0 - omega.item())

            # --- train G
            optimizerG.zero_grad()
            x_fake = netG(z.normal_(0, 1))
            y_fake = netD(x_fake)
            loss = -y_fake.mean()
            loss.backward()
            optimizerG.step()

            # --- logging
            losses.append(lossE.item())
            if (i+1) % 100 == 0:
                print("epoch: {} | [{}/{}] loss: {:.3f}, alpha: {:.3f}, omega: {:.3f}".format(
                    epoch, (i+1), int(len(dataset)/args.batch_size), lossE.item(), alpha.item(), omega.item()))

        # generated images and loss curve
        vutils.save_image(x_fake, '{}/x_{}.png'.format(args.outf, epoch), normalize=True)
        fig, ax = plt.subplots()
        ax.set_ylabel('IPM estimate')
        ax.set_xlabel('iteration')
        ax.semilogy(losses)
        fig.savefig('{}/loss.png'.format(args.outf))
        plt.close(fig)