Skip to content

Commit

Permalink
updates after rebuttal: DC-CNN changed to soft format, add ISTA-Net+
Browse files Browse the repository at this point in the history
  • Loading branch information
hellopipu committed Feb 14, 2022
1 parent e779ce7 commit c29814f
Show file tree
Hide file tree
Showing 34 changed files with 406 additions and 65 deletions.
110 changes: 66 additions & 44 deletions Solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from model.DCCNN import DCCNN
from model.LPDNet import LPDNet
from model.HQSNet import HQSNet

from model.ISTANet_plus import ISTANetplus
import numpy as np

class Solver():
def __init__(self, args):
torch.autograd.set_detect_anomaly(True)
self.args = args
################ experiment settings ################
self.model_name = self.args.model
Expand All @@ -43,22 +45,24 @@ def __init__(self, args):
os.makedirs(self.saveDir)

self.task_name = self.model_name + '_acc_' + str(self.acc) + '_bs_' + str(self.batch_size) \
+ '_lr_' + str(self.lr)
+ '_lr_' + str(self.lr) + 'bf_5_nocat' #first_dc'#'_iter_10'#'_bf=1' #+ _nocat '_bf=1'#
print('task_name: ', self.task_name)
self.model_path = 'weight/' + self.task_name + '_best.pth' # model load path for test and visualization

############################################ Specify network ############################################
if self.model_name == 'dc-cnn':
self.net = DCCNN()
self.net = DCCNN(n_iter=8)
elif self.model_name == 'ista-net-plus':
self.net = ISTANetplus(n_iter=8)
elif self.model_name == 'lpd-net':
self.net = LPDNet()
self.net = LPDNet(n_iter=8)
elif self.model_name == 'hqs-net':
self.net = HQSNet(block_type='cnn')
self.net = HQSNet(block_type='cnn',buffer_size=5, n_iter=8)
elif self.model_name == 'hqs-net-unet':
self.net = HQSNet(block_type='unet')
self.net = HQSNet(block_type='unet', n_iter=10)
else:
assert "wrong model name !"
print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / (1024.0 * 1024)))
print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / 10.**6))
self.net.cuda()

def train(self):
Expand All @@ -70,10 +74,10 @@ def train(self):
## 2. we train the hqs-net-unet model with ssim + l1 loss, the reason is that, we found when using ms-ssim loss,
## the gradient of ms-ssim may be nan. This bug exists in both pytorch and tensoflow implementation of ms-ssim loss.
## see https://github.com/tensorflow/tensorflow/issues/50400, https://github.com/VainF/pytorch-msssim/issues/12
if self.model_name == 'hqs-net-unet':
self.criterion = CompoundLoss('ssim')
else:
self.criterion = CompoundLoss('ms-ssim')
# if self.model_name == 'hqs-net-unet':
# self.criterion = CompoundLoss('ssim')
# else:
self.criterion = CompoundLoss('ms-ssim')

############################################ Specify optimizer ########################################

Expand All @@ -84,10 +88,12 @@ def train(self):
dataset_train = MyData(self.imageDir_train, self.acc, self.img_size, is_training='train')
dataset_val = MyData(self.imageDir_val, self.acc, self.img_size, is_training='val')

num_workers = 4
use_pin_memory = True
loader_train = Data.DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=True,
num_workers=4, pin_memory=True)
num_workers=num_workers, pin_memory=use_pin_memory)
loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False,
num_workers=4, pin_memory=True)
num_workers=num_workers, pin_memory=use_pin_memory)
self.slices_val = len(dataset_val)
print("slices of 2d train data: ", len(dataset_train))
print("slices of 2d validation data: ", len(dataset_val))
Expand All @@ -99,16 +105,22 @@ def train(self):

