-
Notifications
You must be signed in to change notification settings - Fork 5
/
run.py
84 lines (71 loc) · 2.37 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from lib.config import cfg, args
import numpy as np
import os
import cv2
def run_dataset():
    """Iterate once over the data loader built with ``is_train=False``.

    Every batch is fetched and immediately discarded; a tqdm progress bar
    tracks the pass.  Nothing is returned.
    """
    from lib.datasets import make_data_loader
    import tqdm

    # Worker count is forced to 0 before the loader is built.
    # NOTE(review): presumably this keeps loading in the main process so
    # dataset errors surface directly, and presumably make_data_loader reads
    # cfg.train.num_workers even when is_train=False -- confirm both.
    cfg.train.num_workers = 0

    progress = tqdm.tqdm(make_data_loader(cfg, is_train=False))
    for _ in progress:
        continue
def run_network():
    """Measure the network's forward-pass time over the test data loader.

    Builds the network from ``cfg``, moves it to CUDA, loads the checkpoint
    from ``cfg.trained_model_dir`` at ``cfg.test.epoch``, switches it to eval
    mode, and feeds it every batch of the loader built with
    ``is_train=False``.  Prints the mean forward time per batch in seconds.
    Every batch, including the first one, contributes to the average.

    If the loader yields no batches, a short message is printed instead of
    dividing by zero.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    from lib.utils.data_utils import to_cuda
    import tqdm
    import torch
    import time

    network = make_network(cfg).cuda()
    load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    total_time = 0.0
    num_batches = 0
    for batch in tqdm.tqdm(data_loader):
        batch = to_cuda(batch)
        with torch.no_grad():
            # Synchronize on both sides of the call so that only the GPU work
            # launched by this forward pass falls inside the timed interval.
            torch.cuda.synchronize()
            # perf_counter is a monotonic, high-resolution clock meant for
            # measuring short durations; time.time() can jump with clock
            # adjustments.
            start = time.perf_counter()
            network(batch)
            torch.cuda.synchronize()
            total_time += time.perf_counter() - start
        # Count the batches actually timed instead of relying on
        # len(data_loader) when averaging.
        num_batches += 1

    if num_batches == 0:
        print('no batches were produced by the data loader; nothing to time')
        return
    print(total_time / num_batches)
def run_evaluate():
    """Evaluate the trained network on the test loader and report its speed.

    Builds the network from ``cfg``, moves it to CUDA, loads the checkpoint
    from ``cfg.trained_model_dir`` (passing ``cfg.resume`` and
    ``cfg.test.epoch``), and switches it to eval mode.  For each batch of the
    loader built with ``is_train=False`` it moves every entry except
    ``'meta'`` to CUDA, times the forward pass, and hands the output and the
    batch to the evaluator.  After the loop the evaluator is summarized and
    the mean forward time and its reciprocal (fps) are printed.

    When more than one batch was timed, the first measurement is left out of
    the mean.  When no batch was timed, a message is printed instead of
    averaging an empty list.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils import net_utils
    import time

    network = make_network(cfg).cuda()
    net_utils.load_network(network,
                           cfg.trained_model_dir,
                           resume=cfg.resume,
                           epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    net_time = []
    for batch in tqdm.tqdm(data_loader):
        # assumes batch is a dict whose non-'meta' values have .cuda() --
        # TODO confirm against the dataset's collate function.
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            # Synchronize on both sides of the call so that only the GPU work
            # launched by this forward pass falls inside the timed interval.
            torch.cuda.synchronize()
            # perf_counter is a monotonic, high-resolution clock meant for
            # measuring short durations; time.time() can jump with clock
            # adjustments.
            start_time = time.perf_counter()
            output = network(batch)
            torch.cuda.synchronize()
            end_time = time.perf_counter()
        net_time.append(end_time - start_time)
        evaluator.evaluate(output, batch)
    evaluator.summarize()

    if not net_time:
        # np.mean([]) would emit a RuntimeWarning and print nan.
        print('net_time: no batches were evaluated')
        return

    # Leave out the first measurement whenever there is more than one.
    # NOTE(review): presumably because the first pass carries one-off
    # start-up cost -- confirm.
    timed = net_time[1:] if len(net_time) > 1 else net_time
    mean_time = np.mean(timed)
    print('net_time: ', mean_time)
    print('fps: ', 1./mean_time)
if __name__ == '__main__':
    # Look up the module-level function named 'run_<type>', where <type>
    # comes from args.type, then call it with no arguments.
    entry_point = globals()['run_' + args.type]
    entry_point()