forked from black0017/3D-GAN-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_gans.py
65 lines (52 loc) · 1.77 KB
/
train_gans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch.autograd.variable import Variable
def ones_target(size, device=None, dtype=None):
    """Return a ``(size, 1)`` tensor of ones, used as the "real" label batch.

    Args:
        size: Number of rows (typically the batch size).
        device: Optional device to allocate the tensor on; ``None`` uses
            PyTorch's default device.
        dtype: Optional dtype; ``None`` uses PyTorch's default dtype.

    Returns:
        Tensor of shape ``(size, 1)`` filled with ones, with
        ``requires_grad=False``.
    """
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4; a plain tensor
    # is equivalent and lets the caller pick the device/dtype up front.
    return torch.ones(size, 1, device=device, dtype=dtype)
def zeros_target(size, device=None, dtype=None):
    """Return a ``(size, 1)`` tensor of zeros, used as the "fake" label batch.

    Args:
        size: Number of rows (typically the batch size).
        device: Optional device to allocate the tensor on; ``None`` uses
            PyTorch's default device.
        dtype: Optional dtype; ``None`` uses PyTorch's default dtype.

    Returns:
        Tensor of shape ``(size, 1)`` filled with zeros, with
        ``requires_grad=False``.
    """
    # ``Variable`` has been a no-op wrapper since PyTorch 0.4; a plain tensor
    # is equivalent and lets the caller pick the device/dtype up front.
    return torch.zeros(size, 1, device=device, dtype=dtype)
def train_discriminator(discriminator, optimizer, real_data, fake_data, loss):
    """Run one optimisation step of the discriminator on a real and a fake batch.

    The discriminator is trained to output 1 for ``real_data`` and 0 for
    ``fake_data``. Gradients from both halves are accumulated, then applied
    with a single ``optimizer.step()``.

    Args:
        discriminator: Module mapping a batch to per-sample predictions.
        optimizer: Optimizer to step; presumably it owns the discriminator's
            parameters (only ``zero_grad()``/``step()`` are called here).
        real_data: Batch of real samples.
        fake_data: Batch of generated samples. It is detached here, so no
            gradient flows back into whatever produced it.
        loss: Criterion called as ``loss(prediction, target)``, e.g.
            ``torch.nn.BCELoss()``.

    Returns:
        Tuple ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    # Reset gradients accumulated by any previous step.
    optimizer.zero_grad()

    # 1.1 Train on real data: every target is 1.
    prediction_real = discriminator(real_data)
    # ``ones_like`` matches the prediction's shape, dtype and device. The
    # previous ``target.cuda()`` discarded its result (Tensor.cuda returns a
    # copy), leaving the target on CPU and breaking the loss on GPU models.
    target_real = torch.ones_like(prediction_real)
    error_real = loss(prediction_real, target_real)
    error_real.backward()

    # 1.2 Train on fake data: every target is 0. Detaching keeps this backward
    # pass out of the generator's graph; otherwise that graph would be freed
    # here and a later generator step on the same batch would fail.
    prediction_fake = discriminator(fake_data.detach())
    target_fake = torch.zeros_like(prediction_fake)
    error_fake = loss(prediction_fake, target_fake)
    error_fake.backward()

    # 1.3 Update weights with the accumulated gradients.
    optimizer.step()

    # Combined error plus raw predictions for real and fake inputs.
    return error_real + error_fake, prediction_real, prediction_fake
def train_generator(discriminator, optimizer, fake_data, loss):
    """Run one optimisation step of the generator.

    The generator is rewarded when the discriminator labels its samples as
    real, so every target is 1.

    Args:
        discriminator: Module used to score ``fake_data``.
        optimizer: Optimizer to step; presumably it owns the generator's
            parameters (only ``zero_grad()``/``step()`` are called here).
        fake_data: Generated batch. It must still be attached to the
            generator's graph, so that gradients can reach the generator.
        loss: Criterion called as ``loss(prediction, target)``, e.g.
            ``torch.nn.BCELoss()``.

    Returns:
        The generator loss tensor.

    Note:
        ``error.backward()`` also fills ``.grad`` on the discriminator's
        parameters. Only the parameters owned by ``optimizer`` are updated here.
    """
    # Reset gradients accumulated by any previous step.
    optimizer.zero_grad()
    # Score the generated batch. It is not detached: gradients must flow back
    # into the generator.
    prediction = discriminator(fake_data)
    # ``ones_like`` matches the prediction's shape, dtype and device. The
    # previous ``target.cuda()`` discarded its result, leaving the target on CPU.
    target = torch.ones_like(prediction)
    error = loss(prediction, target)
    error.backward()
    # Update weights with gradients.
    optimizer.step()
    return error