import logging
import os
import time

import torch

from config import parse_args
from data_helper import create_dataloaders
from model import MultiModal, MyModel
from util import setup_device, setup_seed, setup_logging, build_optimizer, evaluate
from ema import EMA
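

# Run a full pass over the validation set and return the mean loss together with the
# metric dict from util.evaluate (its 'mean_f1' entry drives checkpointing below).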
def validate(model, val_dataloader):
    model.eval()
    predictions = []
    labels = []
    losses = []
    with torch.no_grad():
        for batch in val_dataloader:
            loss, _, pred_label_id, label = model(batch)
            loss = loss.mean()
            predictions.extend(pred_label_id.cpu().numpy())
            labels.extend(label.cpu().numpy())
            losses.append(loss.cpu().numpy())
    loss = sum(losses) / len(losses)
    results = evaluate(predictions, labels)
    model.train()
    return loss, results
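

# Full training loop: build the dataloaders, model, optimizer/scheduler and (optionally)
# an EMA of the weights, then train for args.max_epochs epochs, validating and
# checkpointing at the end of each epoch.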
def train_and_validate(args):
    # 1. load data
    train_dataloader, val_dataloader = create_dataloaders(args)

    # 2. build model and optimizers
    # model = MultiModal(args)
    model = MyModel(args)
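    # Optionally keep an exponential moving average of the weights; the shadow
    # weights are applied at validation time and restored afterwards.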
    if args.ema:
        ema = EMA(model, 0.99, device=args.device)
        ema.register()
    optimizer, scheduler = build_optimizer(args, model)
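    # When running on GPU, wrap the model in DataParallel; the checkpointing code
    # below unwraps model.module before saving.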
    if args.device == 'cuda':
        model = torch.nn.parallel.DataParallel(model.to(args.device))

    # 3. training
    step = 0
    best_score = args.best_score
    start_time = time.time()
    num_total_steps = len(train_dataloader) * args.max_epochs
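
    # Each step: forward pass, reduce loss/accuracy to scalars, backward,
    # then optimizer / EMA / scheduler updates.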
    for epoch in range(args.max_epochs):
        for batch in train_dataloader:
            model.train()
            loss, accuracy, _, _ = model(batch)
            loss = loss.mean()
            accuracy = accuracy.mean()
            loss.backward()
            optimizer.step()
            if args.ema:
                ema.update()
            optimizer.zero_grad()
            scheduler.step()

            step += 1
            if step % args.print_steps == 0:
                time_per_step = (time.time() - start_time) / max(1, step)
                remaining_time = time_per_step * (num_total_steps - step)
                remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
                logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}, accuracy {accuracy:.3f}")

        # 4. validation
        if args.ema:
            ema.apply_shadow()
        loss, results = validate(model, val_dataloader)
        if args.ema:
            ema.restore()
        results = {k: round(v, 4) for k, v in results.items()}
        logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, {results}")

        # 5. save checkpoint when mean F1 improves on the best score so far
        mean_f1 = results['mean_f1']
        if mean_f1 > best_score:
            best_score = mean_f1
            state_dict = model.module.state_dict() if args.device == 'cuda' else model.state_dict()
            torch.save({'epoch': epoch, 'model_state_dict': state_dict, 'mean_f1': mean_f1},
                       f'{args.savedmodel_path}/model_epoch_{epoch}_mean_f1_{mean_f1}.bin')
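

# Entry point: parse arguments, set up logging, device and random seed, make sure the
# checkpoint directory exists, then run training and validation.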
def main():
    args = parse_args()
    setup_logging(args)
    setup_device(args)
    setup_seed(args)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    logging.info("Training/evaluation parameters: %s", args)

    train_and_validate(args)


if __name__ == '__main__':
    main()