start_epoch = 0
best_val_psnr = 0
if 0:
best_name = self.task_name + '_best.pth'
checkpoint = torch.load(join(self.saveDir, best_name))
self.net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']+1
best_val_psnr = checkpoint['val_psnr']
print('load pretrained model---, start epoch at, ',start_epoch, ', star_psnr_val is: ',best_val_psnr)
for epoch in range(start_epoch, self.num_epoch):
####################### 1. training #######################

loss_g = self._train_cnn(loader_train)
####################### 2. validate #######################
if epoch == start_epoch:
base_psnr, base_ssim = self._validate_base(loader_val)
if epoch % self.val_on_epochs == 0:
if epoch == 0:
base_psnr, base_ssim = self._validate_base(loader_val)
val_psnr, val_ssim = self._validate(loader_val)

########################## 3. print and tensorboard ########################
print("Epoch {}/{}".format(epoch + 1, self.num_epoch))
print(" base PSNR:\t\t{:.6f}".format(base_psnr))
Expand Down Expand Up @@ -148,19 +160,23 @@ def test(self):
self.net.cuda()
self.net.eval()

base_psnr = 0
test_psnr = 0
base_ssim = 0
test_ssim = 0
base_nrmse = 0
test_nrmse = 0
base_psnr = []
test_psnr = []
base_ssim = []
test_ssim = []
base_nrmse = []
test_nrmse = []
with torch.no_grad():
time_0 = time.time()
for data_dict in tqdm(loader_val):
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda(), \
data_dict['k_A_und'].float().cuda(), \
data_dict['mask_A'].float().cuda()
T1 = self.net(im_A_und, k_A_und, mask)

if self.model_name == 'ista-net-plus':
T1, loss_layers_sym = self.net(im_A_und, k_A_und, mask)
else:
T1 = self.net(im_A_und, k_A_und, mask)
############## convert model ouput to complex value in original range

T1 = output2complex(T1)
Expand All @@ -170,42 +186,45 @@ def test(self):
########################### calulate metrics ###################################
for T1_i, im_A_i, im_A_und_i in zip(T1.cpu().numpy(), im_A.cpu().numpy(), im_A_und.cpu().numpy()):
## for skimage.metrics, input is (im_true,im_pred)
base_nrmse += cal_nrmse(im_A_i, im_A_und_i)
test_nrmse += cal_nrmse(im_A_i, T1_i)
base_ssim += cal_ssim(im_A_i, im_A_und_i)
test_ssim += cal_ssim(im_A_i, T1_i)
base_psnr += cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max())
test_psnr += cal_psnr(im_A_i, T1_i, data_range=im_A_i.max())
base_nrmse.append(cal_nrmse(im_A_i, im_A_und_i))
test_nrmse.append(cal_nrmse(im_A_i, T1_i))
base_ssim.append(cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max()))
test_ssim.append(cal_ssim(im_A_i, T1_i, data_range=im_A_i.max()))
base_psnr.append(cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max()))
test_psnr.append(cal_psnr(im_A_i, T1_i, data_range=im_A_i.max()))

time_1 = time.time()
## comment metric calculation code for more precise inference speed
print('inference speed: {:.5f} ms/slice'.format(1000 * (time_1 - time_0) / len_data))
base_psnr /= len_data
test_psnr /= len_data
base_ssim /= len_data
test_ssim /= len_data
base_nrmse /= len_data
test_nrmse /= len_data

print(" base PSNR:\t\t{:.6f}".format(base_psnr))
print(" test PSNR:\t\t{:.6f}".format(test_psnr))
print(" base SSIM:\t\t{:.6f}".format(base_ssim))
print(" test SSIM:\t\t{:.6f}".format(test_ssim))
print(" base NRMSE:\t\t{:.6f}".format(base_nrmse))
print(" test NRMSE:\t\t{:.6f}".format(test_nrmse))

