forked from aimagelab/meshed-memory-transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
323 lines (276 loc) · 13.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import random
from data import ImageDetectionsField, TextField, RawField
from data import COCO, DataLoader
import evaluation
from evaluation import PTBTokenizer, Cider
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import NLLLoss
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import argparse, os, pickle
import numpy as np
import itertools
import multiprocessing
from shutil import copyfile
# Seed NumPy's, PyTorch's and Python's global RNGs with one fixed value so
# repeated runs draw the same random numbers. The three generators are
# independent, so the order of these calls does not matter.
np.random.seed(1234)
torch.manual_seed(1234)
random.seed(1234)
def evaluate_loss(model, dataloader, loss_fn, text_field):
    """Return the mean teacher-forced validation loss of ``model``.

    The model runs in eval mode with gradients disabled. For each batch, the
    model output at position t is scored against caption token t + 1, so the
    first token of every caption is never used as a target.

    Depends on two module-level globals assigned in ``__main__``: ``e`` (the
    current epoch, used only for the progress-bar label) and ``device``
    (where each batch is moved).

    Args:
        model: called as ``model(detections, captions)``.
        dataloader: yields ``(detections, captions)`` batches.
        loss_fn: criterion applied to the flattened outputs and targets.
        text_field: its ``vocab`` length is the size of the output dimension.

    Returns:
        The per-batch loss averaged over all batches, as a Python float.
    """
    model.eval()
    total = .0
    with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar, torch.no_grad():
        for batch_idx, (detections, captions) in enumerate(dataloader):
            detections = detections.to(device)
            captions = captions.to(device)
            pred = model(detections, captions)
            # Shift by one position: the output at step t predicts token t + 1.
            targets = captions[:, 1:].contiguous()
            pred = pred[:, :-1].contiguous()
            total += loss_fn(pred.view(-1, len(text_field.vocab)), targets.view(-1)).item()
            pbar.set_postfix(loss=total / (batch_idx + 1))
            pbar.update()
    return total / len(dataloader)
def evaluate_metrics(model, dataloader, text_field):
    """Beam-decode every image in ``dataloader`` and score it against its references.

    Each batch is decoded with ``model.beam_search`` using maximum length 20
    and beam size 5, keeping a single output per image (``out_size=1``).
    Consecutive repeated words in a generated caption are collapsed into one
    word before scoring. Both the generated and the reference captions are
    PTB-tokenized and then passed to ``evaluation.compute_scores``.

    Depends on the module-level globals ``e`` (progress-bar label) and
    ``device``, which are assigned in ``__main__``.

    Args:
        model: exposes ``beam_search`` returning ``(out, _)``.
        dataloader: yields ``(images, caps_gt)`` batches, where ``caps_gt``
            holds the references for each image.
        text_field: provides ``vocab`` (for the ``'<eos>'`` index) and ``decode``.

    Returns:
        The scores dict from ``evaluation.compute_scores``. The training loop
        reads its ``'CIDEr'``, ``'BLEU'``, ``'METEOR'`` and ``'ROUGE'`` entries.
    """
    model.eval()
    gen, gts = {}, {}
    with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar:
        for batch_idx, (images, caps_gt) in enumerate(dataloader):
            images = images.to(device)
            with torch.no_grad():
                out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
            caps_gen = text_field.decode(out, join_words=False)
            for sample_idx, (refs, hyp_words) in enumerate(zip(caps_gt, caps_gen)):
                key = '%d_%d' % (batch_idx, sample_idx)
                # Collapse runs of identical consecutive words ("a a dog" -> "a dog").
                gen[key] = [' '.join(word for word, _ in itertools.groupby(hyp_words))]
                gts[key] = refs
            pbar.update()
    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)
    return scores
def train_xe(model, dataloader, optim, text_field):
    """Run one epoch of cross-entropy (teacher-forced) training.

    For each batch, the model output at position t is trained to predict
    caption token t + 1. The LR scheduler is stepped exactly once per
    optimizer update, immediately after ``optim.step()``. Stepping it before
    the first optimizer update would skip the first value of the warm-up
    schedule, and stepping it an extra time per epoch would skip a value every
    epoch.

    Depends on module-level globals assigned in ``__main__``: ``e`` (epoch,
    progress-bar label), ``device``, ``loss_fn`` and ``scheduler``.

    Args:
        model: called as ``model(detections, captions)``.
        dataloader: yields ``(detections, captions)`` batches.
        optim: optimizer updated once per batch.
        text_field: its ``vocab`` length is the size of the output dimension.

    Returns:
        The per-batch training loss averaged over the epoch, as a Python float.
    """
    model.train()
    running_loss = .0
    with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, captions) in enumerate(dataloader):
            detections, captions = detections.to(device), captions.to(device)
            out = model(detections, captions)
            optim.zero_grad()
            # Shift by one position: the output at step t predicts token t + 1.
            captions_gt = captions[:, 1:].contiguous()
            out = out[:, :-1].contiguous()
            loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
            loss.backward()
            optim.step()
            # Advance the LR schedule after the optimizer update (PyTorch >= 1.1 ordering).
            scheduler.step()

            running_loss += loss.item()
            pbar.set_postfix(loss=running_loss / (it + 1))
            pbar.update()

    loss = running_loss / len(dataloader)
    return loss
