-
Notifications
You must be signed in to change notification settings - Fork 108
/
CS-CsiNet_onlytest.py
117 lines (101 loc) · 3.89 KB
/
CS-CsiNet_onlytest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import tensorflow as tf
from keras.layers import Input, Dense, BatchNormalization, Reshape, Conv2D, add, LeakyReLU
from keras.models import Model, model_from_json
from keras.callbacks import TensorBoard, Callback
import scipy.io as sio
import numpy as np
import math
import time
# Start from a clean TensorFlow graph. `reset_default_graph` is a top-level
# function in TF 1.x but is only reachable via `tf.compat.v1` in TF 2.x.
if hasattr(tf, 'reset_default_graph'):
    tf.reset_default_graph()
else:
    tf.compat.v1.reset_default_graph()

envir = 'indoor'  # 'indoor' or 'outdoor'
if envir not in ('indoor', 'outdoor'):
    # Fail fast, before any model or data is loaded: every data path used by
    # this script is chosen from this value.
    raise ValueError("envir must be 'indoor' or 'outdoor', got %r" % (envir,))

# image params
img_height = 32
img_width = 32
img_channels = 2
img_total = img_height * img_width * img_channels

# network params
residual_num = 2
# compress rate 1/4 -> dim 512, 1/16 -> 128, 1/32 -> 64, 1/64 -> 32
encoded_dim = 512
# Basename of the saved model files for this environment/dimension pair.
file = 'CS-CsiNet_' + envir + '_dim' + str(encoded_dim)
# Rebuild the decoder architecture from its JSON description, then load the
# trained weights into it. Both files are keyed by the `file` basename.
outfile = "saved_model/model_%s.json" % file
# `with` guarantees the handle is closed even if read() raises.
with open(outfile, 'r') as json_file:
    loaded_model_json = json_file.read()
decoder = model_from_json(loaded_model_json)

# load weights into the new model
outfile = "saved_model/model_%s.h5" % file
decoder.load_weights(outfile)
# Data loading: pick the test-set file for the configured environment.
# Each row of `HT` is one flattened sample (it is projected with A.T and then
# reshaped per row further down the script).
_test_data_files = {
    'indoor': 'data/DATA_Htestin.mat',
    'outdoor': 'data/DATA_Htestout.mat',
}
if envir not in _test_data_files:
    # Without this, an unknown value would leave x_test undefined and fail
    # later with a confusing NameError.
    raise ValueError("envir must be 'indoor' or 'outdoor', got %r" % (envir,))
mat = sio.loadmat(_test_data_files[envir])
x_test = mat['HT']  # array
x_test = x_test.astype('float32')
# Compressive-sensing "encoder": a fixed linear projection y = x · A^T applied
# to every flattened test sample. The measurement matrix A is read from
# data/A<encoded_dim>.mat rather than being learned.
mat = sio.loadmat('data/A%d.mat' % encoded_dim)
A = mat['A']  # measurement matrix (array)
y_test = x_test @ A.T

# Lay the ground truth out as (N, channels, height, width); channel 0 and
# channel 1 are later read as the real and imaginary parts respectively.
x_test = x_test.reshape((len(x_test), img_channels, img_height, img_width))
# Time the decoder's forward pass over the whole test set. perf_counter() is
# monotonic and high-resolution, unlike time.time(), which can jump with
# system clock adjustments.
tStart = time.perf_counter()
x_hat = decoder.predict(y_test)
tEnd = time.perf_counter()
# The printed figure is the average wall-clock time per test sample.
print("It cost %f sec" % ((tEnd - tStart) / x_test.shape[0]))
# Calculate the NMSE and rho (correlation).

# Ground-truth frequency-domain channel for the configured environment.
_freq_data_files = {
    'indoor': 'data/DATA_HtestFin_all.mat',
    'outdoor': 'data/DATA_HtestFout_all.mat',
}
if envir not in _freq_data_files:
    raise ValueError("envir must be 'indoor' or 'outdoor', got %r" % (envir,))
mat = sio.loadmat(_freq_data_files[envir])
X_test = mat['HF_all']  # array

# The reconstruction is zero-padded along its last axis to n_fft points, run
# through an FFT and truncated to the first n_subcarriers bins so that it has
# the same shape as X_test. NOTE(review): these constants must match how
# HF_all was generated; confirm against the data-generation script.
n_fft = 257
n_subcarriers = 125
X_test = np.reshape(X_test, (len(X_test), img_height, n_subcarriers))

# Undo the +0.5 offset on each channel and combine channel 0 (real) and
# channel 1 (imag) into one complex value per entry, flattened per sample.
x_test_real = np.reshape(x_test[:, 0, :, :], (len(x_test), -1))
x_test_imag = np.reshape(x_test[:, 1, :, :], (len(x_test), -1))
x_test_C = x_test_real - 0.5 + 1j * (x_test_imag - 0.5)
x_hat_real = np.reshape(x_hat[:, 0, :, :], (len(x_hat), -1))
x_hat_imag = np.reshape(x_hat[:, 1, :, :], (len(x_hat), -1))
x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5)

x_hat_F = np.reshape(x_hat_C, (len(x_hat_C), img_height, img_width))
padding = np.zeros((len(x_hat_C), img_height, n_fft - img_width))
X_hat = np.fft.fft(np.concatenate((x_hat_F, padding), axis=2), axis=2)
X_hat = X_hat[:, :, 0:n_subcarriers]

# rho: |<h, h_hat>| / (||h|| * ||h_hat||), taken over axis 1 for every column,
# then averaged over the columns of each sample.
# np.abs(...)**2 is already real, so the float64 cast no longer drops an
# imaginary part (the old conj(x)*x form was complex -> ComplexWarning).
n1 = np.sqrt(np.sum(np.abs(X_test) ** 2, axis=1)).astype('float64')
n2 = np.sqrt(np.sum(np.abs(X_hat) ** 2, axis=1)).astype('float64')
aa = abs(np.sum(np.conj(X_test) * X_hat, axis=1))
rho = np.mean(aa / (n1 * n2), axis=1)

X_hat = np.reshape(X_hat, (len(X_hat), -1))
X_test = np.reshape(X_test, (len(X_test), -1))
# NMSE is measured on the complex (img_height x img_width) matrices, not on
# the FFT'd ones.
power = np.sum(abs(x_test_C) ** 2, axis=1)
power_d = np.sum(abs(X_hat) ** 2, axis=1)  # NOTE(review): not used by any metric below
mse = np.sum(abs(x_test_C - x_hat_C) ** 2, axis=1)

print("In "+envir+" environment")
print("When dimension is", encoded_dim)
print("NMSE is ", 10*math.log10(np.mean(mse/power)))
print("Correlation is ", np.mean(rho))
import matplotlib.pyplot as plt

# Plot magnitudes: originals on the top row, reconstructions on the bottom.
# Each image is drawn as (max - |h|), so with the gray colormap the larger
# magnitudes appear darker.
n = min(10, len(x_test))  # never index past the end of a small test set
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    x_testplo = abs(x_test[i, 0, :, :] - 0.5 + 1j * (x_test[i, 1, :, :] - 0.5))
    plt.imshow(np.max(x_testplo) - x_testplo.T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.invert_yaxis()

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    decoded_imgsplo = abs(x_hat[i, 0, :, :] - 0.5
                          + 1j * (x_hat[i, 1, :, :] - 0.5))
    plt.imshow(np.max(decoded_imgsplo) - decoded_imgsplo.T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.invert_yaxis()
plt.show()