Skip to content

Commit

Permalink
2d model
Browse files Browse the repository at this point in the history
  • Loading branch information
HowieMa committed Feb 14, 2019
1 parent b0d982f commit 2c469df
Show file tree
Hide file tree
Showing 16 changed files with 281 additions and 267 deletions.
316 changes: 169 additions & 147 deletions .idea/workspace.xml

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions data_loader/brats15_v2.py → data_loader/brats15_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from src.utils import *

ddd = ['flair', 't1', 't1c', 't2']
np.random.rand(20180128)


class Brats15DataLoader(Dataset):
Expand Down Expand Up @@ -173,11 +172,13 @@ def get_slices(self, img, label):

# test case
if __name__ =="__main__":
slice = 10
slice = 5
vol_num = 4
data_dir = '../data_sample/'
conf = '../config/sample15.conf'
print ('**** whole tumor task *************')
brats15 = Brats15DataLoader(data_dir=data_dir, task_type='wt', conf=conf, is_train=True)
brats15 = Brats15DataLoader(data_dir=data_dir, task_type='wt',
conf=conf, is_train=False)
volume, labels, subjct = brats15[0]

print 'volume size ...'
Expand All @@ -193,15 +194,15 @@ def get_slices(self, img, label):

print ('get sample of images')
for i in range(4):
sample_img = volume[0][i, slice, :, :] # 192 * 192
sample_img = volume[vol_num][i, slice, :, :] # 192 * 192
sample_img = norm(sample_img)
out = np.zeros((192, 200))
out[:, :96] = sample_img[:,:96]
out[:, 100:196] = sample_img[:, 96:]
scipy.misc.imsave('img/img_%s_wt.jpg' % ddd[i], out)

print ('get sample of labels')
sample_label = labels[0][0, slice, :, :] # 192 * 192
sample_label = labels[vol_num][0, slice, :, :] # 192 * 192
print sample_label.shape
label = np.ones((192, 200))
label[:, :96] = sample_label[:,:96]
Expand Down
113 changes: 64 additions & 49 deletions data_loader/brats15.py → data_loader/brats15_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@
ddd = ['flair', 't1', 't1c', 't2']



class Brats15DataLoader(Dataset):
def __init__(self, data_dir, direction='axial', task_type='wt',
conf='../config/train15.conf', with_gt=True):
def __init__(self, data_dir, is_train, direction='axial', task_type='wt',
conf='../config/train15.conf'):
self.data_dir = data_dir #
self.img_lists = []

self.data_box = [16, 128, 128] #
self.margin = 5
self.with_gt = with_gt
self.volume_size = 16
self.data_box = [144, 192, 192] # max 145
self.margin = 0
self.is_train = is_train # True or False