def train_scst(model, dataloader, optim, cider, text_field):
    """Run one epoch of self-critical sequence training (SCST).

    For each batch:
      1. Beam search (max length 20, beam size 5) returns all 5 beams per image.
      2. The beams are decoded, and each image's references are repeated once
         per beam.
      3. Both sides are PTB-tokenized in a process pool, and ``cider`` assigns
         a per-caption reward.
      4. Rewards are reshaped to (batch, beam_size). The baseline for each
         image is the mean reward over its beams.
      5. The loss is ``-mean(log_probs, -1) * (reward - baseline)``, averaged
         over the batch. NOTE(review): this assumes the last dim of
         ``log_probs`` is the time axis; confirm against ``beam_search``.

    Depends on the module-level globals ``e`` (progress-bar label) and
    ``device``, which are assigned in ``__main__``.

    Args:
        model: exposes ``beam_search(x, seq_len, eos_idx, beam_size, out_size=...)``
            returning ``(outs, log_probs)``.
        dataloader: yields ``(detections, caps_gt)`` batches.
        optim: optimizer updated once per batch.
        cider: scorer; ``compute_score(gts, gen)[1]`` gives per-caption scores.
        text_field: provides ``vocab`` (for ``'<eos>'``) and ``decode``.

    Returns:
        ``(loss, reward, reward_baseline)``, each averaged over the epoch's batches.
    """
    running_reward = .0
    running_reward_baseline = .0
    model.train()
    running_loss = .0
    seq_len = 20
    beam_size = 5

    # A new pool is created on every call; the context manager terminates its
    # worker processes when the epoch ends so they do not accumulate.
    with multiprocessing.Pool() as tokenizer_pool, \
            tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, caps_gt) in enumerate(dataloader):
            detections = detections.to(device)
            outs, log_probs = model.beam_search(detections, seq_len, text_field.vocab.stoi['<eos>'],
                                                beam_size, out_size=beam_size)
            optim.zero_grad()

            # Rewards: one CIDEr score per generated beam, scored against that image's references.
            caps_gen = text_field.decode(outs.view(-1, seq_len))
            caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
            caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
            reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
            reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
            # Self-critical baseline: the mean reward of the image's own beams.
            reward_baseline = torch.mean(reward, -1, keepdim=True)
            loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)

            loss = loss.mean()
            loss.backward()
            optim.step()

            running_loss += loss.item()
            running_reward += reward.mean().item()
            running_reward_baseline += reward_baseline.mean().item()
            pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
                             reward_baseline=running_reward_baseline / (it + 1))
            pbar.update()

    loss = running_loss / len(dataloader)
    reward = running_reward / len(dataloader)
    reward_baseline = running_reward_baseline / len(dataloader)
    return loss, reward, reward_baseline
