# data_loader.py
# Forked from cchen-cc/SIFA.
import tensorflow as tf
import numpy as np
# Number of samples per batch; passed as the batch size to the batching calls in load_data.
BATCH_SIZE = 8
def _decode_samples(image_list, shuffle=False):
    """Build the graph ops that read and decode one sample from TFRecord files.

    Args:
        image_list: list of TFRecord file paths handed to the filename queue.
        shuffle: forwarded to ``tf.train.string_input_producer``.

    Returns:
        A pair ``(image, label)``. ``image`` is channel 1 of the decoded
        ``data_vol`` with a trailing singleton axis restored ([256, 256, 1]).
        ``label`` is the depth-5 one-hot encoding of channel 1 of the decoded
        ``label_vol`` ([256, 256, 5]).
    """
    # Scalar int64 size fields stored in every record; they are parsed but
    # the shapes actually used below are hard-coded.
    size_keys = (
        'dsize_dim0', 'dsize_dim1', 'dsize_dim2',
        'lsize_dim0', 'lsize_dim1', 'lsize_dim2',
    )
    feature_spec = {key: tf.FixedLenFeature([], tf.int64) for key in size_keys}
    feature_spec['data_vol'] = tf.FixedLenFeature([], tf.string)
    feature_spec['label_vol'] = tf.FixedLenFeature([], tf.string)

    stored_shape = [256, 256, 3]
    image_shape = [256, 256, 3]
    mask_shape = [256, 256, 1]

    filename_queue = tf.train.string_input_producer(image_list, shuffle=shuffle)
    _, serialized = tf.TFRecordReader().read(filename_queue)
    example = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> float32 -> stored layout -> (full-size) slice.
    image = tf.reshape(tf.decode_raw(example['data_vol'], tf.float32), stored_shape)
    image = tf.slice(image, [0, 0, 0], image_shape)

    # The label is stored with the same layout as the image; only channel 1
    # is kept.
    mask = tf.reshape(tf.decode_raw(example['label_vol'], tf.float32), stored_shape)
    mask = tf.slice(mask, [0, 0, 1], mask_shape)

    one_hot_mask = tf.one_hot(tf.cast(tf.squeeze(mask), tf.uint8), 5)
    centre_slice = tf.expand_dims(image[:, :, 1], axis=2)
    return centre_slice, one_hot_mask
def _read_file_list(list_pth):
    """Return the non-blank lines of the text file *list_pth* without line endings."""
    with open(list_pth, 'r') as fp:
        # Strip only the line terminator (not an arbitrary last character), so
        # a final line without a trailing newline keeps its full path, and
        # drop blank lines so no empty filename reaches the input queue.
        return [line.rstrip('\r\n') for line in fp if line.strip()]


def _load_samples(source_pth, target_pth):
    """Create sample-decoding ops for the source and target record lists.

    Args:
        source_pth: path to a text file listing one source-domain TFRecord
            file per line.
        target_pth: path to a text file listing one target-domain TFRecord
            file per line.

    Returns:
        ``(data_vola, data_volb, label_vola, label_volb)``: the image and
        label tensors produced by ``_decode_samples`` for the source (a) and
        target (b) lists. Both filename queues are created with shuffling on.
    """
    imagea_list = _read_file_list(source_pth)
    imageb_list = _read_file_list(target_pth)

    data_vola, label_vola = _decode_samples(imagea_list, shuffle=True)
    data_volb, label_volb = _decode_samples(imageb_list, shuffle=True)

    return data_vola, data_volb, label_vola, label_volb
# (low, high) bounds used for rescaling, selected by a substring of the list
# path. NOTE(review): presumably dataset-specific intensity ranges per
# modality -- confirm against the preprocessing that produced the records.
_MR_BOUNDS = (-1.7, 4.0)
_CT_BOUNDS = (-1.9, 3.0)


def _rescale_to_signed_unit(image, low, high):
    """Linearly map *image* so that ``low`` -> -1 and ``high`` -> 1 (no clipping)."""
    shifted = tf.subtract(image, low)
    span = tf.subtract(high, low)
    return tf.subtract(tf.multiply(tf.div(shifted, span), 2.0), 1)


def load_data(source_pth, target_pth, do_shuffle=True):
    """Build batched source/target image and label tensors.

    Args:
        source_pth: path to the text file listing source TFRecord files. If
            it contains ``'mr'`` the MR bounds are used for rescaling, else
            if it contains ``'ct'`` the CT bounds are used; otherwise the
            image is left unscaled.
        target_pth: path to the text file listing target TFRecord files.
            Same substring rule, except ``'ct'`` is checked before ``'mr'``.
        do_shuffle: when truthy, batches come from ``tf.train.shuffle_batch``;
            otherwise from ``tf.train.batch`` with a single thread.

    Returns:
        ``(images_i, images_j, gt_i, gt_j)``: batches of ``BATCH_SIZE``
        source images, target images, source labels and target labels. Each
        single-channel image is replicated to three channels before batching.
    """
    image_i, image_j, gt_i, gt_j = _load_samples(source_pth, target_pth)

    # The modality check order differs between source and target; it is kept
    # as-is so paths containing both substrings resolve as before.
    if 'mr' in source_pth:
        image_i = _rescale_to_signed_unit(image_i, *_MR_BOUNDS)
    elif 'ct' in source_pth:
        image_i = _rescale_to_signed_unit(image_i, *_CT_BOUNDS)

    if 'ct' in target_pth:
        image_j = _rescale_to_signed_unit(image_j, *_CT_BOUNDS)
    elif 'mr' in target_pth:
        image_j = _rescale_to_signed_unit(image_j, *_MR_BOUNDS)

    # Replicate the single channel three times along the channel axis.
    image_i = tf.concat((image_i, image_i, image_i), axis=2)
    image_j = tf.concat((image_j, image_j, image_j), axis=2)

    # Batch
    sample = [image_i, image_j, gt_i, gt_j]
    if do_shuffle:
        images_i, images_j, gt_i, gt_j = tf.train.shuffle_batch(
            sample,
            batch_size=BATCH_SIZE,
            capacity=500,
            min_after_dequeue=100,
        )
    else:
        images_i, images_j, gt_i, gt_j = tf.train.batch(
            sample,
            batch_size=BATCH_SIZE,
            num_threads=1,
            capacity=500,
        )

    return images_i, images_j, gt_i, gt_j