-
Notifications
You must be signed in to change notification settings - Fork 50
/
inference.py
74 lines (59 loc) · 2.26 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import imageio
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
from easydict import EasyDict
from torchvision import transforms
from transforms import sep_transforms
from utils.flow_utils import flow_to_image, resize_flow
from utils.torch_utils import restore_model
from models.pwclite import PWCLite
class TestHelper():
    """Wrap a pretrained PWCLite model for optical-flow inference on image lists.

    Args:
        cfg: Mapping converted to an ``EasyDict``. It must provide ``model``
            (passed to ``PWCLite``), ``pretrained_model`` (checkpoint path
            handed to ``restore_model``) and ``test_shape`` (two values
            unpacked into ``sep_transforms.Zoom``).
    """

    def __init__(self, cfg):
        self.cfg = EasyDict(cfg)
        # Prefer the GPU when one is visible, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.init_model()
        # Resize to test_shape, convert to a tensor, then divide by 255
        # (mean 0, std 255). This presumably maps 0-255 pixel values to
        # [0, 1] — TODO confirm the expected input range.
        self.input_transform = transforms.Compose([
            sep_transforms.Zoom(*self.cfg.test_shape),
            sep_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        ])

    def init_model(self):
        """Build PWCLite from ``cfg.model``, load the checkpoint and set eval mode.

        Returns:
            The model returned by ``restore_model``, after calling ``eval()``.
        """
        model = PWCLite(self.cfg.model)
        model = model.to(self.device)
        model = restore_model(model, self.cfg.pretrained_model)
        model.eval()
        return model

    def run(self, imgs):
        """Run the model on a sequence of frames.

        Args:
            imgs: Sequence of images accepted by ``self.input_transform``
                (the script passes float32 numpy arrays read with imageio).

        Returns:
            Whatever the model returns for the stacked frames. The script
            reads its ``'flows_fw'`` entry. Outputs are computed without
            autograd tracking.
        """
        # Each frame becomes a batch of one. Frames are concatenated along
        # dim 1 (presumably the channel axis — TODO confirm against PWCLite).
        frames = [self.input_transform(img).unsqueeze(0) for img in imgs]
        img_pair = torch.cat(frames, 1).to(self.device)
        # This is pure inference, so skip building an autograd graph. That
        # graph would otherwise stay referenced by the returned tensors.
        with torch.no_grad():
            return self.model(img_pair)
def main():
    """Estimate optical flow for the images given on the command line and show it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', default='checkpoints/KITTI15/pwclite_ar.tar',
                        help='path to the pretrained checkpoint')
    parser.add_argument('-s', '--test_shape', default=[384, 640], type=int, nargs=2,
                        help='size the inputs are zoomed to before inference')
    parser.add_argument('-i', '--img_list', nargs='+',
                        default=['examples/img1.png', 'examples/img2.png'],
                        help='input frames, in temporal order')
    args = parser.parse_args()

    cfg = {
        'model': {
            'upsample': True,
            'n_frames': len(args.img_list),
            'reduce_dense': True
        },
        'pretrained_model': args.model,
        'test_shape': args.test_shape,
    }
    ts = TestHelper(cfg)

    imgs = [imageio.imread(img).astype(np.float32) for img in args.img_list]
    # Remember the first frame's resolution so the flow can be resized back to it.
    h, w = imgs[0].shape[:2]

    # First entry of 'flows_fw' (presumably the finest-scale frame1->frame2
    # flow — confirm against PWCLite's output format).
    flow_12 = ts.run(imgs)['flows_fw'][0]
    flow_12 = resize_flow(flow_12, (h, w))
    # Take batch item 0 and move axis 0 last. This presumably converts a
    # channels-first flow to the channels-last layout flow_to_image expects.
    np_flow_12 = flow_12[0].detach().cpu().numpy().transpose([1, 2, 0])
    vis_flow = flow_to_image(np_flow_12)

    plt.figure()
    plt.imshow(vis_flow)
    plt.show()


if __name__ == '__main__':
    main()