# dataset_synapse.py
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
import cv2
from torchvision.transforms import transforms
def random_rot_flip(image: np.ndarray, label: np.ndarray):
    """Apply one random 90-degree rotation and one random flip to a pair.

    The same transform is applied to ``image`` and ``label`` so they stay
    aligned. The rotation is ``k`` quarter turns (``k`` in 0..3) in the plane
    of the first two axes, followed by a flip along axis 0 or axis 1.

    Args:
        image: Array with at least two dimensions.
        label: Array with at least two dimensions.

    Returns:
        ``(image, label)`` as freshly copied arrays.
    """
    # Draw order matters for reproducibility under a fixed NumPy seed:
    # first the number of quarter turns, then the flip axis.
    quarter_turns = np.random.randint(0, 4)
    flip_axis = np.random.randint(0, 2)

    rotated_image = np.rot90(image, quarter_turns)
    rotated_label = np.rot90(label, quarter_turns)

    # np.flip returns a view with negative strides; copy so that callers
    # (e.g. torch.from_numpy) receive an ordinary contiguous array.
    flipped_image = np.flip(rotated_image, axis=flip_axis).copy()
    flipped_label = np.flip(rotated_label, axis=flip_axis).copy()
    return flipped_image, flipped_label
def random_rotate(image: np.ndarray, label: np.ndarray):
    """Rotate an image/label pair by the same small random angle.

    The angle is an integer number of degrees drawn from the half-open range
    [-20, 20), i.e. -20..19. Both arrays are rotated with nearest-neighbour
    interpolation (``order=0``) and ``reshape=False``, so the output shapes
    equal the input shapes.

    Args:
        image: Array with at least two dimensions.
        label: Array with at least two dimensions.

    Returns:
        ``(image, label)`` rotated by the same angle.
    """
    degrees = np.random.randint(-20, 20)
    rotated_image = ndimage.rotate(image, degrees, order=0, reshape=False)
    rotated_label = ndimage.rotate(label, degrees, order=0, reshape=False)
    return rotated_image, rotated_label
class RandomGenerator(object):
    """Randomly augment a sample, resize it, and convert it to tensors.

    Args:
        output_size: Two-element sequence ``(height, width)`` that the first
            two axes of the image and label are resized to.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Transform one ``{'image', 'label'}`` sample.

        Args:
            sample: Mapping with an ``'image'`` array that has exactly three
                dimensions (it is unpacked as ``x, y, _`` and permuted with
                ``(2, 0, 1)``) and a ``'label'`` array whose two axes match
                the image's first two axes.

        Returns:
            A new dict with ``'image'`` as a float32 tensor whose last axis
            has been moved to the front, and ``'label'`` as an int64 tensor.
            Any other keys of ``sample`` are not carried over.
        """
        image, label = sample['image'], sample['label']

        # The second draw only happens when the first one fails, which gives
        # roughly 50% rot90+flip, 25% small-angle rotation, 25% unchanged.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)

        height, width, _ = image.shape
        target_height, target_width = self.output_size[0], self.output_size[1]
        if height != target_height or width != target_width:
            scale = (target_height / height, target_width / width)
            # Cubic spline for the image; the trailing axis is left unscaled.
            image = zoom(image, scale + (1,), order=3)
            # Nearest neighbour for the label so no new class ids appear.
            label = zoom(label, scale, order=0)

        image_tensor = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1)
        label_tensor = torch.from_numpy(label.astype(np.float32)).long()
        return {'image': image_tensor, 'label': label_tensor}
class Synapse_dataset(Dataset):
    """Dataset of ``.npz`` samples listed one per line in ``<split>.txt``.

    Each ``.npz`` file must contain an ``'image'`` and a ``'label'`` array.

    Args:
        base_dir: Directory containing the ``<name>.npz`` files.
        list_dir: Directory containing the ``<split>.txt`` list file.
        split: Split name; ``"train"`` keeps the stored image layout, any
            other value moves the image's last axis to the front.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # using transform in torch!
        self.split = split
        list_path = os.path.join(list_dir, self.split + '.txt')
        # Close the list file deterministically instead of leaking the handle.
        with open(list_path) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        """Return the number of lines in the list file."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        Returns:
            Dict with ``'image'`` (stored image divided by 255), ``'label'``
            and ``'case_name'``; image and label are whatever ``transform``
            returns when one is set.
        """
        case_name = self.sample_list[idx].strip('\n')
        data_path = os.path.join(self.data_dir, case_name + '.npz')

        # np.load on an .npz keeps the archive open until it is closed; the
        # arrays are fully read on indexing, so they remain valid afterwards.
        with np.load(data_path) as data:
            image, label = data['image'], data['label']

        if self.split != "train":
            image = image.transpose(2, 0, 1)

        sample = {'image': image / 255, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = case_name
        return sample
if __name__ == '__main__':
    # Manual smoke test: build the dataset for the "test_vol" split and load
    # its first sample through the augmentation pipeline.
    # NOTE(review): for non-"train" splits Synapse_dataset moves the image's
    # last axis to the front, while RandomGenerator unpacks image.shape as
    # (x, y, _) and resizes the first two axes - confirm this combination is
    # intended for the data stored under test_vol_h5.
    dataset = Synapse_dataset(
        base_dir='../data/ISIC/test_vol_h5',
        list_dir='../lists/lists_ISIC',
        split="test_vol",
        transform=transforms.Compose(
            [RandomGenerator(output_size=[224, 224])]),
    )
    dataset[0]  # raises if the first sample cannot be loaded or transformed