-
Notifications
You must be signed in to change notification settings - Fork 20
/
demo.py
120 lines (99 loc) · 4.06 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import glob
import os
import time
import torch
from PIL import Image
from vizer.draw import draw_boxes
from dssd.config import cfg
from dssd.data.datasets import COCODataset, VOCDataset
import argparse
import numpy as np
from dssd.data.transforms import build_transforms
from dssd.modeling.detector import build_detection_model
from dssd.utils import mkdir
from dssd.utils.checkpoint import CheckPointer
@torch.no_grad()
def run_demo(cfg, ckpt, score_threshold, images_dir, output_dir, dataset_type):
    """Run the detector on every ``*.jpg`` in ``images_dir`` and save annotated copies.

    For each image the model output is resized back to the original image
    size, detections whose score is not strictly greater than
    ``score_threshold`` are dropped, the remaining boxes are drawn on the
    image, and the result is written to ``output_dir`` under the same file
    name. A one-line timing summary is printed per image.

    Args:
        cfg: Project configuration node; ``cfg.MODEL.DEVICE`` and
            ``cfg.OUTPUT_DIR`` are read here, and the whole object is passed
            to ``build_detection_model`` and ``build_transforms``.
        ckpt: Path to the weights to load, or ``None``. When ``None`` the
            checkpointer is called with ``use_latest=True``.
        score_threshold: Detections must score strictly above this value
            to be drawn.
        images_dir: Directory that is searched (non-recursively) for
            ``*.jpg`` files.
        output_dir: Directory the annotated images are saved to; created
            via ``mkdir`` before processing starts.
        dataset_type: ``"voc"`` or ``"coco"``; selects the class-name table
            used for the box labels.

    Raises:
        NotImplementedError: If ``dataset_type`` is neither ``"voc"`` nor
            ``"coco"``.
    """
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Not implemented now.')

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_detection_model(cfg)
    model = model.to(device)

    checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(ckpt, use_latest=ckpt is None)
    weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
    print('Loaded weights from {}'.format(weight_file))

    # glob returns entries in arbitrary order; sort so that the progress
    # counter and the processing order are reproducible between runs.
    image_paths = sorted(glob.glob(os.path.join(images_dir, '*.jpg')))
    mkdir(output_dir)

    cpu_device = torch.device("cpu")
    transforms = build_transforms(cfg, is_train=False)
    model.eval()

    for i, image_path in enumerate(image_paths):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can report a zero interval on coarse-clock platforms.
        start = time.perf_counter()
        image_name = os.path.basename(image_path)
        # Close the file handle deterministically instead of leaving it to
        # the garbage collector.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        height, width = image.shape[:2]
        # The transform returns a sequence whose first item is the image
        # tensor; add a leading batch dimension of size one.
        images = transforms(image)[0].unsqueeze(0)
        load_time = time.perf_counter() - start

        # NOTE(review): on a CUDA device, kernels may still be in flight when
        # the timer stops unless the model synchronizes internally - confirm.
        start = time.perf_counter()
        result = model(images.to(device))[0]
        inference_time = time.perf_counter() - start

        # Map the predictions back to the original image resolution.
        result = result.resize((width, height)).to(cpu_device).numpy()
        boxes, labels, scores = result['boxes'], result['labels'], result['scores']

        indices = scores > score_threshold
        boxes = boxes[indices]
        labels = labels[indices]
        scores = scores[indices]

        # Guard the division: a zero-length interval must not abort the run.
        fps = round(1.0 / inference_time) if inference_time > 0 else float('inf')
        meters = ' | '.join(
            [
                'objects {:02d}'.format(len(boxes)),
                'load {:03d}ms'.format(round(load_time * 1000)),
                'inference {:03d}ms'.format(round(inference_time * 1000)),
                'FPS {}'.format(fps),
            ]
        )
        print('({:04d}/{:04d}) {}: {}'.format(i + 1, len(image_paths), image_name, meters))

        drawn_image = draw_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
        Image.fromarray(drawn_image).save(os.path.join(output_dir, image_name))
def main():
    """Parse the command line, assemble the configuration and run the demo.

    The global ``cfg`` is updated from ``--config-file`` (when one is given)
    and then from the trailing ``opts`` overrides, frozen, printed, and
    handed to ``run_demo`` together with the remaining arguments.

    ``--config-file`` defaults to an empty string. In that case the file
    merge and the dump of the file contents are skipped, so the script runs
    with the built-in defaults plus any ``opts`` overrides instead of
    failing while trying to open ``""``.
    """
    parser = argparse.ArgumentParser(description="DSSD Demo.")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--ckpt", type=str, default=None, help="Trained weights.")
    parser.add_argument("--score_threshold", type=float, default=0.7)
    parser.add_argument("--images_dir", default='demo', type=str, help='Specify a image dir to do prediction.')
    parser.add_argument("--output_dir", default='demo/result', type=str, help='Specify a image dir to save predicted images.')
    parser.add_argument("--dataset_type", default="voc", type=str, help='Specify dataset type. Currently support voc and coco.')
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    print(args)

    # Only touch the file system when a config path was actually supplied;
    # a path that is given but missing still raises as before.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    if args.config_file:
        print("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, "r") as cf:
            config_str = "\n" + cf.read()
            print(config_str)
    print("Running with config:\n{}".format(cfg))

    run_demo(cfg=cfg,
             ckpt=args.ckpt,
             score_threshold=args.score_threshold,
             images_dir=args.images_dir,
             output_dir=args.output_dir,
             dataset_type=args.dataset_type)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()