diff --git a/openfl-workspace/tf_2dunet/README.md b/openfl-workspace/tf_2dunet/README.md
index 12dab8fc2e..d7665c33ab 100644
--- a/openfl-workspace/tf_2dunet/README.md
+++ b/openfl-workspace/tf_2dunet/README.md
@@ -14,18 +14,27 @@ To use a `tree` command, you have to install it first: `sudo apt-get install tree`
 - `HGG`: glioblastoma scans
 - `LGG`: lower grade glioma scans
-Let's pick `HGG`: `export SUBFOLDER=HGG`. The learning rate has been already tuned for this task, so you don't have to change it. If you pick `LGG`, all the next steps will be the same.
+Let's pick `HGG`: `export SUBFOLDER=MICCAI_BraTS_2019_Data_Training/HGG`. The learning rate has been already tuned for this task, so you don't have to change it. If you pick `LGG`, all the next steps will be the same.
 3) In order for each collaborator to use separate slice of data, we split main folder into `n` subfolders:
 ```bash
+#!/bin/bash
 cd $DATA_PATH/$SUBFOLDER
-i=0;
-for f in *;
-do
- d=dir_$(printf $((i%n))); # change n to number of data slices (number of collaborators in federation)
- mkdir -p $d;
- mv "$f" $d;
- let i++;
+
+n=2 # Set this to the number of directories you want to create
+
+# Get a list of all files and shuffle them
+files=($(ls | shuf))
+
+# Create the target directories if they don't exist
+for ((i=0; i<n; i++)); do
+  mkdir -p "$i"
+done
+
+# Distribute the shuffled files across the target directories round-robin
+for ((i=0; i<${#files[@]}; i++)); do
+  mv "${files[$i]}" "$((i % n))"
 done
 ```
 Each of the resulting subfolders (`0`, `1`, ...) holds a separate slice of the data and is used as the data path of one collaborator: it is listed in `plan/data.yaml` and passed to `fx collaborator create` through its `-d` flag.
diff --git a/openfl-workspace/tf_2dunet/plan/data.yaml b/openfl-workspace/tf_2dunet/plan/data.yaml
index 69a3568b14..0b27947ad4 100644
--- a/openfl-workspace/tf_2dunet/plan/data.yaml
+++ b/openfl-workspace/tf_2dunet/plan/data.yaml
@@ -1,8 +1,8 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2024 Intel Corporation
 # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
 # all keys under 'collaborators' corresponds to a specific colaborator name the corresponding dictionary has data_name, data_path pairs.
 # Note that in the mnist case we do not store the data locally, and the data_path is used to pass an integer that helps the data object
 # construct the shard of the mnist dataset to be use for this collaborator.
-one,/raid/datasets/MICCAI_BraTS_2019_Data_Training/HGG/0
-two,/raid/datasets/MICCAI_BraTS_2019_Data_Training/HGG/1
+collaborator1,../data/MICCAI_BraTS_2019_Data_Training/HGG/0
+collaborator2,../data/MICCAI_BraTS_2019_Data_Training/HGG/1
diff --git a/openfl-workspace/tf_2dunet/plan/plan.yaml b/openfl-workspace/tf_2dunet/plan/plan.yaml
index 2d00302208..c925de767a 100644
--- a/openfl-workspace/tf_2dunet/plan/plan.yaml
+++ b/openfl-workspace/tf_2dunet/plan/plan.yaml
@@ -1,13 +1,13 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2024 Intel Corporation
 # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
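The shard directories produced by the split script above are exactly what the new `plan/data.yaml` entries point at, and the loader later splits train/validation by whole brain rather than by slice, so each shard should still contain intact subject folders. A quick sanity check along those lines (an editor's sketch, not part of the workspace; it assumes the split was run with `n=2` and that the shards are reachable under the relative `../data/...` paths used in `data.yaml`):

```python
# Verify that each collaborator's shard exists and holds whole BraTS subject
# folders with NIfTI files before creating the collaborators.
from pathlib import Path

shards = {
    "collaborator1": Path("../data/MICCAI_BraTS_2019_Data_Training/HGG/0"),
    "collaborator2": Path("../data/MICCAI_BraTS_2019_Data_Training/HGG/1"),
}

for name, shard in shards.items():
    if not shard.is_dir():
        raise SystemExit(f"{name}: shard directory {shard} does not exist")
    subjects = [d for d in shard.iterdir() if d.is_dir()]
    n_nifti = sum(len(list(d.glob("*.nii.gz"))) for d in subjects)
    print(f"{name}: {len(subjects)} subjects, {n_nifti} NIfTI files in {shard}")
```

If a shard is empty or the subject folders were split apart, re-run the split script before proceeding with the plan below.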
aggregator : defaults : plan/defaults/aggregator.yaml template : openfl.component.Aggregator settings : - init_state_path : save/tf_2dunet_brats_init.pbuf - last_state_path : save/tf_2dunet_brats_latest.pbuf - best_state_path : save/tf_2dunet_brats_best.pbuf + init_state_path : save/init.pbuf + last_state_path : save/latest.pbuf + best_state_path : save/best.pbuf rounds_to_train : 10 db_store_rounds : 2 @@ -20,7 +20,7 @@ collaborator : data_loader : defaults : plan/defaults/data_loader.yaml - template : src.tfbrats_inmemory.TensorFlowBratsInMemory + template : src.dataloader.BratsDataloader settings : batch_size: 64 percent_train: 0.8 @@ -29,7 +29,7 @@ data_loader : task_runner : defaults : plan/defaults/task_runner.yaml - template : src.tf_2dunet.TensorFlow2DUNet + template : src.taskrunner.UNet2D network : defaults : plan/defaults/network.yaml @@ -37,8 +37,32 @@ network : assigner : defaults : plan/defaults/assigner.yaml -tasks : - defaults : plan/defaults/tasks_tensorflow.yaml +tasks: + defaults : plan/defaults/task_tensorflow.yaml + aggregated_model_validation: + function : validate_task + kwargs : + batch_size : 32 + apply : global + metrics : + - dice_coef + - soft_dice_coef + locally_tuned_model_validation: + function : validate_task + kwargs : + batch_size : 32 + apply : local + metrics : + - dice_coef + - soft_dice_coef + train: + function : train_task + kwargs : + batch_size : 32 + metrics : + - loss + epochs : 1 + compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/tf_2dunet/requirements.txt b/openfl-workspace/tf_2dunet/requirements.txt index f655f6c7d2..53a5095174 100644 --- a/openfl-workspace/tf_2dunet/requirements.txt +++ b/openfl-workspace/tf_2dunet/requirements.txt @@ -1,3 +1,3 @@ nibabel -tensorflow==2.13 +tensorflow==2.15.1 setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/openfl-workspace/tf_2dunet/src/brats_utils.py b/openfl-workspace/tf_2dunet/src/brats_utils.py deleted file mode 100644 index 653e26cbca..0000000000 --- a/openfl-workspace/tf_2dunet/src/brats_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -"""You may copy this file as the starting point of your own model.""" - -import logging -import os - -import numpy as np - -from .nii_reader import nii_reader - -logger = logging.getLogger(__name__) - - -def train_val_split(features, labels, percent_train, shuffle): - """Train/validation splot of the BraTS dataset. - - Splits incoming feature and labels into training and validation. The value - of shuffle determines whether shuffling occurs before the split is performed. - - Args: - features: The input images - labels: The ground truth labels - percent_train (float): The percentage of the dataset that is training. - shuffle (bool): True = shuffle the dataset before the split - - Returns: - train_features: The input images for the training dataset - train_labels: The ground truth labels for the training dataset - val_features: The input images for the validation dataset - val_labels: The ground truth labels for the validation dataset - """ - - def split(lst, idx): - """Split a Python list into 2 lists. 
- - Args: - lst: The Python list to split - idx: The index where to split the list into 2 parts - - Returns: - Two lists - - """ - if idx < 0 or idx > len(lst): - raise ValueError('split was out of expected range.') - return lst[:idx], lst[idx:] - - nb_features = len(features) - nb_labels = len(labels) - if nb_features != nb_labels: - raise RuntimeError('Number of features and labels do not match.') - if shuffle: - new_order = np.random.permutation(np.arange(nb_features)) - features = features[new_order] - labels = labels[new_order] - split_idx = int(percent_train * nb_features) - train_features, val_features = split(lst=features, idx=split_idx) - train_labels, val_labels = split(lst=labels, idx=split_idx) - return train_features, train_labels, val_features, val_labels - - -def load_from_nifti(parent_dir, - percent_train, - shuffle, - channels_last=True, - task='whole_tumor', - **kwargs): - """Load the BraTS dataset from the NiFTI file format. - - Loads data from the parent directory (NIfTI files for whole brains are - assumed to be contained in subdirectories of the parent directory). - Performs a split of the data into training and validation, and the value - of shuffle determined whether shuffling is performed before this split - occurs - both split and shuffle are done in a way to - keep whole brains intact. The kwargs are passed to nii_reader. - - Args: - parent_dir: The parent directory for the BraTS data - percent_train (float): The percentage of the data to make the training dataset - shuffle (bool): True means shuffle the dataset order before the split - channels_last (bool): Input tensor uses channels as last dimension (Default is True) - task: Prediction task (Default is 'whole_tumor' prediction) - **kwargs: Variable arguments to pass to the function - - Returns: - train_features: The input images for the training dataset - train_labels: The ground truth labels for the training dataset - val_features: The input images for the validation dataset - val_labels: The ground truth labels for the validation dataset - - """ - path = os.path.join(parent_dir) - subdirs = os.listdir(path) - subdirs.sort() - if not subdirs: - raise SystemError(f'''{parent_dir} does not contain subdirectories. -Please make sure you have BraTS dataset downloaded -and located in data directory for this collaborator. 
- ''') - subdir_paths = [os.path.join(path, subdir) for subdir in subdirs] - - imgs_all = [] - msks_all = [] - for brain_path in subdir_paths: - these_imgs, these_msks = nii_reader( - brain_path=brain_path, - task=task, - channels_last=channels_last, - **kwargs - ) - # the needed files where not present if a tuple of None is returned - if these_imgs is None: - logger.debug(f'Brain subdirectory: {brain_path} did not contain the needed files.') - else: - imgs_all.append(these_imgs) - msks_all.append(these_msks) - - # converting to arrays to allow for numpy indexing used during split - imgs_all = np.array(imgs_all) - msks_all = np.array(msks_all) - - # note here that each is a list of 155 slices per brain, and so the - # split keeps brains intact - imgs_all_train, msks_all_train, imgs_all_val, msks_all_val = train_val_split( - features=imgs_all, - labels=msks_all, - percent_train=percent_train, - shuffle=shuffle - ) - # now concatenate the lists - imgs_train = np.concatenate(imgs_all_train, axis=0) - msks_train = np.concatenate(msks_all_train, axis=0) - imgs_val = np.concatenate(imgs_all_val, axis=0) - msks_val = np.concatenate(msks_all_val, axis=0) - - return imgs_train, msks_train, imgs_val, msks_val diff --git a/openfl-workspace/tf_2dunet/src/nii_reader.py b/openfl-workspace/tf_2dunet/src/dataloader.py similarity index 63% rename from openfl-workspace/tf_2dunet/src/nii_reader.py rename to openfl-workspace/tf_2dunet/src/dataloader.py index ba90a644b1..4e2acd9c04 100644 --- a/openfl-workspace/tf_2dunet/src/nii_reader.py +++ b/openfl-workspace/tf_2dunet/src/dataloader.py @@ -1,13 +1,174 @@ -# Copyright (C) 2020-2021 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """You may copy this file as the starting point of your own model.""" import os +import logging import nibabel as nib import numpy as np import numpy.ma as ma +from openfl.federated import TensorFlowDataLoader + + +logger = logging.getLogger(__name__) + + +class BratsDataloader(TensorFlowDataLoader): + """TensorFlow Data Loader for the BraTS dataset.""" + + def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, num_classes=1, + **kwargs): + """Initialize. + + Args: + data_path: The file path for the BraTS dataset + batch_size (int): The batch size to use + percent_train (float): The percentage of the data to use for training (Default=0.8) + pre_split_shuffle (bool): True= shuffle the dataset before + performing the train/validate split (Default=True) + **kwargs: Additional arguments, passed to super init and load_from_nifti + + Returns: + Data loader with BraTS data + """ + super().__init__(batch_size, **kwargs) + + X_train, y_train, X_valid, y_valid = load_from_nifti(parent_dir=data_path, + percent_train=percent_train, + shuffle=pre_split_shuffle, + **kwargs) + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + self.num_classes = num_classes + + +def train_val_split(features, labels, percent_train, shuffle): + """Train/validation splot of the BraTS dataset. + + Splits incoming feature and labels into training and validation. The value + of shuffle determines whether shuffling occurs before the split is performed. + + Args: + features: The input images + labels: The ground truth labels + percent_train (float): The percentage of the dataset that is training. 
+ shuffle (bool): True = shuffle the dataset before the split + + Returns: + train_features: The input images for the training dataset + train_labels: The ground truth labels for the training dataset + val_features: The input images for the validation dataset + val_labels: The ground truth labels for the validation dataset + """ + + def split(lst, idx): + """Split a Python list into 2 lists. + + Args: + lst: The Python list to split + idx: The index where to split the list into 2 parts + + Returns: + Two lists + + """ + if idx < 0 or idx > len(lst): + raise ValueError('split was out of expected range.') + return lst[:idx], lst[idx:] + + nb_features = len(features) + nb_labels = len(labels) + if nb_features != nb_labels: + raise RuntimeError('Number of features and labels do not match.') + if shuffle: + new_order = np.random.permutation(np.arange(nb_features)) + features = features[new_order] + labels = labels[new_order] + split_idx = int(percent_train * nb_features) + train_features, val_features = split(lst=features, idx=split_idx) + train_labels, val_labels = split(lst=labels, idx=split_idx) + return train_features, train_labels, val_features, val_labels + + +def load_from_nifti(parent_dir, + percent_train, + shuffle, + channels_last=True, + task='whole_tumor', + **kwargs): + """Load the BraTS dataset from the NiFTI file format. + + Loads data from the parent directory (NIfTI files for whole brains are + assumed to be contained in subdirectories of the parent directory). + Performs a split of the data into training and validation, and the value + of shuffle determined whether shuffling is performed before this split + occurs - both split and shuffle are done in a way to + keep whole brains intact. The kwargs are passed to nii_reader. + + Args: + parent_dir: The parent directory for the BraTS data + percent_train (float): The percentage of the data to make the training dataset + shuffle (bool): True means shuffle the dataset order before the split + channels_last (bool): Input tensor uses channels as last dimension (Default is True) + task: Prediction task (Default is 'whole_tumor' prediction) + **kwargs: Variable arguments to pass to the function + + Returns: + train_features: The input images for the training dataset + train_labels: The ground truth labels for the training dataset + val_features: The input images for the validation dataset + val_labels: The ground truth labels for the validation dataset + + """ + path = os.path.join(parent_dir) + subdirs = os.listdir(path) + subdirs.sort() + if not subdirs: + raise SystemError(f'''{parent_dir} does not contain subdirectories. +Please make sure you have BraTS dataset downloaded +and located in data directory for this collaborator. 
+    ''')
+    subdir_paths = [os.path.join(path, subdir) for subdir in subdirs]
+
+    imgs_all = []
+    msks_all = []
+    for brain_path in subdir_paths:
+        these_imgs, these_msks = nii_reader(
+            brain_path=brain_path,
+            task=task,
+            channels_last=channels_last,
+            **kwargs
+        )
+        # the needed files were not present if a tuple of None is returned
+        if these_imgs is None:
+            logger.debug(f'Brain subdirectory: {brain_path} did not contain the needed files.')
+        else:
+            imgs_all.append(these_imgs)
+            msks_all.append(these_msks)
+
+    # converting to arrays to allow for numpy indexing used during split
+    imgs_all = np.array(imgs_all)
+    msks_all = np.array(msks_all)
+
+    # note here that each is a list of 155 slices per brain, and so the
+    # split keeps brains intact
+    imgs_all_train, msks_all_train, imgs_all_val, msks_all_val = train_val_split(
+        features=imgs_all,
+        labels=msks_all,
+        percent_train=percent_train,
+        shuffle=shuffle
+    )
+    # now concatenate the lists
+    imgs_train = np.concatenate(imgs_all_train, axis=0)
+    msks_train = np.concatenate(msks_all_train, axis=0)
+    imgs_val = np.concatenate(imgs_all_val, axis=0)
+    msks_val = np.concatenate(msks_all_val, axis=0)
+
+    return imgs_train, msks_train, imgs_val, msks_val
 def parse_segments(seg, msk_modes):
diff --git a/openfl-workspace/tf_2dunet/src/taskrunner.py b/openfl-workspace/tf_2dunet/src/taskrunner.py
new file mode 100644
index 0000000000..4f8ff86b16
--- /dev/null
+++ b/openfl-workspace/tf_2dunet/src/taskrunner.py
@@ -0,0 +1,230 @@
+# Copyright (C) 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own model."""
+
+import numpy as np
+import tensorflow as tf
+
+from openfl.utilities import Metric
+from openfl.federated import TensorFlowTaskRunner
+
+class UNet2D(TensorFlowTaskRunner):
+    def __init__(self, initial_filters=16,
+                 depth=5,
+                 batch_norm=True,
+                 use_upsampling=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+
+        self.model = self.build_model(
+            input_shape=self.feature_shape,
+            n_cl_out=self.data_loader.num_classes,
+            initial_filters=initial_filters,
+            use_upsampling=use_upsampling,
+            depth=depth,
+            batch_norm=batch_norm,
+        )
+        self.initialize_tensorkeys_for_functions()
+
+        self.model.summary(print_fn=self.logger.info, line_length=120)
+
+    def build_model(self,
+                    input_shape,
+                    n_cl_out=1,
+                    use_upsampling=False,
+                    dropout=0.2,
+                    seed=816,
+                    depth=5,
+                    dropout_at=(2, 3),
+                    initial_filters=16,
+                    batch_norm=True):
+        """
+        Build and compile 2D UNet model.
+
+        Args:
+            input_shape (List[int]): The shape of the data
+            n_cl_out (int): Number of channels in output layer (Default=1)
+            use_upsampling (bool): True = use bilinear interpolation;
+                                   False = use transposed convolution (Default=False)
+            dropout (float): Dropout percentage (Default=0.2)
+            seed: random seed (Default=816)
+            depth (int): Number of max pooling layers in encoder (Default=5)
+            dropout_at (List[int]): Layers to perform dropout after (Default=[2,3])
+            initial_filters (int): Number of filters in first convolutional layer (Default=16)
+            batch_norm (bool): Apply batch normalization (Default=True)
+
+        Returns:
+            keras.src.engine.functional.Functional
+                A compiled Keras model ready for training.
+        """
+
+        if (input_shape[0] % (2**depth)) > 0:
+            raise ValueError(f'Crop dimension must be a multiple of 2^(depth of U-Net) = {2**depth}')
+
+        inputs = tf.keras.layers.Input(input_shape, name='brats_mr_image')
+
+        activation = tf.keras.activations.relu
+
+        params = {'kernel_size': (3, 3), 'activation': activation,
+                  'padding': 'same',
+                  'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed)}
+
+        convb_layers = {}
+
+        net = inputs
+        filters = initial_filters
+        for i in range(depth):
+            name = f'conv{i + 1}a'
+            net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
+            if i in dropout_at:
+                net = tf.keras.layers.Dropout(dropout)(net)
+            name = f'conv{i + 1}b'
+            net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net)
+            if batch_norm:
+                net = tf.keras.layers.BatchNormalization()(net)
+            convb_layers[name] = net
+            # only pool if not last level
+            if i != depth - 1:
+                name = f'pool{i + 1}'
+                net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net)
+                filters *= 2
+
+        # do the up levels
+        filters //= 2
+        for i in range(depth - 1):
+            if use_upsampling:
+                up = tf.keras.layers.UpSampling2D(
+                    name=f'up{depth + i + 1}', size=(2, 2))(net)
+            else:
+                up = tf.keras.layers.Conv2DTranspose(name=f'transConv{depth + i + 1}',
+                                                     filters=filters,
+                                                     kernel_size=(2, 2),
+                                                     strides=(2, 2),
+                                                     padding='same')(net)
+            net = tf.keras.layers.concatenate(
+                [up, convb_layers[f'conv{depth - i - 1}b']],
+                axis=-1
+            )
+            net = tf.keras.layers.Conv2D(
+                name=f'conv{depth + i + 1}a',
+                filters=filters, **params)(net)
+            net = tf.keras.layers.Conv2D(
+                name=f'conv{depth + i + 1}b',
+                filters=filters, **params)(net)
+            filters //= 2
+
+        net = tf.keras.layers.Conv2D(name='prediction', filters=n_cl_out,
+                                     kernel_size=(1, 1),
+                                     activation='sigmoid')(net)
+
+        model = tf.keras.models.Model(inputs=[inputs], outputs=[net])
+
+        model.compile(
+            loss=dice_loss,
+            optimizer=tf.keras.optimizers.Adam(),
+            metrics=[dice_coef, soft_dice_coef],
+        )
+
+        return model
+
+    def train_(self, batch_generator, metrics: list = None, **kwargs):
+        """
+        Train single epoch.
+
+        Override this function for custom training.
+
+        Args:
+            batch_generator (generator): Generator of training batches.
+                Each batch is a tuple of N train images and N train labels
+                where N is the batch size of the DataLoader of the current TaskRunner instance.
+            metrics (List[str]): A list of metric names to compute and save
+            **kwargs (dict): Additional keyword arguments
+
+        Returns:
+            list: Metric objects containing the computed metrics
+        """
+        history = self.model.fit(batch_generator,
+                                 verbose=1,
+                                 **kwargs)
+        results = []
+        for metric in metrics:
+            value = np.mean([history.history[metric]])
+            results.append(Metric(name=metric, value=np.array(value)))
+        return results
+
+
+def dice_coef(target, prediction, axis=(1, 2), smooth=0.0001):
+    """
+    Calculate the Sorenson-Dice coefficient.
+
+    Args:
+        target (tf.Tensor): The ground truth binary labels.
+        prediction (tf.Tensor): The predicted binary labels, rounded to 0 or 1.
+        axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions.
+        smooth (float, optional): A small constant added to numerator and denominator for numerical stability.
+
+    Returns:
+        tf.Tensor: The mean Dice coefficient over the batch.
+    """
+    prediction = tf.round(prediction)  # Round to 0 or 1
+
+    intersection = tf.reduce_sum(target * prediction, axis=axis)
+    union = tf.reduce_sum(target + prediction, axis=axis)
+    numerator = tf.constant(2.)
* intersection + smooth + denominator = union + smooth + coef = numerator / denominator + + return tf.reduce_mean(coef) + + +def soft_dice_coef(target, prediction, axis=(1, 2), smooth=0.0001): + """ + Calculate the soft Sorenson-Dice coefficient. + + Does not round the predictions to either 0 or 1. + + Args: + target (tf.Tensor): The ground truth binary labels. + prediction (tf.Tensor): The predicted probabilities. + axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions. + smooth (float, optional): A small constant added to numerator and denominator for numerical stability. + + Returns: + tf.Tensor: The mean soft Dice coefficient over the batch. + """ + intersection = tf.reduce_sum(target * prediction, axis=axis) + union = tf.reduce_sum(target + prediction, axis=axis) + numerator = tf.constant(2.) * intersection + smooth + denominator = union + smooth + coef = numerator / denominator + + return tf.reduce_mean(coef) + + +def dice_loss(target, prediction, axis=(1, 2), smooth=0.0001): + """ + Calculate the (Soft) Sorenson-Dice loss. + + Using -log(Dice) as the loss since it is better behaved. + Also, the log allows avoidance of the division which + can help prevent underflow when the numbers are very small. + + Args: + target (tf.Tensor): The ground truth binary labels. + prediction (tf.Tensor): The predicted probabilities. + axis (tuple, optional): The axes along which to compute the loss, typically the spatial dimensions. + smooth (float, optional): A small constant added to numerator and denominator for numerical stability. + + Returns: + tf.Tensor: The mean Dice loss over the batch. + """ + intersection = tf.reduce_sum(prediction * target, axis=axis) + p = tf.reduce_sum(prediction, axis=axis) + t = tf.reduce_sum(target, axis=axis) + numerator = tf.reduce_mean(intersection + smooth) + denominator = tf.reduce_mean(t + p + smooth) + dice_loss = -tf.math.log(2. * numerator) + tf.math.log(denominator) + + return dice_loss \ No newline at end of file diff --git a/openfl-workspace/tf_2dunet/src/tf_2dunet.py b/openfl-workspace/tf_2dunet/src/tf_2dunet.py deleted file mode 100644 index 5073344050..0000000000 --- a/openfl-workspace/tf_2dunet/src/tf_2dunet.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -import tensorflow.compat.v1 as tf - -from openfl.federated import TensorFlowTaskRunner - -tf.disable_v2_behavior() - - -class TensorFlow2DUNet(TensorFlowTaskRunner): - """Initialize. - - Args: - **kwargs: Additional parameters to pass to the function - - """ - - def __init__(self, **kwargs): - """Initialize. - - Args: - **kwargs: Additional parameters to pass to the function - - """ - super().__init__(**kwargs) - - self.create_model(**kwargs) - self.initialize_tensorkeys_for_functions() - - def create_model(self, - training_smoothing=32.0, - validation_smoothing=1.0, - **kwargs): - """Create the TensorFlow 2D U-Net model. 
- - Args: - training_smoothing (float): (Default=32.0) - validation_smoothing (float): (Default=1.0) - **kwargs: Additional parameters to pass to the function - - """ - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - config.intra_op_parallelism_threads = 112 - config.inter_op_parallelism_threads = 1 - self.sess = tf.Session(config=config) - - self.X = tf.placeholder(tf.float32, self.input_shape) - self.y = tf.placeholder(tf.float32, self.input_shape) - self.output = define_model(self.X, use_upsampling=True, **kwargs) - - self.loss = dice_coef_loss(self.y, self.output, smooth=training_smoothing) - self.loss_name = dice_coef_loss.__name__ - self.validation_metric = dice_coef( - self.y, self.output, smooth=validation_smoothing) - self.validation_metric_name = dice_coef.__name__ - - self.global_step = tf.train.get_or_create_global_step() - - self.tvars = tf.trainable_variables() - - self.optimizer = tf.train.RMSPropOptimizer(1e-2) - - self.gvs = self.optimizer.compute_gradients(self.loss, self.tvars) - self.train_step = self.optimizer.apply_gradients(self.gvs, - global_step=self.global_step) - - self.opt_vars = self.optimizer.variables() - - # FIXME: Do we really need to share the opt_vars? - # Two opt_vars for one tvar: gradient and square sum for RMSprop. - self.fl_vars = self.tvars + self.opt_vars - - self.initialize_globals() - - -def dice_coef(y_true, y_pred, smooth=1.0, **kwargs): - """Dice coefficient. - - Calculate the Dice Coefficient - - Args: - y_true: Ground truth annotation array - y_pred: Prediction array from model - smooth (float): Laplace smoothing factor (Default=1.0) - **kwargs: Additional parameters to pass to the function - - Returns: - float: Dice cofficient metric - - """ - intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3]) - coef = ( - (tf.constant(2.) * intersection + tf.constant(smooth)) - / (tf.reduce_sum(y_true, axis=[1, 2, 3]) - + tf.reduce_sum(y_pred, axis=[1, 2, 3]) + tf.constant(smooth)) - ) - return tf.reduce_mean(coef) - - -def dice_coef_loss(y_true, y_pred, smooth=1.0, **kwargs): - """Dice coefficient loss. - - Calculate the -log(Dice Coefficient) loss - - Args: - y_true: Ground truth annotation array - y_pred: Prediction array from model - smooth (float): Laplace smoothing factor (Default=1.0) - **kwargs: Additional parameters to pass to the function - - Returns: - float: -log(Dice cofficient) metric - - """ - intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3)) - - term1 = -tf.log(tf.constant(2.0) * intersection + smooth) - term2 = tf.log(tf.reduce_sum(y_true, axis=(1, 2, 3)) - + tf.reduce_sum(y_pred, axis=(1, 2, 3)) + smooth) - - term1 = tf.reduce_mean(term1) - term2 = tf.reduce_mean(term2) - - loss = term1 + term2 - - return loss - - -CHANNEL_LAST = True -if CHANNEL_LAST: - concat_axis = -1 - data_format = 'channels_last' -else: - concat_axis = 1 - data_format = 'channels_first' - -tf.keras.backend.set_image_data_format(data_format) - - -def define_model(input_tensor, - use_upsampling=False, - n_cl_out=1, - dropout=0.2, - print_summary=True, - activation_function='relu', - seed=0xFEEDFACE, - depth=5, - dropout_at=None, - initial_filters=32, - batch_norm=True, - **kwargs): - """Define the TensorFlow model. 
- - Args: - input_tensor: input shape ot the model - use_upsampling (bool): True = use bilinear interpolation; - False = use transposed convolution (Default=False) - n_cl_out (int): Number of channels in input layer (Default=1) - dropout (float): Dropout percentage (Default=0.2) - print_summary (bool): True = print the model summary (Default = True) - activation_function: The activation function to use after convolutional - layers (Default='relu') - seed: random seed (Default=0xFEEDFACE) - depth (int): Number of max pooling layers in encoder (Default=5) - dropout_at: Layers to perform dropout after (Default=[2,3]) - initial_filters (int): Number of filters in first convolutional - layer (Default=32) - batch_norm (bool): True = use batch normalization (Default=True) - **kwargs: Additional parameters to pass to the function - - """ - if dropout_at is None: - dropout_at = [2, 3] - # Set keras learning phase to train - tf.keras.backend.set_learning_phase(True) - - # Don't initialize variables on the fly - tf.keras.backend.manual_variable_initialization(False) - - inputs = tf.keras.layers.Input(tensor=input_tensor, name='Images') - - if activation_function == 'relu': - activation = tf.nn.relu - elif activation_function == 'leakyrelu': - activation = tf.nn.leaky_relu - - params = { - 'activation': activation, - 'data_format': data_format, - 'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed), - 'kernel_size': (3, 3), - 'padding': 'same', - } - - convb_layers = {} - - net = inputs - filters = initial_filters - for i in range(depth): - name = f'conv{i + 1}a' - net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) - if i in dropout_at: - net = tf.keras.layers.Dropout(dropout)(net) - name = f'conv{i + 1}b' - net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) - if batch_norm: - net = tf.keras.layers.BatchNormalization()(net) - convb_layers[name] = net - # only pool if not last level - if i != depth - 1: - name = f'pool{i + 1}' - net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net) - filters *= 2 - - # do the up levels - filters //= 2 - for i in range(depth - 1): - if use_upsampling: - up = tf.keras.layers.UpSampling2D( - name=f'up{depth + i + 1}', size=(2, 2))(net) - else: - up = tf.keras.layers.Conv2DTranspose( - name='transConv6', filters=filters, data_format=data_format, - kernel_size=(2, 2), strides=(2, 2), padding='same')(net) - net = tf.keras.layers.concatenate( - [up, convb_layers[f'conv{depth - i - 1}b']], - axis=concat_axis - ) - net = tf.keras.layers.Conv2D( - name=f'conv{depth + i + 1}a', - filters=filters, **params)(net) - net = tf.keras.layers.Conv2D( - name=f'conv{depth + i + 1}b', - filters=filters, **params)(net) - filters //= 2 - - net = tf.keras.layers.Conv2D(name='Mask', filters=n_cl_out, - kernel_size=(1, 1), data_format=data_format, - activation='sigmoid')(net) - - model = tf.keras.models.Model(inputs=[inputs], outputs=[net]) - - if print_summary: - print(model.summary()) - - return net diff --git a/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py b/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py deleted file mode 100644 index 49b4484fc2..0000000000 --- a/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -from openfl.federated import TensorFlowDataLoader -from .brats_utils import load_from_nifti - - -class 
TensorFlowBratsInMemory(TensorFlowDataLoader): - """TensorFlow Data Loader for the BraTS dataset.""" - - def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, **kwargs): - """Initialize. - - Args: - data_path: The file path for the BraTS dataset - batch_size (int): The batch size to use - percent_train (float): The percentage of the data to use for training (Default=0.8) - pre_split_shuffle (bool): True= shuffle the dataset before - performing the train/validate split (Default=True) - **kwargs: Additional arguments, passed to super init and load_from_nifti - - Returns: - Data loader with BraTS data - """ - super().__init__(batch_size, **kwargs) - - X_train, y_train, X_valid, y_valid = load_from_nifti(parent_dir=data_path, - percent_train=percent_train, - shuffle=pre_split_shuffle, - **kwargs) - self.X_train = X_train - self.y_train = y_train - self.X_valid = X_valid - self.y_valid = y_valid diff --git a/openfl-workspace/tf_3dunet_brats/plan/data.yaml b/openfl-workspace/tf_3dunet_brats/plan/data.yaml index d006410dda..ca127225af 100644 --- a/openfl-workspace/tf_3dunet_brats/plan/data.yaml +++ b/openfl-workspace/tf_3dunet_brats/plan/data.yaml @@ -11,6 +11,6 @@ # Symbolically link the ./data directory to whereever you have BraTS stored. # e.g. ln -s ~/data/MICCAI_BraTS2020_TrainingData ./data/one -one,~/MICCAI_BraTS2020_TrainingData/split_0 -two,~/MICCAI_BraTS2020_TrainingData/split_1 +collaborator1,../data/split/split_0 +collaborator2,../data/split/split_1 diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults b/openfl-workspace/tf_3dunet_brats/plan/defaults new file mode 100644 index 0000000000..5042bedbcf --- /dev/null +++ b/openfl-workspace/tf_3dunet_brats/plan/defaults @@ -0,0 +1 @@ +../../workspace/plan/defaults \ No newline at end of file diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/aggregator.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/aggregator.yaml deleted file mode 100644 index d3ef6e5082..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/aggregator.yaml +++ /dev/null @@ -1,4 +0,0 @@ -template : openfl.component.Aggregator -settings : - db_store_rounds : 1 - diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/assigner.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/assigner.yaml deleted file mode 100644 index 0b7e744475..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/assigner.yaml +++ /dev/null @@ -1,9 +0,0 @@ -template : openfl.component.RandomGroupedAssigner -settings : - task_groups : - - name : train_and_validate - percentage : 1.0 - tasks : - - aggregated_model_validation - - train - - locally_tuned_model_validation diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/collaborator.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/collaborator.yaml deleted file mode 100644 index a9c2e6eb7b..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/collaborator.yaml +++ /dev/null @@ -1,5 +0,0 @@ -template : openfl.component.Collaborator -settings : - opt_treatment : 'CONTINUE_LOCAL' - delta_updates : True - db_store_rounds : 1 diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/compression_pipeline.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/compression_pipeline.yaml deleted file mode 100644 index a508f94fd2..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/compression_pipeline.yaml +++ /dev/null @@ -1 +0,0 @@ -template: openfl.pipelines.NoCompressionPipeline diff --git 
a/openfl-workspace/tf_3dunet_brats/plan/defaults/data_loader.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/data_loader.yaml deleted file mode 100644 index 33accd5ab2..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/data_loader.yaml +++ /dev/null @@ -1 +0,0 @@ -template: openfl.federated.DataLoader diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/network.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/network.yaml deleted file mode 100644 index 9528631585..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/network.yaml +++ /dev/null @@ -1,9 +0,0 @@ -template: openfl.federation.Network -settings: - agg_addr : auto - agg_port : auto - hash_salt : auto - disable_tls : False - client_reconnect_interval : 5 - disable_client_auth : False - cert_folder : cert diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/task_runner.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/task_runner.yaml deleted file mode 100644 index b162724693..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/task_runner.yaml +++ /dev/null @@ -1 +0,0 @@ -template: openfl.federated.task_runner.CoreTaskRunner diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_fast_estimator.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_fast_estimator.yaml deleted file mode 100644 index 1548d4b225..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_fast_estimator.yaml +++ /dev/null @@ -1,22 +0,0 @@ -aggregated_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : global - metrics : - - accuracy - -locally_tuned_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : local - metrics : - - accuracy -train: - function : train - kwargs : - batch_size : 32 - epochs : 1 - metrics : - - loss diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_keras.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_keras.yaml deleted file mode 100644 index 79d067d8d2..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_keras.yaml +++ /dev/null @@ -1,23 +0,0 @@ -aggregated_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : global - metrics : - - accuracy - -locally_tuned_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : local - metrics : - - accuracy - -train: - function : train - kwargs : - batch_size : 32 - epochs : 1 - metrics : - - loss diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_tensorflow.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_tensorflow.yaml deleted file mode 100644 index 586a885b40..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_tensorflow.yaml +++ /dev/null @@ -1,23 +0,0 @@ -aggregated_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : global - metrics : - - acc - -locally_tuned_model_validation: - function : validate - kwargs : - batch_size : 32 - apply : local - metrics : - - acc - -train: - function : train_batches - kwargs : - batch_size : 32 - num_batches : 1 - metrics : - - loss diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_torch.yaml b/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_torch.yaml deleted file mode 100644 index a240c2003b..0000000000 --- a/openfl-workspace/tf_3dunet_brats/plan/defaults/tasks_torch.yaml +++ /dev/null @@ -1,19 +0,0 @@ -aggregated_model_validation: - function : validate - kwargs : - apply : global - metrics : - - acc - 
-locally_tuned_model_validation: - function : validate - kwargs : - apply: local - metrics : - - acc - -train: - function : train_batches - kwargs : - metrics : - - loss diff --git a/openfl-workspace/tf_3dunet_brats/plan/plan.yaml b/openfl-workspace/tf_3dunet_brats/plan/plan.yaml index fa8fb911de..fe5a1edcd9 100644 --- a/openfl-workspace/tf_3dunet_brats/plan/plan.yaml +++ b/openfl-workspace/tf_3dunet_brats/plan/plan.yaml @@ -1,83 +1,68 @@ -aggregator: - defaults: plan/defaults/aggregator.yaml - settings: - best_state_path: save/tf_3dunet_brats_best.pbuf - init_state_path: save/tf_3dunet_brats_init.pbuf - last_state_path: save/tf_3dunet_brats_latest.pbuf - db_store_rounds: 2 - rounds_to_train: 10 - template: openfl.component.Aggregator -assigner: - defaults: plan/defaults/assigner.yaml - settings: - task_groups: - - name: train_and_validate - percentage: 1.0 - tasks: - - aggregated_model_validation - - train - - locally_tuned_model_validation - template: openfl.component.RandomGroupedAssigner -collaborator: - defaults: plan/defaults/collaborator.yaml - settings: - db_store_rounds: 2 - delta_updates: true - opt_treatment: RESET - template: openfl.component.Collaborator +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/init.pbuf + last_state_path : save/latest.pbuf + best_state_path : save/best.pbuf + rounds_to_train : 10 + db_store_rounds : 2 + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : true + opt_treatment : RESET + data_loader: defaults: plan/defaults/data_loader.yaml + template: src.dataloader.BratsDataloader settings: batch_size: 4 crop_dim: 64 num_classes: 1 number_input_channels: 1 percent_train: 0.8 - template: src.tf_brats_dataloader.TensorFlowBratsDataLoader -network: - defaults: plan/defaults/network.yaml - settings: - agg_addr: DESKTOP-AOKV1IJ.localdomain - agg_port: auto - cert_folder: cert - client_reconnect_interval: 5 - disable_client_auth: false - disable_tls: false - hash_salt: auto - template: openfl.federation.Network + task_runner: defaults: plan/defaults/task_runner.yaml + template: src.taskrunner.UNet3D settings: batch_norm: true batch_size: 4 depth: 4 initial_filters: 16 use_upsampling: false - template: src.tf_3dunet_model.TensorFlow3dUNet + +network : + defaults : plan/defaults/network.yaml + +assigner : + defaults : plan/defaults/assigner.yaml + tasks: + defaults : plan/defaults/task_tensorflow.yaml aggregated_model_validation: - function: validate + function: validate_task kwargs: apply: global batch_size: 4 metrics: - - dice_coef - - soft_dice_coef - defaults: plan/defaults/tasks_tensorflow.yaml + - dice_coef + - soft_dice_coef locally_tuned_model_validation: - function: validate + function: validate_task kwargs: apply: local batch_size: 4 metrics: - - dice_coef - - soft_dice_coef - settings: {} + - dice_coef + - soft_dice_coef train: - function: train + function: train_task kwargs: batch_size: 4 epochs: 1 metrics: - - loss - num_batches: 1 + - loss diff --git a/openfl-workspace/tf_3dunet_brats/requirements.txt b/openfl-workspace/tf_3dunet_brats/requirements.txt index ed58705a66..d79023f767 100644 --- a/openfl-workspace/tf_3dunet_brats/requirements.txt +++ b/openfl-workspace/tf_3dunet_brats/requirements.txt @@ -1,4 +1,4 @@ -tensorflow>=2 +tensorflow==2.15.1 nibabel numpy diff --git a/openfl-workspace/tf_3dunet_brats/src/dataloader.py b/openfl-workspace/tf_3dunet_brats/src/dataloader.py 
index 80ac1cd004..afef27ca62 100644 --- a/openfl-workspace/tf_3dunet_brats/src/dataloader.py +++ b/openfl-workspace/tf_3dunet_brats/src/dataloader.py @@ -1,14 +1,103 @@ -# Copyright (C) 2020-2021 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """You may copy this file as the starting point of your own model.""" -import os - import nibabel as nib import numpy as np +import os import tensorflow as tf +from openfl.federated import TensorFlowDataLoader + + +class BratsDataloader(TensorFlowDataLoader): + """TensorFlow Data Loader for the BraTS dataset.""" + + def __init__(self, data_path, batch_size=4, + crop_dim=64, percent_train=0.8, + number_input_channels=1, + num_classes=1, + **kwargs): + """Initialize. + + Args: + data_path: The file path for the BraTS dataset + batch_size (int): The batch size to use + crop_dim (int): Crop the original image to this size on each dimension + percent_train (float): The percentage of the data to use for training (Default=0.8) + pre_split_shuffle (bool): True= shuffle the dataset before + performing the train/validate split (Default=True) + **kwargs: Additional arguments, passed to super init + + Returns: + Data loader with BraTS data + """ + super().__init__(batch_size, **kwargs) + + self.data_path = os.path.abspath(os.path.expanduser(data_path)) + self.batch_size = batch_size + self.crop_dim = [crop_dim, crop_dim, crop_dim, number_input_channels] + self.num_input_channels = number_input_channels + self.num_classes = num_classes + + self.train_test_split = percent_train + + self.brats_data = DatasetGenerator(crop_dim, + data_path=data_path, + number_input_channels=number_input_channels, + batch_size=batch_size, + train_test_split=percent_train, + validate_test_split=0.5, + num_classes=num_classes, + random_seed=816) + + def get_feature_shape(self): + """ + Get the shape of an example feature array. + + Returns: + tuple: shape of an example feature array + """ + return tuple(self.brats_data.get_input_shape()) + + def get_train_loader(self, batch_size=None, num_batches=None): + """ + Get training data loader. + + Returns + ------- + loader object + """ + return self.brats_data.ds_train + + def get_valid_loader(self, batch_size=None): + """ + Get validation data loader. + + Returns: + loader object + """ + return self.brats_data.ds_val + + def get_train_data_size(self): + """ + Get total number of training samples. + + Returns: + int: number of training samples + """ + return self.brats_data.num_train + + def get_valid_data_size(self): + """ + Get total number of validation samples. + + Returns: + int: number of validation samples + """ + return self.brats_data.num_val + class DatasetGenerator: """Generate a TensorFlow data loader from the BraTS .nii.gz files.""" diff --git a/openfl-workspace/tf_3dunet_brats/src/define_model.py b/openfl-workspace/tf_3dunet_brats/src/define_model.py index 148d66e9ad..df972ee30f 100644 --- a/openfl-workspace/tf_3dunet_brats/src/define_model.py +++ b/openfl-workspace/tf_3dunet_brats/src/define_model.py @@ -8,11 +8,16 @@ def dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001): """ - Sorenson Dice. + Calculate the Sorenson-Dice coefficient. - Returns - ------- - dice coefficient (float) + Args: + target (tf.Tensor): The ground truth binary labels. + prediction (tf.Tensor): The predicted binary labels, rounded to 0 or 1. + axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions. 
+ smooth (float, optional): A small constant added to numerator and denominator for numerical stability. + + Returns: + tf.Tensor: The mean Dice coefficient over the batch. """ prediction = tf.round(prediction) # Round to 0 or 1 @@ -27,13 +32,18 @@ def dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001): def soft_dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001): """ - Soft Sorenson Dice. + Calculate the soft Sorenson-Dice coefficient. Does not round the predictions to either 0 or 1. - Returns - ------- - soft dice coefficient (float) + Args: + target (tf.Tensor): The ground truth binary labels. + prediction (tf.Tensor): The predicted probabilities. + axis (tuple, optional): The axes along which to compute the coefficient, typically the spatial dimensions. + smooth (float, optional): A small constant added to numerator and denominator for numerical stability. + + Returns: + tf.Tensor: The mean soft Dice coefficient over the batch. """ intersection = tf.reduce_sum(target * prediction, axis=axis) union = tf.reduce_sum(target + prediction, axis=axis) @@ -46,15 +56,20 @@ def soft_dice_coef(target, prediction, axis=(1, 2, 3), smooth=0.0001): def dice_loss(target, prediction, axis=(1, 2, 3), smooth=0.0001): """ - Sorenson (Soft) Dice loss. + Calculate the (Soft) Sorenson-Dice loss. Using -log(Dice) as the loss since it is better behaved. Also, the log allows avoidance of the division which can help prevent underflow when the numbers are very small. - Returns - ------- - dice loss (float) + Args: + target (tf.Tensor): The ground truth binary labels. + prediction (tf.Tensor): The predicted probabilities. + axis (tuple, optional): The axes along which to compute the loss, typically the spatial dimensions. + smooth (float, optional): A small constant added to numerator and denominator for numerical stability. + + Returns: + tf.Tensor: The mean Dice loss over the batch. """ intersection = tf.reduce_sum(prediction * target, axis=axis) p = tf.reduce_sum(prediction, axis=axis) @@ -70,29 +85,29 @@ def build_model(input_shape, n_cl_out=1, use_upsampling=False, dropout=0.2, - print_summary=True, seed=816, depth=5, dropout_at=(2, 3), initial_filters=16, - batch_norm=True, - **kwargs): - """Build the TensorFlow model. + batch_norm=True,): + """ + Build and compile 3D UNet model. Args: - input_tensor: input shape ot the model - use_upsampling (bool): True = use bilinear interpolation; - False = use transposed convolution (Default=False) + input_shape (List[int]): The shape of the data n_cl_out (int): Number of channels in output layer (Default=1) + use_upsampling (bool): True = use bilinear interpolation; + False = use transposed convolution (Default=False) dropout (float): Dropout percentage (Default=0.2) - print_summary (bool): True = print the model summary (Default = True) seed: random seed (Default=816) depth (int): Number of max pooling layers in encoder (Default=5) - dropout_at: Layers to perform dropout after (Default=[2,3]) - initial_filters (int): Number of filters in first convolutional - layer (Default=16) - batch_norm (bool): True = use batch normalization (Default=True) - **kwargs: Additional parameters to pass to the function + dropout_at (List[int]): Layers to perform dropout after (Default=[2,3]) + initial_filters (int): Number of filters in first convolutional layer (Default=16) + batch_norm (bool): Aply batch normalization (Default=True) + + Returns: + keras.src.engine.functional.Functional + A compiled Keras model ready for training. 
""" if (input_shape[0] % (2**depth)) > 0: raise ValueError(f'Crop dimension must be a multiple of 2^(depth of U-Net) = {2**depth}') diff --git a/openfl-workspace/tf_3dunet_brats/src/tf_3dunet_model.py b/openfl-workspace/tf_3dunet_brats/src/taskrunner.py similarity index 94% rename from openfl-workspace/tf_3dunet_brats/src/tf_3dunet_model.py rename to openfl-workspace/tf_3dunet_brats/src/taskrunner.py index 8beeaf2375..d15e569598 100644 --- a/openfl-workspace/tf_3dunet_brats/src/tf_3dunet_model.py +++ b/openfl-workspace/tf_3dunet_brats/src/taskrunner.py @@ -3,16 +3,15 @@ """You may copy this file as the starting point of your own model.""" +import argparse +from openfl.federated import TensorFlowTaskRunner +import os +from src.define_model import build_model, dice_coef, dice_loss, soft_dice_coef +from src.dataloader import DatasetGenerator import tensorflow as tf -from openfl.federated import KerasTaskRunner -from .define_model import build_model -from .define_model import dice_coef -from .define_model import dice_loss -from .define_model import soft_dice_coef - -class TensorFlow3dUNet(KerasTaskRunner): +class UNet3D(TensorFlowTaskRunner): """Initialize. Args: @@ -51,7 +50,6 @@ def create_model(self, n_cl_out=1, use_upsampling=False, dropout=0.2, - print_summary=True, seed=816, depth=5, dropout_at=(2, 3), @@ -73,7 +71,6 @@ def create_model(self, n_cl_out=n_cl_out, use_upsampling=use_upsampling, dropout=dropout, - print_summary=print_summary, seed=seed, depth=depth, dropout_at=dropout_at, @@ -101,12 +98,6 @@ def create_model(self, if __name__ == '__main__': - - from tf_brats_dataloader import DatasetGenerator - import os - - import argparse - parser = argparse.ArgumentParser( description='Train 3D U-Net model', add_help=True, formatter_class=argparse.ArgumentDefaultsHelpFormatter) diff --git a/openfl-workspace/tf_3dunet_brats/src/tf_brats_dataloader.py b/openfl-workspace/tf_3dunet_brats/src/tf_brats_dataloader.py deleted file mode 100644 index 85e5c576c3..0000000000 --- a/openfl-workspace/tf_3dunet_brats/src/tf_brats_dataloader.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - - -import os - -from openfl.federated import TensorFlowDataLoader -from .dataloader import DatasetGenerator - - -class TensorFlowBratsDataLoader(TensorFlowDataLoader): - """TensorFlow Data Loader for the BraTS dataset.""" - - def __init__(self, data_path, batch_size=4, - crop_dim=64, percent_train=0.8, - pre_split_shuffle=True, - number_input_channels=1, - num_classes=1, - **kwargs): - """Initialize. 
- - Args: - data_path: The file path for the BraTS dataset - batch_size (int): The batch size to use - crop_dim (int): Crop the original image to this size on each dimension - percent_train (float): The percentage of the data to use for training (Default=0.8) - pre_split_shuffle (bool): True= shuffle the dataset before - performing the train/validate split (Default=True) - **kwargs: Additional arguments, passed to super init - - Returns: - Data loader with BraTS data - """ - super().__init__(batch_size, **kwargs) - - self.data_path = os.path.abspath(os.path.expanduser(data_path)) - self.batch_size = batch_size - self.crop_dim = [crop_dim, crop_dim, crop_dim, number_input_channels] - self.num_input_channels = number_input_channels - self.num_classes = num_classes - - self.train_test_split = percent_train - - self.brats_data = DatasetGenerator(crop_dim, - data_path=data_path, - number_input_channels=number_input_channels, - batch_size=batch_size, - train_test_split=percent_train, - validate_test_split=0.5, - num_classes=num_classes, - random_seed=816) - - def get_feature_shape(self): - """ - Get the shape of an example feature array. - - Returns: - tuple: shape of an example feature array - """ - return tuple(self.brats_data.get_input_shape()) - - def get_train_loader(self, batch_size=None, num_batches=None): - """ - Get training data loader. - - Returns - ------- - loader object - """ - return self.brats_data.ds_train - - def get_valid_loader(self, batch_size=None): - """ - Get validation data loader. - - Returns: - loader object - """ - return self.brats_data.ds_val - - def get_train_data_size(self): - """ - Get total number of training samples. - - Returns: - int: number of training samples - """ - return self.brats_data.num_train - - def get_valid_data_size(self): - """ - Get total number of validation samples. 
- - Returns: - int: number of validation samples - """ - return self.brats_data.num_val diff --git a/openfl-workspace/tf_cnn_histology/plan/plan.yaml b/openfl-workspace/tf_cnn_histology/plan/plan.yaml index f834794113..87ecfa207f 100644 --- a/openfl-workspace/tf_cnn_histology/plan/plan.yaml +++ b/openfl-workspace/tf_cnn_histology/plan/plan.yaml @@ -5,10 +5,10 @@ aggregator : defaults : plan/defaults/aggregator.yaml template : openfl.component.Aggregator settings : - init_state_path : save/tf_cnn_histology_init.pbuf - last_state_path : save/tf_cnn_histology_latest.pbuf - best_state_path : save/tf_cnn_histology_best.pbuf - db_store_rounds: 2 + init_state_path : save/init.pbuf + last_state_path : save/latest.pbuf + best_state_path : save/best.pbuf + db_store_rounds : 2 rounds_to_train : 10 collaborator : @@ -21,7 +21,7 @@ collaborator : data_loader : defaults : plan/defaults/data_loader.yaml - template : src.tfhistology_inmemory.TensorFlowHistologyInMemory + template : src.dataloader.HistologyDataloader settings : batch_size: 64 percent_train: 0.8 @@ -30,7 +30,7 @@ data_loader : task_runner : defaults : plan/defaults/task_runner.yaml - template : src.tf_cnn.TensorFlowCNN + template : src.taskrunner.CNN network : defaults : plan/defaults/network.yaml @@ -41,14 +41,14 @@ assigner : tasks: defaults: plan/defaults/tasks_tensorflow.yaml aggregated_model_validation: - function: validate + function: validate_task kwargs: apply: global batch_size: 32 metrics: - sparse_categorical_accuracy locally_tuned_model_validation: - function: validate + function: validate_task kwargs: apply: local batch_size: 32 @@ -56,7 +56,7 @@ tasks: - sparse_categorical_accuracy settings: {} train: - function: train + function: train_task kwargs: batch_size: 32 epochs: 1 diff --git a/openfl-workspace/tf_cnn_histology/requirements.txt b/openfl-workspace/tf_cnn_histology/requirements.txt index 59ee6430c8..23b0b78e6f 100644 --- a/openfl-workspace/tf_cnn_histology/requirements.txt +++ b/openfl-workspace/tf_cnn_histology/requirements.txt @@ -1,3 +1,3 @@ pillow -tensorflow==2.13 +tensorflow==2.15.1 tensorflow-datasets diff --git a/openfl-workspace/tf_cnn_histology/src/tfds_utils.py b/openfl-workspace/tf_cnn_histology/src/dataloader.py similarity index 79% rename from openfl-workspace/tf_cnn_histology/src/tfds_utils.py rename to openfl-workspace/tf_cnn_histology/src/dataloader.py index 92977ebad0..8314df8153 100644 --- a/openfl-workspace/tf_cnn_histology/src/tfds_utils.py +++ b/openfl-workspace/tf_cnn_histology/src/dataloader.py @@ -3,6 +3,7 @@ """You may copy this file as the starting point of your own model.""" +from openfl.federated import TensorFlowDataLoader from logging import getLogger import numpy as np @@ -11,18 +12,31 @@ logger = getLogger(__name__) -def one_hot(labels, classes): - """ - One Hot encode a vector. +class HistologyDataloader(TensorFlowDataLoader): + """TensorFlow Data Loader for Colorectal Histology Dataset.""" - Args: - labels (list): List of labels to onehot encode - classes (int): Total number of categorical classes + def __init__(self, data_path, batch_size, **kwargs): + """ + Initialize. 
- Returns: - np.array: Matrix of one-hot encoded labels - """ - return np.eye(classes)[labels] + Args: + data_path: File path for the dataset + batch_size (int): The batch size for the data loader + **kwargs: Additional arguments, passed to super init and load_mnist_shard + """ + super().__init__(batch_size, **kwargs) + + _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( + shard_num=data_path, + categorical=False, **kwargs + ) + + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + + self.num_classes = num_classes def _load_raw_datashards(shard_num, collaborator_count): @@ -120,7 +134,7 @@ def load_histology_shard(shard_num, collaborator_count, categorical=True, if categorical: # convert class vectors to binary class matrices - y_train = one_hot(y_train, num_classes) - y_valid = one_hot(y_valid, num_classes) + y_train = np.eye(num_classes)[y_train] + y_valid = np.eye(num_classes)[y_valid] return input_shape, num_classes, X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/tf_cnn_histology/src/tf_cnn.py b/openfl-workspace/tf_cnn_histology/src/taskrunner.py similarity index 60% rename from openfl-workspace/tf_cnn_histology/src/tf_cnn.py rename to openfl-workspace/tf_cnn_histology/src/taskrunner.py index 7041396678..0f33c45317 100644 --- a/openfl-workspace/tf_cnn_histology/src/tf_cnn.py +++ b/openfl-workspace/tf_cnn_histology/src/taskrunner.py @@ -3,12 +3,14 @@ """You may copy this file as the starting point of your own model.""" +import numpy as np import tensorflow as tf -from openfl.federated import KerasTaskRunner +from openfl.utilities import Metric +from openfl.federated import TensorFlowTaskRunner -class TensorFlowCNN(KerasTaskRunner): +class CNN(TensorFlowTaskRunner): """Initialize. Args: @@ -17,40 +19,38 @@ class TensorFlowCNN(KerasTaskRunner): """ def __init__(self, **kwargs): - """Initialize. - - Args: - **kwargs: Additional parameters to pass to the function - - """ super().__init__(**kwargs) - self.model = self.create_model( + self.model = self.build_model( self.feature_shape, self.data_loader.num_classes, **kwargs ) self.initialize_tensorkeys_for_functions() - def create_model(self, - input_shape, - num_classes, - training_smoothing=32.0, - validation_smoothing=1.0, - **kwargs): - """Create the TensorFlow CNN Histology model. + self.model.summary(print_fn=self.logger.info) + + self.logger.info(f'Train Set Size : {self.get_train_data_size()}') + self.logger.info(f'Valid Set Size : {self.get_valid_data_size()}') + + def build_model(self, + input_shape, + num_classes, + **kwargs): + """ + Build and compile a convolutional neural network model. Args: - training_smoothing (float): (Default=32.0) - validation_smoothing (float): (Default=1.0) + input_shape (List[int]): The shape of the data + num_classes (int): The number of classes of the dataset **kwargs: Additional parameters to pass to the function + Returns: + keras.src.engine.functional.Functional + A compiled Keras model ready for training. 
""" - print(tf.config.threading.get_intra_op_parallelism_threads()) - print(tf.config.threading.get_inter_op_parallelism_threads()) - # ## Define Model - # - # Convolutional neural network model + + # Define Model using Functional API inputs = tf.keras.layers.Input(shape=input_shape) conv = tf.keras.layers.Conv2D( @@ -96,13 +96,30 @@ def create_model(self, metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], ) - self.tvars = model.layers - print(f'layer names: {[var.name for var in self.tvars]}') + return model - self.opt_vars = self.optimizer.variables() - print(f'optimizer vars: {self.opt_vars}') + def train_(self, batch_generator, metrics: list = None, **kwargs): + """ + Train single epoch. - # Two opt_vars for one tvar: gradient and square sum for RMSprop. - self.fl_vars = self.tvars + self.opt_vars + Override this function for custom training. - return model + Args: + batch_generator (generator): Generator of training batches. + Each batch is a tuple of N train images and N train labels + where N is the batch size of the DataLoader of the current TaskRunner instance. + metrics (List[str]): A list of metric names to compute and save + **kwargs (dict): Additional keyword arguments + + Returns: + list: Metric objects containing the computed metrics + """ + + history = self.model.fit(batch_generator, + verbose=1, + **kwargs) + results = [] + for metric in metrics: + value = np.mean([history.history[metric]]) + results.append(Metric(name=metric, value=np.array(value))) + return results diff --git a/openfl-workspace/tf_cnn_histology/src/tfhistology_inmemory.py b/openfl-workspace/tf_cnn_histology/src/tfhistology_inmemory.py deleted file mode 100644 index 69cf5fc7e6..0000000000 --- a/openfl-workspace/tf_cnn_histology/src/tfhistology_inmemory.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -from openfl.federated import TensorFlowDataLoader -from .tfds_utils import load_histology_shard - - -class TensorFlowHistologyInMemory(TensorFlowDataLoader): - """TensorFlow Data Loader for Colorectal Histology Dataset.""" - - def __init__(self, data_path, batch_size, **kwargs): - """ - Initialize. - - Args: - data_path: File path for the dataset - batch_size (int): The batch size for the data loader - **kwargs: Additional arguments, passed to super init and load_mnist_shard - """ - super().__init__(batch_size, **kwargs) - - _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( - shard_num=data_path, - categorical=False, **kwargs - ) - - self.X_train = X_train - self.y_train = y_train - self.X_valid = X_valid - self.y_valid = y_valid - - self.num_classes = num_classes diff --git a/openfl-workspace/tf_cnn_mnist/.workspace b/openfl-workspace/tf_cnn_mnist/.workspace new file mode 100644 index 0000000000..3c2c5d08b4 --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/.workspace @@ -0,0 +1,2 @@ +current_plan_name: default + diff --git a/openfl-workspace/tf_cnn_mnist/plan/cols.yaml b/openfl-workspace/tf_cnn_mnist/plan/cols.yaml new file mode 100644 index 0000000000..95307de3bc --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/plan/cols.yaml @@ -0,0 +1,5 @@ +# Copyright (C) 2020-2021 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ +collaborators: + \ No newline at end of file diff --git a/openfl-workspace/tf_cnn_mnist/plan/data.yaml b/openfl-workspace/tf_cnn_mnist/plan/data.yaml new file mode 100644 index 0000000000..257c7825fe --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/plan/data.yaml @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2021 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +# collaborator_name,data_directory_path +one,1 + + diff --git a/openfl-workspace/tf_3dunet_brats/plan/defaults/defaults b/openfl-workspace/tf_cnn_mnist/plan/defaults similarity index 100% rename from openfl-workspace/tf_3dunet_brats/plan/defaults/defaults rename to openfl-workspace/tf_cnn_mnist/plan/defaults diff --git a/openfl-workspace/tf_cnn_mnist/plan/plan.yaml b/openfl-workspace/tf_cnn_mnist/plan/plan.yaml new file mode 100644 index 0000000000..5a9164260a --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/plan/plan.yaml @@ -0,0 +1,42 @@ +# Copyright (C) 2020-2024 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/init.pbuf + best_state_path : save/best.pbuf + last_state_path : save/last.pbuf + rounds_to_train : 10 + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : RESET + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.dataloader.MNISTDataloader + settings : + collaborator_count : 2 + data_group_name : mnist + batch_size : 256 + +task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.taskrunner.CNN + +network : + defaults : plan/defaults/network.yaml + +assigner : + defaults : plan/defaults/assigner.yaml + +tasks : + defaults : plan/defaults/tasks_tensorflow.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/tf_cnn_mnist/requirements.txt b/openfl-workspace/tf_cnn_mnist/requirements.txt new file mode 100644 index 0000000000..4a8d507a47 --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/requirements.txt @@ -0,0 +1 @@ +tensorflow==2.15.1 \ No newline at end of file diff --git a/openfl-workspace/tf_cnn_mnist/src/__init__.py b/openfl-workspace/tf_cnn_mnist/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfl-workspace/tf_cnn_mnist/src/dataloader.py b/openfl-workspace/tf_cnn_mnist/src/dataloader.py new file mode 100644 index 0000000000..5adac25874 --- /dev/null +++ b/openfl-workspace/tf_cnn_mnist/src/dataloader.py @@ -0,0 +1,129 @@ +# Copyright (C) 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from openfl.federated import TensorFlowDataLoader +from logging import getLogger + +import numpy as np +from tensorflow.python.keras.utils.data_utils import get_file + +logger = getLogger(__name__) + + +class MNISTDataloader(TensorFlowDataLoader): + """TensorFlow Data Loader for MNIST Dataset.""" + + def __init__(self, data_path, batch_size, **kwargs): + """ + Initialize. 
+ + Args: + data_path: File path for the dataset + batch_size (int): The batch size for the data loader + **kwargs: Additional arguments, passed to super init and load_mnist_shard + """ + super().__init__(batch_size, **kwargs) + + num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard( + shard_num=int(data_path), **kwargs + ) + + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + + self.num_classes = num_classes + + +def _load_raw_datashards(shard_num, collaborator_count): + """ + Load the raw data by shard. + + Returns tuples of the dataset shard divided into training and validation. + + Args: + shard_num (int): The shard number to use + collaborator_count (int): The number of collaborators in the federation + + Returns: + 2 tuples: (image, label) of the training, validation dataset + """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' + path = get_file('mnist.npz', + origin=origin_folder + 'mnist.npz', + file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') + + with np.load(path) as f: + # get all of mnist + X_train_tot = f['x_train'] + y_train_tot = f['y_train'] + + X_valid_tot = f['x_test'] + y_valid_tot = f['y_test'] + + # create the shards + shard_num = int(shard_num) + X_train = X_train_tot[shard_num::collaborator_count] + y_train = y_train_tot[shard_num::collaborator_count] + + X_valid = X_valid_tot[shard_num::collaborator_count] + y_valid = y_valid_tot[shard_num::collaborator_count] + + return (X_train, y_train), (X_valid, y_valid) + + +def load_mnist_shard(shard_num, collaborator_count, categorical=True, + channels_last=True, **kwargs): + """ + Load the MNIST dataset. + + Args: + shard_num (int): The shard to use from the dataset + collaborator_count (int): The number of collaborators in the federation + categorical (bool): True = convert the labels to one-hot encoded + vectors (Default = True) + channels_last (bool): True = The input images have the channels + last (Default = True) + **kwargs: Additional parameters to pass to the function + + Returns: + list: The input shape + int: The number of classes + numpy.ndarray: The training data + numpy.ndarray: The training labels + numpy.ndarray: The validation data + numpy.ndarray: The validation labels + """ + img_rows, img_cols = 28, 28 + num_classes = 10 + + (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards( + shard_num, collaborator_count + ) + + if channels_last: + X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1) + X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1) + else: + X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols) + X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols) + + X_train = X_train.astype('float32') + X_valid = X_valid.astype('float32') + X_train /= 255 + X_valid /= 255 + + logger.info(f'MNIST > X_train Shape : {X_train.shape}') + logger.info(f'MNIST > y_train Shape : {y_train.shape}') + logger.info(f'MNIST > Train Samples : {X_train.shape[0]}') + logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}') + + if categorical: + # convert class vectors to binary class matrices + y_train = np.eye(num_classes)[y_train] + y_valid = np.eye(num_classes)[y_valid] + + return num_classes, X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/tf_cnn_mnist/src/taskrunner.py b/openfl-workspace/tf_cnn_mnist/src/taskrunner.py new file mode 100644 index 0000000000..bb3cced194 --- /dev/null +++ 
b/openfl-workspace/tf_cnn_mnist/src/taskrunner.py @@ -0,0 +1,93 @@ +# Copyright (C) 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import numpy as np +import tensorflow as tf + +from openfl.utilities import Metric +from openfl.federated import TensorFlowTaskRunner + + +class CNN(TensorFlowTaskRunner): + """A basic convolutional neural network model.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.model = self.build_model(self.feature_shape, self.data_loader.num_classes, **kwargs) + + self.initialize_tensorkeys_for_functions() + + self.model.summary(print_fn=self.logger.info) + + self.logger.info(f'Train Set Size : {self.get_train_data_size()}') + self.logger.info(f'Valid Set Size : {self.get_valid_data_size()}') + + def build_model(self, + input_shape, + num_classes, + **kwargs): + """ + Build and compile a convolutional neural network model. + + Args: + input_shape (List[int]): The shape of the data + num_classes (int): The number of classes of the dataset + **kwargs (dict): Additional keyword arguments [optional] + + Returns: + tf.keras.models.Sequential + A compiled Keras Sequential model ready for training. + """ + + model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(16, + kernel_size=(4, 4), + strides=(2, 2), + activation='relu', + input_shape=input_shape), + tf.keras.layers.Conv2D(32, + kernel_size=(4, 4), + strides=(2, 2), + activation='relu'), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(100, + activation='relu'), + tf.keras.layers.Dense(num_classes, + activation='softmax') + ]) + + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=tf.keras.optimizers.Adam(), + metrics=['accuracy']) + + return model + + def train_(self, batch_generator, metrics: list = None, **kwargs): + """ + Train single epoch. + + Override this function for custom training. + + Args: + batch_generator (generator): Generator of training batches. + Each batch is a tuple of N train images and N train labels + where N is the batch size of the DataLoader of the current TaskRunner instance. 
+ metrics (List[str]): A list of metric names to compute and save + **kwargs (dict): Additional keyword arguments + + Returns: + list: Metric objects containing the computed metrics + """ + + history = self.model.fit(batch_generator, + verbose=1, + **kwargs) + results = [] + for metric in metrics: + value = np.mean([history.history[metric]]) + results.append(Metric(name=metric, value=np.array(value))) + + return results diff --git a/openfl-workspace/workspace/plan/defaults/tasks_keras.yaml b/openfl-workspace/workspace/plan/defaults/tasks_keras.yaml index 79d067d8d2..0ef460da87 100644 --- a/openfl-workspace/workspace/plan/defaults/tasks_keras.yaml +++ b/openfl-workspace/workspace/plan/defaults/tasks_keras.yaml @@ -1,5 +1,5 @@ aggregated_model_validation: - function : validate + function : validate_task kwargs : batch_size : 32 apply : global @@ -7,7 +7,7 @@ aggregated_model_validation: - accuracy locally_tuned_model_validation: - function : validate + function : validate_task kwargs : batch_size : 32 apply : local @@ -15,7 +15,7 @@ locally_tuned_model_validation: - accuracy train: - function : train + function : train_task kwargs : batch_size : 32 epochs : 1 diff --git a/openfl-workspace/workspace/plan/defaults/tasks_tensorflow.yaml b/openfl-workspace/workspace/plan/defaults/tasks_tensorflow.yaml index 6d000cc618..e3d5348ca3 100644 --- a/openfl-workspace/workspace/plan/defaults/tasks_tensorflow.yaml +++ b/openfl-workspace/workspace/plan/defaults/tasks_tensorflow.yaml @@ -1,23 +1,23 @@ aggregated_model_validation: - function : validate + function : validate_task kwargs : batch_size : 32 apply : global metrics : - - acc + - accuracy locally_tuned_model_validation: - function : validate + function : validate_task kwargs : batch_size : 32 apply : local metrics : - - acc + - accuracy train: - function : train_batches + function : train_task kwargs : - batch_size : 32 - metrics : + batch_size : 32 + metrics : - loss - epochs : 1 + epochs : 1 diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py index b2b4f4fd1f..7849172bf4 100644 --- a/openfl/federated/__init__.py +++ b/openfl/federated/__init__.py @@ -3,14 +3,14 @@ """openfl.federated package.""" -import pkgutil +import importlib.util from .plan import Plan # NOQA from .task import TaskRunner # NOQA from .data import DataLoader # NOQA -if pkgutil.find_loader('tensorflow'): +if importlib.util.find_spec('tensorflow'): from .task import TensorFlowTaskRunner, KerasTaskRunner, FederatedModel # NOQA from .data import TensorFlowDataLoader, KerasDataLoader, FederatedDataSet # NOQA -if pkgutil.find_loader('torch'): +if importlib.util.find_spec('torch'): from .task import PyTorchTaskRunner, FederatedModel # NOQA from .data import PyTorchDataLoader, FederatedDataSet # NOQA diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index b5efcdcd50..a0837db687 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -3,23 +3,22 @@ """Task package.""" -import pkgutil +import importlib.util from warnings import catch_warnings from warnings import simplefilter with catch_warnings(): simplefilter(action='ignore', category=FutureWarning) - if pkgutil.find_loader('tensorflow'): + if importlib.util.find_spec('tensorflow'): # ignore deprecation warnings in command-line interface import tensorflow # NOQA from .runner import TaskRunner # NOQA - -if pkgutil.find_loader('tensorflow'): +if importlib.util.find_spec('tensorflow'): from .runner_tf import TensorFlowTaskRunner # NOQA from 
.runner_keras import KerasTaskRunner # NOQA from .fl_model import FederatedModel # NOQA -if pkgutil.find_loader('torch'): +if importlib.util.find_spec('torch'): from .runner_pt import PyTorchTaskRunner # NOQA from .fl_model import FederatedModel # NOQA diff --git a/openfl/federated/task/runner_keras.py b/openfl/federated/task/runner_keras.py index c7daaa3d33..8cd0cab6c3 100644 --- a/openfl/federated/task/runner_keras.py +++ b/openfl/federated/task/runner_keras.py @@ -61,8 +61,8 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def train(self, col_name, round_num, input_tensor_dict, - metrics, epochs=1, batch_size=1, **kwargs): + def train_task(self, col_name, round_num, input_tensor_dict, + metrics, epochs=1, batch_size=1, **kwargs): """ Perform the training. @@ -81,9 +81,9 @@ def train(self, col_name, round_num, input_tensor_dict, self.rebuild_model(round_num, input_tensor_dict) for epoch in range(epochs): self.logger.info(f'Run {epoch} epoch of {round_num} round') - results = self.train_iteration(self.data_loader.get_train_loader(batch_size), - metrics=metrics, - **kwargs) + results = self.train_(self.data_loader.get_train_loader(batch_size), + metrics=metrics, + **kwargs) # output metric tensors (scalar) origin = col_name @@ -145,7 +145,7 @@ def train(self, col_name, round_num, input_tensor_dict, return global_tensor_dict, local_tensor_dict - def train_iteration(self, batch_generator, metrics: list = None, **kwargs): + def train_(self, batch_generator, metrics: list = None, **kwargs): """Train single epoch. Override this function for custom training. @@ -185,7 +185,7 @@ def train_iteration(self, batch_generator, metrics: list = None, **kwargs): results.append(Metric(name=metric, value=np.array(value))) return results - def validate(self, col_name, round_num, input_tensor_dict, **kwargs): + def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): """ Run the trained model on validation data; report results. @@ -396,7 +396,7 @@ def set_required_tensorkeys_for_function(self, func_name, # of the methods in the class and declare the tensors. 
# For now this is done manually - if func_name == 'validate': + if func_name == 'validate_task': # Should produce 'apply=global' or 'apply=local' local_model = 'apply' + kwargs['apply'] self.required_tensorkeys_for_function[func_name][ @@ -419,7 +419,7 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): List [TensorKey] """ - if func_name == 'validate': + if func_name == 'validate_task': local_model = 'apply=' + str(kwargs['apply']) return self.required_tensorkeys_for_function[func_name][local_model] else: @@ -449,19 +449,19 @@ def update_tensorkeys_for_functions(self): opt_names = self._get_weights_names(self.model.optimizer) tensor_names = model_layer_names + opt_names self.logger.debug(f'Updating model tensor names: {tensor_names}') - self.required_tensorkeys_for_function['train'] = [ + self.required_tensorkeys_for_function['train_task'] = [ TensorKey(tensor_name, 'GLOBAL', 0, ('model',)) for tensor_name in tensor_names ] # Validation may be performed on local or aggregated (global) model, # so there is an extra lookup dimension for kwargs - self.required_tensorkeys_for_function['validate'] = {} - self.required_tensorkeys_for_function['validate']['local_model=True'] = [ + self.required_tensorkeys_for_function['validate_task'] = {} + self.required_tensorkeys_for_function['validate_task']['local_model=True'] = [ TensorKey(tensor_name, 'LOCAL', 0, ('trained',)) for tensor_name in tensor_names ] - self.required_tensorkeys_for_function['validate']['local_model=False'] = [ + self.required_tensorkeys_for_function['validate_task']['local_model=False'] = [ TensorKey(tensor_name, 'GLOBAL', 0, ('model',)) for tensor_name in tensor_names ] @@ -502,31 +502,31 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): **self.tensor_dict_split_fn_kwargs ) - self.required_tensorkeys_for_function['train'] = [ + self.required_tensorkeys_for_function['train_task'] = [ TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) for tensor_name in global_model_dict ] - self.required_tensorkeys_for_function['train'] += [ + self.required_tensorkeys_for_function['train_task'] += [ TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) for tensor_name in local_model_dict ] # Validation may be performed on local or aggregated (global) model, # so there is an extra lookup dimension for kwargs - self.required_tensorkeys_for_function['validate'] = {} + self.required_tensorkeys_for_function['validate_task'] = {} # TODO This is not stateless. 
The optimizer will not be - self.required_tensorkeys_for_function['validate']['apply=local'] = [ + self.required_tensorkeys_for_function['validate_task']['apply=local'] = [ TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) for tensor_name in { **global_model_dict_val, **local_model_dict_val } ] - self.required_tensorkeys_for_function['validate']['apply=global'] = [ + self.required_tensorkeys_for_function['validate_task']['apply=global'] = [ TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) for tensor_name in global_model_dict_val ] - self.required_tensorkeys_for_function['validate']['apply=global'] += [ + self.required_tensorkeys_for_function['validate_task']['apply=global'] += [ TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) for tensor_name in local_model_dict_val ] diff --git a/openfl/federated/task/runner_tf.py b/openfl/federated/task/runner_tf.py index f63ffce3f8..56b61c5b64 100644 --- a/openfl/federated/task/runner_tf.py +++ b/openfl/federated/task/runner_tf.py @@ -1,24 +1,17 @@ -# Copyright (C) 2020-2023 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """TensorFlowTaskRunner module.""" - import numpy as np -import tensorflow.compat.v1 as tf -from tqdm import tqdm +import tensorflow as tf +from openfl.utilities import change_tags, Metric, TensorKey from openfl.utilities.split import split_tensor_dict_for_holdouts -from openfl.utilities import TensorKey from .runner import TaskRunner class TensorFlowTaskRunner(TaskRunner): - """ - Base class for TensorFlow models in the Federated Learning solution. - - child classes should have __init__ function signature (self, data, kwargs), - and should overwrite at least the following while defining the model - """ + """The base model for Keras models in the federation.""" def __init__(self, **kwargs): """ @@ -27,53 +20,26 @@ def __init__(self, **kwargs): Args: **kwargs: Additional parameters to pass to the function """ - tf.disable_v2_behavior() - super().__init__(**kwargs) - self.assign_ops = None - self.placeholders = None - - self.tvar_assign_ops = None - self.tvar_placeholders = None + self.model = tf.keras.Model() - # construct the shape needed for the input features - self.input_shape = (None,) + self.data_loader.get_feature_shape() + self.model_tensor_names = [] - # Required tensorkeys for all public functions in TensorFlowTaskRunner + # this is a map of all of the required tensors for each of the public + # functions in KerasTaskRunner self.required_tensorkeys_for_function = {} - - # Required tensorkeys for all public functions in TensorFlowTaskRunner - self.required_tensorkeys_for_function = {} - - # tensorflow session - self.sess = None - # input featrures to the model - self.X = None - # input labels to the model - self.y = None - # optimizer train step operation - self.train_step = None - # model loss function - self.loss = None - # model output tensor - self.output = None - # function used to validate the model outputs against labels - self.validation_metric = None - # tensorflow trainable variables - self.tvars = None - # self.optimizer.variables() once self.optimizer is defined - self.opt_vars = None - # self.tvars + self.opt_vars - self.fl_vars = None + tf.keras.backend.clear_session() def rebuild_model(self, round_num, input_tensor_dict, validation=False): """ Parse tensor names and update weights of model. Handles the optimizer treatment. 
- Returns: - None + Returns + ------- + None """ + if self.opt_treatment == 'RESET': self.reset_opt_vars() self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) @@ -83,49 +49,38 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): else: self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - def train_batches(self, col_name, round_num, input_tensor_dict, - epochs=1, use_tqdm=False, **kwargs): + def train_task(self, col_name, round_num, input_tensor_dict, + metrics, epochs=1, batch_size=1, **kwargs): """ Perform the training. - Is expected to perform draws randomly, without replacement until data is exausted. Then - data is replaced and shuffled and draws continue. + Is expected to perform draws randomly, without replacement until data is exausted. + Then data is replaced and shuffled and draws continue. - Args: - use_tqdm (bool): True = use tqdm to print a progress - bar (Default=False) - epochs (int): Number of epochs to train - Returns: - float: loss metric + Returns + ------- + dict + 'TensorKey: nparray' """ - batch_size = self.data_loader.batch_size - - if kwargs['batch_size']: - batch_size = kwargs['batch_size'] + if metrics is None: + raise KeyError('metrics must be defined') # rebuild model with updated weights self.rebuild_model(round_num, input_tensor_dict) - - tf.keras.backend.set_learning_phase(True) - losses = [] - for epoch in range(epochs): self.logger.info(f'Run {epoch} epoch of {round_num} round') - # get iterator for batch draws (shuffling happens here) - gen = self.data_loader.get_train_loader(batch_size) - if use_tqdm: - gen = tqdm.tqdm(gen, desc='training epoch') + results = self.train_(self.data_loader.get_train_loader(batch_size), + metrics=metrics, + **kwargs) - for (X, y) in gen: - losses.append(self.train_batch(X, y)) - - # Output metric tensors (scalar) + # output metric tensors (scalar) origin = col_name tags = ('trained',) output_metric_dict = { TensorKey( - self.loss_name, origin, round_num, True, ('metric',) - ): np.array(np.mean(losses)) + metric_name, origin, round_num, True, ('metric',) + ): metric_value + for (metric_name, metric_value) in results } # output model tensors (Doesn't include TensorKey) @@ -135,23 +90,24 @@ def train_batches(self, col_name, round_num, input_tensor_dict, **self.tensor_dict_split_fn_kwargs ) - # Create global tensorkeys + # create global tensorkeys global_tensorkey_model_dict = { TensorKey(tensor_name, origin, round_num, False, tags): nparray for tensor_name, nparray in global_model_dict.items() } - # Create tensorkeys that should stay local + # create tensorkeys that should stay local local_tensorkey_model_dict = { TensorKey(tensor_name, origin, round_num, False, tags): nparray for tensor_name, nparray in local_model_dict.items() } - # The train/validate aggregated function of the next round will - # look for the updated model parameters. - # This ensures they will be resolved locally + # the train/validate aggregated function of the next round will look + # for the updated model parameters. 
+ # this ensures they will be resolved locally next_local_tensorkey_model_dict = { TensorKey( tensor_name, origin, round_num + 1, False, ('model',) - ): nparray for tensor_name, nparray in local_model_dict.items()} + ): nparray for tensor_name, nparray in local_model_dict.items() + } global_tensor_dict = { **output_metric_dict, @@ -162,11 +118,11 @@ def train_batches(self, col_name, round_num, input_tensor_dict, **next_local_tensorkey_model_dict } - # Update the required tensors if they need to be pulled from - # the aggregator + # update the required tensors if they need to be pulled from the + # aggregator # TODO this logic can break if different collaborators have different - # roles between rounds. - # For example, if a collaborator only performs validation in the first + # roles between rounds. + # for example, if a collaborator only performs validation in the first # round but training in the second, it has no way of knowing the # optimizer state tensor names to request from the aggregator because # these are only created after training occurs. A work around could @@ -177,51 +133,85 @@ def train_batches(self, col_name, round_num, input_tensor_dict, return global_tensor_dict, local_tensor_dict - def train_batch(self, X, y): - """ - Train the model on a single batch. + def train_(self, batch_generator, metrics: list = None, **kwargs): + """Train single epoch. + + Override this function for custom training. Args: - X: Input to the model - y: Ground truth label to the model + batch_generator: Generator of training batches. + Each batch is a tuple of N train images and N train labels + where N is the batch size of the DataLoader of the current TaskRunner instance. - Returns: - float: loss metric + epochs: Number of epochs to train. + metrics: Names of metrics to save. """ - feed_dict = {self.X: X, self.y: y} - - # run the train step and return the loss - _, loss = self.sess.run([self.train_step, self.loss], feed_dict=feed_dict) - - return loss - - def validate(self, col_name, round_num, - input_tensor_dict, use_tqdm=False, **kwargs): + if metrics is None: + metrics = [] + # TODO Currently assuming that all metrics are defined at + # initialization (build_model). + # If metrics are added (i.e. not a subset of what was originally + # defined) then the model must be recompiled. + model_metrics_names = self.model.metrics_names + + # TODO if there are new metrics in the flplan that were not included + # in the originally + # compiled model, that behavior is not currently handled. + for param in metrics: + if param not in model_metrics_names: + raise ValueError( + f'KerasTaskRunner does not support specifying new metrics. ' + f'Param_metrics = {metrics}, model_metrics_names = {model_metrics_names}' + ) + + history = self.model.fit(batch_generator, + verbose=1, + **kwargs) + results = [] + for metric in metrics: + value = np.mean([history.history[metric]]) + results.append(Metric(name=metric, value=np.array(value))) + return results + + def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs): """ - Run validation. + Run the trained model on validation data; report results. - Returns: - dict: {: } - """ - batch_size = self.data_loader.batch_size + Parameters + ---------- + input_tensor_dict : either the last aggregated or locally trained model - if kwargs['batch_size']: + Returns + ------- + output_tensor_dict : {TensorKey: nparray} (these correspond to acc, + precision, f1_score, etc.) 
+ """ + if 'batch_size' in kwargs: batch_size = kwargs['batch_size'] + else: + batch_size = 1 self.rebuild_model(round_num, input_tensor_dict, validation=True) + param_metrics = kwargs['metrics'] - tf.keras.backend.set_learning_phase(False) - - score = 0 - - gen = self.data_loader.get_valid_loader(batch_size) - if use_tqdm: - gen = tqdm.tqdm(gen, desc='validating') - - for X, y in gen: - weight = X.shape[0] / self.data_loader.get_valid_data_size() - _, s = self.validate_batch(X, y) - score += s * weight + vals = self.model.evaluate( + self.data_loader.get_valid_loader(batch_size), + verbose=1 + ) + model_metrics_names = self.model.metrics_names + if type(vals) is not list: + vals = [vals] + ret_dict = dict(zip(model_metrics_names, vals)) + + # TODO if there are new metrics in the flplan that were not included in + # the originally compiled model, that behavior is not currently + # handled. + for param in param_metrics: + if param not in model_metrics_names: + raise ValueError( + f'KerasTaskRunner does not support specifying new metrics. ' + f'Param_metrics = {param_metrics}, model_metrics_names = {model_metrics_names}' + ) origin = col_name suffix = 'validate' @@ -229,112 +219,201 @@ def validate(self, col_name, round_num, suffix += '_local' else: suffix += '_agg' - tags = ('metric', suffix) + tags = ('metric',) + tags = change_tags(tags, add_field=suffix) output_tensor_dict = { - TensorKey( - self.validation_metric_name, origin, round_num, True, tags - ): np.array(score)} + TensorKey(metric, origin, round_num, True, tags): + np.array(ret_dict[metric]) + for metric in param_metrics} - # return empty dict for local metrics return output_tensor_dict, {} - def validate_batch(self, X, y): - """Validate the model on a single local batch. - - Args: - X: Input to the model - y: Ground truth label to the model + def save_native(self, filepath): + """Save model.""" + self.model.save(filepath) - Returns: - float: loss metric + def load_native(self, filepath): + """Load model.""" + self.model = tf.keras.models.load_model(filepath) + @staticmethod + def _get_weights_names(obj, with_opt_vars): """ - feed_dict = {self.X: X, self.y: y} + Get the list of weight names. - return self.sess.run( - [self.output, self.validation_metric], feed_dict=feed_dict) + Parameters + ---------- + obj : Model or Optimizer + The target object that we want to get the weights. - def get_tensor_dict(self, with_opt_vars=True): - """Get the dictionary weights. + with_opt_vars (bool): Specify if we want to get optimizer weights - Get the weights from the tensor + Returns + ------- + dict + The weight name list + """ + if with_opt_vars: + weight_names = [weight.name for weight in obj.variables] - Args: - with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer + weight_names = [weight.name for weight in obj.weights] + return weight_names - Returns: - dict: The weight dictionary {: } + @staticmethod + def _get_weights_dict(obj, suffix='', with_opt_vars=False): + """ + Get the dictionary of weights. + Parameters + ---------- + obj : Model or Optimizer + The target object that we want to get the weights. + + with_opt_vars (bool): Specify if we want to get optimizer weights + + Returns + ------- + dict + The weight dictionary. 
""" - if with_opt_vars is True: - variables = self.fl_vars + + weights_dict = {} + if with_opt_vars: + weight_names = [weight.name for weight in obj.variables] + weight_values = [weight.numpy() for weight in obj.variables] else: - variables = self.tvars + weight_names = [weight.name for weight in obj.weights] + weight_values = obj.get_weights() - # FIXME: do this in one call? - return {var.name: val for var, val in zip( - variables, self.sess.run(variables))} + for name, value in zip(weight_names, weight_values): + weights_dict[name + suffix] = value + return weights_dict - def set_tensor_dict(self, tensor_dict, with_opt_vars): - """Set the tensor dictionary. + @staticmethod + def _set_weights_dict(obj, weights_dict, with_opt_vars=False): + """Set the object weights with a dictionary. - Set the model weights with a tensor - dictionary: {: }. + The obj can be a model or an optimizer. Args: - tensor_dict (dict): The model weights dictionary - with_opt_vars (bool): Specify if we also want to set the variables - of the optimizer - + obj (Model or Optimizer): The target object that we want to set + the weights. + weights_dict (dict): The weight dictionary. + with_opt_vars (bool): Specify if we want to set optimizer weights Returns: None """ + if with_opt_vars: - self.assign_ops, self.placeholders = tf_set_tensor_dict( - tensor_dict, self.sess, self.fl_vars, - self.assign_ops, self.placeholders - ) + weight_names = [weight.name for weight in obj.variables] else: - self.tvar_assign_ops, self.tvar_placeholders = tf_set_tensor_dict( - tensor_dict, - self.sess, - self.tvars, - self.tvar_assign_ops, - self.tvar_placeholders - ) + weight_names = [weight.name for weight in obj.weights] - def reset_opt_vars(self): - """Reinitialize the optimizer variables.""" - for v in self.opt_vars: - v.initializer.run(session=self.sess) + weight_values = [weights_dict[name] for name in weight_names] - def initialize_globals(self): - """Initialize Global Variables. + obj.set_weights(weight_values) - Initialize all global variables + def get_tensor_dict(self, with_opt_vars, suffix=''): + """ + Get the model weights as a tensor dictionary. + + Parameters + ---------- + with_opt_vars : bool + If we should include the optimizer's status. + suffix : string + Universally Returns: - None + dict: The tensor dictionary. """ - self.sess.run(tf.global_variables_initializer()) - def _get_weights_names(self, with_opt_vars=True): - """Get the weights. + model_weights = self._get_weights_dict(self.model, suffix) - Args: - with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer. + if with_opt_vars: - Returns: - list : The weight names list + opt_weights = self._get_weights_dict(self.model.optimizer, suffix, with_opt_vars) + + model_weights.update(opt_weights) + + if len(opt_weights) == 0: + self.logger.debug( + "WARNING: We didn't find variables for the optimizer.") + return model_weights + + def set_tensor_dict(self, tensor_dict, with_opt_vars): """ - if with_opt_vars is True: - variables = self.fl_vars + Set the model weights with a tensor dictionary. + + Args: + tensor_dict: the tensor dictionary + with_opt_vars (bool): True = include the optimizer's status. 
+ """ + if with_opt_vars is False: + # It is possible to pass in opt variables from the input tensor dict + # This will make sure that the correct layers are updated + model_weight_names = [weight.name for weight in self.model.weights] + model_weights_dict = { + name: tensor_dict[name] for name in model_weight_names + } + self._set_weights_dict(self.model, model_weights_dict) else: - variables = self.tvars + model_weight_names = [ + weight.name for weight in self.model.weights + ] + model_weights_dict = { + name: tensor_dict[name] for name in model_weight_names + } + + opt_weight_names = [ + weight.name for weight in self.model.optimizer.variables + ] + opt_weights_dict = { + name: tensor_dict[name] for name in opt_weight_names + } + self._set_weights_dict(self.model, model_weights_dict) + self._set_weights_dict(self.model.optimizer, opt_weights_dict, with_opt_vars) + + def reset_opt_vars(self): + """ + Reset optimizer variables. + + Resets the optimizer variables + + """ + for var in self.model.optimizer.variables(): + var.assign(tf.zeros_like(var)) + self.logger.debug('Optimizer variables reset') + + def set_required_tensorkeys_for_function(self, func_name, + tensor_key, **kwargs): + """ + Set the required tensors for specified function that could be called as part of a task. + + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function + + Parameters + ---------- + func_name: string + tensor_key: TensorKey (namedtuple) + **kwargs: Any function arguments {} + + Returns + ------- + None + """ + # TODO there should be a way to programmatically iterate through all + # of the methods in the class and declare the tensors. + # For now this is done manually - return [var.name for var in variables] + if func_name == 'validate_task': + # Should produce 'apply=global' or 'apply=local' + local_model = 'apply' + kwargs['apply'] + self.required_tensorkeys_for_function[func_name][ + local_model].append(tensor_key) + else: + self.required_tensorkeys_for_function[func_name].append(tensor_key) def get_required_tensorkeys_for_function(self, func_name, **kwargs): """ @@ -342,15 +421,62 @@ def get_required_tensorkeys_for_function(self, func_name, **kwargs): By default, this is just all of the layers and optimizer of the model. - Returns: - list : [TensorKey] + Parameters + ---------- + None + + Returns + ------- + List + [TensorKey] """ - if func_name == 'validate': + if func_name == 'validate_task': local_model = 'apply=' + str(kwargs['apply']) return self.required_tensorkeys_for_function[func_name][local_model] else: return self.required_tensorkeys_for_function[func_name] + def update_tensorkeys_for_functions(self): + """ + Update the required tensors for all publicly accessible methods \ + that could be called as part of a task. + + By default, this is just all of the layers and optimizer of the model. + Custom tensors should be added to this function + + Parameters + ---------- + None + + Returns + ------- + None + """ + # TODO complete this function. 
It is only needed for opt_treatment, + # and making the model stateless + + # Minimal required tensors for train function + model_layer_names = self._get_weights_names(self.model) + opt_names = self._get_weights_names(self.model.optimizer) + tensor_names = model_layer_names + opt_names + self.logger.debug(f'Updating model tensor names: {tensor_names}') + self.required_tensorkeys_for_function['train_task'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, ('model',)) + for tensor_name in tensor_names + ] + + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + self.required_tensorkeys_for_function['validate_task'] = {} + self.required_tensorkeys_for_function['validate_task']['local_model=True'] = [ + TensorKey(tensor_name, 'LOCAL', 0, ('trained',)) + for tensor_name in tensor_names + ] + self.required_tensorkeys_for_function['validate_task']['local_model=False'] = [ + TensorKey(tensor_name, 'GLOBAL', 0, ('model',)) + for tensor_name in tensor_names + ] + def initialize_tensorkeys_for_functions(self, with_opt_vars=False): """ Set the required tensors for all publicly accessible methods \ @@ -359,9 +485,16 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): By default, this is just all of the layers and optimizer of the model. Custom tensors should be added to this function + Parameters + ---------- + None + + Returns + ------- + None """ - # TODO there should be a way to programmatically iterate through - # all of the methods in the class and declare the tensors. + # TODO there should be a way to programmatically iterate through all + # of the methods in the class and declare the tensors. # For now this is done manually output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) @@ -380,66 +513,31 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): **self.tensor_dict_split_fn_kwargs ) - self.required_tensorkeys_for_function['train_batches'] = [ + self.required_tensorkeys_for_function['train_task'] = [ TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in global_model_dict] - self.required_tensorkeys_for_function['train_batches'] += [ + for tensor_name in global_model_dict + ] + self.required_tensorkeys_for_function['train_task'] += [ TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in local_model_dict] + for tensor_name in local_model_dict + ] - # Validation may be performed on local or aggregated (global) - # model, so there is an extra lookup dimension for kwargs - self.required_tensorkeys_for_function['validate'] = {} + # Validation may be performed on local or aggregated (global) model, + # so there is an extra lookup dimension for kwargs + self.required_tensorkeys_for_function['validate_task'] = {} # TODO This is not stateless. 
The optimizer will not be - self.required_tensorkeys_for_function['validate']['apply=local'] = [ + self.required_tensorkeys_for_function['validate_task']['apply=local'] = [ TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) for tensor_name in { **global_model_dict_val, **local_model_dict_val } ] - self.required_tensorkeys_for_function['validate']['apply=global'] = [ + self.required_tensorkeys_for_function['validate_task']['apply=global'] = [ TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) for tensor_name in global_model_dict_val ] - self.required_tensorkeys_for_function['validate']['apply=global'] += [ + self.required_tensorkeys_for_function['validate_task']['apply=global'] += [ TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) for tensor_name in local_model_dict_val - ] - - -# FIXME: what's a nicer construct than this? ugly interface. Perhaps we -# get an object with an assumed interface that lets is set/get these? -# Note that this will return the assign_ops and placeholder nodes it uses -# if called with None, it will create them. -# to avoid inflating the graph, caller should keep these and pass them back -# What if we want to set a different group of vars in the middle? -# It is good if it is the subset of the original variables. -def tf_set_tensor_dict(tensor_dict, session, variables, - assign_ops=None, placeholders=None): - """Tensorflow set tensor dictionary. - - Args: - tensor_dict: Dictionary of tensors - session: TensorFlow session - variables: TensorFlow variables - assign_ops: TensorFlow operations (Default=None) - placeholders: TensorFlow placeholders (Default=None) - - Returns: - assign_ops, placeholders - - """ - if placeholders is None: - placeholders = { - v.name: tf.placeholder(v.dtype, shape=v.shape) for v in variables - } - if assign_ops is None: - assign_ops = { - v.name: tf.assign(v, placeholders[v.name]) for v in variables - } - - for k, v in tensor_dict.items(): - session.run(assign_ops[k], feed_dict={placeholders[k]: v}) - - return assign_ops, placeholders + ] \ No newline at end of file
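Note: the refactor above standardizes the task-runner entry points (`train`/`train_batches` → `train_task`, `validate` → `validate_task`) and has the TensorFlow runners build metrics by running `model.fit()` once per epoch and converting the requested entries of the Keras history into `openfl.utilities.Metric` objects, raising an error if the plan requests a metric the compiled model does not report. The snippet below is a minimal, self-contained sketch of that conversion pattern only; the tiny Dense model, the random data, and the local `Metric` namedtuple (a stand-in for `openfl.utilities.Metric`, used so the sketch runs without OpenFL installed) are illustrative placeholders and are not part of this diff.

```python
from collections import namedtuple

import numpy as np
import tensorflow as tf

# Stand-in for openfl.utilities.Metric so this sketch runs without OpenFL installed.
Metric = namedtuple('Metric', ['name', 'value'])


def train_(model, batch_generator, metrics=None, **kwargs):
    """Run one fit() call and convert the requested Keras history entries
    into Metric objects, mirroring the train_() helper added in this diff."""
    metrics = metrics or []
    history = model.fit(batch_generator, verbose=1, **kwargs)
    results = []
    for metric in metrics:
        # Average the per-epoch values recorded by Keras for this metric.
        value = np.mean([history.history[metric]])
        results.append(Metric(name=metric, value=np.array(value)))
    return results


if __name__ == '__main__':
    # Placeholder model and data, for illustration only.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,)),
    ])
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    x = np.random.rand(64, 784).astype('float32')
    y = np.random.randint(0, 10, size=(64,)).astype('int64')
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

    for m in train_(model, dataset, metrics=['loss', 'accuracy'], epochs=1):
        print(f'{m.name}: {m.value}')
```

The metric names passed here ('loss', 'accuracy') must match names the compiled model actually tracks, which is the same constraint the runners above enforce against the `metrics` lists declared in the plan task definitions.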