-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
71 lines (60 loc) · 2.42 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms
import matplotlib.pyplot as plt
from NYUDepth import NYUDepth
from utils import make_depth_fig, resize_image_depth, make_error_map, dump
def _visualize_first_sample(model, loader, title, prefix, block=None):
    """Run ``model`` on the first batch of ``loader``, plot the result and dump it.

    Only the first batch is processed (the script is a quick visual check, not
    an evaluation loop); an empty loader is a no-op.

    Args:
        model: Depth model, called as ``model(img)`` on a float tensor.
        loader: DataLoader yielding ``(img, depth)`` pairs.
        title: Suptitle for the image / ground-truth / prediction figure.
        prefix: Prefix forwarded to ``utils.dump`` for the saved outputs.
        block: Forwarded to ``plt.show``; ``None`` keeps matplotlib's default.
    """
    batch = next(iter(loader), None)
    if batch is None:
        return
    img, depth = batch
    img, depth = img.float(), depth.float()
    # Insert a channel axis at dim 1 -- assumes depth arrives as (N, H, W);
    # TODO confirm against NYUDepth.__getitem__ / resize_image_depth.
    depth = depth.unsqueeze(1)
    output = model(img)
    image, depth_gt, depth_pred = resize_image_depth(img, depth, output)
    # NOTE(review): depths are transposed before plotting/dumping, presumably
    # to match the axis order the utils helpers expect -- confirm.
    fig = make_depth_fig(image, depth_gt.T, depth_pred.T)
    fig.suptitle(title)
    # make_error_map returns (figure, error_map); unpacking it here keeps the
    # dumped error map tied to *this* sample rather than a previous one.
    _error_fig, error_map = make_error_map(image, depth_gt.T, depth_pred.T)
    plt.show(block=block)
    dump(image=image,
         depth=depth_pred.T,
         depth_gt=depth_gt.T,
         error_map=error_map,
         prefix=prefix,
         n=0)


if __name__ == "__main__":
    model_path = 'upproj.pth'
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
    # trusted checkpoints. On PyTorch >= 2.6, loading a whole pickled model
    # (rather than a state_dict) may additionally require weights_only=False.
    model = torch.load(model_path, map_location='cpu')
    model.to('cpu')
    model.eval()

    data_root = '/data/mengli/undergrade//nyu-dataset'
    train_dataset = NYUDepth(root=data_root, mode='train')
    test_dataset = NYUDepth(root=data_root, mode='test')
    # shuffle=True so each run visualizes a randomly chosen sample.
    train_loader = data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = data.DataLoader(test_dataset, batch_size=1, shuffle=True)
    print('Size of train dataset:', len(train_dataset))
    print('Size of test dataset:', len(test_dataset))

    with torch.no_grad():
        # Non-blocking first, so the final (blocking) show displays both the
        # training and the testing figures.
        _visualize_first_sample(model, train_loader, 'Training dataset',
                                'infer_train', block=False)
        _visualize_first_sample(model, test_loader, 'Testing dataset',
                                'infer_test')