-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_miniplaces.py
79 lines (67 loc) · 2.43 KB
/
eval_miniplaces.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# python imports
import os
import time
import argparse
from tqdm import tqdm
# torch imports
import torch
import torch.nn as nn
import torch.optim as optim
# helper functions for computer vision
import torchvision
import torchvision.transforms as transforms
from dataloader import MiniPlaces
from student_code import LeNet, test_model
# main function for evaluating a saved model checkpoint
def main(args):
    """Evaluate a saved LeNet checkpoint on the MiniPlaces "val" split.

    Args:
        args: Parsed command-line namespace. Only ``args.load`` is read; when
            it is empty it is overwritten in place with
            "./outputs/model_best.pth.tar".

    Returns:
        None. Progress and timing are printed to stdout. If the checkpoint
        file does not exist, a message is printed and the function returns
        without evaluating.
    """
    # set up random seed
    torch.manual_seed(0)

    ###################################
    # setup model                     #
    ###################################
    model = LeNet()

    # set up transforms to transform the PIL Image to tensors, then
    # normalize each channel with a fixed mean / std
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    ################################
    # setup dataset and dataloader #
    ################################
    data_folder = './data'
    os.makedirs(os.path.expanduser(data_folder), exist_ok=True)
    test_set = MiniPlaces(
        root=data_folder, split="val", download=True, transform=test_transform)
    # shuffle=False keeps the evaluation order fixed across runs
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=32, shuffle=False)

    ################################
    # evaluating the model         #
    ################################
    # fall back to the default checkpoint path when none was given
    if not args.load:
        args.load = "./outputs/model_best.pth.tar"

    # guard clause: nothing to evaluate without a checkpoint file
    if not os.path.isfile(args.load):
        print("=> no checkpoint found at '{}'".format(args.load))
        return

    print("=> loading checkpoint '{:s}'".format(args.load))
    # map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
    # machine without a GPU; load_state_dict copies the values into the model.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(args.load, map_location="cpu")
    # load model weight
    model.load_state_dict(checkpoint['state_dict'])
    epoch = checkpoint['epoch']
    print("=> loaded checkpoint '{:s}' (epoch {:d})".format(
        args.load, epoch))

    # evaluation and timing
    print("Evaluting the model ...\n")
    start = time.time()
    # evaluate the loaded model
    # NOTE(review): epoch - 1 presumably converts the stored "next epoch"
    # counter back to the last finished epoch -- confirm against the
    # training script and test_model's signature.
    test_model(model, test_loader, epoch - 1)
    end = time.time()
    print("Evaluation took {:0.2f} sec".format(end - start))
if __name__ == '__main__':
    # Build the command-line interface; --load is the only option.
    cli = argparse.ArgumentParser(
        description='Image Classification using Pytorch',
    )
    cli.add_argument(
        '--load',
        type=str,
        default='',
        metavar='PATH',
        help='path to latest checkpoint (default: none)',
    )
    # Parse the command line and hand the resulting namespace to main().
    main(cli.parse_args())