update docstrings
Signed-off-by: kta-intel <[email protected]>
kta-intel committed Jul 12, 2024
1 parent 1810f31 commit 6f3f3da
Showing 4 changed files with 149 additions and 192 deletions.
4 changes: 2 additions & 2 deletions openfl-workspace/tf_2dunet/plan/data.yaml
@@ -4,5 +4,5 @@
# all keys under 'collaborators' correspond to a specific collaborator name; the corresponding dictionary has data_name, data_path pairs.
# Note that in the mnist case we do not store the data locally, and the data_path is used to pass an integer that helps the data object
# construct the shard of the mnist dataset to be used for this collaborator.
collaborator1,~/MICCAI_BraTS_2019_Data_Training/HGG/0
collaborator2,~/MICCAI_BraTS_2019_Data_Training/HGG/1
collaborator1,../data/MICCAI_BraTS_2019_Data_Training/HGG/0
collaborator2,../data/MICCAI_BraTS_2019_Data_Training/HGG/1
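For reference, the entries above use a simple "collaborator_name,data_path" line format. A minimal, hypothetical sketch of how such a file could be parsed with the standard library (this is illustrative Python, not the parser OpenFL actually uses):

from pathlib import Path

def read_data_config(path='plan/data.yaml'):
    """Illustrative only: map each collaborator name to its data path."""
    mapping = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            continue  # skip blank lines and comments
        name, data_path = line.split(',', 1)
        mapping[name] = data_path
    return mapping

# e.g. {'collaborator1': '../data/MICCAI_BraTS_2019_Data_Training/HGG/0', ...}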
275 changes: 124 additions & 151 deletions openfl-workspace/tf_2dunet/src/taskrunner.py
@@ -10,70 +10,115 @@
from openfl.federated import TensorFlowTaskRunner

class UNet2D(TensorFlowTaskRunner):
"""Initialize.
Args:
**kwargs: Additional parameters to pass to the function
"""

def __init__(self, initial_filters=16,
depth=5,
batch_norm=True,
use_upsampling=False,
**kwargs):
"""Initialize.
Args:
**kwargs: Additional parameters to pass to the function
"""
super().__init__(**kwargs)

self.model = self.create_model(
self.model = self.build_model(
input_shape=self.feature_shape,
n_cl_out=self.data_loader.num_classes,
initial_filters=initial_filters,
use_upsampling=use_upsampling,
depth=depth,
batch_norm=batch_norm,
**kwargs
)
self.initialize_tensorkeys_for_functions()

self.model.summary(print_fn=self.logger.info, line_length=120)

def create_model(self,
input_shape,
n_cl_out=1,
use_upsampling=False,
dropout=0.2,
print_summary=True,
seed=816,
depth=5,
dropout_at=(2, 3),
initial_filters=16,
batch_norm=True,
**kwargs):
"""Create the TensorFlow 3D U-Net CNN model.
def build_model(self,
input_shape,
n_cl_out=1,
use_upsampling=False,
dropout=0.2,
seed=816,
depth=5,
dropout_at=(2, 3),
initial_filters=16,
batch_norm=True):
"""
Build and compile 2D UNet model.
Args:
input_shape (list): input shape of the data
n_cl_out (int): Number of output classes in label (Default=1)
**kwargs: Additional parameters to pass to the function
input_shape (List[int]): The shape of the data
n_cl_out (int): Number of channels in output layer (Default=1)
use_upsampling (bool): True = use bilinear interpolation;
False = use transposed convolution (Default=False)
dropout (float): Dropout percentage (Default=0.2)
seed (int): Random seed (Default=816)
depth (int): Number of max pooling layers in encoder (Default=5)
dropout_at (List[int]): Layers to perform dropout after (Default=[2,3])
initial_filters (int): Number of filters in first convolutional layer (Default=16)
batch_norm (bool): Apply batch normalization (Default=True)
Returns:
keras.src.engine.functional.Functional
A compiled Keras model ready for training.
"""

model = build_model(input_shape,
n_cl_out=n_cl_out,
use_upsampling=use_upsampling,
dropout=dropout,
print_summary=print_summary,
seed=seed,
depth=depth,
dropout_at=dropout_at,
initial_filters=initial_filters,
batch_norm=batch_norm)

if (input_shape[0] % (2**depth)) > 0:
raise ValueError(f'Crop dimension must be a multiple of 2^(depth of U-Net) = {2**depth}')

inputs = tf.keras.layers.Input(input_shape, name='brats_mr_image')

activation = tf.keras.activations.relu

params = {'kernel_size': (3, 3), 'activation': activation,
'padding': 'same',
'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed)}

convb_layers = {}

net = inputs
filters = initial_filters
for i in range(depth):
name = f'conv{i + 1}a'
net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
if i in dropout_at:
net = tf.keras.layers.Dropout(dropout)(net)
name = f'conv{i + 1}b'
net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
if batch_norm:
net = tf.keras.layers.BatchNormalization()(net)
convb_layers[name] = net
# only pool if not last level
if i != depth - 1:
name = f'pool{i + 1}'
net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net)
filters *= 2

# do the up levels
filters //= 2
for i in range(depth - 1):
if use_upsampling:
up = tf.keras.layers.UpSampling2D(
name=f'up{depth + i + 1}', size=(2, 2))(net)
else:
up = tf.keras.layers.Conv2DTranspose(name=f'transConv{depth + i + 1}',
filters=filters,
kernel_size=(2, 2),
strides=(2, 2),
padding='same')(net)
net = tf.keras.layers.concatenate(
[up, convb_layers[f'conv{depth - i - 1}b']],
axis=-1
)
net = tf.keras.layers.Conv2D(
name=f'conv{depth + i + 1}a',
filters=filters, **params)(net)
net = tf.keras.layers.Conv2D(
name=f'conv{depth + i + 1}b',
filters=filters, **params)(net)
filters //= 2

net = tf.keras.layers.Conv2D(name='prediction', filters=n_cl_out,
kernel_size=(1, 1),
activation='sigmoid')(net)

model = tf.keras.models.Model(inputs=[inputs], outputs=[net])

model.compile(
loss=dice_loss,
@@ -84,18 +129,22 @@ def create_model(self,
return model
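The ValueError guard at the top of build_model exists because each of the depth encoder levels halves the spatial resolution with 2x2 max pooling, so the crop dimension must divide evenly by 2**depth. A quick plain-Python illustration with hypothetical crop sizes:

# Each encoder level halves the spatial size, so an input dimension
# must be a multiple of 2**depth for the decoder shapes to line up.
depth = 5
for dim in (128, 144, 160):
    ok = dim % (2 ** depth) == 0
    print(dim, 'ok' if ok else 'would raise ValueError in build_model')
# 128 ok, 144 would raise, 160 ok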

def train_(self, batch_generator, metrics: list = None, **kwargs):
"""Train single epoch.
"""
Train single epoch.
Override this function for custom training.
Args:
batch_generator: Generator of training batches.
batch_generator (generator): Generator of training batches.
Each batch is a tuple of N train images and N train labels
where N is the batch size of the DataLoader of the current TaskRunner instance.
metrics (List[str]): A list of metric names to compute and save
**kwargs (dict): Additional keyword arguments
epochs: Number of epochs to train.
metrics: Names of metrics to save.
Returns:
list: Metric objects containing the computed metrics
"""
history = self.model.fit(batch_generator,
verbose=1,
**kwargs)
@@ -108,11 +157,16 @@ def train_(self, batch_generator, metrics: list = None, **kwargs):
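The fold above hides how train_ converts the Keras History object into metric objects. A rough sketch of that pattern, assuming the hidden code reads the last epoch's value for each requested metric name (the Metric namedtuple is a hypothetical stand-in for OpenFL's metric object):

from collections import namedtuple

Metric = namedtuple('Metric', ['name', 'value'])  # hypothetical stand-in

def metrics_from_history(history, metric_names):
    """Illustrative only: extract the final value of each named metric."""
    results = []
    for name in metric_names:
        if name not in history.history:
            raise ValueError(f'Metric {name} not found in training history')
        results.append(Metric(name=name, value=history.history[name][-1]))
    return results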

def dice_coef(target, prediction, axis=(1, 2), smooth=0.0001):
"""
Sorenson Dice.
Calculate the Sorenson-Dice coefficient.
Args:
target (tf.Tensor): The ground truth binary labels.
prediction (tf.Tensor): The predicted binary labels, rounded to 0 or 1.
axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions.
smooth (float, optional): A small constant added to numerator and denominator for numerical stability.
Returns
-------
dice coefficient (float)
Returns:
tf.Tensor: The mean Dice coefficient over the batch.
"""
prediction = tf.round(prediction) # Round to 0 or 1

@@ -127,13 +181,18 @@ def dice_coef(target, prediction, axis=(1, 2), smooth=0.0001):

def soft_dice_coef(target, prediction, axis=(1, 2), smooth=0.0001):
"""
Soft Sorenson Dice.
Calculate the soft Sorenson-Dice coefficient.
Does not round the predictions to either 0 or 1.
Returns
-------
soft dice coefficient (float)
Args:
target (tf.Tensor): The ground truth binary labels.
prediction (tf.Tensor): The predicted probabilities.
axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions.
smooth (float, optional): A small constant added to numerator and denominator for numerical stability.
Returns:
tf.Tensor: The mean soft Dice coefficient over the batch.
"""
intersection = tf.reduce_sum(target * prediction, axis=axis)
union = tf.reduce_sum(target + prediction, axis=axis)
@@ -146,15 +205,20 @@ def soft_dice_coef(target, prediction, axis=(1, 2), smooth=0.0001):

def dice_loss(target, prediction, axis=(1, 2), smooth=0.0001):
"""
Sorenson (Soft) Dice loss.
Calculate the (Soft) Sorenson-Dice loss.
Using -log(Dice) as the loss since it is better behaved.
Also, the log allows avoidance of the division which
can help prevent underflow when the numbers are very small.
Returns
-------
dice loss (float)
Args:
target (tf.Tensor): The ground truth binary labels.
prediction (tf.Tensor): The predicted probabilities.
axis (tuple, optional): The axes along which to compute the loss, typically the spatial dimensions.
smooth (float, optional): A small constant added to numerator and denominator for numerical stability.
Returns:
tf.Tensor: The mean Dice loss over the batch.
"""
intersection = tf.reduce_sum(prediction * target, axis=axis)
p = tf.reduce_sum(prediction, axis=axis)
@@ -163,95 +227,4 @@ def dice_loss(target, prediction, axis=(1, 2), smooth=0.0001):
denominator = tf.reduce_mean(t + p + smooth)
dice_loss = -tf.math.log(2. * numerator) + tf.math.log(denominator)

return dice_loss
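Note the identity behind dice_loss above: -log(2 * numerator) + log(denominator) equals -log(2 * numerator / denominator), which is why the code takes two separate logs instead of dividing. Assuming the folded bodies match their docstrings, a tiny eager-mode sanity check of the three functions (output values are illustrative):

import tensorflow as tf

# One 2x2 mask with a batch dimension: three positive pixels in the target.
target = tf.constant([[[1., 0.], [1., 1.]]])
prediction = tf.constant([[[0.9, 0.1], [0.8, 0.4]]])  # soft probabilities

print(float(dice_coef(target, prediction)))       # hard Dice: rounds prediction first
print(float(soft_dice_coef(target, prediction)))  # soft Dice: no rounding
print(float(dice_loss(target, prediction)))       # -log form of the soft Dice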


def build_model(input_shape,
n_cl_out=1,
use_upsampling=False,
dropout=0.2,
seed=816,
depth=5,
dropout_at=(2, 3),
initial_filters=16,
batch_norm=True,
**kwargs):
"""Build the TensorFlow model.
Args:
input_tensor: input shape of the model
use_upsampling (bool): True = use bilinear interpolation;
False = use transposed convolution (Default=False)
n_cl_out (int): Number of channels in output layer (Default=1)
dropout (float): Dropout percentage (Default=0.2)
print_summary (bool): True = print the model summary (Default = True)
seed: random seed (Default=816)
depth (int): Number of max pooling layers in encoder (Default=5)
dropout_at: Layers to perform dropout after (Default=[2,3])
initial_filters (int): Number of filters in first convolutional
layer (Default=16)
batch_norm (bool): True = use batch normalization (Default=True)
**kwargs: Additional parameters to pass to the function
"""
if (input_shape[0] % (2**depth)) > 0:
raise ValueError(f'Crop dimension must be a multiple of 2^(depth of U-Net) = {2**depth}')

inputs = tf.keras.layers.Input(input_shape, name='brats_mr_image')

activation = tf.keras.activations.relu

params = {'kernel_size': (3, 3), 'activation': activation,
'padding': 'same',
'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed)}

convb_layers = {}

net = inputs
filters = initial_filters
for i in range(depth):
name = f'conv{i + 1}a'
net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
if i in dropout_at:
net = tf.keras.layers.Dropout(dropout)(net)
name = f'conv{i + 1}b'
net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
if batch_norm:
net = tf.keras.layers.BatchNormalization()(net)
convb_layers[name] = net
# only pool if not last level
if i != depth - 1:
name = f'pool{i + 1}'
net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net)
filters *= 2

# do the up levels
filters //= 2
for i in range(depth - 1):
if use_upsampling:
up = tf.keras.layers.UpSampling2D(
name=f'up{depth + i + 1}', size=(2, 2))(net)
else:
up = tf.keras.layers.Conv2DTranspose(name=f'transConv{depth + i + 1}',
filters=filters,
kernel_size=(2, 2),
strides=(2, 2),
padding='same')(net)
net = tf.keras.layers.concatenate(
[up, convb_layers[f'conv{depth - i - 1}b']],
axis=-1
)
net = tf.keras.layers.Conv2D(
name=f'conv{depth + i + 1}a',
filters=filters, **params)(net)
net = tf.keras.layers.Conv2D(
name=f'conv{depth + i + 1}b',
filters=filters, **params)(net)
filters //= 2

net = tf.keras.layers.Conv2D(name='prediction', filters=n_cl_out,
kernel_size=(1, 1),
activation='sigmoid')(net)

model = tf.keras.models.Model(inputs=[inputs], outputs=[net])

return model
return dice_loss
