-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·111 lines (89 loc) · 3.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
import argparse
import numpy as np
import time
from experiments.utils import load_data_wrapper, load_model
from experiments.parser import parser
from experiments.models.utils import *
import torch
from torchvision.datasets import CIFAR10, MNIST
from torch.utils.data import Dataset, DataLoader, TensorDataset
## READ CONFIGURATION PARAMETERS
import configparser
# Result-file names live in the [general] section of config.ini, which is
# looked up relative to the current working directory. ConfigParser.read
# silently skips a missing file, so a missing config.ini surfaces as a
# NoSectionError from the first get() below.
config = configparser.ConfigParser()
config.read('config.ini')
low_fname, high_fname, band_fname = (
    config.get('general', option)
    for option in ('accuracy_low', 'accuracy_high', 'accuracy_band')
)
def main():
############
## PARSER ##
############
args = parser(train=True, attack=False)
threshold = (*[int(val) for val in args.threshold.split(',')],) if args.threshold else None
input_size = (*[int(val) for val in args.input_size.split(',')],)
output_size = (*[int(val) for val in args.output_size.split(',')],) if args.output_size else None
############
## GLOBAL ##
############
NUM_WORKERS = 2
BATCH_SIZE = 128
##################
## Load Dataset ##
##################
trainset, trainloader, testset, testloader, validloader = load_data_wrapper(BATCH_SIZE, args.root, args.dataset,
args.augment,
input_size=input_size, output_size=output_size,
validation=True)
###########
## Train ##
###########
## Remember to use GPU for training and move dataset & model to GPU memory
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
torch.cuda.empty_cache()
print("=> Using device: %s"%device)
else:
print("=> Using device: CPU")
classes = len(testset.classes)
model, model_name = load_model(args.model, classes, args)
model = model.to(device)
train_time = 0
if args.pretrained:
print("\n=> Using pretrained model.")
model.load_state_dict(torch.load(f"pretrained/{model_name}.pt", map_location=torch.device('cpu')))
else:
print("\n=> Training...")
start_time = time.time()
train(model, trainloader, validloader, epochs=args.epochs, lr=args.lr, lr_decay=args.lr_decay,
model_name=model_name, l_curve_name=model_name)
train_time = time.time() - start_time
print("\n=> [TOTAL TRAINING] %.4f mins."%(train_time/60))
##############
## Evaluate ##
##############
accuracy = calc_accuracy(model, testloader)
out_args = dict(LR=args.lr, LR_Decay=args.lr_decay, Runtime=train_time/60)
if args.model == 'wideresnet':
out_args['depth'] = args.depth
out_args['width'] = args.width
## Only when filter is applied
## test accuracy on filtered test set
if args.threshold:
_, _, _, filtered_testloader = load_data_wrapper(BATCH_SIZE, args.root, args.dataset,
args.augment, filter_test=True,
input_size=input_size, output_size=output_size)
accuracy_filtered = calc_accuracy(model, filtered_testloader)
out_args['filter'] = f"{args.filter}, threshold: {threshold}"
out_args['accuracy_filtered_dataset'] = f'{accuracy_filtered*100}%'
if args.filter == 'low':
f = open(low_fname, 'a+')
elif args.filter == 'high':
f = open(high_fname, 'a+')
else:
f = open(band_fname, 'a+')
f.write(f"{threshold}, {accuracy}\n")
f.close()
write_train_output(model_name, accuracy, **out_args)
if __name__ == "__main__":
main()