-
Notifications
You must be signed in to change notification settings - Fork 62
/
train_t2i_discrete.py
325 lines (262 loc) · 12.8 KB
/
train_t2i_discrete.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import tempfile
from tools.fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import libs.autoencoder
import numpy as np
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Return the Stable Diffusion "scaled linear" beta schedule as a float64 numpy array.

    The square roots of the betas are spaced linearly from ``sqrt(linear_start)``
    to ``sqrt(linear_end)`` over ``n_timestep`` points and then squared.
    """
    sqrt_betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64)
    squared = sqrt_betas ** 2
    return squared.numpy()
def get_skip(alphas, betas):
    """Precompute products/sums used to jump between timesteps of a discrete schedule.

    With ``size = len(betas)``:
      * ``skip_alphas[s, t] = alphas[s + 1: t + 1].prod()`` for ``t > s``, and 1 otherwise.
      * ``skip_betas[s, t] = sum_{k=s+1..t} betas[k] * skip_alphas[k, t]`` for ``s < t``,
        and 0 otherwise.

    Both matrices are ``(size, size)`` with the dtype of ``betas``.
    """
    size = len(betas)

    skip_alphas = np.ones([size, size], dtype=betas.dtype)
    for start in range(size):
        # Running products alphas[start+1], alphas[start+1]*alphas[start+2], ...
        skip_alphas[start, start + 1:] = alphas[start + 1:].cumprod()

    skip_betas = np.zeros([size, size], dtype=betas.dtype)
    for end in range(size):
        weighted = betas[1: end + 1] * skip_alphas[1: end + 1, end]
        # Suffix sums: entry s accumulates the weighted terms k = s+1 .. end.
        skip_betas[:end, end] = np.flip(np.flip(weighted).cumsum())

    return skip_alphas, skip_betas
def stp(s, ts: torch.Tensor):  # scalar tensor product
    """Scale each leading-axis slice of ``ts`` by the matching entry of ``s``.

    ``s`` is reshaped to ``(-1, 1, ..., 1)`` (one trailing singleton per remaining
    dim of ``ts``) and broadcast-multiplied with ``ts``.

    Args:
        s: per-item scale factors. A numpy array, a Python/numpy scalar, or a torch
            tensor. Non-tensor inputs are converted with ``type_as(ts)``; tensors are
            used as given.
        ts: tensor to scale.

    Returns:
        The broadcast product ``s * ts``.
    """
    if isinstance(s, np.ndarray):
        s = torch.from_numpy(s).type_as(ts)
    elif not isinstance(s, torch.Tensor):
        # NumPy scalars (e.g. ``arr[i]`` for an integer i) and Python numbers are not
        # ndarrays; convert them too, otherwise ``.view(-1, ...)`` below fails.
        s = torch.as_tensor(s).type_as(ts)
    extra_dims = (1,) * (ts.dim() - 1)
    return s.view(-1, *extra_dims) * ts
def mos(a, start_dim=1):  # mean of square
    """Return the mean of the squared entries of ``a`` over all dims from ``start_dim`` on."""
    squared = a.pow(2)
    return torch.flatten(squared, start_dim=start_dim).mean(dim=-1)
class Schedule(object):  # discrete time
    def __init__(self, _betas):
        r"""Discrete-time diffusion variance schedule.

        _betas[0...N-1] = betas[1...N]
        for n>=1, betas[n] is the variance of q(xn|xn-1)
        for n=0, betas[0]=0
        """
        self._betas = _betas
        self.betas = np.append(0., _betas)
        self.alphas = 1. - self.betas
        self.N = len(_betas)

        assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
        assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
        assert len(self.betas) == len(self.alphas)

        # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
        self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
        self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
        self.cum_betas = self.skip_betas[0]
        # cum_betas[0] is 0 (nothing is accumulated for t = 0), so snr[0] is +inf.
        # Keep that value but suppress numpy's divide-by-zero RuntimeWarning.
        with np.errstate(divide='ignore'):
            self.snr = self.cum_alphas / self.cum_betas

    def tilde_beta(self, s, t):
        """Return skip_betas[s, t] * cum_betas[s] / cum_betas[t]."""
        return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]

    def sample(self, x0):  # sample from q(xn|x0), where n is uniform
        """Draw a noised version of ``x0`` at a uniformly random timestep per item.

        Args:
            x0: batch tensor; one timestep is drawn per entry of its first axis.

        Returns:
            ``(n, eps, xn)``:
            * ``n``: integer tensor on ``x0.device``, each entry uniform in {1, ..., N}.
            * ``eps``: standard normal noise shaped like ``x0``.
            * ``xn``: ``sqrt(cum_alphas[n]) * x0 + sqrt(cum_betas[n]) * eps``.
        """
        # Draw ints directly instead of materializing list(range(1, N + 1)) each call.
        n = np.random.randint(1, self.N + 1, size=(len(x0),))
        eps = torch.randn_like(x0)
        xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
        return torch.tensor(n, device=x0.device), eps, xn

    def __repr__(self):
        return f'Schedule({self.betas[:10]}..., {self.N})'
def LSimple(x0, nnet, schedule, **kwargs):
    """Per-sample epsilon-prediction loss.

    Noises ``x0`` at random timesteps via ``schedule.sample``, asks ``nnet`` to
    predict the noise, and returns the per-sample mean squared error.
    Extra keyword arguments (e.g. ``context``) are forwarded to ``nnet``.
    """
    timesteps, noise, noisy = schedule.sample(x0)  # timesteps drawn from {1, ..., schedule.N}
    predicted = nnet(noisy, timesteps, **kwargs)
    residual = noise - predicted
    return mos(residual)
def train(config):
    """Train a discrete-time text-to-image latent diffusion model with ``accelerate``.

    The run proceeds in four stages:

    1. Setup: distributed state, data loaders, model/EMA/optimizer (from
       ``utils.initialize_train_state``), the autoencoder, and a Stable-Diffusion
       beta schedule.
    2. Training loop. It logs every ``config.train.log_interval`` steps and saves a
       10-image sample grid (5 per row) every ``config.train.eval_interval`` steps.
    3. Checkpointing: every ``config.train.save_interval`` steps, and at the final
       step, it saves a checkpoint and computes FID on 10000 samples.
    4. Final evaluation: it reloads the checkpoint with the lowest FID seen in this
       run and evaluates it with ``config.sample.n_samples`` /
       ``config.sample.sample_steps``.

    Args:
        config: mutable ``ml_collections.ConfigDict``. ``mixed_precision`` is written
            into it before it is frozen.
    """
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    # The global batch is split evenly across processes.
    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
                   name=config.hparams, job_type='train', mode='offline')
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
    else:
        utils.set_logger(log_level='error')
        # Silence print() on non-main ranks. Accept keyword arguments too, so that
        # print(..., file=f) / print(..., flush=True) become no-ops instead of
        # raising TypeError.
        builtins.print = lambda *args, **kwargs: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    dataset = get_dataset(**config.dataset)
    assert os.path.exists(dataset.fid_stat)
    train_dataset = dataset.get_split(split='train', labeled=True)
    train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
                                      num_workers=8, pin_memory=True, persistent_workers=True)
    test_dataset = dataset.get_split(split='test', labeled=True)  # for sampling
    test_dataset_loader = DataLoader(test_dataset, batch_size=config.sample.mini_batch_size, shuffle=True,
                                     drop_last=True, num_workers=8, pin_memory=True, persistent_workers=True)

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer, train_dataset_loader, test_dataset_loader = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader, test_dataset_loader)
    lr_scheduler = train_state.lr_scheduler
    train_state.resume(config.ckpt_root)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    def get_data_generator():
        # Endless stream of training batches; each pass over the loader is one epoch.
        while True:
            for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
                yield data

    data_generator = get_data_generator()

    def get_context_generator():
        # Endless stream of text contexts from the test split, used as sampling prompts.
        while True:
            for data in test_dataset_loader:
                _, _context = data
                yield _context

    context_generator = get_context_generator()

    _betas = stable_diffusion_beta_schedule()
    _schedule = Schedule(_betas)
    logging.info(f'use {_schedule}')

    # The unconditional (empty) context never changes; copy it to the device once
    # instead of on every model call during sampling.
    empty_context = torch.tensor(dataset.empty_context, device=device)

    def cfg_nnet(x, timesteps, context):
        # Classifier-free guidance: cond + scale * (cond - uncond), using the EMA model.
        _cond = nnet_ema(x, timesteps, context=context)
        _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
        _uncond = nnet_ema(x, timesteps, context=_empty_context)
        return _cond + config.sample.scale * (_cond - _uncond)

    def train_step(_batch):
        # One optimizer step plus EMA update; returns lr and the loss averaged across processes.
        _metrics = dict()
        optimizer.zero_grad()
        # '*feature*' datasets go through autoencoder.sample (presumably pre-extracted
        # latent features); otherwise the images are encoded here.
        _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
        loss = LSimple(_z, nnet, _schedule, context=_batch[1])  # currently only support the extracted feature version
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

    @torch.no_grad()  # sampling never back-propagates; don't build autograd graphs
    def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
        # Sample latents with DPM-Solver on the discrete schedule, then decode them to images.
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            # Map the solver's continuous time back to the discrete timestep scale.
            t = t_continuous * _schedule.N
            return cfg_nnet(x, t, **kwargs)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
        return decode(_z)

    def eval_step(n_samples, sample_steps):
        # Generate n_samples images, compute FID on the main process, and share it with every rank.
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=dpm_solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples):
            _context = next(context_generator)
            assert _context.size(0) == _n_samples
            return dpm_solver_sample(_n_samples, sample_steps, context=_context)

        with tempfile.TemporaryDirectory() as temp_path:
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn,
                             dataset.unpreprocess)

            _fid = 0.
            if accelerator.is_main_process:
                _fid = calculate_fid_given_paths((dataset.fid_stat, path))
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
            # All ranks must contribute a tensor of the same dtype to the collective.
            # Previously non-main ranks sent int64 (from the literal 0) while the main
            # rank sent the FID's dtype, so pin the dtype explicitly.
            _fid = torch.tensor(float(_fid), dtype=torch.float64, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    metrics = None  # defined even if the loop body never runs (e.g. resumed at n_steps)
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = tree_map(lambda x: x.to(device), next(data_generator))
        metrics = train_step(batch)

        nnet.eval()
        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            contexts = torch.tensor(dataset.contexts, device=device)[: 2 * 5]
            samples = dpm_solver_sample(_n_samples=2 * 5, _sample_steps=50, context=contexts)
            samples = make_grid(dataset.unpreprocess(samples), 5)
            save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
            wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            if accelerator.local_process_index == 0:
                train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=10000, sample_steps=50)  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    if step_fid:
        # min() returns the earliest entry on ties, like the stable sort it replaces.
        step_best = min(step_fid, key=lambda item: item[1])[0]
        logging.info(f'step_best: {step_best}')
        train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    else:
        # No checkpoint was evaluated in this run (nothing left to train), so
        # evaluate the currently loaded weights instead of crashing on an empty list.
        logging.info('step_fid is empty; evaluating the current train state')
    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path
# absl's global flag registry; main() reads FLAGS.config and FLAGS.workdir from it.
FLAGS = flags.FLAGS
# `--config` names an ml_collections config file. lock_config=False presumably leaves
# the loaded config unlocked so main() can add new fields (config_name, hparams,
# workdir, ...) — TODO confirm against the ml_collections config_flags docs.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
# Optional run directory; when unset, main() uses workdir/<config_name>/<hparams>.
flags.DEFINE_string("workdir", None, "Work unit directory.")
def get_config_name():
    """Return the file stem of the config path passed via ``--config``.

    Both ``--config=path/to/cfg.py`` and ``--config path/to/cfg.py`` are recognized
    (the original only handled the ``=`` form and returned ``None`` for the other,
    which later crashed ``os.path.join`` in ``main``).

    Returns:
        The stem of the config path, or ``None`` if no ``--config`` argument is found.
    """
    argv = sys.argv
    for i, arg in enumerate(argv[1:], start=1):
        if arg.startswith('--config='):
            # Split on the first '=' only, so config paths containing '=' survive.
            return Path(arg.split('=', 1)[1]).stem
        if arg == '--config' and i + 1 < len(argv):
            return Path(argv[i + 1]).stem
    return None
def get_hparams():
    """Build a run name from the ``--config.<field>=<value>`` overrides in ``sys.argv``.

    Each override contributes ``<last field name>=<value>``. Values of fields whose
    name ends with ``path`` are shortened to the file stem, and ``--config.dataset.path``
    overrides are ignored. The pieces are joined with ``-``.

    Other arguments (``--config <file>``, ``--workdir dir``, absl boolean flags, ...)
    do not affect the name and need not contain ``=``.

    Returns:
        The joined override string, or ``'default'`` if there are no overrides.

    Raises:
        ValueError: if a ``--config.<field>`` override is given without ``=``, since it
            cannot then be reflected in the run name.
    """
    parts = []
    for arg in sys.argv[1:]:
        if not arg.startswith('--config.') or arg.startswith('--config.dataset.path'):
            continue
        # partition() splits on the first '=' only, so values may themselves contain '='.
        key, sep, val = arg.partition('=')
        if not sep:
            raise ValueError(f'config override must use the --config.<field>=<value> form: {arg!r}')
        hparam = key.split('.')[-1]
        if hparam.endswith('path'):
            val = Path(val).stem
        parts.append(f'{hparam}={val}')
    return '-'.join(parts) or 'default'
def main(argv):
    """absl entry point: name the run, derive its output directories, and train.

    ``argv`` is not used; the run name comes from ``sys.argv`` via
    ``get_config_name`` and ``get_hparams``.
    """
    cfg = FLAGS.config
    cfg.config_name = get_config_name()
    cfg.hparams = get_hparams()
    if FLAGS.workdir:
        cfg.workdir = FLAGS.workdir
    else:
        # Default layout: workdir/<config file stem>/<hparam overrides>
        cfg.workdir = os.path.join('workdir', cfg.config_name, cfg.hparams)
    run_dir = cfg.workdir
    cfg.ckpt_root = os.path.join(run_dir, 'ckpts')
    cfg.sample_dir = os.path.join(run_dir, 'samples')
    train(cfg)
if __name__ == "__main__":
    # Hand control to absl: it parses the registered flags, then calls main(argv).
    app.run(main)