self.task_type = task_type # whole tumor, tumor core,
self.direction = direction # 'axial', 'sagittal', or 'coronal'
Expand All @@ -47,16 +46,16 @@ def __init__(self, data_dir, direction='axial', task_type='wt',
for data in train_config:
self.img_lists.append(os.path.join(self.data_dir, data.strip('\n')))

print ('**** Loading data from disk ....')
print('~' * 50)
print ('******** Loading data from disk ********')
self.data = {}
for subject in self.img_lists:
if subject not in self.data:
self.data[subject] = self.get_subject(subject)

np.random.rand(20180128)

print ('**** Finish loading data ...')
print ('**** total number of data is ' + str(len(self.data)))
print ('******** Finish loading data ********')
print ('******** Total number of subject is ' + str(len(self.data)))
print('~' * 50)

def __len__(self):
return len(self.img_lists)
Expand All @@ -66,31 +65,29 @@ def __getitem__(self, item):
subject = self.img_lists[item] # get absolute dir
volume, label = self.data[subject] # get whole data for one subject

# ********** get slice from whole images **********
volume, label = self.get_rand_slices(volume, label)

# ********** change data type from numpy to torch.Tensor **********
volume = torch.from_numpy(volume) # modal(4) * volume_size(16) * height * weight
label = torch.from_numpy(label) # modal(4) * volume_size(16) * height * weight
volume = torch.from_numpy(volume).float() # Float Tensor 4 * 144 * 192 * 192
label = torch.from_numpy(label).float() # Float Tensor 4 * 144 * 192 * 192

return volume.float(), label.float(), subject
# ********** get slice from whole images **********
images_vol, labels_vol = self.get_slices(volume, label)
return images_vol, labels_vol, subject

def get_subject(self, subject):
"""
get
:param subject: absolute dir
:return:
volume 4D numpy
label 4D numpy
volume 4D numpy 4 * 144 * 192 * 192
label 4D numpy 4 * 144 * 192 * 192
"""
# **************** get file ****************
files = os.listdir(subject) # [XXX.Flair, XXX.T1, XXX.T1c, XXX.T2, XXX.OT]

multi_mode_dir = []
label_dir = ""
for f in files:
if f == '.DS_Store':
continue

# if is data
if 'Flair' in f or 'T1' in f or 'T2' in f:
multi_mode_dir.append(f)
Expand All @@ -113,12 +110,12 @@ def get_subject(self, subject):
label = load_mha_as_array(label_dir) #

# *********** image pre-processing *************
# step1 ********* crop none-zero images and labels *********
# step1 ****** resize images and labels to 160 * 192 * 192 *******
for i in range(len(multi_mode_imgs)):
multi_mode_imgs[i] = crop_with_box(multi_mode_imgs[i], bbmin, bbmax, self.data_box)
multi_mode_imgs[i] = normalize_one_volume(multi_mode_imgs[i])

label = crop_with_box(label, bbmin, bbmax, self.data_box)

# step2 ********* transfer images to different direction *********
multi_mode_imgs = transpose_volumes(multi_mode_imgs, self.direction)
label = transpose_volumes([label], self.direction)[0]
Expand All @@ -128,7 +125,7 @@ def get_subject(self, subject):
label = get_whole_tumor_labels(label)
# for whole tumor task, bouding box is self
bbmin = [0, 0, 0]
bbmax = [label.shape[0] - 1, label.shape[1] - 1, label.shape[2] - 1]
bbmax = [label.shape[0], label.shape[1], label.shape[2]]

elif self.task_type == 'tc':
# for tumor core task, bounding box is the whole tumor box
Expand All @@ -146,53 +143,71 @@ def get_subject(self, subject):

return volume, label

def get_rand_slices(self, img, label):
def get_slices(self, img, label):
"""
get volume randomly
:param img:
:param label:
For training, get volume randomly; For testing, get volume step by step
:param img: 4D Float Tensor 4 * 144 * 192 * 192
:param label: 4D Float Tensor 4 * 144 * 192 * 192
:return:
img Float Tensor List [4 * 16 * 192 * 192] * 9
label Float Tensor List [4 * 16 * 192 * 192] * 9
"""
d, w, h = self.data_box # default is [16, 128, 128]
start = np.random.randint(0, img.shape[1] - d + 1)
width_start = np.random.randint(0, img.shape[2] - w + 1)
height_start = np.random.randint(0, img.shape[3] - h + 1)
times = img.shape[1] / self.volume_size # 144 / 16 = 9

images_vol = []
labels_vol = []
for t in range(times):
if self.is_train is True:
start = np.random.randint(0, img.shape[1] - self.volume_size + 1)
else:
start = t * self.volume_size # 0 ,16, 32, 48, ...

img = img[:, start: start + d,
width_start:width_start + w,
height_start:height_start + h]
label = label[:, start: start + d,
width_start:width_start + w,
height_start:height_start + h]
return img, label
im = img[:, start: start + self.volume_size, :, :]
lbl = label[:, start: start + self.volume_size, :, :]
images_vol.append(im)
labels_vol.append(lbl)

return images_vol, labels_vol


# test case
if __name__ =="__main__":
slice = 10
slice = 5
vol_num = 4
data_dir = '../data_sample/'
conf = '../config/sample15.conf'
print ('**** whole tumor task *************')
brats15 = Brats15DataLoader(data_dir=data_dir, task_type='wt', conf=conf)
brats15 = Brats15DataLoader(data_dir=data_dir, task_type='wt',
conf=conf, is_train=False)
volume, labels, subjct = brats15[0]

print 'volume size ...'
print len(volume)

print ('image size ......')
print (volume.shape) # (4, 16, 128, 128)
print (volume[0].shape) # (4, 16, 192, 192)

print ('label size ......')
print (labels.shape) # (1, 16, 128, 128)
print (labels[0].shape) # (1, 16, 192, 192)

print subjct

print ('get sample of images')
for i in range(4):
sample_img = volume[i, slice, :, :] # 128 * 128
print sample_img.shape # 128 * 128
scipy.misc.imsave('img/img_%s_wt.jpg' % ddd[i], sample_img)
sample_img = volume[vol_num][i, slice, :, :] # 192 * 192
sample_img = norm(sample_img)
out = np.zeros((192, 200))
out[:, :96] = sample_img[:,:96]
out[:, 100:196] = sample_img[:, 96:]
scipy.misc.imsave('img/img_%s_wt.jpg' % ddd[i], out)

print ('get sample of labels')
sample_label = labels[0, slice, :, :] # 128 * 128
sample_label = labels[vol_num][0, slice, :, :] # 192 * 192
print sample_label.shape
scipy.misc.imsave('img/label_wt.jpg', sample_label)
label = np.ones((192, 200))
label[:, :96] = sample_label[:,:96]
label[:, 100:196] = sample_label[:, 96:]
scipy.misc.imsave('img/label_wt.jpg', label)



Expand Down
Binary file modified data_loader/img/img_flair_wt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified data_loader/img/img_t1_wt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified data_loader/img/img_t1c_wt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified data_loader/img/img_t2_wt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified data_loader/img/label_wt.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions model/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def forward(self, x):
return x


class UpBlock(nn.Module):
class UpBlock3d(nn.Module):
def __init__(self, in_ch, out_ch):
super(UpBlock, self).__init__()
super(UpBlock3d, self).__init__()
self.up_conv = ConvTrans3d(in_ch, out_ch)
self.conv = ConvBlock3d(2 * out_ch, out_ch)

Expand Down
64 changes: 14 additions & 50 deletions model/msnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F

from src.utils import *

class conv(nn.Module):
def __init__(self,in_ch,out_ch):
Expand All @@ -20,9 +21,6 @@ def forward(self, x):


class ResBlock(nn.Module):
"""
"""
def __init__(self,in_ch, out_ch, d=1):
"""
Expand Down Expand Up @@ -170,60 +168,36 @@ def __init__(self, n_channels, out_ch, n_classes):
self.out = up(7*n_classes, n_classes, 1)

def forward(self, x):
x = self.conv(x)
x = self.block0(x)
x0, x = self.block1(x)
x0 = self.up0(x0)
x = self.block2(x)

x1 =self.up1(x)
x = self.block3(x)
x = self.up2(x)
x = torch.cat([x0, x1, x], dim=1)
x = self.out(x)
return F.sigmoid(x)


class ENET(nn.Module):
def __init__(self, n_channels, out_ch, n_classes):
super(ENET, self).__init__()

self.conv = conv(n_channels, out_ch)
self.block0 = ResBlock2(out_ch, out_ch, 1) #
self.block1 = ResBlock2(out_ch, out_ch, 1)
self.block2 = ResBlock3(out_ch, out_ch, 1)
self.block3 = ResBlock3(out_ch, out_ch, 0)

self.up0 = up(out_ch, n_classes, 1)
self.up1 = up(out_ch, n_classes * 2, 2)
self.up2 = up(out_ch, n_classes * 2, 2)
"""
self.out = up(5 * n_classes, n_classes, 1)
:param x: 5D Tensor BatchSize * 4(modal) * 16 * W * H
:return:
"""
x = self.conv(x) # BatchSize * 32 * 16 * W * H
x = self.block0(x) # BatchSize * 32 * 16/2 * W/2 * H/2
print x.shape

def forward(self, x):
x = self.conv(x)
x,_ = self.block0(x)
x0, x = self.block1(x)

x0 = self.up0(x0)
x = self.block2(x)
x1 = self.up1(x)

x1 =self.up1(x)
x = self.block3(x)
x = self.up2(x)

x = torch.cat([x0, x1, x], dim=1)

x = self.out(x)
return F.sigmoid(x)
return x


if __name__ =='__main__':
x = torch.ones(1, 4, 24, 24, 24)
x = torch.ones(1, 4, 16, 192, 192)
print ('test wnet............')
print ('shape of X ')
print x.shape

net = WNET(1, 32, 4)
net = WNET(4, 32, 2)
print"total parameter:" + str(netSize(net)) # 241784
if torch.cuda.is_available():
net = net.cuda()
x = x.cuda()
Expand All @@ -232,17 +206,7 @@ def forward(self, x):
print ('shape of Y ')
print (y.shape)

print ('test Enet.............')
print ('shape of X ')
print x.shape
net = ENET(1, 32, 4)
if torch.cuda.is_available():
net = net.cuda()
x = x.cuda()

y = net(x)
print ('shape of Y ')
print (y.shape)



Expand Down
Loading

0 comments on commit 2c469df

Please sign in to comment.