-
Notifications
You must be signed in to change notification settings - Fork 8
/
Test.py
82 lines (71 loc) · 2.99 KB
/
Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.utils.data as tud
import os
import argparse
from Utils import *
import scipy.io as sio
import numpy as np
from Dataset import dataset
from torch.autograd import Variable
import time
from skimage import measure
from thop import profile
from HerosNet import HerosNet
# Select the GPU only when the caller has not already done so (e.g. via a job
# scheduler or the shell); hard-assigning here would override that choice.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"`` -- so it cannot be used
    for boolean options.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description="PyTorch Spectral Compressive Imaging")
parser.add_argument('--data_path', default='/userhome/zxy/test/', type=str, help='path of data')
parser.add_argument('--mask_path', default='/userhome/zxy/mask.mat', type=str, help='path of mask')
parser.add_argument("--size", default=256, type=int, help='the size of trainset image')
parser.add_argument("--trainset_num", default=2000, type=int, help='total number of trainset')
parser.add_argument("--testset_num", default=10, type=int, help='total number of testset')
parser.add_argument("--seed", default=1, type=int, help='Random_seed')
parser.add_argument("--batch_size", default=1, type=int, help='batch_size')
# `_str2bool` instead of `bool` so that `--isTrain False` really yields False.
parser.add_argument("--isTrain", default=False, type=_str2bool, help='train or test')
opt = parser.parse_args()
print(opt)
def prepare_data_test(path, file_num, size=256, channels=28):
    """Load ``file_num`` test scenes into one array clipped to [0, 1].

    Scenes are read from ``<path>/scene01.mat``, ``<path>/scene02.mat``, ...
    and each file must contain an ``'img'`` variable that can be assigned
    into a ``(size, size, channels)`` slot.

    Args:
        path: Directory containing the ``sceneXX.mat`` files. A trailing
            path separator is optional.
        file_num: Number of scenes to load (``scene01`` .. ``scene<file_num>``).
        size: Spatial height/width of every scene (default 256).
        channels: Number of bands per scene (default 28).

    Returns:
        float64 ndarray of shape ``(size, size, channels, file_num)``; values
        below 0 are set to 0 and values above 1 are set to 1.

    Raises:
        KeyError: if a file has no ``'img'`` variable.
        ValueError: if an ``'img'`` array does not fit ``(size, size, channels)``.
        Errors from ``scipy.io.loadmat`` (e.g. a missing file) propagate.
    """
    hr_hsi = np.zeros((size, size, channels, file_num))
    for idx in range(file_num):
        # Joining as two components inserts the separator when `path` lacks a
        # trailing slash; plain string concatenation did not.
        scene_path = os.path.join(path, 'scene%02d.mat' % (idx + 1))
        hr_hsi[:, :, :, idx] = sio.loadmat(scene_path)['img']
    # Clip the whole stack once, in place, after all scenes are loaded.
    np.clip(hr_hsi, 0., 1., out=hr_hsi)
    return hr_hsi
# Number of unrolled stages in HerosNet; the final reconstruction is the output
# of the last stage, so the model and the output index share this constant.
STAGES = 8
# Label printed in the summary line (the checkpoint path does not depend on it).
MODEL_ID = 69
# Predictions are clamped to [0, 1] below, and the ground truth comes from
# prepare_data_test, which clips to [0, 1] (assuming `dataset` does not
# rescale it -- TODO confirm). This matches the PSNR call's data_range=1.0.
# NOTE: the legacy `measure.compare_ssim` inferred 2.0 from the float dtype,
# so set 2.0 here to reproduce numbers from older runs.
SSIM_DATA_RANGE = 1.0
RESULTS_DIR = './Results/'

try:
    from skimage.metrics import structural_similarity as _structural_similarity
except ImportError:  # scikit-image < 0.16 only provides measure.compare_ssim
    _structural_similarity = None


def _ssim_hwc(pred, ref):
    """Return SSIM between two (H, W, C) arrays, treating the last axis as channels.

    Works across scikit-image versions: `measure.compare_ssim` was removed in
    0.18, and `channel_axis` replaced `multichannel` in 0.19.
    """
    if _structural_similarity is None:
        return measure.compare_ssim(pred, ref, multichannel=True,
                                    data_range=SSIM_DATA_RANGE)
    try:
        return _structural_similarity(pred, ref, channel_axis=-1,
                                      data_range=SSIM_DATA_RANGE)
    except TypeError:  # 0.16 <= scikit-image < 0.19: no `channel_axis` yet
        return _structural_similarity(pred, ref, multichannel=True,
                                      data_range=SSIM_DATA_RANGE)


HR_HSI = prepare_data_test(opt.data_path, opt.testset_num)
# Named `test_set` so it does not shadow the imported `dataset` class.
test_set = dataset(opt, HR_HSI)
loader_test = tud.DataLoader(test_set, batch_size=opt.batch_size)

# savemat fails if the output directory is missing.
os.makedirs(RESULTS_DIR, exist_ok=True)

model = HerosNet(Ch=28, stages=STAGES, size=256)
model = dataparallel(model, 1)
model.load_state_dict(torch.load('./checkpoint/net.pkl'))
model = model.eval()

psnr_total = 0
ssim_total = 0
k = 0  # number of scenes evaluated so far
Phi = None
with torch.no_grad():
    for label in loader_test:
        label = label.cuda()
        out, Phi = model(label)
        result = out[STAGES - 1].clamp(min=0., max=1.)
        result_np = result.cpu().numpy()
        label_np = label.cpu().numpy()
        # Score and save every scene of the batch separately; with the
        # default batch_size=1 this is one scene per iteration.
        for b in range(result_np.shape[0]):
            pred = result_np[b:b + 1]
            ref = label_np[b:b + 1]
            psnr = compare_psnr(pred, ref, data_range=1.0)
            # (C, H, W) -> (H, W, C): channels last for multichannel SSIM.
            # Assumes the network output is (B, C, H, W) -- TODO confirm.
            pred_hwc = pred[0].transpose(1, 2, 0)
            ref_hwc = ref[0].transpose(1, 2, 0)
            ssim = _ssim_hwc(pred_hwc, ref_hwc)
            psnr_total = psnr_total + psnr
            ssim_total = ssim_total + ssim
            k = k + 1
            print(psnr, ssim)
            # Files are numbered 1.mat, 2.mat, ... in evaluation order.
            sio.savemat(RESULTS_DIR + str(k) + '.mat', {'res': pred_hwc})

# Only the mask returned with the last batch is saved.
if Phi is not None:
    sio.savemat(RESULTS_DIR + 'binarymask.mat', {'mask': Phi.cpu().numpy()})
print(k)
if k == 0:
    print("model %d: no test scenes were evaluated" % MODEL_ID)
else:
    print("model %d, Avg PSNR = %.4f, Avg SSIM = %.4f" % (MODEL_ID, psnr_total / k, ssim_total / k))