import sys
import warnings

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau

from data.fer2013 import get_dataloaders
from utils.checkpoint import save
from utils.hparams import setup_hparams
from utils.loops import train, evaluate
from utils.setup_network import setup_network

warnings.filterwarnings("ignore")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def run(net, logger, hps):
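    """Train the network on FER2013, validate each epoch, and report test accuracy.

    Checkpoints are written via utils.checkpoint.save whenever validation accuracy
    improves and additionally every `save_freq` epochs.
    """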
    # Create dataloaders
    trainloader, valloader, testloader = get_dataloaders(bs=hps['bs'])

    net = net.to(device)
    learning_rate = float(hps['lr'])
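
    # GradScaler handles loss scaling for mixed-precision training; it is passed to the training loop below.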
    scaler = GradScaler()

    # optimizer = torch.optim.Adadelta(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    # optimizer = torch.optim.Adagrad(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    # optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0001, amsgrad=True)
    # optimizer = torch.optim.ASGD(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, nesterov=True, weight_decay=0.0001)
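
    # ReduceLROnPlateau with mode='max' tracks validation accuracy (see scheduler.step(acc_v) below)
    # and multiplies the learning rate by 0.75 after 5 epochs without improvement.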
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.75, patience=5, verbose=True)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5, last_epoch=-1, verbose=True)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=hps['n_epochs'])
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1, verbose=True)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6, last_epoch=-1, verbose=False)
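
    # CrossEntropyLoss expects raw (unnormalised) logits and integer class labels.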
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0

    print("Training", hps['name'], "on", device)
    for epoch in range(hps['start_epoch'], hps['n_epochs']):
        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer, scaler)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        # Update learning rate
        scheduler.step(acc_v)
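
        # Save a checkpoint whenever validation accuracy improves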
        if acc_v > best_acc:
            best_acc = acc_v
            save(net, logger, hps, epoch + 1)
            logger.save_plt(hps)
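
        # Periodic checkpoint every save_freq epochs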
        if (epoch + 1) % hps['save_freq'] == 0:
            save(net, logger, hps, epoch + 1)
            logger.save_plt(hps)

        print('Epoch %2d' % (epoch + 1),
              'Train Accuracy: %2.4f %%' % acc_tr,
              'Val Accuracy: %2.4f %%' % acc_v,
              sep='\t\t')

    # Calculate performance on test set
    acc_test, loss_test = evaluate(net, testloader, criterion)
    print('Test Accuracy: %2.4f %%' % acc_test,
          'Test Loss: %2.6f' % loss_test,
          sep='\t\t')


if __name__ == "__main__":
    # Parse hyperparameters from the command line and build the network and logger
    hps = setup_hparams(sys.argv[1:])
    logger, net = setup_network(hps)
    run(net, logger, hps)
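
    # Example invocation (the key=value format is an assumption about utils.hparams.setup_hparams;
    # only 'bs', 'lr', 'name', 'start_epoch', 'n_epochs' and 'save_freq' are read in this file):
    #     python train.py name=my_run bs=64 lr=0.01 n_epochs=300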