-
Notifications
You must be signed in to change notification settings - Fork 8
/
Utils.py
138 lines (116 loc) · 3.51 KB
/
Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
import scipy.io as sio
import os
import glob
import re
import torch
import torch.nn as nn
import math
import random
def _as_floats(im1, im2):
    """Promote two arrays to a common floating-point dtype (at least float32).

    Parameters
    ----------
    im1, im2 : numpy.ndarray
        Inputs; both must expose ``.dtype``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(im1, im2)`` converted to
        ``np.result_type(im1.dtype, im2.dtype, np.float32)``.
    """
    float_type = np.result_type(im1.dtype, im2.dtype, np.float32)
    im1 = np.asarray(im1, dtype=float_type)
    im2 = np.asarray(im2, dtype=float_type)
    return im1, im2


def compare_mse(im1, im2):
    """Return the mean squared error of two arrays, accumulated in float64."""
    im1, im2 = _as_floats(im1, im2)
    return np.mean(np.square(im1 - im2), dtype=np.float64)


def _default_data_range(image):
    """Infer the peak-to-peak data range of *image* from its dtype.

    Integer images use the dtype's limits (``np.iinfo``); floating-point and
    boolean images are assumed to lie in [-1, 1]. The range is the dtype
    maximum when the image is non-negative, else ``max - min``.

    Raises
    ------
    ValueError
        If the image holds values outside the range expected for its dtype,
        in which case the caller must pass ``data_range`` explicitly.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.integer):
        info = np.iinfo(image.dtype)
        dmin, dmax = info.min, info.max
    else:
        dmin, dmax = -1, 1
    if image.size:
        true_min, true_max = image.min(), image.max()
        if true_min < dmin or true_max > dmax:
            raise ValueError(
                "im_true has intensity values outside the range expected "
                "for its data type. Please manually specify the data_range."
            )
        non_negative = true_min >= 0
    else:
        non_negative = True
    return dmax if non_negative else dmax - dmin


def compare_psnr(im_true, im_test, data_range=None):
    """Return the peak signal-to-noise ratio (dB) of *im_test* w.r.t. *im_true*.

    Parameters
    ----------
    im_true, im_test : numpy.ndarray
        Reference and test images of the same shape.
    data_range : float, optional
        Peak-to-peak range of the data. When omitted it is inferred from the
        dtype of *im_true* (previously ``None`` always crashed on ``None ** 2``).

    Returns
    -------
    float
        ``10 * log10(data_range**2 / MSE)``; ``inf`` (with a numpy divide
        warning) when the images are identical.

    Raises
    ------
    ValueError
        If *data_range* must be inferred but *im_true* exceeds its dtype range.
    """
    if data_range is None:
        # Must run before _as_floats, which discards the original dtype.
        data_range = _default_data_range(im_true)
    im_true, im_test = _as_floats(im_true, im_test)
    err = compare_mse(im_true, im_test)
    return 10 * np.log10((data_range ** 2) / err)
def psnr(img1, img2):
    """Return the PSNR in dB of two images given on a 0-255 scale.

    Both images are rescaled by 1/255 so the peak signal is 1. When the mean
    squared error is below 1e-10 the images are treated as identical and the
    sentinel value 100 is returned instead of an unbounded result.
    """
    scaled_diff = img1 / 255. - img2 / 255.
    mse = np.mean(scaled_diff ** 2)
    return 100 if mse < 1.0e-10 else 20 * math.log10(1 / math.sqrt(mse))
def PSNR_GPU(im_true, im_fake):
    """Return the PSNR (dB) of two tensors after quantizing them to 8-bit levels.

    Both inputs are scaled by 255 and rounded before the MSE is taken, so they
    are presumably normalized to [0, 1] -- TODO confirm with callers. The peak
    value used is 255.

    Parameters
    ----------
    im_true, im_fake : torch.Tensor
        Tensors of identical shape. The previous implementation read C, H, W
        from the first three dimensions, so 3-D ``(C, H, W)`` inputs are
        expected; the MSE here is the mean over all elements, which for such
        inputs equals ``sum / (C*H*W)``.

    Returns
    -------
    torch.Tensor
        0-d tensor on the inputs' device, detached from the autograd graph.
        The inputs are left unmodified.
    """
    data_range = 255
    esp = 1e-12  # keeps the log finite when the quantized images match exactly
    # Scale out-of-place: the former ``*=`` silently rescaled the caller's
    # tensors, corrupting any later use of them (e.g. a metric computed next).
    quant_true = (im_true * data_range).round()
    quant_fake = (im_fake * data_range).round()
    err = torch.mean((quant_true - quant_fake) ** 2).detach()
    # torch.log10 works on any device; np.log cannot convert a CUDA tensor.
    return 10. * torch.log10((data_range ** 2) / (err + esp))
def PSNR_Nssr(im_true, im_fake):
    """Return the PSNR in dB of two signals assuming a peak value of 1.

    Computes ``10 * log10(1 / MSE)``, with the MSE taken as the mean of the
    squared element-wise difference. There is no epsilon guard, so identical
    inputs divide by zero.
    """
    squared_error = (im_true - im_fake) ** 2
    return 10. * np.log10(1 / squared_error.mean())
def dataparallel(model, ngpus, gpu0=0):
    """Move *model* to CUDA, wrapping it in ``DataParallel`` when ``ngpus > 1``.

    Parameters
    ----------
    model : torch.nn.Module
        Model to place on the GPU(s).
    ngpus : int
        Number of GPUs to use; must be at least 1 (CPU mode is unsupported).
    gpu0 : int, optional
        Index of the first GPU; devices ``gpu0 .. gpu0 + ngpus - 1`` are used.

    Returns
    -------
    torch.nn.Module
        The CUDA model; a ``DataParallel`` wrapper for multi-GPU use unless
        *model* already is one.

    Raises
    ------
    AssertionError
        If ``ngpus < 1`` or fewer than ``gpu0 + ngpus`` CUDA devices exist.
    """
    # Explicit raises rather than ``assert``: ``python -O`` strips asserts, and
    # ngpus == 0 then fell through and silently returned a CPU model.
    # AssertionError is kept so existing callers see the same exception type.
    if ngpus < 1:
        raise AssertionError("only support gpu mode")
    if torch.cuda.device_count() < gpu0 + ngpus:
        raise AssertionError(
            "requested GPUs %d..%d but only %d CUDA device(s) available"
            % (gpu0, gpu0 + ngpus - 1, torch.cuda.device_count())
        )
    if ngpus > 1 and not isinstance(model, torch.nn.DataParallel):
        gpu_list = list(range(gpu0, gpu0 + ngpus))
        return torch.nn.DataParallel(model, gpu_list).cuda()
    return model.cuda()
def findLastCheckpoint(save_dir):
    """Return the highest epoch among ``model_<epoch>.pth`` files in *save_dir*.

    Only files whose name is exactly ``model_`` + decimal digits + ``.pth``
    count. Other matches of the ``model_*.pth`` glob (e.g. ``model_best.pth``)
    are ignored instead of crashing ``int()``.

    Returns
    -------
    int
        The largest epoch found, or 0 when there is no numbered checkpoint.
    """
    # Anchored on the basename, so directory names cannot affect the match.
    epoch_pattern = re.compile(r"model_(\d+)\.pth")
    epochs_exist = []
    for file_ in glob.glob(os.path.join(save_dir, 'model_*.pth')):
        match = epoch_pattern.fullmatch(os.path.basename(file_))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)
# load HSIs
def prepare_data(path, file_list, file_num, *, cube_shape=(512, 512, 28)):
    """Load *file_num* cubes from ``.mat`` files, scaled by 1/65535 and clipped.

    File ``idx`` is read from ``path + file_list[idx] + '.mat'``. This is plain
    string concatenation, so *path* must end with a path separator (or be a
    deliberate filename prefix). The ``'data_slice'`` variable of each file
    is divided by 65535 and the result is clipped to [0, 1].

    Parameters
    ----------
    path : str
        Directory/prefix prepended to each file code.
    file_list : sequence of str
        File codes (names without the ``.mat`` extension).
    file_num : int
        Number of leading entries of *file_list* to load.
    cube_shape : tuple of int, optional
        Shape of each ``'data_slice'`` array; defaults to the previously
        hard-coded ``(512, 512, 28)``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(*cube_shape, file_num)``.
    """
    HR_HSI = np.zeros(tuple(cube_shape) + (file_num,))
    for idx in range(file_num):
        HR_code = file_list[idx]
        path1 = os.path.join(path) + HR_code + '.mat'
        data = sio.loadmat(path1)
        HR_HSI[..., idx] = data['data_slice'] / 65535.0
    # Clip once after loading; clipping is idempotent and per-element, so this
    # matches the old per-iteration clipping of the whole array, which was
    # quadratic in file_num.
    HR_HSI[HR_HSI < 0.] = 0.
    HR_HSI[HR_HSI > 1.] = 1.
    return HR_HSI
def loadpath(pathlistfile):
    """Read a newline-separated list of paths and return it in random order.

    The list is shuffled in place with the global ``random`` generator, so the
    order depends on (and advances) its state.

    Parameters
    ----------
    pathlistfile : str
        Text file with one entry per line.

    Returns
    -------
    list of str
        The lines of the file (without line terminators), shuffled.
    """
    # ``with`` guarantees the file is closed even if reading fails.
    with open(pathlistfile) as fp:
        pathlist = fp.read().splitlines()
    random.shuffle(pathlist)
    return pathlist
def y2x(y):
    """Split operator: cut a shifted 2-D measurement into 28 band images.

    Output band ``t`` is the column window ``y[..., 2*t : 2*t + H]``, i.e. each
    successive band is read from a window shifted right by 2 columns.

    Parameters
    ----------
    y : torch.Tensor
        Measurement of shape ``(bs, 1, H, W)`` or ``(1, H, W)`` (a batch axis is
        added for 3-D input). The window width is taken from ``H``, so bands are
        ``H x H`` and ``W`` must be at least ``H + 54`` (presumably exactly
        ``H + 54``, as produced by ``x2y``; TODO confirm).

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(bs, 28, H, H)`` in the default floating dtype,
        allocated on ``y.device``.
    """
    sz = y.size()
    if len(sz) == 3:
        y = y.unsqueeze(0)
        bs = 1
    else:
        bs = sz[0]
    sz = y.size()
    # Allocate on the input's device instead of a hard-coded ``.cuda()``:
    # unchanged for CUDA inputs, and the operator now also works on CPU.
    x = torch.zeros([bs, 28, sz[2], sz[2]], device=y.device)
    for t in range(28):
        temp = y[:, :, :, 2 * t: sz[2] + 2 * t]
        x[:, t, :, :] = temp.squeeze(1)
    return x
def x2y(x):
    """Shift-and-sum operator: shift each of 28 bands right by 2*t and add them.

    Band ``t`` of *x* is added into columns ``2*t : 2*t + H`` of a single-channel
    output that is ``H + 54`` columns wide.

    Parameters
    ----------
    x : torch.Tensor
        Cube of shape ``(bs, 28, H, W)`` or ``(28, H, W)`` (a batch axis is added
        for 3-D input). The window width is taken from ``H``, so bands are
        expected to be square (``W == H``).

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(bs, 1, H, H + 54)`` in the default floating dtype,
        allocated on ``x.device``.
    """
    sz = x.size()
    if len(sz) == 3:
        # A single (28, H, W) cube needs only a batch axis. The former double
        # unsqueeze made it 5-D, so every 3-D call failed.
        x = x.unsqueeze(0)
        bs = 1
    else:
        bs = sz[0]
    sz = x.size()
    # Allocate on the input's device instead of a hard-coded ``.cuda()``.
    y = torch.zeros([bs, 1, sz[2], sz[2] + 2 * 27], device=x.device)
    for t in range(28):
        window = slice(2 * t, sz[2] + 2 * t)
        y[:, :, :, window] = x[:, t, :, :].unsqueeze(1) + y[:, :, :, window]
    return y