print(" base PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_psnr),np.std(base_psnr)))
print(" test PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_psnr),np.std(test_psnr)))
print(" base SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_ssim),np.std(base_ssim)))
print(" test SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_ssim),np.std(test_ssim)))
print(" base NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_nrmse),np.std(base_nrmse)))
print(" test NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_nrmse),np.std(test_nrmse)))

def _train_cnn(self, loader_train):
self.net.train()
for data_dict in tqdm(loader_train):
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict['mask_A'].float().cuda()
T1 = self.net(im_A_und, k_A_und, mask)
if self.model_name == 'ista-net-plus':
T1,loss_layers_sym = self.net(im_A_und, k_A_und, mask)
else:
T1 = self.net(im_A_und, k_A_und, mask)

T1 = output2complex(T1)
im_A = output2complex(im_A)
############################################# 1.2 update generator #############################################

loss_g = self.criterion(T1, im_A, data_range=im_A.max())
if self.model_name == 'ista-net-plus':
loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
for k in range(len(loss_layers_sym)-1):
loss_constraint += torch.mean(torch.pow(loss_layers_sym[k + 1], 2))
loss_g = loss_g + 0.01 * loss_constraint

self.optimizer_G.zero_grad()
loss_g.backward()
self.optimizer_G.step()
Expand Down Expand Up @@ -244,7 +263,10 @@ def _validate(self, loader_val):
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict[
'mask_A'].float().cuda()
T1 = self.net(im_A_und, k_A_und, mask)
if self.model_name == 'ista-net-plus':
T1,_ = self.net(im_A_und, k_A_und, mask)
else:
T1 = self.net(im_A_und, k_A_und, mask)
############## convert model ouput to complex value in original range
T1 = output2complex(T1)
im_A = output2complex(im_A)
Expand Down
2 changes: 2 additions & 0 deletions dd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @author : Bingyu Xin
# @Institute : CS@Rutgers
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main(args):
############################### experiment settings ##########################
parser.add_argument('--mode', default='train', choices=['train', 'test'],
help='mode for the program')
parser.add_argument('--model', default='hqs-net', choices=['dc-cnn', 'lpd-net', 'hqs-net', 'hqs-net-unet'],
parser.add_argument('--model', default='hqs-net', choices=['dc-cnn', 'lpd-net', 'hqs-net', 'hqs-net-unet','ista-net-plus'],
help='models to reconstruct')
parser.add_argument('--acc', type=int, default=5,
help='Acceleration factor for k-space sampling')
Expand Down
2 changes: 2 additions & 0 deletions model/BasicModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def conv_block(model_name='hqs-net', channel_in=22, n_convs=3, n_filters=32):
layers = []
if model_name == 'dc-cnn':
channel_out = channel_in
if model_name == 'ista-net':
channel_out = n_filters
elif model_name == 'prim-net' or model_name == 'hqs-net':
channel_out = channel_in - 2
elif model_name == 'dual-net':
Expand Down
26 changes: 24 additions & 2 deletions model/DCCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class DCCNN(nn.Module):
def __init__(self, n_iter=5, n_convs=5, n_filters=64, norm='ortho'):
def __init__(self, n_iter=8, n_convs=6, n_filters=64, norm='ortho'):
'''
DC-CNN modified from paper " A Deep Cascade of Convolutional Neural Networks for Dynamic MR Image Reconstruction "
( https://arxiv.org/pdf/1704.02422.pdf ) ( https://github.com/js3611/Deep-MRI-Reconstruction )
Expand All @@ -19,6 +19,7 @@ def __init__(self, n_iter=5, n_convs=5, n_filters=64, norm='ortho'):
channel_in = 2
rec_blocks = []
self.norm = norm
self.mu = nn.Parameter(torch.Tensor([0.5]))
self.n_iter = n_iter
for i in range(n_iter):
rec_blocks.append(conv_block('dc-cnn', channel_in, n_filters=n_filters, n_convs=n_convs))
Expand All @@ -32,18 +33,39 @@ def dc_operation(self, x_rec, k_un, mask):
k_rec = torch.fft.fft2(torch.view_as_complex(x_rec.contiguous()), norm=self.norm)

k_rec = torch.view_as_real(k_rec)
# noiseless
k_out = k_rec + (k_un - k_rec) * mask

k_out = torch.view_as_complex(k_out)
x_out = torch.view_as_real(torch.fft.ifft2(k_out, norm=self.norm))
x_out = x_out.permute(0, 3, 1, 2)
return x_out
def _forward_operation(self, img, mask):

k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()),
norm=self.norm)
k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous()
k = mask * k
return k

def _backward_operation(self, k, mask):

k = mask * k
img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm)
img = torch.view_as_real(img).permute(0, 3, 1, 2).contiguous()
return img

def update_opration(self, f_1, k, mask):
h_1 = k - self._forward_operation(f_1, mask)
update = f_1 + self.mu * self._backward_operation(h_1, mask)
return update

def forward(self, x, k, m):
for i in range(self.n_iter):
# x = self.update_opration(x, k, m)
x_cnn = self.rec_blocks[i](x)
x = x + x_cnn
x = self.dc_operation(x, k, m)
x = self.update_opration(x, k, m)
return x


Expand Down
19 changes: 16 additions & 3 deletions model/HQSNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class HQSNet(nn.Module):
def __init__(self, buffer_size=5, n_iter=10, n_convs=3, n_filters=32, block_type='cnn', norm='ortho'):
def __init__(self, buffer_size=5, n_iter=8, n_convs=6, n_filters=64, block_type='cnn', norm='ortho'):
'''
HQS-Net
:param buffer_size: buffer_size m
Expand All @@ -20,13 +20,13 @@ def __init__(self, buffer_size=5, n_iter=10, n_convs=3, n_filters=32, block_type
self.m = buffer_size
self.n_iter = n_iter
## the initialization of mu may influence the final accuracy
self.mu = nn.Parameter(2. * torch.ones((1, 1)))
self.mu = nn.Parameter(0.5 * torch.ones((1, 1))) #2
self.block_type = block_type
if self.block_type == 'cnn':
rec_blocks = []
for i in range(self.n_iter):
rec_blocks.append(
conv_block('hqs-net', channel_in=2 * (self.m + 1), n_convs=n_convs, n_filters=n_filters))
conv_block('hqs-net', channel_in=2 * (self.m+1 ), n_convs=n_convs, n_filters=n_filters)) #self.m +
self.rec_blocks = nn.ModuleList(rec_blocks)
elif self.block_type == 'unet':
self.rec_blocks = UNetRes(in_nc=2 * (self.m + 1), out_nc=2 * self.m, nc=[64, 128, 256, 512], nb=4,
Expand Down Expand Up @@ -56,14 +56,27 @@ def update_opration(self, f_1, k, mask):
def forward(self, img, k, mask):

## initialize buffer f : the concatenation of m copies of the complex-valued zero-filled images

f = torch.cat([img] * self.m, 1).to(img.device)

## n reconstruction blocks buff=5_nocat
# for i in range(self.n_iter):
# for j in range(self.m):
# f_1 = f[:, j*2:j*2+2].clone()
# f[:, j*2:j*2+2] = self.update_opration(f_1, k, mask)
# if self.block_type == 'cnn':
# # f = f + self.rec_blocks[i](torch.cat([f, updated_f_1], 1))
# f = f + self.rec_blocks[i](f)
# elif self.block_type == 'unet':
# f = f + self.rec_blocks(torch.cat([f, updated_f_1], 1))

## n reconstruction blocks
for i in range(self.n_iter):
f_1 = f[:, 0:2].clone()
updated_f_1 = self.update_opration(f_1, k, mask)
if self.block_type == 'cnn':
f = f + self.rec_blocks[i](torch.cat([f, updated_f_1], 1))
# f = updated_f_1 + self.rec_blocks[i](updated_f_1)
elif self.block_type == 'unet':
f = f + self.rec_blocks(torch.cat([f, updated_f_1], 1))
return f[:, 0:2]
Expand Down
Loading

0 comments on commit c29814f

Please sign in to comment.