if __name__ == '__main__':
    # Script entry point. The names assigned here (device, e, loss_fn, scheduler)
    # are module globals that train_xe / train_scst / evaluate_* read directly.
    device = torch.device('cuda')
    parser = argparse.ArgumentParser(description='Meshed-Memory Transformer')
    parser.add_argument('--exp_name', type=str, default='m2_transformer')
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--m', type=int, default=40)
    # NOTE(review): --head is parsed but never used in this script.
    parser.add_argument('--head', type=int, default=8)
    parser.add_argument('--warmup', type=int, default=10000)
    parser.add_argument('--resume_last', action='store_true')
    parser.add_argument('--resume_best', action='store_true')
    parser.add_argument('--features_path', type=str)
    parser.add_argument('--annotation_folder', type=str)
    parser.add_argument('--logs_folder', type=str, default='tensorboard_logs')
    args = parser.parse_args()
    print(args)

    print('Meshed-Memory Transformer Training')

    writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))

    # Pipeline for image regions
    image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)

    # Pipeline for text
    text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
                           remove_punctuation=True, nopoints=False)

    # Create the dataset
    dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
    train_dataset, val_dataset, test_dataset = dataset.splits

    # Build the vocabulary once per experiment and cache it in the working directory.
    vocab_path = 'vocab_%s.pkl' % args.exp_name
    if not os.path.isfile(vocab_path):
        print("Building vocabulary")
        text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
        with open(vocab_path, 'wb') as vocab_file:
            pickle.dump(text_field.vocab, vocab_file)
    else:
        # pickle.load can execute arbitrary code: only load vocab files this script wrote.
        with open(vocab_path, 'rb') as vocab_file:
            text_field.vocab = pickle.load(vocab_file)

    # Model and dataloaders
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': args.m})
    decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
    model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

    dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    ref_caps_train = list(train_dataset.text)
    cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
    dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
    dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})

    def lambda_lr(s):
        # Inverse-square-root schedule with linear warm-up ("Attention Is All You
        # Need"). The optimizer's base lr is 1, so this value is the actual LR.
        warm_up = args.warmup
        s += 1
        return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)

    # Initial conditions
    optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
    scheduler = LambdaLR(optim, lambda_lr)
    loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
    use_rl = False
    best_cider = .0
    patience = 0
    start_epoch = 0

    # torch.save does not create missing parent directories.
    os.makedirs('saved_models', exist_ok=True)
    last_path = 'saved_models/%s_last.pth' % args.exp_name
    best_path = 'saved_models/%s_best.pth' % args.exp_name

    if args.resume_last or args.resume_best:
        fname = last_path if args.resume_last else best_path
        if os.path.exists(fname):
            # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
            # which may reject the NumPy RNG state stored in these checkpoints;
            # confirm the torch version and pass weights_only=False if loading fails.
            data = torch.load(fname)
            torch.set_rng_state(data['torch_rng_state'])
            torch.cuda.set_rng_state(data['cuda_rng_state'])
            np.random.set_state(data['numpy_rng_state'])
            random.setstate(data['random_rng_state'])
            model.load_state_dict(data['state_dict'], strict=False)
            optim.load_state_dict(data['optimizer'])
            scheduler.load_state_dict(data['scheduler'])
            start_epoch = data['epoch'] + 1
            best_cider = data['best_cider']
            patience = data['patience']
            use_rl = data['use_rl']
            print('Resuming from epoch %d, validation loss %f, and best cider %f' % (
                data['epoch'], data['val_loss'], data['best_cider']))
        else:
            print('No checkpoint found at %s, training from scratch' % fname)

    print("Training starts")
    for e in range(start_epoch, start_epoch + 100):
        dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
                                      drop_last=True)
        dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True,
                                           num_workers=args.workers)
        dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
        dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)

        if not use_rl:
            train_loss = train_xe(model, dataloader_train, optim, text_field)
            writer.add_scalar('data/train_loss', train_loss, e)
        else:
            train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim, cider_train, text_field)
            writer.add_scalar('data/train_loss', train_loss, e)
            writer.add_scalar('data/reward', reward, e)
            writer.add_scalar('data/reward_baseline', reward_baseline, e)

        # Validation loss
        val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field)
        writer.add_scalar('data/val_loss', val_loss, e)

        # Validation scores
        scores = evaluate_metrics(model, dict_dataloader_val, text_field)
        print("Validation scores", scores)
        val_cider = scores['CIDEr']
        writer.add_scalar('data/val_cider', val_cider, e)
        writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
        writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
        writer.add_scalar('data/val_meteor', scores['METEOR'], e)
        writer.add_scalar('data/val_rouge', scores['ROUGE'], e)

        # Test scores
        scores = evaluate_metrics(model, dict_dataloader_test, text_field)
        print("Test scores", scores)
        writer.add_scalar('data/test_cider', scores['CIDEr'], e)
        writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
        writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
        writer.add_scalar('data/test_meteor', scores['METEOR'], e)
        writer.add_scalar('data/test_rouge', scores['ROUGE'], e)

        # Prepare for next epoch
        best = False
        if val_cider >= best_cider:
            best_cider = val_cider
            patience = 0
            best = True
        else:
            patience += 1

        switch_to_rl = False
        exit_train = False
        # After 5 epochs without a validation-CIDEr improvement: switch from
        # cross-entropy to SCST with a fresh low-LR Adam, or stop if already in SCST.
        if patience == 5:
            if not use_rl:
                use_rl = True
                switch_to_rl = True
                patience = 0
                optim = Adam(model.parameters(), lr=5e-6)
                print("Switching to RL")
            else:
                print('patience reached.')
                exit_train = True

        # SCST starts from the best cross-entropy checkpoint, not the latest weights.
        if switch_to_rl and not best:
            data = torch.load(best_path)
            torch.set_rng_state(data['torch_rng_state'])
            torch.cuda.set_rng_state(data['cuda_rng_state'])
            np.random.set_state(data['numpy_rng_state'])
            random.setstate(data['random_rng_state'])
            model.load_state_dict(data['state_dict'])
            print('Resuming from epoch %d, validation loss %f, and best cider %f' % (
                data['epoch'], data['val_loss'], data['best_cider']))

        torch.save({
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'numpy_rng_state': np.random.get_state(),
            'random_rng_state': random.getstate(),
            'epoch': e,
            'val_loss': val_loss,
            'val_cider': val_cider,
            'state_dict': model.state_dict(),
            'optimizer': optim.state_dict(),
            'scheduler': scheduler.state_dict(),
            'patience': patience,
            'best_cider': best_cider,
            'use_rl': use_rl,
        }, last_path)

        if best:
            copyfile(last_path, best_path)

        if exit_train:
            break

    # Close (and flush) the writer whether training stopped early or ran every epoch.
    writer.close()