-
Notifications
You must be signed in to change notification settings - Fork 38
/
test_fcn8.py
executable file
·77 lines (68 loc) · 2.48 KB
/
test_fcn8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
# Test script for the FCN8 (VGG) model:
# predicts and saves masks for 5 object categories: HD, FV, RO, RI, WR
# See https://arxiv.org/pdf/2004.01241.pdf
"""
from __future__ import print_function, division
import os
import ntpath
import numpy as np
from PIL import Image
from os.path import join, exists
# local libs
from models.fcn import fcn_8, fcn_8_vgg
from utils.data_utils import getPaths
## experiment directories
# test_dir = "/mnt/data1/ImageSeg/suim/TEST/images/"
test_dir = "data/test/images/"

## output directories: one sub-folder of samples_dir per predicted category
samples_dir = "data/test/output/"
RO_dir = samples_dir + "RO/"
FB_dir = samples_dir + "FB/"
WR_dir = samples_dir + "WR/"
HD_dir = samples_dir + "HD/"
RI_dir = samples_dir + "RI/"
# create every output folder that is still missing
for out_dir in (samples_dir, RO_dir, FB_dir, WR_dir, HD_dir, RI_dir):
    if not exists(out_dir):
        os.makedirs(out_dir)

## input/output shapes
# im_res_ is ordered (width, height, channels); see how it is unpacked below
im_res_ = (320, 240, 3)
im_h, im_w = im_res_[1], im_res_[0]

## build the network and restore the trained weights
ckpt_name = "fcn8_vgg5.hdf5"
model = fcn_8_vgg(n_classes=5,
                  input_height=im_h,
                  input_width=im_w)
# summary() is expected to print the layer table itself and return None
# (Keras behaviour -- assumes fcn_8_vgg returns a Keras model), so it is
# not wrapped in print(), which would emit an extra "None" line.
model.summary()
model.load_weights(join("ckpt/saved/", ckpt_name))
def testGenerator():
    """Run the model on every image in ``test_dir`` and save binary masks.

    For each path returned by ``getPaths(test_dir)`` the image is resized to
    ``(im_w, im_h)``, scaled to [0, 1], passed through ``model.predict`` and
    thresholded at 0.5. Each of the five output channels is then written as
    an 8-bit ``.bmp`` (0 or 255) into its category directory, using the
    input file's base name.

    Raises:
        AssertionError: if ``test_dir`` does not exist.
    """
    assert exists(test_dir), "local image path doesnt exist"
    # output channel index -> directory receiving that channel's mask
    # NOTE(review): channel 1 is called FV in the module docstring but is
    # written to FB_dir ("FB/"); the path is kept as-is -- confirm naming.
    channel_dirs = (
        (0, RO_dir),
        (1, FB_dir),
        (2, HD_dir),
        (3, RI_dir),
        (4, WR_dir),
    )
    for p in getPaths(test_dir):
        # read and scale inputs; force 3 channels so grayscale / palette /
        # RGBA files still match the model input, and close the file handle
        with Image.open(p) as src:
            img = src.convert("RGB").resize((im_w, im_h))
        img = np.array(img) / 255.
        img = np.expand_dims(img, axis=0)
        # inference
        out_img = model.predict(img)
        # thresholding: strictly binary, so a score of exactly 0.5 becomes
        # background instead of being saved as a gray (127) pixel
        masks = out_img[0] > 0.5
        print ("tested: {0}".format(p))
        # get filename; splitext keeps dots inside the stem ("a.b.jpg" ->
        # "a.b.bmp") so distinct inputs do not overwrite each other
        img_name = os.path.splitext(ntpath.basename(p))[0] + '.bmp'
        # save individual output masks
        for channel, out_dir in channel_dirs:
            mask = np.reshape(masks[:, :, channel], (im_h, im_w))
            Image.fromarray(np.uint8(mask) * 255).save(out_dir + img_name)
# run inference on every image in test_dir and write the predicted masks to disk
testGenerator()