# gan.py
import torch
import torch.nn.functional as F
from typing import Optional, Callable, Dict

from ai.train.env.base import MultiEnv
from ai.util import on_interval
from ai.train.logger import log
from ai.model import Model

class Gan(MultiEnv):
    '''Training environment for Generative Adversarial Networks.

    ARGS
    aug : callable or null
        optional augmentation applied before calling the discriminator
        TODO: adaptive discriminator augmentation
    g_reg_interval : int or null
        perform regularization for the generator every g_reg_interval steps.
        NOTE: no default implementation. for an example, see path length
        regularization in ai/examples/stylegan2/train.py
    g_reg_weight : float
        weight of the generator's regularization loss
    d_reg_interval : int or null
        perform regularization for the discriminator every d_reg_interval
        steps. default is the R1 gradient penalty on real samples
        (https://arxiv.org/abs/1801.04406)
    d_reg_weight : float
        weight of the discriminator's regularization loss
    '''
    def __init__(s,
        aug: Optional[Callable] = None,
        g_reg_interval: Optional[int] = None,
        g_reg_weight: float = 1.,
        d_reg_interval: Optional[int] = 16,
        d_reg_weight: float = 1.,
    ):
        s._aug = aug
        s._g_reg_interval, s._g_reg_weight = g_reg_interval, g_reg_weight
        s._d_reg_interval, s._d_reg_weight = d_reg_interval, d_reg_weight
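
    # usage sketch (assumes an external alternating-step training loop;
    # `models`, `loader`, and the optimizer steps are placeholders, not
    # part of this file):
    #
    #   env = Gan()
    #   for step, batch in enumerate(loader):
    #       d_loss = env.D(models, batch, step)   # then optimize D on d_loss
    #       g_loss = env.G(models, batch, step)   # then optimize G on g_loss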

    # generator step
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def G(s,
        models: Dict[str, Model],
        batch: torch.Tensor,
        step: int = 0,
    ) -> torch.Tensor:
        # main
        loss = s._g_main(models, batch)
        log('G.loss.main', loss)

        # regularize (lazy: runs every g_reg_interval steps, so the loss is
        # scaled by the interval to keep the effective weight unchanged)
        if on_interval(step, s._g_reg_interval):
            reg_loss = s._g_reg(models, batch)
            log('G.loss.reg', reg_loss)
            loss += reg_loss * s._g_reg_weight * s._g_reg_interval

        return loss

    def _g_main(s, models, batch):
        g_out = s._generate(models['G'], batch)
        d_out = s._discriminate(models['D'], g_out, detach=False)
        return s._g_loss_fn(d_out)

    def _g_reg(s, models, batch):
        raise NotImplementedError(
            'set g_reg_interval without implementing _g_reg')
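
    # override sketch (hypothetical subclass; the actual path length
    # regularizer in ai/examples/stylegan2/train.py also needs the
    # generator's intermediate latents, which this sketch glosses over):
    #
    #   class MyGan(Gan):
    #       def _g_reg(s, models, batch):
    #           g_out = s._generate(models['G'], batch)
    #           return my_reg_loss(g_out)  # placeholder loss fn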

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # discriminator step
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def D(s,
        models: Dict[str, Model],
        batch: torch.Tensor,
        step: int = 0,
    ) -> torch.Tensor:
        # main
        loss = s._d_main(models, batch)
        log('D.loss.main', loss)

        # regularize (lazy, scaled by the interval as in G)
        if on_interval(step, s._d_reg_interval):
            reg_loss = s._d_reg(models, batch)
            log('D.loss.reg', reg_loss)
            loss += reg_loss * s._d_reg_weight * s._d_reg_interval

        return loss

    def _d_main(s, models, batch):
        G, D = models['G'], models['D']
        g_out = s._generate(G, batch)
        # detach the generator output so no gradients flow back into G here
        loss_fake = s._d_loss_fn(s._discriminate(D, g_out, detach=True), False)
        loss_real = s._d_loss_fn(s._discriminate(D, batch, detach=True), True)
        return loss_fake + loss_real

    def _d_reg(s, models, batch):
        return s._gradient_penalty(models['D'], batch)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # helpers
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _generate(s, generator, batch):
        z = s._random_z(generator.z_dim, batch.shape[0], batch.device)
        return generator(z)

    def _discriminate(s, discriminator, x, detach):
        if s._aug is not None:
            x = s._aug(x)
        if detach:
            x = x.detach()
        return discriminator(x)

    def _g_loss_fn(s, logits):
        return F.softplus(-logits).mean()

    def _d_loss_fn(s, logits, is_real):
        sign = -1 if is_real else 1
        return F.softplus(sign * logits).mean()
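
    # together these implement the non-saturating logistic GAN loss:
    #   L_D = E[softplus(-D(real))] + E[softplus(D(fake))]
    #   L_G = E[softplus(-D(fake))]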

    def _random_z(s, z_dim, bs, device):
        return torch.randn([bs, z_dim], device=device)

    def _gradient_penalty(s, discriminator, batch):
        input_real = batch.detach().requires_grad_(True)
        logits_real = s._discriminate(discriminator, input_real, detach=False)
        grads = torch.autograd.grad(
            outputs=[logits_real.sum()],
            inputs=[input_real],
            create_graph=True,
            only_inputs=True,
        )[0]
        # per-sample squared gradient norm (assumes NCHW image tensors)
        loss = grads.square().sum([1, 2, 3])
        return loss.mean()
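
    # the penalty above is the R1 regularizer,
    #   R1 = E_real[ ||grad_x D(x)||^2 ],
    # evaluated at real samples only (unlike WGAN-GP, there is no
    # unit-norm target)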
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~