forked from google-research/medical-ai-research-foundations
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
221 lines (194 loc) · 7.48 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright 2023 The medical_research_foundations Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline."""
import functools
from typing import Dict, Any, Optional, Callable
from absl import flags
from . import data_util
import tensorflow.compat.v1 as tf
# Process-wide absl flag values. This module reads FLAGS.train_mode,
# FLAGS.train_split and FLAGS.eval_split; none of them is defined in this file.
FLAGS = flags.FLAGS
def pad_to_batch(dataset, batch_size):
  """Pads every tensor of each dataset element to `batch_size` along axis 0.

  Each tensor is zero-padded at the end of its leading dimension (tf.pad's
  default constant) and then given a static leading dimension of
  `batch_size`, so a short batch looks the same as a full one.

  Args:
    dataset: An instance of tf.data.Dataset.
    batch_size: The number of samples per batch of input requested.

  Returns:
    An instance of tf.data.Dataset that yields the same Tensors with the same
    structure as the original padded to batch_size along the leading
    dimension.

  Raises:
    ValueError: If the dataset does not comprise any tensors; if a tensor
      yielded by the dataset has an unknown number of dimensions or is a
      scalar; or if it can be statically determined that tensors comprising
      a single dataset element will have different leading dimensions.
  """

  def _leading_dim_deps(tensor, reference, reference_shape, reference_rows):
    """Returns control deps ensuring `tensor` matches `reference` on axis 0.

    If both leading dimensions are statically known, the comparison is done
    right here at graph-construction time (raising ValueError on mismatch),
    because otherwise execution would never reach the dynamic assertion.
    Otherwise a runtime tf.Assert is returned.
    """
    both_static = (reference.shape[:1].is_fully_defined() and
                   tensor.shape[:1].is_fully_defined())
    if both_static:
      if reference.shape[0] != tensor.shape[0]:
        raise ValueError(
            'Batch size of dataset tensors does not match. %s '
            'has shape %s, but %s has shape %s'
            % (
                reference.name,
                reference.shape,
                tensor.name,
                tensor.shape,
            )
        )
      return []

    tensor_shape = tf.shape(tensor)
    message = (
        'Batch size of dataset tensors %s and %s do not match. '
        'Shapes are' % (tensor.name, reference.name)
    )
    assertion = tf.Assert(
        tf.equal(tensor_shape[0], reference_rows),
        [message, tensor_shape, reference_shape],
    )
    return [assertion]

  def _pad_to_batch(*args):
    """Given Tensors yielded by a Dataset, pads all to the batch size."""
    flat_args = tf.nest.flatten(args)

    # Every tensor needs a known, non-zero rank to be padded on axis 0.
    for tensor in flat_args:
      rank = tensor.shape.ndims
      if rank is None:
        raise ValueError(
            'Unknown number of dimensions for tensor %s.' % tensor.name
        )
      if rank == 0:
        raise ValueError('Tensor %s is a scalar.' % tensor.name)

    # Indexing raises if flat_args is empty; as of this writing
    # tf.data.Dataset.map fails earlier with an internal error, so that case
    # is not checked explicitly.
    reference = flat_args[0]
    reference_shape = tf.shape(reference)
    reference_rows = reference_shape[0]
    pad_rows = batch_size - reference_rows

    padded = []
    for index, tensor in enumerate(flat_args):
      if index == 0:
        deps = []
      else:
        deps = _leading_dim_deps(
            tensor, reference, reference_shape, reference_rows)
      with tf.control_dependencies(deps):
        # Pad only the leading dimension; all other axes are left as-is.
        paddings = [[0, pad_rows]] + [[0, 0]] * (tensor.shape.ndims - 1)
        result = tf.pad(tensor, paddings)
        result.set_shape([batch_size] + tensor.shape.as_list()[1:])
      padded.append(result)
    return tf.nest.pack_sequence_as(args, padded)

  return dataset.map(_pad_to_batch)
def build_input_fn_for_builder(
    builder,
    is_training,
    cache_dataset=False,
    image_size=224,
    rotation_range=0,
    color_jitter_strength=1.0,
    options=None,
):
  """Build input function for TFDS builder.

  Args:
    builder: TFDS builder for specified dataset.
    is_training: (bool) Whether to build in training mode.
    cache_dataset: (bool) whether to cache the entire dataset in memory.
    image_size: (int) input image size, assumes image is square.
    rotation_range: If 0 no rotation, for x, rotation in range (-x, x) degree.
      Only forwarded to the fine-tuning preprocessing function.
    color_jitter_strength: (float) The strength of color jittering.
    options: Distortion Options, used to keep track of data augmentation
      options. When None (the default), a new `data_util.DistortionOptions()`
      is created for this call.

  Returns:
    A function that accepts a dict of params and returns
    `(images, {'labels': labels, 'mask': mask})`, to be used as the input_fn
    in TPUEstimator. Params must include batch_size.
  """
  # A default argument expression is evaluated once at import time, so a
  # `DistortionOptions()` default in the signature would be one instance
  # shared by every call. Build the default per call instead.
  if options is None:
    options = data_util.DistortionOptions()

  def _input_fn(params):
    """Generates TF Dataset from `params`."""
    batch_size = params['batch_size']
    preprocess_fn_pretrain = get_preprocess_fn(
        is_training,
        is_pretrain=True,
        image_size=image_size,
        color_jitter_strength=color_jitter_strength,
        options=options,
    )
    # NOTE(review): rotation_range is passed only to the fine-tuning
    # preprocessing, not to the pretraining one above — confirm intended.
    preprocess_fn_finetune = get_preprocess_fn(
        is_training,
        is_pretrain=False,
        image_size=image_size,
        rotation_range=rotation_range,
        color_jitter_strength=color_jitter_strength,
        options=options,
    )
    num_classes = builder.info.features['label'].num_classes

    def _map_fn(image, label):
      """Preprocesses one (image, label) example according to train_mode.

      In 'pretrain' mode the image is augmented twice and the two views are
      concatenated along the last axis, and the label becomes all zeros.
      Otherwise the image is preprocessed once and the label is one-hot
      encoded. The trailing 1.0 is the mask value of a real example; rows
      that pad_to_batch adds later are zero-filled and so get mask 0.
      """
      if FLAGS.train_mode == 'pretrain':
        views = [preprocess_fn_pretrain(image) for _ in range(2)]
        image = tf.concat(views, -1)
        label = tf.zeros([num_classes])
      else:
        image = preprocess_fn_finetune(image)
        label = tf.one_hot(label, num_classes)
      return image, label, 1.0

    dataset = builder.as_dataset(
        split=FLAGS.train_split if is_training else FLAGS.eval_split,
        shuffle_files=is_training,
        as_supervised=True,
    )
    if cache_dataset:
      dataset = dataset.cache()
    if is_training:
      # Shuffle buffer size (in examples): a larger multiple of the batch
      # size is used for small (<= 32 px) images.
      buffer_multiplier = 50 if image_size <= 32 else 10
      dataset = dataset.shuffle(batch_size * buffer_multiplier)
      # Repeat to iterate infinitely.
      dataset = dataset.repeat(count=-1)
    dataset = dataset.map(
        _map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    # Training drops a short final batch; evaluation keeps it, and
    # pad_to_batch pads it up to batch_size with zero-mask rows.
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = pad_to_batch(dataset, batch_size)
    images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next()
    return images, {'labels': labels, 'mask': mask}

  return _input_fn
def get_preprocess_fn(
    is_training,
    is_pretrain,
    image_size=224,
    rotation_range=0,
    color_jitter_strength=1.0,
    crop=True,
    options=None,
):
  """Get function that accepts an image and returns a preprocessed image.

  Args:
    is_training: (bool) Forwarded to `data_util.preprocess_image`.
    is_pretrain: (bool) Forwarded as `color_distort`, i.e. color distortion
      is requested only for pretraining.
    image_size: (int) Passed as both `height` and `width`.
    rotation_range: Forwarded as `rotation_range`.
    color_jitter_strength: (float) Forwarded as `color_jitter_strength`.
    crop: (bool) Requested `test_crop`; forced to False when
      `image_size <= 32`.
    options: Distortion options forwarded as `options`. When None (the
      default), a new `data_util.DistortionOptions()` is created for this
      call.

  Returns:
    A `functools.partial` of `data_util.preprocess_image` with every
    argument except the image bound.
  """
  # Build the default per call: a default argument expression is evaluated
  # once at import time and would be shared by every returned function.
  if options is None:
    options = data_util.DistortionOptions()
  # Disable test cropping for small images (e.g. CIFAR).
  test_crop = crop if image_size > 32 else False
  return functools.partial(
      data_util.preprocess_image,
      height=image_size,
      width=image_size,
      is_training=is_training,
      color_distort=is_pretrain,
      test_crop=test_crop,
      rotation_range=rotation_range,
      color_jitter_strength=color_jitter_strength,
      options=options,
  )