From 25797ad5fa3be4a399a51b81115d8c2ac2c40c80 Mon Sep 17 00:00:00 2001
From: Mody-SHARK
Date: Tue, 30 Nov 2021 09:19:13 +0100
Subject: [PATCH] First commit
---
.gitignore | 9 +
README.md | 49 +
__init__.py | 0
demo/demo_dropout_ce.py | 105 +
demo/demo_dropout_cebasic.py | 105 +
demo/demo_dropout_dice.py | 105 +
demo/demo_flipout_ce.py | 109 +
src/__init__.py | 0
src/config.py | 401 +++
src/dataloader/__init__.py | 0
src/dataloader/augmentations.py | 592 ++++
src/dataloader/dataset.py | 53 +
src/dataloader/extractors/__init__.py | 0
src/dataloader/extractors/han_deepmindtcia.py | 339 +++
src/dataloader/extractors/han_miccai2015.py | 339 +++
src/dataloader/han_deepmindtcia.py | 726 +++++
src/dataloader/han_miccai2015.py | 765 ++++++
src/dataloader/utils.py | 1273 +++++++++
src/dataloader/utils_viz.py | 437 +++
src/model/__init__.py | 0
src/model/losses.py | 488 ++++
src/model/models.py | 572 ++++
src/model/trainer.py | 1718 ++++++++++++
src/model/trainer_flipout.py | 2435 +++++++++++++++++
src/model/utils.py | 585 ++++
25 files changed, 11205 insertions(+)
create mode 100644 .gitignore
create mode 100644 README.md
create mode 100644 __init__.py
create mode 100644 demo/demo_dropout_ce.py
create mode 100644 demo/demo_dropout_cebasic.py
create mode 100644 demo/demo_dropout_dice.py
create mode 100644 demo/demo_flipout_ce.py
create mode 100644 src/__init__.py
create mode 100644 src/config.py
create mode 100644 src/dataloader/__init__.py
create mode 100644 src/dataloader/augmentations.py
create mode 100644 src/dataloader/dataset.py
create mode 100644 src/dataloader/extractors/__init__.py
create mode 100644 src/dataloader/extractors/han_deepmindtcia.py
create mode 100644 src/dataloader/extractors/han_miccai2015.py
create mode 100644 src/dataloader/han_deepmindtcia.py
create mode 100644 src/dataloader/han_miccai2015.py
create mode 100644 src/dataloader/utils.py
create mode 100644 src/dataloader/utils_viz.py
create mode 100644 src/model/__init__.py
create mode 100644 src/model/losses.py
create mode 100644 src/model/models.py
create mode 100644 src/model/trainer.py
create mode 100644 src/model/trainer_flipout.py
create mode 100644 src/model/utils.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ad9170e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+# Python
+__pycache__
+.pyc
+
+
+# Modeling
+_data
+_models
+_logs
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..496df7e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,49 @@
+# Bayesian Uncertainty for Quality Assessment of Deep Learning Contours
+This repository contains TensorFlow 2.4 code for the paper(s)
+ - Comparing Bayesian Models for Organ Contouring in Head and Neck Radiotherapy
+
+
+## Installation
+1. Install [Anaconda](https://docs.anaconda.com/anaconda/install/) with python3.7
+2. Install [git](https://git-scm.com/downloads)
+3. Open a terminal and run the commands below
+ - Clone this repository
+ - `git clone git@github.com:prerakmody/hansegmentation-uncertainty-qa.git`
+ - Create conda env
+ - (Specifically For Windows): `conda init powershell` (and restart the terminal)
+        - (For all platforms)
+ ```
+ cd hansegmentation-uncertainty-qa
+ conda deactivate
+ conda create --name hansegmentation-uncertainty-qa python=3.8
+ conda activate hansegmentation-uncertainty-qa
+ conda develop . # check for conda.pth file in $ANACONDA_HOME/envs/hansegmentation-uncertainty-qa/lib/python3.8/site-packages
+ ```
+ - Install packages
+        - Tensorflow (check [here](https://www.tensorflow.org/install/source#tested_build_configurations) for CUDA/cuDNN requirements)
+ - (stick to the exact commands)
+ - For tensorflow2.4
+ ```
+ conda install -c nvidia cudnn=8.0.0=cuda11.0_0
+ pip install tensorflow==2.4
+ ```
+ - Check tensorflow installation
+ ```
+ python -c "import tensorflow as tf;print('\n\n\n====================== \n GPU Devices: ',tf.config.list_physical_devices('GPU'), '\n======================')"
+ python -c "import tensorflow as tf;print('\n\n\n====================== \n', tf.reduce_sum(tf.random.normal([1000, 1000])), '\n======================' )"
+ ```
+ - [unix] upon running either of the above commands, you will see tensorflow searching for library files like libcudart.so, libcublas.so, libcublasLt.so, libcufft.so, libcurand.so, libcusolver.so, libcusparse.so, libcudnn.so in the location `$ANACONDA_HOME/envs/hansegmentation-uncertainty-qa/lib/`
+ - [windows] upon running either of the above commands, you will see tensorflow searching for library files like cudart64_110.dll ... and so on in the location `$ANACONDA_HOME\envs\hansegmentation-uncertainty-qa\Library\bin`
+
+    - Other tensorflow packages
+ ```
+ pip install tensorflow-probability==0.12.1 tensorflow-addons==0.12.1
+ ```
+ - Other packages
+ ```
+ pip install scipy seaborn tqdm psutil humanize pynrrd pydicom SimpleITK itk scikit-image
+ pip install psutil humanize pynvml
+ ```
+
+# Notes
+ - The `src/train{}.py` scripts are the ones used to train the models, as shown in the `demo/` folder (see the example below)
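+ - For example, to train a dropout model with the cross-entropy loss, run the corresponding demo script from the repository root (illustrative command; the data directory and training parameters are set inside the script itself):
+    ```
+    python demo/demo_dropout_ce.py
+    ```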
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/demo/demo_dropout_ce.py b/demo/demo_dropout_ce.py
new file mode 100644
index 0000000..7ba6001
--- /dev/null
+++ b/demo/demo_dropout_ce.py
@@ -0,0 +1,105 @@
+# Import private libraries
+import src.config as config
+from src.model.trainer import Trainer,Validator
+
+# Import public libraries
+import os
+import pdb
+import traceback
+import tensorflow as tf
+from pathlib import Path
+
+
+
+if __name__ == "__main__":
+
+ exp_name = 'HansegmentationUncertaintyQA-Dropout-CE'
+
+ data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ resampled = True
+ crop_init = True
+ grid = True
+ batch_size = 2
+
+ model = config.MODEL_FOCUSNET_DROPOUT
+
+ # To train
+ params = {
+ 'exp_name': exp_name
+ , 'random_seed':42
+ , 'dataloader':{
+ 'data_dir': data_dir
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD]
+ , 'resampled' : resampled
+ , 'crop_init' : crop_init
+ , 'grid' : grid
+ , 'random_grid': True
+ , 'filter_grid': False
+ , 'centred_prob' : 0.3
+ , 'batch_size' : batch_size
+ , 'shuffle' : 5
+ , 'prefetch_batch': 4
+ , 'parallel_calls': 3
+ }
+ , 'model': {
+ 'name': model
+ , 'optimizer' : config.OPTIMIZER_ADAM
+ , 'init_lr' : 0.001
+ , 'fixed_lr' : True
+ , 'epochs' : 1500
+ , 'epochs_save': 50
+ , 'epochs_eval': 50
+ , 'epochs_viz' : 500
+ , 'load_model':{
+ 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None
+ }
+ , 'profiler': {
+ 'profile': False
+ , 'epochs': [2,3]
+ , 'steps_per_epoch': 60
+ , 'starting_step': 4
+ }
+ , 'model_tboard': False
+ }
+ , 'metrics' : {
+ 'logging_tboard': True
+ # for full 3D volume
+ , 'metrics_eval': {'Dice': config.LOSS_DICE}
+ ## for smaller grid/patch
+ , 'metrics_loss' : {'CE': config.LOSS_CE} # [config.LOSS_CE, config.LOSS_DICE]
+ , 'loss_weighted' : {'CE': True}
+ , 'loss_mask' : {'CE': True}
+ , 'loss_combo' : {'CE': 1.0}
+ }
+ , 'others': {
+ 'epochs_timer': 20
+ , 'epochs_memory':5
+ }
+ }
+
+ # Call the trainer
+ trainer = Trainer(params)
+ trainer.train()
+
+ # To evaluate on MICCAI2015
+ params = {
+ 'exp_name': exp_name
+ , 'pid' : os.getpid()
+ , 'dataloader': {
+ 'data_dir' : data_dir
+ , 'resampled' : resampled
+ , 'grid' : grid
+ , 'crop_init' : crop_init
+ , 'batch_size' : batch_size
+ , 'prefetch_batch': 1
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE]
+ , 'eval_type' : config.MODE_TEST
+ }
+ , 'model': {
+ 'name': model
+ , 'load_epoch' : 1000
+ , 'MC_RUNS' : 30
+ , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time]
+ }
+ , 'save': True
+ }
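+
+    # Evaluation sketch (hypothetical; the Validator imported above presumably consumes this
+    # params dict the way Trainer does, but its exact interface is not shown in this commit):
+    #   validator = Validator(params)
+    #   validator.validate()  # hypothetical method name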
\ No newline at end of file
diff --git a/demo/demo_dropout_cebasic.py b/demo/demo_dropout_cebasic.py
new file mode 100644
index 0000000..d3ff9ad
--- /dev/null
+++ b/demo/demo_dropout_cebasic.py
@@ -0,0 +1,105 @@
+# Import private libraries
+import src.config as config
+from src.model.trainer import Trainer,Validator
+
+# Import public libraries
+import os
+import pdb
+import traceback
+import tensorflow as tf
+from pathlib import Path
+
+
+
+if __name__ == "__main__":
+
+ exp_name = 'HansegmentationUncertaintyQA-Dropout-CEBasic'
+
+ data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ resampled = True
+ crop_init = True
+ grid = True
+ batch_size = 2
+
+ model = config.MODEL_FOCUSNET_DROPOUT
+
+ # To train
+ params = {
+ 'exp_name': exp_name
+ , 'random_seed':42
+ , 'dataloader':{
+ 'data_dir': data_dir
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD]
+ , 'resampled' : resampled
+ , 'crop_init' : crop_init
+ , 'grid' : grid
+ , 'random_grid': True
+ , 'filter_grid': False
+ , 'centred_prob' : 0.3
+ , 'batch_size' : batch_size
+ , 'shuffle' : 5
+ , 'prefetch_batch': 4
+ , 'parallel_calls': 3
+ }
+ , 'model': {
+ 'name': model
+ , 'optimizer' : config.OPTIMIZER_ADAM
+ , 'init_lr' : 0.001
+ , 'fixed_lr' : True
+ , 'epochs' : 1500
+ , 'epochs_save': 50
+ , 'epochs_eval': 50
+ , 'epochs_viz' : 500
+ , 'load_model':{
+ 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None
+ }
+ , 'profiler': {
+ 'profile': False
+ , 'epochs': [2,3]
+ , 'steps_per_epoch': 60
+ , 'starting_step': 4
+ }
+ , 'model_tboard': False
+ }
+ , 'metrics' : {
+ 'logging_tboard': True
+ # for full 3D volume
+ , 'metrics_eval': {'Dice': config.LOSS_DICE}
+ ## for smaller grid/patch
+ , 'metrics_loss' : {'CE-Basic': config.LOSS_CE_BASIC}
+ , 'loss_weighted' : {'CE-Basic': True}
+ , 'loss_mask' : {'CE-Basic': True}
+ , 'loss_combo' : {'CE-Basic': 1.0}
+ }
+ , 'others': {
+ 'epochs_timer': 20
+ , 'epochs_memory':5
+ }
+ }
+
+ # Call the trainer
+ trainer = Trainer(params)
+ trainer.train()
+
+ # To evaluate on MICCAI2015
+ params = {
+ 'exp_name': exp_name
+ , 'pid' : os.getpid()
+ , 'dataloader': {
+ 'data_dir' : data_dir
+ , 'resampled' : resampled
+ , 'grid' : grid
+ , 'crop_init' : crop_init
+ , 'batch_size' : batch_size
+ , 'prefetch_batch': 1
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE]
+ , 'eval_type' : config.MODE_TEST
+ }
+ , 'model': {
+ 'name': model
+ , 'load_epoch' : 1000
+ , 'MC_RUNS' : 30
+ , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time]
+ }
+ , 'save': True
+ }
\ No newline at end of file
diff --git a/demo/demo_dropout_dice.py b/demo/demo_dropout_dice.py
new file mode 100644
index 0000000..3257769
--- /dev/null
+++ b/demo/demo_dropout_dice.py
@@ -0,0 +1,105 @@
+# Import private libraries
+import src.config as config
+from src.model.trainer import Trainer,Validator
+
+# Import public libraries
+import os
+import pdb
+import traceback
+import tensorflow as tf
+from pathlib import Path
+
+
+
+if __name__ == "__main__":
+
+ exp_name = 'HansegmentationUncertaintyQA-Dropout-DICE'
+
+ data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ resampled = True
+ crop_init = True
+ grid = True
+ batch_size = 2
+
+ model = config.MODEL_FOCUSNET_DROPOUT
+
+ # To train
+ params = {
+ 'exp_name': exp_name
+ , 'random_seed':42
+ , 'dataloader':{
+ 'data_dir': data_dir
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD]
+ , 'resampled' : resampled
+ , 'crop_init' : crop_init
+ , 'grid' : grid
+ , 'random_grid': True
+ , 'filter_grid': False
+ , 'centred_prob' : 0.3
+ , 'batch_size' : batch_size
+ , 'shuffle' : 5
+ , 'prefetch_batch': 4
+ , 'parallel_calls': 3
+ }
+ , 'model': {
+ 'name': model
+ , 'optimizer' : config.OPTIMIZER_ADAM
+ , 'init_lr' : 0.001
+ , 'fixed_lr' : True
+ , 'epochs' : 1500
+ , 'epochs_save': 50
+ , 'epochs_eval': 50
+ , 'epochs_viz' : 500
+ , 'load_model':{
+ 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None
+ }
+ , 'profiler': {
+ 'profile': False
+ , 'epochs': [2,3]
+ , 'steps_per_epoch': 60
+ , 'starting_step': 4
+ }
+ , 'model_tboard': False
+ }
+ , 'metrics' : {
+ 'logging_tboard': True
+ # for full 3D volume
+ , 'metrics_eval': {'Dice': config.LOSS_DICE}
+ ## for smaller grid/patch
+ , 'metrics_loss' : {'Dice': config.LOSS_DICE}
+ , 'loss_weighted' : {'Dice': True}
+ , 'loss_mask' : {'Dice': True}
+ , 'loss_combo' : {'Dice': 1.0}
+ }
+ , 'others': {
+ 'epochs_timer': 20
+ , 'epochs_memory':5
+ }
+ }
+
+ # Call the trainer
+ trainer = Trainer(params)
+ trainer.train()
+
+ # To evaluate on MICCAI2015
+ params = {
+ 'exp_name': exp_name
+ , 'pid' : os.getpid()
+ , 'dataloader': {
+ 'data_dir' : data_dir
+ , 'resampled' : resampled
+ , 'grid' : grid
+ , 'crop_init' : crop_init
+ , 'batch_size' : batch_size
+ , 'prefetch_batch': 1
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE]
+ , 'eval_type' : config.MODE_TEST
+ }
+ , 'model': {
+ 'name': model
+ , 'load_epoch' : 1000
+ , 'MC_RUNS' : 30
+ , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time]
+ }
+ , 'save': True
+ }
\ No newline at end of file
diff --git a/demo/demo_flipout_ce.py b/demo/demo_flipout_ce.py
new file mode 100644
index 0000000..faad657
--- /dev/null
+++ b/demo/demo_flipout_ce.py
@@ -0,0 +1,109 @@
+# Import private libraries
+import src.config as config
+from src.model.trainer import Trainer,Validator
+
+# Import public libraries
+import os
+import pdb
+import traceback
+import tensorflow as tf
+from pathlib import Path
+
+
+
+if __name__ == "__main__":
+
+ exp_name = 'HansegmentationUncertaintyQA-Flipout-CE'
+
+ data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ resampled = True
+ crop_init = True
+ grid = True
+ batch_size = 2
+
+ model = config.MODEL_FOCUSNET_FLIPOUT
+
+ # To train
+ params = {
+ 'exp_name': exp_name
+ , 'random_seed':42
+ , 'dataloader':{
+ 'data_dir': data_dir
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD]
+ , 'resampled' : resampled
+ , 'crop_init' : crop_init
+ , 'grid' : grid
+ , 'random_grid': True
+ , 'filter_grid': False
+ , 'centred_prob' : 0.3
+ , 'batch_size' : batch_size
+ , 'shuffle' : 5
+ , 'prefetch_batch': 4
+ , 'parallel_calls': 3
+ }
+ , 'model': {
+ 'name': model
+ , 'kl_alpha_init' : 0.01
+ , 'kl_schedule' : config.KL_DIV_FIXED # [config.KL_DIV_FIXED, config.KL_DIV_ANNEALING]
+        , 'kl_scale_factor' : 154 # distribute the kl_div loss 'kl_scale_factor' times across an epoch,
+                                  # i.e. kl_scale_factor = dataset_len/batch_size => 308/2.0 = 154 (with grid size (140,140,40))
+ , 'optimizer' : config.OPTIMIZER_ADAM
+ , 'init_lr' : 0.001
+ , 'fixed_lr' : True
+ , 'epochs' : 1500
+ , 'epochs_save': 50
+ , 'epochs_eval': 50
+ , 'epochs_viz' : 500
+ , 'load_model':{
+ 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None
+ }
+ , 'profiler': {
+ 'profile': False
+ , 'epochs': [2,3]
+ , 'steps_per_epoch': 60
+ , 'starting_step': 4
+ }
+ , 'model_tboard': False
+ }
+ , 'metrics' : {
+ 'logging_tboard': True
+ # for full 3D volume
+ , 'metrics_eval': {'Dice': config.LOSS_DICE}
+ ## for smaller grid/patch
+ , 'metrics_loss' : {'CE': config.LOSS_CE} # [config.LOSS_CE, config.LOSS_DICE]
+ , 'loss_weighted' : {'CE': True}
+ , 'loss_mask' : {'CE': True}
+ , 'loss_combo' : {'CE': 1.0}
+ }
+ , 'others': {
+ 'epochs_timer': 20
+ , 'epochs_memory':5
+ }
+ }
+
+ # Call the trainer
+ trainer = Trainer(params)
+ trainer.train()
+
+ # To evaluate on MICCAI2015
+ params = {
+ 'exp_name': exp_name
+ , 'pid' : os.getpid()
+ , 'dataloader': {
+ 'data_dir' : data_dir
+ , 'resampled' : resampled
+ , 'grid' : grid
+ , 'crop_init' : crop_init
+ , 'batch_size' : batch_size
+ , 'prefetch_batch': 1
+ , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE]
+ , 'eval_type' : config.MODE_TEST
+ }
+ , 'model': {
+ 'name': model
+ , 'load_epoch' : 1000
+ , 'MC_RUNS' : 30
+ , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time]
+ }
+ , 'save': True
+ }
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..8bee63c
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,401 @@
+############################################################
+# INIT #
+############################################################
+
+from pathlib import Path
+PROJECT_DIR = Path(__file__).parent.absolute().parent.absolute()
+MAIN_DIR = Path(PROJECT_DIR).parent.absolute()
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
+
+import tensorflow as tf
+try:
+ if len(tf.config.list_physical_devices('GPU')):
+ tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
+ sys_details = tf.sysconfig.get_build_info()
+ print (' - [TFlow Build Info] ver: ', tf.__version__, 'CUDA(major.minor):', sys_details["cuda_version"], ' || cuDNN(major): ', sys_details["cudnn_version"])
+
+ else:
+ print (' - No GPU present!! Exiting ...')
+ import sys; sys.exit(1)
+except:
+ pass
+
+############################################################
+# MODEL RELATED #
+############################################################
+MODEL_CHKPOINT_MAINFOLDER = '_models'
+MODEL_CHKPOINT_NAME_FMT = 'ckpt_epoch{:03d}'
+MODEL_LOGS_FOLDERNAME = 'logs'
+MODEL_IMGS_FOLDERNAME = 'images'
+MODEL_PATCHES_FOLDERNAME = 'patches'
+
+EXT_NRRD = '.nrrd'
+
+MODE_TRAIN = 'Train'
+MODE_TRAIN_VAL = 'Train_val'
+MODE_VAL = 'Val'
+MODE_VAL_NEW = 'Val_New'
+MODE_TEST = 'Test'
+MODE_DEEPMINDTCIA_TEST_ONC = 'DeepMindTCIATestOnc'
+MODE_DEEPMINDTCIA_TEST_RAD = 'DeepMindTCIATestRad'
+MODE_DEEPMINDTCIA_VAL_ONC = 'DeepMindTCIAValOnc'
+MODE_DEEPMINDTCIA_VAL_RAD = 'DeepMindTCIAValRad'
+
+ACT_SIGMOID = 'sigmoid'
+ACT_SOFTMAX = 'softmax'
+
+MODEL_FOCUSNET_DROPOUT = 'ModelFocusNetDropOut'
+MODEL_FOCUSNET_FLIPOUT = 'ModelFocusNetFlipOut'
+
+OPTIMIZER_ADAM = 'Adam'
+
+THRESHOLD_SIGMA_IGNORE = 0.3
+MIN_SIZE_COMPONENT = 10
+
+KL_DIV_FIXED = 'fixed'
+KL_DIV_ANNEALING = 'annealing'
+
+############################################################
+# EVAL RELATED #
+############################################################
+KEY_DICE_AVG = 'dice_avg'
+KEY_DICE_LABELS = 'dice_labels'
+KEY_HD_AVG = 'hd_avg'
+KEY_HD_LABELS = 'hd_labels'
+KEY_HD95_AVG = 'hd95_avg'
+KEY_HD95_LABELS = 'hd95_labels'
+KEY_MSD_AVG = 'msd_avg'
+KEY_MSD_LABELS = 'msd_labels'
+KEY_ECE_AVG = 'ece_avg'
+KEY_ECE_LABELS = 'ece_labels'
+KEY_AVU_ENT = 'avu_ent'
+KEY_AVU_PAC_ENT = 'avu_pac_ent'
+KEY_AVU_PUI_ENT = 'avu_pui_ent'
+KEY_THRESH_ENT = 'avu_thresh_ent'
+KEY_AVU_MIF = 'avu_mif'
+KEY_AVU_PAC_MIF = 'avu_pac_mif'
+KEY_AVU_PUI_MIF = 'avu_pui_mif'
+KEY_THRESH_MIF = 'avu_thresh_mif'
+KEY_AVU_UNC = 'avu_unc'
+KEY_AVU_PAC_UNC = 'avu_pac_unc'
+KEY_AVU_PUI_UNC = 'avu_pui_unc'
+KEY_THRESH_UNC = 'avu_thresh_unc'
+
+PAVPU_UNC_THRESHOLD = 'adaptive-median' # [0.3, 'adaptive', 'adaptive-median']
+PAVPU_ENT_THRESHOLD = 0.5
+PAVPU_MIF_THRESHOLD = 0.1
+PAVPU_GRID_SIZE = (4,4,2)
+PAVPU_RATIO_NEG = 0.9
+
+KEY_ENT = 'ent'
+KEY_MIF = 'mif'
+KEY_STD = 'std'
+KEY_PERC = 'perc'
+
+KEY_SUM = 'sum'
+KEY_AVG = 'avg'
+
+CMAP_MAGMA = 'magma'
+CMAP_GRAY = 'gray'
+
+FILENAME_EVAL3D_JSON = 'res.json'
+
+FOLDERNAME_TMP = '_tmp'
+FOLDERNAME_TMP_BOKEH = 'bokeh-plots'
+FOLDERNAME_TMP_ENTMIF = 'entmif'
+
+VAL_ECE_NAN = -0.1
+VAL_DICE_NAN = -1.0
+VAL_MC_RUNS_DEFAULT = 20
+
+KEY_PATIENT_GLOBAL = 'global'
+
+SUFFIX_DET = '-Det'
+SUFFIX_MC = '-MC{}'
+
+KEY_MC_RUNS = 'MC_RUNS'
+KEY_TRAINING_BOOL = 'training_bool'
+
+############################################################
+# LOSSES RELATED #
+############################################################
+LOSS_DICE = 'Dice'
+LOSS_CE = 'CE'
+LOSS_CE_BASIC = 'CE-Basic'
+
+############################################################
+# DATALOADER RELATED #
+############################################################
+
+HEAD_AND_NECK = 'HaN'
+PROSTATE = 'Prostrate'
+THORACIC = 'Thoracic'
+
+DIRNAME_PROCESSED = 'processed'
+DIRNAME_PROCESSED_SPACING = 'processed_{}'
+DIRNAME_RAW = 'raw'
+DIRNAME_SAVE_3D = 'data_3D'
+
+FILENAME_JSON_IMG = 'img.json'
+FILENAME_JSON_MASK = 'mask.json'
+FILENAME_JSON_IMG_RESAMPLED = 'img_resampled.json'
+FILENAME_JSON_MASK_RESAMPLED = 'mask_resampled.json'
+FILENAME_CSV_IMG = 'img.csv'
+FILENAME_CSV_MASK = 'mask.csv'
+FILENAME_CSV_IMG_RESAMPLED = 'img_resampled.csv'
+FILENAME_CSV_MASK_RESAMPLED = 'mask_resampled.csv'
+
+FILENAME_VOXEL_INFO = 'voxelinfo.json'
+
+KEYNAME_LABEL_OARS = 'labels_oars'
+KEYNAME_LABEL_EXTERNAL = 'labels_external'
+KEYNAME_LABEL_TUMORS = 'labels_tumors'
+KEYNAME_LABEL_MISSING = 'labels_missing'
+
+DATAEXTRACTOR_WORKERS = 8
+
+import itk
+import numpy as np
+import tensorflow as tf
+import SimpleITK as sitk
+
+DATATYPE_VOXEL_IMG = np.int16
+DATATYPE_VOXEL_MASK = np.uint8
+
+DATATYPE_SITK_VOXEL_IMG = sitk.sitkInt16
+DATATYPE_SITK_VOXEL_MASK = sitk.sitkUInt8
+
+DATATYPE_ITK_VOXEL_MASK = itk.UC
+
+DATATYPE_NP_INT32 = np.int32
+
+DATATYPE_TF_STRING = tf.string
+DATATYPE_TF_UINT8 = tf.uint8
+DATATYPE_TF_INT16 = tf.int16
+DATATYPE_TF_INT32 = tf.int32
+DATATYPE_TF_FLOAT32 = tf.float32
+
+DUMMY_LABEL = 255
+
+# Keys - Volume params
+KEYNAME_PIXEL_SPACING = 'pixel_spacing'
+KEYNAME_ORIGIN = 'origin'
+KEYNAME_SHAPE = 'shape'
+KEYNAME_INTERCEPT = 'intercept'
+KEYNAME_SLOPE = 'slope'
+KEYNAME_ZVALS = 'z_vals'
+KEYNAME_MEAN_MIDPOINT = 'mean_midpoint'
+KEYNAME_OTHERS = 'others'
+KEYNAME_INTERPOLATOR = 'interpolator'
+KEYNAME_INTERPOLATOR_IMG = 'interpolator_img'
+KEYNAME_INTERPOLATOR_MASK = 'interpolator_mask'
+KEYNAME_SHAPE_ORIG = 'shape_orig'
+KEYNAME_SHAPE_RESAMPLED = 'shape_resampled'
+
+# Keys - Dataset params
+KEY_VOXELRESO = 'VOXEL_RESO'
+KEY_LABEL_MAP = 'LABEL_MAP'
+KEY_LABEL_MAP_FULL = 'LABEL_MAP_FULL'
+KEY_LABEL_MAP_EXTERNAL = 'LABEL_MAP_EXTERNAL'
+KEY_LABEL_COLORS = 'LABEL_COLORS'
+KEY_LABEL_WEIGHTS = 'LABEL_WEIGHTS'
+KEY_IGNORE_LABELS = 'IGNORE_LABELS'
+KEY_LABELID_BACKGROUND = 'LABELID_BACKGROUND'
+KEY_LABELID_MIDPOINT = 'LABELID_MIDPOINT'
+KEY_HU_MIN = 'HU_MIN'
+KEY_HU_MAX = 'HU_MAX'
+KEY_PREPROCESS = 'PREPROCESS'
+KEY_CROP = 'CROP'
+KEY_GRID_3D = 'GRID_3D'
+
+# Common Labels
+LABELNAME_BACKGROUND = 'Background'
+
+# Keys - Cropping
+KEY_MIDPOINT_EXTENSION_W_LEFT = 'MIDPOINT_EXTENSION_W_LEFT'
+KEY_MIDPOINT_EXTENSION_W_RIGHT = 'MIDPOINT_EXTENSION_W_RIGHT'
+KEY_MIDPOINT_EXTENSION_H_BACK = 'MIDPOINT_EXTENSION_H_BACK'
+KEY_MIDPOINT_EXTENSION_H_FRONT = 'MIDPOINT_EXTENSION_H_FRONT'
+KEY_MIDPOINT_EXTENSION_D_TOP = 'MIDPOINT_EXTENSION_D_TOP'
+KEY_MIDPOINT_EXTENSION_D_BOTTOM = 'MIDPOINT_EXTENSION_D_BOTTOM'
+
+# Keys - Gridding
+KEY_GRID_SIZE = 'GRID_SIZE'
+KEY_GRID_OVERLAP = 'GRID_OVERLAP'
+KEY_GRID_SAMPLER_PERC = 'GRID_SAMPLER_PERC'
+KEY_GRID_RANDOM_SHIFT_MAX = 'GRID_RANDOM_SHIFT_MAX'
+KEY_GRID_RANDOM_SHIFT_PERC = 'GRID_RANDOM_SHIFT_PERC'
+
+# Keys - .nrrd file keys
+KEY_NRRD_PIXEL_SPACING = 'space directions'
+KEY_NRRD_ORIGIN = 'space origin'
+KEY_NRRD_SHAPE = 'sizes'
+
+TYPE_VOXEL_ORIGSHAPE = 'orig'
+TYPE_VOXEL_RESAMPLED = 'resampled'
+
+MASK_TYPE_ONEHOT = 'one_hot'
+MASK_TYPE_COMBINED = 'combined'
+
+PREFETCH_BUFFER = 5
+
+################################### HaN - MICCAI 2015 ###################################
+
+DATASET_MICCAI2015 = 'miccai2015'
+DATALOADER_MICCAI2015_TRAIN = 'train'
+DATALOADER_MICCAI2015_TRAIN_ADD = 'train_additional'
+DATALOADER_MICCAI2015_TEST = 'test_offsite'
+DATALOADER_MICCAI2015_TESTONSITE = 'test_onsite'
+
+HaN_MICCAI2015 = {
+ KEY_LABEL_MAP : {
+ 'Background':0
+ , 'BrainStem':1 , 'Chiasm':2, 'Mandible':3
+ , 'OpticNerve_L':4, 'OpticNerve_R':5
+ , 'Parotid_L':6,'Parotid_R':7
+ ,'Submandibular_L':8, 'Submandibular_R':9
+ }
+ , KEY_LABEL_COLORS : {
+ 0: [255,255,255,10]
+ , 1:[0,110,254,255], 2: [225,128,128,255], 3:[254,0,128,255]
+ , 4:[191,50,191,255], 5:[254,128,254,255]
+ , 6: [182, 74, 74,255], 7:[128,128,0,255]
+ , 8:[50,105,161,255], 9:[46,194,194,255]
+ }
+ , KEY_LABEL_WEIGHTS : [1/4231347, 1/16453, 1/372, 1/32244, 1/444, 1/397, 1/16873, 1/17510, 1/4419, 1/4410] # avg voxel count
+ , KEY_IGNORE_LABELS : [0]
+ , KEY_LABELID_BACKGROUND : 0
+ , KEY_LABELID_MIDPOINT : 1
+    , KEY_HU_MIN : -125 # window_level=50, window_width=350, 50 - (350/2) for soft tissue
+    , KEY_HU_MAX : 225 # window_level=50, window_width=350, 50 + (350/2) for soft tissue
+ , KEY_VOXELRESO : (0.8, 0.8, 2.5) # [(0.8, 0.8, 2.5), (1.0, 1.0, 2.0) , (1,1,1), (1,1,2)]
+ , KEY_GRID_3D : {
+ TYPE_VOXEL_RESAMPLED:{
+ '(0.8, 0.8, 2.5)':{
+ KEY_GRID_SIZE : [140,140,40]
+ , KEY_GRID_OVERLAP : [20,20,0]
+ , KEY_GRID_SAMPLER_PERC : 0.90
+ , KEY_GRID_RANDOM_SHIFT_MAX : 40
+ , KEY_GRID_RANDOM_SHIFT_PERC : 0.5
+ }
+ }
+ }
+ , KEY_PREPROCESS:{
+ TYPE_VOXEL_RESAMPLED:{
+ KEY_CROP: {
+ '(0.8, 0.8, 2.5)':{
+ KEY_MIDPOINT_EXTENSION_W_LEFT : 120
+ ,KEY_MIDPOINT_EXTENSION_W_RIGHT : 120 # 240
+ ,KEY_MIDPOINT_EXTENSION_H_BACK : 66
+ ,KEY_MIDPOINT_EXTENSION_H_FRONT : 174 # 240
+ ,KEY_MIDPOINT_EXTENSION_D_TOP : 20
+ ,KEY_MIDPOINT_EXTENSION_D_BOTTOM : 60 # 80 [96(20-76) ]
+ }
+ }
+ }
+ }
+}
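+# Note: with the (0.8, 0.8, 2.5) mm spacing above, the midpoint extensions crop a volume of
+# roughly 240 x 240 x 80 voxels around the midpoint label (KEY_LABELID_MIDPOINT, i.e. BrainStem),
+# from which (140,140,40) grid patches with (20,20,0) overlap are sampled (see KEY_GRID_3D).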
+
+
+PATIENT_MICCAI2015_TESTOFFSITE = 'HaN_MICCAI2015-test_offsite-{}_resample_True'
+FILENAME_SAVE_CT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_img.nrrd'
+FILENAME_SAVE_GT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_mask.nrrd'
+FILENAME_SAVE_PRED_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpred.nrrd'
+FILENAME_SAVE_MIF_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredmif.nrrd'
+FILENAME_SAVE_ENT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredent.nrrd'
+FILENAME_SAVE_STD_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredstd.nrrd'
+
+PATIENTIDS_MICCAI2015_TEST = ['0522c0555', '0522c0576', '0522c0598', '0522c0659', '0522c0661', '0522c0667', '0522c0669', '0522c0708', '0522c0727', '0522c0746']
+
+################################### HaN - DeepMindTCIA ###################################
+
+DATALOADER_DEEPMINDTCIA_TEST = 'test'
+DATALOADER_DEEPMINDTCIA_VAL = 'validation'
+DATALOADER_DEEPMINDTCIA_ONC = 'oncologist'
+DATALOADER_DEEPMINDTCIA_RAD = 'radiographer'
+DATASET_DEEPMINDTCIA = 'deepmindtcia'
+
+
+KEY_LABELMAP_MICCAI_DEEPMINDTCIA = KEY_LABEL_MAP + 'MICCAI_TCIADEEPMIND'
+HaN_DeepMindTCIA = {
+ KEY_LABELMAP_MICCAI_DEEPMINDTCIA: {
+ 'Background': 'Background'
+ , 'Brainstem': 'BrainStem'
+ , 'Mandible':'Mandible'
+ , 'Optic-Nerve-Lt': 'OpticNerve_L', 'Optic-Nerve-Rt': 'OpticNerve_R'
+ , 'Optic_Nerve_Lt': 'OpticNerve_L', 'Optic_Nerve_Rt': 'OpticNerve_R'
+ , 'Parotid-Lt':'Parotid_L', 'Parotid-Rt':'Parotid_R'
+ , 'Parotid_Lt':'Parotid_L', 'Parotid_Rt':'Parotid_R'
+ , 'Submandibular-Lt': 'Submandibular_L', 'Submandibular-Rt':'Submandibular_R'
+ , 'Submandibular_Lt': 'Submandibular_L', 'Submandibular_Rt':'Submandibular_R'
+ }
+ , KEY_LABEL_MAP : {
+ 'Background':0
+ , 'BrainStem':1 , 'Chiasm':2, 'Mandible':3
+ , 'OpticNerve_L':4, 'OpticNerve_R':5
+ , 'Parotid_L':6,'Parotid_R':7
+ ,'Submandibular_L':8, 'Submandibular_R':9
+ }
+ , KEY_LABEL_COLORS : {
+ 0: [255,255,255,10]
+ , 1:[0,110,254,255], 2: [225,128,128,255], 3:[254,0,128,255]
+ , 4:[191,50,191,255], 5:[254,128,254,255]
+ , 6: [182, 74, 74,255], 7:[128,128,0,255]
+ , 8:[50,105,161,255], 9:[46,194,194,255]
+ }
+ , KEY_LABEL_WEIGHTS : []
+ , KEY_IGNORE_LABELS : [0]
+ , KEY_LABELID_BACKGROUND : 0
+ , KEY_LABELID_MIDPOINT : 1
+    , KEY_HU_MIN : -125 # window_level=50, window_width=350, 50 - (350/2) for soft tissue
+    , KEY_HU_MAX : 225 # window_level=50, window_width=350, 50 + (350/2) for soft tissue
+ , KEY_VOXELRESO : (0.8, 0.8, 2.5) # [(0.8, 0.8, 2.5), (1.0, 1.0, 2.0) , (1,1,1), (1,1,2)]
+ , KEY_GRID_3D : {
+ TYPE_VOXEL_RESAMPLED:{
+ '(0.8, 0.8, 2.5)':{
+ KEY_GRID_SIZE : [140,140,40]
+ , KEY_GRID_OVERLAP : [20,20,0]
+ , KEY_GRID_SAMPLER_PERC : 0.90
+ , KEY_GRID_RANDOM_SHIFT_MAX : 40
+ , KEY_GRID_RANDOM_SHIFT_PERC : 0.5
+ }
+ }
+ }
+ , KEY_PREPROCESS:{
+ TYPE_VOXEL_RESAMPLED:{
+ KEY_CROP: {
+ '(0.8, 0.8, 2.5)':{
+ KEY_MIDPOINT_EXTENSION_W_LEFT : 120
+ ,KEY_MIDPOINT_EXTENSION_W_RIGHT : 120 # 240
+ ,KEY_MIDPOINT_EXTENSION_H_BACK : 66
+ ,KEY_MIDPOINT_EXTENSION_H_FRONT : 174 # 240
+ ,KEY_MIDPOINT_EXTENSION_D_TOP : 20
+ ,KEY_MIDPOINT_EXTENSION_D_BOTTOM : 60 # 80 [96(20-76) ]
+ }
+ }
+ }
+ }
+}
+
+PATIENTIDS_DEEPMINDTCIA_TEST = ['0522c0331', '0522c0416', '0522c0419', '0522c0629', '0522c0768', '0522c0770', '0522c0773', '0522c0845', 'TCGA-CV-7236', 'TCGA-CV-7243', 'TCGA-CV-7245', 'TCGA-CV-A6JO', 'TCGA-CV-A6JY', 'TCGA-CV-A6K0', 'TCGA-CV-A6K1']
+
+PATIENT_DEEPMINDTCIA_TEST_ONC = 'HaN_DeepMindTCIA-test-oncologist-{}_resample_True'
+FILENAME_SAVE_CT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_img.nrrd'
+FILENAME_SAVE_GT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_mask.nrrd'
+FILENAME_SAVE_PRED_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpred.nrrd'
+FILENAME_SAVE_MIF_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredmif.nrrd'
+FILENAME_SAVE_ENT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredent.nrrd'
+FILENAME_SAVE_STD_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredstd.nrrd'
+
+############################################################
+# VISUALIZATION #
+############################################################
+FIGSIZE=(15,15)
+IGNORE_LABELS = []
+PREDICT_THRESHOLD_MASK = 0.6
+
+ENT_MIN, ENT_MAX = 0.0, 0.5
+MIF_MIN, MIF_MAX = 0.0, 0.1
\ No newline at end of file
diff --git a/src/dataloader/__init__.py b/src/dataloader/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/dataloader/augmentations.py b/src/dataloader/augmentations.py
new file mode 100644
index 0000000..f941252
--- /dev/null
+++ b/src/dataloader/augmentations.py
@@ -0,0 +1,592 @@
+# Import private libraries
+import src.config as config
+import src.dataloader.utils as utils
+
+
+# Import public libraries
+import pdb
+import math
+import traceback
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+import SimpleITK as sitk
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+class Rotate3D:
+
+ def __init__(self):
+ self.name = 'Rotate3D'
+
+ @tf.function
+ def execute(self, x, y, meta1, meta2):
+ """
+        Rotates a 3D image about the z-axis by a random multiple of 90 degrees
+ - Ref: https://www.tensorflow.org/api_docs/python/tf/image/rot90
+
+ Parameters
+ ----------
+ x: tf.Tensor
+ This is the 3D image of dtype=tf.int16 and shape=(H,W,C,1)
+ y: tf.Tensor
+ This is the 3D mask of dtype=tf.uint8 and shape=(H,W,C,Labels)
+        meta1: tf.Tensor
+            This contains some indexing and meta info. Irrelevant to this function
+        meta2: tf.Tensor
+            This contains some string information on patient identification. Irrelevant to this function
+ """
+
+ try:
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= 0.2:
+ rotate_count = tf.random.uniform([], minval=1, maxval=4, dtype=tf.dtypes.int32)
+
+ # k=3 (rot=270) (anti-clockwise)
+ if rotate_count == 3:
+ xtmp = tf.transpose(tf.reverse(x, [0]), [1,0,2,3])
+ ytmp = tf.transpose(tf.reverse(y, [0]), [1,0,2,3])
+
+ # k = 1 (rot=90) (anti-clockwise)
+ elif rotate_count == 1:
+ xtmp = tf.reverse(tf.transpose(x, [1,0,2,3]), [0])
+ ytmp = tf.reverse(tf.transpose(y, [1,0,2,3]), [0])
+
+ # k=2 (rot=180) (clock-wise)
+ elif rotate_count == 2:
+ xtmp = tf.reverse(x, [0,1])
+ ytmp = tf.reverse(y, [0,1])
+
+ else:
+ xtmp = x
+ ytmp = y
+
+ return xtmp, ytmp, meta1, meta2
+ # return xtmp.read_value(), ytmp.read_value(), meta1, meta2
+
+ else:
+ return x, y, meta1, meta2
+ except:
+ traceback.print_exc()
+ return x, y, meta1, meta2
+
+ @tf.function
+ def execute2(self, x_moving, x_fixed, y_moving, y_fixed, meta1, meta2):
+
+ try:
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= 0.5:
+ rotate_count = tf.random.uniform([], minval=1, maxval=4, dtype=tf.dtypes.int32)
+
+ # k=3 (rot=270) (anti-clockwise)
+ if rotate_count == 3:
+ x_moving_tmp = tf.transpose(tf.reverse(x_moving, [0]), [1,0,2,3])
+ x_fixed_tmp = tf.transpose(tf.reverse(x_fixed, [0]), [1,0,2,3])
+ y_moving_tmp = tf.transpose(tf.reverse(y_moving, [0]), [1,0,2,3])
+ y_fixed_tmp = tf.transpose(tf.reverse(y_fixed, [0]), [1,0,2,3])
+
+ # k = 1 (rot=90) (anti-clockwise)
+ elif rotate_count == 1:
+ x_moving_tmp = tf.reverse(tf.transpose(x_moving, [1,0,2,3]), [0])
+ x_fixed_tmp = tf.reverse(tf.transpose(x_fixed, [1,0,2,3]), [0])
+ y_moving_tmp = tf.reverse(tf.transpose(y_moving, [1,0,2,3]), [0])
+ y_fixed_tmp = tf.reverse(tf.transpose(y_fixed, [1,0,2,3]), [0])
+
+ # k=2 (rot=180) (clock-wise)
+ elif rotate_count == 2:
+ x_moving_tmp = tf.reverse(x_moving, [0,1])
+ x_fixed_tmp = tf.reverse(x_fixed, [0,1])
+ y_moving_tmp = tf.reverse(y_moving, [0,1])
+ y_fixed_tmp = tf.reverse(y_fixed, [0,1])
+
+ else:
+ x_moving_tmp = x_moving
+ x_fixed_tmp = x_fixed
+ y_moving_tmp = y_moving
+ y_fixed_tmp = y_fixed
+
+ return x_moving_tmp, x_fixed_tmp, y_moving_tmp, y_fixed_tmp, meta1, meta2
+
+ else:
+ return x_moving, x_fixed, y_moving, y_fixed, meta1, meta2
+
+ except:
+ tf.print(' - [ERROR][Rotate3D][execute2]')
+ return x_moving, x_fixed, y_moving, y_fixed, meta1, meta2
+
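+# Example usage (an illustrative sketch, not the pipeline wiring from this repository): these
+# augmentation classes are meant to be mapped over a tf.data pipeline that yields
+# (x, y, meta1, meta2) tuples, e.g.
+#   dataset = dataset.map(Rotate3D().execute, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+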
+class Rotate3DSmall:
+
+ def __init__(self, label_map, mask_type, prob=0.2, angle_degrees=15, interpolation='bilinear'):
+
+ self.name = 'Rotate3DSmall'
+
+ self.label_ids = label_map.values()
+ self.class_count = len(label_map)
+ self.mask_type = mask_type
+
+ self.prob = prob
+ self.angle_degrees = angle_degrees
+ self.interpolation = interpolation
+
+ @tf.function
+ def execute(self, x, y, meta1, meta2):
+ """
+ - Ref: https://www.tensorflow.org/addons/api_docs/python/tfa/image/rotate
+
+ """
+
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob:
+
+ angle_radians = tf.random.uniform([], minval=math.radians(-self.angle_degrees), maxval=math.radians(self.angle_degrees)
+ , dtype=tf.dtypes.float32)
+
+
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ """
+ - x: [H,W,D,1]
+ - y: [H,W,D,class]
+ """
+
+ x = tf.expand_dims(x[:,:,:,0], axis=0) # [1,H,W,D]
+ x = tf.transpose(x) # [D,W,H,1]
+ x = tfa.image.rotate(x, angle_radians, interpolation='bilinear') # [D,W,H,1]
+ x = tf.transpose(x) # [1,H,W,D]
+ x = tf.expand_dims(x[0], axis=-1) # [H,W,D,1]
+
+ y = tf.concat([
+ tf.expand_dims(
+ tf.transpose( # [1,H,W,D]
+ tfa.image.rotate( # [D,W,H,1]
+ tf.transpose( # [D,W,H,1]
+ tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D]
+ )
+ , angle_radians, interpolation='bilinear'
+ )
+ )[0] # [H,W,D]
+ , axis=-1 # [H,W,D,1]
+ ) for class_id in range(self.class_count)
+ ], axis=-1) # [H,W,D,10]
+ y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y)
+ y = tf.where(tf.math.less(y,0.5), 0.0, y)
+
+ elif self.mask_type == config.MASK_TYPE_COMBINED:
+ """
+ - x: [H,W,D]
+ - y: [H,W,D]
+ """
+
+ x = tf.expand_dims(x,axis=0) # [1,H,W,D]
+ x = tf.transpose(x) # [D,W,H,1]
+ x = tfa.image.rotate(x, angle_radians, interpolation=self.interpolation) # [D,W,H,1]
+ x = tf.transpose(x) # [1,H,W,D]
+ x = x[0] # [H,W,D]
+
+ y = tf.concat([tf.expand_dims(tf.math.equal(y, label), axis=-1) for label in self.label_ids], axis=-1) # [H,W,D,L]
+ y = tf.cast(y, dtype=tf.float32)
+ y = tf.concat([
+ tf.expand_dims(
+ tf.transpose( # [1,H,W,D]
+ tfa.image.rotate( # [D,W,H,1]
+ tf.transpose( # [D,W,H,1]
+ tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D]
+ )
+ , angle_radians, interpolation='bilinear'
+ )
+ )[0] # [H,W,D]
+ , axis=-1 # [H,W,D,1]
+ ) for class_id in range(self.class_count)
+ ], axis=-1) # [H,W,D,L]
+ y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y)
+ y = tf.where(tf.math.less(y,0.5), 0.0, y)
+ y = tf.math.argmax(y, axis=-1) # [H,W,D]
+
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ else:
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ return x, y, meta1, meta2
+
+class Rotate3DSmallZ:
+
+ def __init__(self, label_map, mask_type, prob=0.2, angle_degrees=5, interpolation='bilinear'):
+
+ self.name = 'Rotate3DSmallZ'
+
+ self.label_ids = label_map.values()
+ self.class_count = len(label_map)
+ self.mask_type = mask_type
+
+ self.prob = prob
+ self.angle_degrees = angle_degrees
+ self.interpolation = interpolation
+
+ @tf.function
+ def execute(self, x, y, meta1, meta2):
+ """
+ Params
+ ------
+ x: [H,W,D,1]
+ y: [H,W,D,C]
+ - Ref: https://www.tensorflow.org/addons/api_docs/python/tfa/image/rotate
+
+ """
+
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob:
+
+ angle_radians = tf.random.uniform([], minval=math.radians(-self.angle_degrees), maxval=math.radians(self.angle_degrees)
+ , dtype=tf.dtypes.float32)
+
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ """
+ - x: [H,W,D,1]
+ - y: [H,W,D,class]
+ """
+
+ x = tfa.image.rotate(x, angle_radians, interpolation='bilinear') # [H,W,D,1]
+
+ y = tf.concat([
+ tfa.image.rotate(
+ tf.expand_dims(y[:,:,:,class_id], axis=-1) # [H,W,D,1]
+ , angle_radians, interpolation='bilinear'
+ )
+ for class_id in range(self.class_count)
+ ], axis=-1) # [H,W,D,10]
+ y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y)
+ y = tf.where(tf.math.less(y,0.5), 0.0, y)
+
+ elif self.mask_type == config.MASK_TYPE_COMBINED:
+ pass
+
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ else:
+ x = tf.cast(x, dtype=tf.float32)
+ y = tf.cast(y, dtype=tf.float32)
+
+ return x, y, meta1, meta2
+
+class Translate:
+
+ def __init__(self, label_map, translations=[40,40], prob=0.2):
+
+ self.translations = translations
+ self.prob = prob
+ self.label_ids = label_map.values()
+ self.class_count = len(label_map)
+ self.name = 'Translate'
+
+
+
+ @tf.function
+ def execute(self,x,y,meta1,meta2):
+ """
+ Params
+ ------
+ x: [H,W,D,1]
+ y: [H,W,D,class]
+
+ Ref
+ ---
+ - tfa.image.translate(image): image= (num_images, num_rows, num_columns, num_channels)
+ """
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob:
+
+ translate_x = tf.random.uniform([], minval=-self.translations[0], maxval=self.translations[0], dtype=tf.dtypes.int32)
+ translate_y = tf.random.uniform([], minval=-self.translations[1], maxval=self.translations[1], dtype=tf.dtypes.int32)
+
+ x = tf.expand_dims(x[:,:,:,0], axis=0) # [1,H,W,D]
+ x = tf.transpose(x) # [D,W,H,1]
+ x = tfa.image.translate(x, [translate_x, translate_y], interpolation='bilinear') # [D,W,H,1]; x=(num_images, num_rows, num_columns, num_channels)
+ x = tf.transpose(x) # [1,H,W,D]
+ x = tf.expand_dims(x[0], axis=-1) # [H,W,D,1]
+
+ y = tf.concat([
+ tf.expand_dims(
+ tf.transpose( # [1,H,W,D]
+ tfa.image.translate( # [D,W,H,1]
+ tf.transpose( # [D,W,H,1]
+ tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D]
+ )
+ , [translate_x, translate_y], interpolation='bilinear'
+ )
+ )[0] # [H,W,D]
+ , axis=-1 # [H,W,D,1]
+ ) for class_id in range(self.class_count)
+ ], axis=-1) # [H,W,D,10]
+
+ return x,y,meta1,meta2
+
+class Noise:
+
+ def __init__(self, x_shape, mean=0.0, std=0.1, prob=0.2):
+
+ self.mean = mean
+ self.std = std
+ self.prob = prob
+ self.x_shape = x_shape
+ self.name = 'Noise'
+
+ @tf.function
+ def execute(self,x,y,meta1,meta2):
+ """
+ Params
+ ------
+ x: [H,W,D,1]
+ y: [H,W,D,class]
+ """
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob:
+
+ x = x + tf.random.normal(self.x_shape, self.mean, self.std)
+
+ return x,y,meta1,meta2
+
+class Deform2Punt5D:
+
+ def __init__(self, img_shape, label_map, grid_points=50, stddev=2.0, div_factor=2, prob=0.2, debug=False):
+ """
+ img_shape = [H,W] or [H,W,D]
+ """
+
+ self.name = 'Deform2Punt5D'
+
+ if debug:
+ import os
+ import psutil
+ import pynvml
+ pynvml.nvmlInit()
+ self.device_id = pynvml.nvmlDeviceGetHandleByIndex(0)
+ self.process = psutil.Process(os.getpid())
+
+ self.img_shape = img_shape[:2] # no deformation in z-dimension if img=3D
+ self.depth = img_shape[-1]
+ self.grid_points = grid_points
+ self.stddev = stddev
+ self.div_factor = int(div_factor)
+ self.debug = debug
+ self.label_ids = label_map.values()
+ self.prob = prob
+
+ self.flattened_grid_locations = []
+
+ self._get_grid_control_points()
+ self._get_grid_locations(*self.img_shape)
+
+ def _get_grid_control_points(self):
+
+ # Step 1 - Define grid shape using control grid spacing & final image shape
+ grid_shape = np.zeros(len(self.img_shape), dtype=int)
+
+ for idx in range(len(self.img_shape)):
+ num_elem = float(self.img_shape[idx])
+ if num_elem % 2 == 0:
+ grid_shape[idx] = np.ceil( (num_elem - 1) / (2*self.grid_points) + 0.5) * 2 + 2
+ else:
+ grid_shape[idx] = np.ceil((num_elem - 1) / (2*self.grid_points)) * 2 + 3
+
+ coords = []
+ for i, size in enumerate(grid_shape):
+ coords.append(tf.linspace(-(size - 1) / 2*self.grid_points, (size - 1) / 2*self.grid_points, size))
+ permutation = np.roll(np.arange(len(coords) + 1), -1)
+ self.grid_control_points_orig = tf.cast(tf.transpose(tf.meshgrid(*coords, indexing="ij"), permutation), dtype=tf.float32)
+ self.grid_control_points_orig += tf.expand_dims(tf.expand_dims(tf.constant(self.img_shape, dtype=tf.float32)/2.0,0),0)
+ self.grid_control_points = tf.reshape(self.grid_control_points_orig, [-1,2])
+
+ def _get_grid_locations(self, image_height, image_width):
+
+ image_height = image_height//self.div_factor
+ image_width = image_width//self.div_factor
+
+ y_range = np.linspace(0, image_height - 1, image_height)
+ x_range = np.linspace(0, image_width - 1, image_width)
+ y_grid, x_grid = np.meshgrid(y_range, x_range, indexing="ij") # grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
+ grid_locations = np.stack((y_grid, x_grid), -1)
+
+ flattened_grid_locations = np.reshape(grid_locations, [image_height * image_width, 2])
+ self.flattened_grid_locations = tf.cast(tf.expand_dims(flattened_grid_locations, 0), dtype=tf.float32)
+
+ return flattened_grid_locations
+
+ @tf.function
+ def _get_dense_flow(self, grid_control_points, grid_control_points_new):
+ """
+ params
+ -----
+ grid_control_points: [B, points, 2]
+ grid_control_points_new: [B, points, 2]
+
+ return
+ dense_flows: [1,H,W,2]
+ """
+
+ # Step 1 - Get flows
+ source_control_point_locations = tf.cast(grid_control_points/self.div_factor, dtype=tf.float32)
+ dest_control_point_locations = tf.cast(grid_control_points_new/self.div_factor, dtype=tf.float32)
+ control_point_flows = dest_control_point_locations - source_control_point_locations
+
+ # Step 2 - Get dense flow via bspline interp
+ if self.debug:
+ import pynvml
+ tf.print (' - [_get_dense_flow()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB), ')
+ tf.print (' - [_get_dense_flow()] img size: ', self.img_shape)
+ tf.print (' - [_get_dense_flow()] img size for bspline interp: ', self.img_shape[0]//self.div_factor, self.img_shape[1]//self.div_factor)
+ flattened_flows = tfa.image.interpolate_spline(
+ dest_control_point_locations,
+ control_point_flows,
+ self.flattened_grid_locations,
+ order=2
+ )
+ if self.debug:
+ import pynvml
+ tf.print (' - [_get_dense_flow()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB), ')
+
+ # Step 3 - Reshape dense flow to original image size
+ dense_flows = tf.reshape(flattened_flows, [1, self.img_shape[0]//self.div_factor, self.img_shape[1]//self.div_factor, 2])
+ dense_flows = tf.image.resize(dense_flows, (self.img_shape[0], self.img_shape[1]))
+ if self.debug:
+ print (' - [_get_dense_flow()] dense flow: ', dense_flows.shape)
+
+ return dense_flows
+
+ def execute2D(self, x, y, show=True):
+ """
+ x = [H,W]
+ y = [H,W,L]
+ """
+
+ # Step 1 - Get new control points and dense flow for each slice
+ grid_control_points_new = self.grid_control_points + tf.random.normal(self.grid_control_points.shape, 0, self.stddev)
+ x_flow = self._get_dense_flow(tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0)) # [1,H,W,2]
+
+ # Step 2 - Transform
+ x_tf = tf.expand_dims(tf.expand_dims(x, -1), 0) # [1,H,W,1]
+ y_tf = tf.expand_dims(y, 0) # [1,H,W,L]
+ if 0:
+ x_tf, x_flow = tfa.image.sparse_image_warp(x_tf, tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0))
+ else:
+ x_tf = tfa.image.dense_image_warp(x_tf, x_flow) # [1,H,W,1]
+ y_tf = tf.concat([
+ tfa.image.dense_image_warp(
+ tf.expand_dims(y_tf[:,:,:,class_id],-1), x_flow
+ ) for class_id in self.label_ids]
+ , axis=-1
+ ) # [B,H,W,L]
+ y_tf = tf.where(tf.math.greater_equal(y_tf, 0.5), 1.0, y_tf)
+ y_tf = tf.where(tf.math.less(y_tf, 0.5), 0.0, y_tf)
+
+
+ # Step 99 - Show
+ if show:
+ f,axarr = plt.subplots(2,2, figsize=(15,10))
+ axarr[0][0].imshow(x, cmap='gray')
+ y_plot = tf.argmax(y, axis=-1) # [H,W]
+ axarr[0][0].imshow(y_plot, alpha=0.5)
+ grid_og = self.grid_control_points_orig
+ axarr[0][0].plot(grid_og[:,:,1], grid_og[:,:,0], 'y-.', alpha=0.2)
+ axarr[0][0].plot(tf.transpose(grid_og[:,:,1]), tf.transpose(grid_og[:,:,0]), 'y-.', alpha=0.2)
+ axarr[0][0].set_xlim([0, self.img_shape[1]])
+ axarr[0][0].set_ylim([0, self.img_shape[0]])
+
+ axarr[0][1].imshow(x_tf[0,:,:,0], cmap='gray')
+ y_tf_plot = tf.argmax(y_tf, axis=-1)[0,:,:] # [H,W]
+ axarr[0][1].imshow(y_tf_plot, alpha=0.5)
+ grid_new = tf.reshape(grid_control_points_new, self.grid_control_points_orig.shape)
+ axarr[0][1].plot(grid_new[:,:,1], grid_new[:,:,0], 'y-.', alpha=0.2)
+ axarr[0][1].plot(tf.transpose(grid_new[:,:,1]), tf.transpose(grid_new[:,:,0]), 'y-.', alpha=0.2)
+ axarr[0][1].set_xlim([0, self.img_shape[1]])
+ axarr[0][1].set_ylim([0, self.img_shape[0]])
+
+ axarr_dy = axarr[1][0].imshow(x_flow[0,:,:,1], vmin=-7.5, vmax=7.5, interpolation='none', cmap='rainbow')
+ f.colorbar(axarr_dy, ax=axarr[1][0], extend='both')
+ axarr_dx = axarr[1][1].imshow(x_flow[0,:,:,0], vmin=-7.5, vmax=7.5, interpolation='none', cmap='rainbow')
+ f.colorbar(axarr_dx, ax=axarr[1][1], extend='both')
+
+ plt.suptitle('Image Size: {} \n Control Points: {}\n StdDev: {}\nDiv Factor={}'.format(self.img_shape, self.grid_points, self.stddev, self.div_factor))
+ plt.show()
+ pdb.set_trace()
+
+ return x_tf, y_tf
+
+ @tf.function
+ def execute(self, x, y, meta1, meta2, show=False):
+ """
+ x = [H,W,D,1]
+ y = [H,W,D,L]
+ No deformation in z-axis
+ """
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob:
+ if self.debug:
+ tf.print (' - [Deform2Punt5D()][execute()] img_shape: ', self.img_shape, ' || depth: ', self.depth)
+
+ # Step 1 - Get new control points and dense flow for each slice
+ grid_control_points_new = self.grid_control_points + tf.random.normal(self.grid_control_points.shape, 0, self.stddev)
+ x_flow = self._get_dense_flow(tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0)) # [1,H,W,2]
+
+ # Step 2 - Transform
+ if show:
+ idx = 70
+ import matplotlib.pyplot as plt
+ f,axarr = plt.subplots(2,2, figsize=(15,10))
+
+ x_slice = x[:,:,idx,0] # [H,W]
+ axarr[0][0].imshow(x_slice, cmap='gray')
+ y_slice = tf.argmax(y[:,:,idx,:], axis=-1) # [H,W]
+ axarr[0][0].imshow(y_slice, cmap='gray', alpha=0.5)
+ grid_og = self.grid_control_points_orig
+ axarr[0][0].plot(grid_og[:,:,1], grid_og[:,:,0], 'y-.', alpha=0.2)
+ axarr[0][0].plot(tf.transpose(grid_og[:,:,1]), tf.transpose(grid_og[:,:,0]), 'y-.', alpha=0.5)
+ axarr[0][0].set_xlim([0, self.img_shape[1]])
+ axarr[0][0].set_ylim([0, self.img_shape[0]])
+
+ x = tf.transpose(x, [2,0,1,3]) # [H,W,D,1] -> [D,H,W,1]
+ y = tf.transpose(y, [2,0,1,3]) # [H,W,D,L] -> [D,H,W,L]
+
+ if self.debug:
+ import pynvml
+ tf.print (' - [Deform2Punt5D][execute()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB, ')
+ x_flow_repeat = tf.repeat(x_flow, self.depth, axis=0) # [B,H,W,2] or [D,H,W,2]
+ if self.debug:
+ tf.print (' - [execute()] x_tf: ', x.shape, ' || x_flow_repeat: ', x_flow_repeat.shape)
+ x = tfa.image.dense_image_warp(x, x_flow_repeat)
+ x = tf.transpose(x, [1,2,0,3]) # [D,H,W,1] -> [H,W,D,1]
+ y = tf.concat([
+ tfa.image.dense_image_warp(
+ tf.expand_dims(y[:,:,:,class_id],-1), x_flow_repeat
+ ) for class_id in self.label_ids]
+ , axis=-1
+ ) # [D,H,W,L]
+ y = tf.where(tf.math.greater_equal(y, 0.5), 1.0, y)
+ y = tf.where(tf.math.less(y, 0.5), 0.0, y)
+ y = tf.transpose(y, [1,2,0,3]) # [D,H,W,L] --> [H,W,D,L]
+ if self.debug:
+ import pynvml
+ tf.print (' - [execute()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB, ')
+
+ if show:
+
+ x_slice = x[:,:,idx,0] # [H,W]
+ axarr[0][1].imshow(x_slice, cmap='gray')
+ y_slice = tf.argmax(y[:,:,idx,:], axis=-1) # [H,W]
+ axarr[0][1].imshow(y_slice, cmap='gray', alpha=0.5)
+ grid_new = tf.reshape(grid_control_points_new, self.grid_control_points_orig.shape)
+ axarr[0][1].plot(grid_new[:,:,1], grid_new[:,:,0], 'y-.', alpha=0.2)
+ axarr[0][1].plot(tf.transpose(grid_new[:,:,1]), tf.transpose(grid_new[:,:,0]), 'y-.', alpha=0.5)
+ axarr[0][1].set_xlim([0, self.img_shape[1]])
+ axarr[0][1].set_ylim([0, self.img_shape[0]])
+
+ axarr[1][0].imshow(x_flow[0,:,:,1], cmap='gray')
+ axarr[1][1].imshow(x_flow[0,:,:,0], cmap='gray')
+ plt.show()
+
+ # return x_tf, y, meta1, meta2
+ return x, y, meta1, meta2
diff --git a/src/dataloader/dataset.py b/src/dataloader/dataset.py
new file mode 100644
index 0000000..b359813
--- /dev/null
+++ b/src/dataloader/dataset.py
@@ -0,0 +1,53 @@
+import tensorflow as tf
+
+class ZipDataset:
+
+ def __init__(self, datasets):
+ self.datasets = datasets
+ self.datasets_generators = []
+ self._init_constants()
+
+
+ def _init_constants(self):
+ self.HU_MIN = self.datasets[0].HU_MIN
+ self.HU_MAX = self.datasets[0].HU_MAX
+ self.dataset_dir_processed = self.datasets[0].dataset_dir_processed
+ self.grid = self.datasets[0].grid
+ self.pregridnorm = self.datasets[0].pregridnorm
+
+ def __len__(self):
+ length = 0
+ for dataset in self.datasets:
+ length += len(dataset)
+
+ return length
+
+ def generator(self):
+ for dataset in self.datasets:
+ self.datasets_generators.append(dataset.generator())
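+        # Note: tf.data.experimental.sample_from_datasets expects a list of tf.data.Dataset
+        # objects, so each sub-dataset's generator() above is assumed to return a tf.data.Dataset.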
+ return tf.data.experimental.sample_from_datasets(self.datasets_generators) #<_DirectedInterleaveDataset shapes: (, , , ), types: (tf.float32, tf.float32, tf.int16, tf.string)>
+
+ def get_subdataset(self, param_name):
+ if type(param_name) == str:
+ for dataset in self.datasets:
+ if dataset.name == param_name:
+ return dataset
+ else:
+            print (' - [ERROR][ZipDataset] param_name needs to be a str')
+
+ return None
+
+ def get_label_map(self, label_map_full=False):
+ if label_map_full:
+ return self.datasets[0].LABEL_MAP_FULL
+ else:
+ return self.datasets[0].LABEL_MAP
+
+ def get_label_colors(self):
+ return self.datasets[0].LABEL_COLORS
+
+ def get_label_weights(self):
+ return self.datasets[0].LABEL_WEIGHTS
+
+ def get_mask_type(self, idx=0):
+ return self.datasets[idx].mask_type
\ No newline at end of file
diff --git a/src/dataloader/extractors/__init__.py b/src/dataloader/extractors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/dataloader/extractors/han_deepmindtcia.py b/src/dataloader/extractors/han_deepmindtcia.py
new file mode 100644
index 0000000..a80175a
--- /dev/null
+++ b/src/dataloader/extractors/han_deepmindtcia.py
@@ -0,0 +1,339 @@
+# Import private libraries
+import medloader.dataloader.config as config
+import medloader.dataloader.utils as utils
+
+# Import public libraries
+import pdb
+import tqdm
+import nrrd
+import copy
+import traceback
+import numpy as np
+from pathlib import Path
+
+class HaNDeepMindTCIADownloader:
+
+ def __init__(self, dataset_dir_raw, dataset_dir_processed):
+
+ self.class_name = 'HaNDeepMindTCIADownloader'
+ self.dataset_dir_raw = dataset_dir_raw
+ self.dataset_dir_processed = dataset_dir_processed
+
+ self.url_zip = 'https://github.com/deepmind/tcia-ct-scan-dataset/archive/refs/heads/master.zip'
+ self.unzipped_folder_name = self.url_zip.split('/')[-5] + '-' + self.url_zip.split('/')[-1].split('.')[0]
+
+ def download(self):
+
+ # Step 1 - Make raw directory
+ self.dataset_dir_raw.mkdir(parents=True, exist_ok=True)
+
+ # Step 2 - Download .zip and then unzip it
+ path_zip_folder = Path(self.dataset_dir_raw, self.unzipped_folder_name + '.zip')
+
+ if not Path(path_zip_folder).exists():
+ utils.download_zip(self.url_zip, path_zip_folder, self.dataset_dir_raw)
+ else:
+ utils.read_zip(path_zip_folder, self.dataset_dir_raw)
+
+ def sort(self):
+
+ path_download_nrrds = Path(self.dataset_dir_raw).joinpath(self.unzipped_folder_name, 'nrrds')
+ if Path(path_download_nrrds).exists():
+ path_download_nrrds_test_src = Path(path_download_nrrds).joinpath('test')
+ path_download_nrrds_val_src = Path(path_download_nrrds).joinpath('validation')
+
+ path_nrrds_test_dest = Path(self.dataset_dir_raw).joinpath('test')
+ path_nrrds_val_dest = Path(self.dataset_dir_raw).joinpath('validation')
+
+ utils.move_folder(path_download_nrrds_test_src, path_nrrds_test_dest)
+ utils.move_folder(path_download_nrrds_val_src, path_nrrds_val_dest)
+ else:
+ print (' - [ERROR][{}] Could not find nrrds folder: {}'.format(self.class_name, path_download_nrrds))
+
+
+class HaNDeepMindTCIAExtractor:
+
+ def __init__(self, name, dataset_dir_raw, dataset_dir_processed, dataset_dir_datatypes):
+
+ self.name = name
+ self.class_name = 'HaNDeepMindTCIAExtractor'
+ self.dataset_dir_raw = dataset_dir_raw
+ self.dataset_dir_processed = dataset_dir_processed
+ self.dataset_dir_datatypes = dataset_dir_datatypes
+
+ self._preprint()
+ self._init_constants()
+
+ def _preprint(self):
+ self.VOXEL_RESO = getattr(config, self.name)[config.KEY_VOXELRESO]
+ print ('')
+ print (' - [{}] VOXEL_RESO: {}'.format(self.class_name, self.VOXEL_RESO))
+ print ('')
+
+ def _init_constants(self):
+
+ # File names
+ self.DATATYPE_ORIG = '.nrrd'
+ self.IMG_VOXEL_FILENAME = 'CT_IMAGE.nrrd'
+ self.MASK_VOXEL_FILENAME = 'mask.nrrd'
+ self.MASK_ORGANS_FOLDERNAME = 'segmentations'
+
+ # Label information
+ self.dataset_config = getattr(config, self.name)
+ self.LABEL_MAP_MICCAI2015_DEEPMINDTCIA = self.dataset_config[config.KEY_LABELMAP_MICCAI_DEEPMINDTCIA]
+ self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP]
+ self.IGNORE_LABELS = self.dataset_config[config.KEY_IGNORE_LABELS]
+ self.LABELID_MIDPOINT = self.dataset_config[config.KEY_LABELID_MIDPOINT]
+
+ def extract3D(self):
+
+ if 1:
+ import concurrent
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ for dir_type in self.dataset_dir_datatypes:
+ path_extract3D = Path(self.dataset_dir_raw).joinpath(dir_type)
+ if Path(path_extract3D).exists():
+ executor.submit(self._extract3D_patients, path_extract3D)
+ else:
+ print (' - [ERROR][{}][extract3D()] {} does not exist: '.format(self.class_name, path_extract3D))
+
+ else:
+ for dir_type in [Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_ONC)]:
+ path_extract3D = Path(self.dataset_dir_raw).joinpath(dir_type)
+ if Path(path_extract3D).exists():
+ self._extract3D_patients(path_extract3D)
+ else:
+ print (' - [ERROR][{}][extract3D()] {} does not exist: '.format(self.class_name, path_extract3D))
+
+ print ('')
+ print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer')
+ print ('')
+
+ def _extract3D_patients(self, dir_dataset):
+
+ dir_type = Path(Path(dir_dataset).parts[-2], Path(dir_dataset).parts[-1])
+ paths_global_voxel_img = []
+ paths_global_voxel_mask = []
+
+ # Step 1 - Loop over patients of dir_type and get their img and mask paths
+ with tqdm.tqdm(total=len(list(dir_dataset.glob('*'))), desc='[3D][{}] Patients: '.format(str(dir_type)), disable=False) as pbar:
+ for _, patient_dir_path in enumerate(dir_dataset.iterdir()):
+ try:
+ if Path(patient_dir_path).is_dir():
+ voxel_img_filepath, voxel_mask_filepath, _ = self._extract3D_patient(patient_dir_path)
+ paths_global_voxel_img.append(voxel_img_filepath)
+ paths_global_voxel_mask.append(voxel_mask_filepath)
+ pbar.update(1)
+
+ except:
+ print ('')
+ print (' - [ERROR][{}][_extract3D_patients()] Error with patient_id: {}'.format(self.class_name, Path(patient_dir_path).parts[-3:]))
+ traceback.print_exc()
+ pdb.set_trace()
+
+ # Step 2 - Save paths in .csvs
+ if len(paths_global_voxel_img) and len(paths_global_voxel_mask):
+ paths_global_voxel_img = list(map(lambda x: str(x), paths_global_voxel_img))
+ paths_global_voxel_mask = list(map(lambda x: str(x), paths_global_voxel_mask))
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG), paths_global_voxel_img)
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK), paths_global_voxel_mask)
+
+ if len(self.VOXEL_RESO):
+ paths_global_voxel_img_resampled = []
+ for path_global_voxel_img in paths_global_voxel_img:
+ path_global_voxel_img_parts = list(Path(path_global_voxel_img).parts)
+ patient_id = path_global_voxel_img_parts[-1].split('_')[-1].split('.')[0]
+ path_global_voxel_img_parts[-1] = config.FILENAME_IMG_RESAMPLED_3D.format(patient_id)
+ path_global_voxel_img_resampled = Path(*path_global_voxel_img_parts)
+ paths_global_voxel_img_resampled.append(path_global_voxel_img_resampled)
+
+ paths_global_voxel_mask_resampled = []
+ for path_global_voxel_mask in paths_global_voxel_mask:
+ path_global_voxel_mask_parts = list(Path(path_global_voxel_mask).parts)
+ patient_id = path_global_voxel_mask_parts[-1].split('_')[-1].split('.')[0]
+ path_global_voxel_mask_parts[-1] = config.FILENAME_MASK_RESAMPLED_3D.format(patient_id)
+ path_global_voxel_mask_resampled = Path(*path_global_voxel_mask_parts)
+ paths_global_voxel_mask_resampled.append(path_global_voxel_mask_resampled)
+
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG_RESAMPLED), paths_global_voxel_img_resampled)
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK_RESAMPLED), paths_global_voxel_mask_resampled)
+
+ else:
+ print (' - [ERROR][{}][_extract3D_patients()] Unable to save .csv'.format(self.class_name))
+ pdb.set_trace()
+ print (' - Exiting!')
+ import sys; sys.exit(1)
+
+ def _extract3D_patient(self, patient_dir):
+
+ try:
+ voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers = self._get_data3D(patient_dir)
+
+ dir_type = Path(Path(patient_dir).parts[-3], Path(patient_dir).parts[-2])
+ patient_id = Path(patient_dir).parts[-1]
+ return self._save_data3D(dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers)
+
+ except:
+            print (' - [ERROR][{}][_extract3D_patient()] path_folder: {}'.format(self.class_name, patient_dir.parts[-3:]))
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_data3D(self, patient_dir):
+
+ try:
+ voxel_img, voxel_mask = [], []
+ voxel_img_headers, voxel_mask_headers = {}, {}
+
+ if Path(patient_dir).exists():
+
+ if Path(patient_dir).is_dir():
+
+ # Step 1 - Get Voxel Data
+ path_voxel_img = Path(patient_dir).joinpath(self.IMG_VOXEL_FILENAME)
+ voxel_img, voxel_img_headers = self._get_voxel_img(path_voxel_img)
+
+ # Step 2 - Get Mask Data
+ path_voxel_mask = Path(patient_dir).joinpath(self.MASK_VOXEL_FILENAME)
+ voxel_mask, voxel_mask_headers = self._get_voxel_mask(path_voxel_mask)
+
+ else:
+                print (' - [ERROR][{}][_get_data3D()]: Path does not exist: patient_dir: {}'.format(self.class_name, patient_dir))
+
+ return voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers
+
+ except:
+            print (' - [ERROR][{}][_get_data3D()] patient_dir: {}'.format(self.class_name, patient_dir.parts[-3:]))
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_voxel_img(self, path_voxel, histogram=False):
+
+ try:
+ if Path(path_voxel).exists():
+ voxel_img_data, voxel_img_header = nrrd.read(str(path_voxel)) # shape=[H,W,D]
+
+ if histogram:
+ import matplotlib.pyplot as plt
+ plt.hist(voxel_img_data.flatten())
+ plt.show()
+
+ return voxel_img_data, voxel_img_header
+ else:
+ print (' - [ERROR][{}][_get_voxel_img()]: Path does not exist: {}'.format(self.class_name, path_voxel))
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_voxel_mask(self, path_voxel_mask):
+
+ try:
+
+ # Step 1 - Get mask data and headers
+            if 0:  # set to 1 to read a pre-merged mask.nrrd directly (Path(path_voxel_mask).exists()); by default the per-organ masks are merged below
+ voxel_mask_data, voxel_mask_headers = nrrd.read(str(path_voxel_mask))
+
+ else:
+ path_mask_folder = Path(*Path(path_voxel_mask).parts[:-1]).joinpath(self.MASK_ORGANS_FOLDERNAME)
+ voxel_mask_data, voxel_mask_headers = self._merge_masks(path_mask_folder)
+
+ # Step 2 - Make a list of available headers
+ voxel_mask_headers = dict(voxel_mask_headers)
+ voxel_mask_headers[config.KEYNAME_LABEL_OARS] = []
+ voxel_mask_headers[config.KEYNAME_LABEL_MISSING] = []
+ label_map_inverse = {label_id: label_name for label_name, label_id in self.LABEL_MAP.items()}
+ label_ids_all = self.LABEL_MAP.values()
+ label_ids_voxel_mask = np.unique(voxel_mask_data)
+ for label_id in label_ids_all:
+ label_name = label_map_inverse[label_id]
+ if label_id not in label_ids_voxel_mask:
+ voxel_mask_headers[config.KEYNAME_LABEL_MISSING].append(label_name)
+ else:
+ voxel_mask_headers[config.KEYNAME_LABEL_OARS].append(label_name)
+
+ return voxel_mask_data, voxel_mask_headers
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _merge_masks(self, path_mask_folder):
+
+ try:
+ voxel_mask_full = []
+ voxel_mask_headers = {}
+ labels_oars = []
+ labels_missing = []
+ patient_id = Path(path_mask_folder).parts[-2]
+
+ if Path(path_mask_folder).exists():
+ with tqdm.tqdm(total=len(list(Path(path_mask_folder).glob('*{}'.format(self.DATATYPE_ORIG)))), leave=False, disable=True) as pbar_mask:
+ for filepath_mask in Path(path_mask_folder).iterdir():
+ class_name = Path(filepath_mask).parts[-1].split(self.DATATYPE_ORIG)[0]
+ class_name_miccai2015 = self.LABEL_MAP_MICCAI2015_DEEPMINDTCIA.get(class_name, None)
+ class_id = -1
+
+ if class_name_miccai2015 in self.LABEL_MAP:
+ class_id = self.LABEL_MAP[class_name_miccai2015]
+ labels_oars.append(class_name_miccai2015)
+ voxel_mask, voxel_mask_headers = nrrd.read(str(filepath_mask))
+
+ if class_id not in self.IGNORE_LABELS and class_id > 0:
+ if len(voxel_mask_full) == 0:
+ voxel_mask_full = np.array(voxel_mask, copy=True)
+ idxs = np.argwhere(voxel_mask > 0)
+ voxel_mask_full[idxs[:,0], idxs[:,1], idxs[:,2]] = class_id
+ if 0:
+ print (' - [merge_masks()] class_id:', class_id, ' || name: ', class_name_miccai2015, ' || idxs: ', len(idxs))
+ print (' --- [merge_masks()] label_ids: ', np.unique(voxel_mask_full))
+
+ pbar_mask.update(1)
+
+ path_mask = Path(*Path(path_mask_folder).parts[:-1]).joinpath(self.MASK_VOXEL_FILENAME)
+ nrrd.write(str(path_mask), voxel_mask_full, voxel_mask_headers)
+
+ else:
+ print (' - [ERROR][{}][_merge_masks()] path_mask_folder: {} does not exist '.format(self.class_name, path_mask_folder))
+
+ return voxel_mask_full, voxel_mask_headers
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _save_data3D(self, dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers):
+
+ try:
+
+ # Step 1 - Create directory
+ voxel_savedir = Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D)
+
+ # Step 2.1 - Save voxel
+ voxel_img_headers_new = {config.TYPE_VOXEL_ORIGSHAPE:{}}
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] = voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING][voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING] > 0].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] = voxel_img_headers[config.KEY_NRRD_ORIGIN].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] = voxel_img_headers[config.KEY_NRRD_SHAPE].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_OTHERS] = voxel_img_headers
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] = voxel_mask_headers[config.KEYNAME_LABEL_MISSING]
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS] = voxel_mask_headers[config.KEYNAME_LABEL_OARS]
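+            # (Hedged note) voxel_img_headers_new is a nested dict of the form
+            #   {config.TYPE_VOXEL_ORIGSHAPE: {KEYNAME_PIXEL_SPACING, KEYNAME_ORIGIN, KEYNAME_SHAPE,
+            #    KEYNAME_OTHERS, KEYNAME_LABEL_MISSING, KEYNAME_LABEL_OARS, KEYNAME_MEAN_MIDPOINT}}
+            # and is what the dataloaders later read back per patient via config.FILENAME_VOXEL_INFO.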
+
+ if len(voxel_img) and len(voxel_mask):
+ resample_save = False
+ if len(self.VOXEL_RESO):
+ resample_save = True
+
+ # Find midpoint 3D coord of self.LABELID_MIDPOINT
+ meanpoint_idxs = np.argwhere(voxel_mask == self.LABELID_MIDPOINT)
+ meanpoint_idxs_mean = np.mean(meanpoint_idxs, axis=0)
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] = meanpoint_idxs_mean.tolist()
+
+ return utils.save_as_mha(voxel_savedir, patient_id, voxel_img, voxel_img_headers_new, voxel_mask
+ , labelid_midpoint=self.LABELID_MIDPOINT
+ , resample_spacing=self.VOXEL_RESO)
+
+ else:
+                print (' - [ERROR][{}][_save_data3D()] Error with patient_id: {}'.format(self.class_name, patient_id))
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
\ No newline at end of file
diff --git a/src/dataloader/extractors/han_miccai2015.py b/src/dataloader/extractors/han_miccai2015.py
new file mode 100644
index 0000000..71f1a97
--- /dev/null
+++ b/src/dataloader/extractors/han_miccai2015.py
@@ -0,0 +1,339 @@
+# Import private libraries
+import src.config as config
+import src.dataloader.utils as utils
+
+# Import public libraries
+import pdb
+import tqdm
+import nrrd
+import copy
+import traceback
+import numpy as np
+from pathlib import Path
+
+
+class HaNMICCAI2015Downloader:
+
+ def __init__(self, dataset_dir_raw, dataset_dir_processed):
+ self.dataset_dir_raw = dataset_dir_raw
+ self.dataset_dir_processed = dataset_dir_processed
+
+ def download(self):
+ self.dataset_dir_raw.mkdir(parents=True, exist_ok=True)
+ # Step 1 - Download .zips and unzip them
+ urls_zip = ['http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part1.zip'
+ , 'http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part2.zip'
+ , 'http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part3.zip']
+
+ import concurrent
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ for url_zip in urls_zip:
+ filepath_zip = Path(self.dataset_dir_raw, url_zip.split('/')[-1])
+
+ # Step 1.1 - Download .zip and then unzip it
+ if not Path(filepath_zip).exists():
+ executor.submit(utils.download_zip, url_zip, filepath_zip, self.dataset_dir_raw)
+ else:
+ executor.submit(utils.read_zip, filepath_zip, self.dataset_dir_raw)
+
+ # Step 1.2 - Unzip .zip
+ # executor.submit(utils.read_zip(filepath_zip, self.dataset_dir_raw)
+
+ def sort(self, dataset_dir_datatypes, dataset_dir_datatypes_ranges):
+
+ print ('')
+        import shutil  # tqdm and numpy are already imported at module level
+
+        # Step 1 - Make necessary directories
+ self.dataset_dir_raw.mkdir(parents=True, exist_ok=True)
+ for each in dataset_dir_datatypes:
+ path_tmp = Path(self.dataset_dir_raw).joinpath(each)
+ path_tmp.mkdir(parents=True, exist_ok=True)
+
+ # Step 2 - Sort
+ with tqdm.tqdm(total=len(list(Path(self.dataset_dir_raw).glob('0522*'))), desc='Sorting', leave=False) as pbar:
+ for path_patient in self.dataset_dir_raw.iterdir():
+ if '.zip' not in path_patient.parts[-1]: #and path_patient.parts[-1] not in dataset_dir_datatypes:
+ try:
+ patient_number = Path(path_patient).parts[-1][-3:]
+ if patient_number.isdigit():
+                            folder_id = np.digitize(int(patient_number), dataset_dir_datatypes_ranges)  # cast to int so the comparison against the numeric ranges is well-defined
+ shutil.move(src=str(path_patient), dst=str(Path(self.dataset_dir_raw).joinpath(dataset_dir_datatypes[folder_id])))
+ pbar.update(1)
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+class HaNMICCAI2015Extractor:
+ """
+ More information on the .nrrd format can be found here: http://teem.sourceforge.net/nrrd/format.html#space
+ """
+
+ def __init__(self, name, dataset_dir_raw, dataset_dir_processed, dataset_dir_datatypes):
+
+ self.name = name
+ self.dataset_dir_raw = dataset_dir_raw
+ self.dataset_dir_processed = dataset_dir_processed
+ self.dataset_dir_datatypes = dataset_dir_datatypes
+ self.folder_prefix = '0522'
+
+ self._preprint()
+ self._init_constants()
+
+ def _preprint(self):
+ self.VOXEL_RESO = getattr(config, self.name)[config.KEY_VOXELRESO]
+ print ('')
+ print (' - [HaNMICCAI2015Extractor] VOXEL_RESO: ', self.VOXEL_RESO)
+ print ('')
+
+ def _init_constants(self):
+
+ # File names
+ self.DATATYPE_ORIG = '.nrrd'
+ self.IMG_VOXEL_FILENAME = 'img.nrrd'
+ self.MASK_VOXEL_FILENAME = 'mask.nrrd'
+
+ # Label information
+ self.LABEL_MAP = getattr(config, self.name)[config.KEY_LABEL_MAP]
+ self.IGNORE_LABELS = getattr(config,self.name)[config.KEY_IGNORE_LABELS]
+ self.LABELID_MIDPOINT = getattr(config, self.name)[config.KEY_LABELID_MIDPOINT]
+
+ def extract3D(self):
+
+ import concurrent
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ for dir_type in self.dataset_dir_datatypes:
+ executor.submit(self._extract3D_patients, Path(self.dataset_dir_raw).joinpath(dir_type))
+ # for dir_type in ['train']:
+ # self._extract3D_patients(Path(self.dataset_dir_raw).joinpath(dir_type))
+
+ print ('')
+ print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer')
+ print ('')
+
+ def _extract3D_patients(self, dir_dataset):
+
+ dir_type = Path(dir_dataset).parts[-1]
+ paths_global_voxel_img = []
+ paths_global_voxel_mask = []
+
+ # Step 1 - Loop over patients of dir_type and get their img and mask paths
+ dir_type_idx = self.dataset_dir_datatypes.index(dir_type)
+ with tqdm.tqdm(total=len(list(dir_dataset.glob('*'))), desc='[3D][{}] Patients: '.format(dir_type), disable=False, position=dir_type_idx) as pbar:
+ for _, patient_dir_path in enumerate(dir_dataset.iterdir()):
+ try:
+ if Path(patient_dir_path).is_dir():
+ voxel_img_filepath, voxel_mask_filepath, _ = self._extract3D_patient(patient_dir_path)
+ paths_global_voxel_img.append(voxel_img_filepath)
+ paths_global_voxel_mask.append(voxel_mask_filepath)
+ pbar.update(1)
+
+ except:
+ print ('')
+ print (' - [ERROR][HaNMICCAI2015Extractor] Error with patient_id: ', Path(patient_dir_path).parts[-2:])
+ traceback.print_exc()
+ pdb.set_trace()
+
+ # Step 2 - Save paths in .csvs
+ if len(paths_global_voxel_img) and len(paths_global_voxel_mask):
+ paths_global_voxel_img = list(map(lambda x: str(x), paths_global_voxel_img))
+ paths_global_voxel_mask = list(map(lambda x: str(x), paths_global_voxel_mask))
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG), paths_global_voxel_img)
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK), paths_global_voxel_mask)
+
+ if len(self.VOXEL_RESO):
+ paths_global_voxel_img_resampled = []
+ for path_global_voxel_img in paths_global_voxel_img:
+ path_global_voxel_img_parts = list(Path(path_global_voxel_img).parts)
+ patient_id = path_global_voxel_img_parts[-1].split('_')[-1].split('.')[0]
+ path_global_voxel_img_parts[-1] = config.FILENAME_IMG_RESAMPLED_3D.format(patient_id)
+ path_global_voxel_img_resampled = Path(*path_global_voxel_img_parts)
+ paths_global_voxel_img_resampled.append(path_global_voxel_img_resampled)
+
+ paths_global_voxel_mask_resampled = []
+ for path_global_voxel_mask in paths_global_voxel_mask:
+ path_global_voxel_mask_parts = list(Path(path_global_voxel_mask).parts)
+ patient_id = path_global_voxel_mask_parts[-1].split('_')[-1].split('.')[0]
+ path_global_voxel_mask_parts[-1] = config.FILENAME_MASK_RESAMPLED_3D.format(patient_id)
+ path_global_voxel_mask_resampled = Path(*path_global_voxel_mask_parts)
+ paths_global_voxel_mask_resampled.append(path_global_voxel_mask_resampled)
+
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG_RESAMPLED), paths_global_voxel_img_resampled)
+ utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK_RESAMPLED), paths_global_voxel_mask_resampled)
+
+ else:
+ print (' - [ERROR][HaNMICCAI2015Extractor] Unable to save .csv')
+ pdb.set_trace()
+ print (' - Exiting!')
+ import sys; sys.exit(1)
+
+ def _extract3D_patient(self, patient_dir):
+
+ try:
+ voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers = self._get_data3D(patient_dir)
+
+ dir_type = Path(patient_dir).parts[-2]
+ patient_id = Path(patient_dir).parts[-1]
+ return self._save_data3D(dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers)
+
+ except:
+            print (' - [ERROR][_extract3D_patient()] path_folder: ', patient_dir.parts[-1])
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_data3D(self, patient_dir):
+ try:
+ voxel_img, voxel_mask = [], []
+ voxel_img_headers, voxel_mask_headers = {}, {}
+ if Path(patient_dir).exists():
+
+ if Path(patient_dir).is_dir():
+ # Step 1 - Get Voxel Data
+ path_voxel_img = Path(patient_dir).joinpath(self.IMG_VOXEL_FILENAME)
+ voxel_img, voxel_img_headers = self._get_voxel_img(path_voxel_img)
+
+ # Step 2 - Get Mask Data
+ path_voxel_mask = Path(patient_dir).joinpath(self.MASK_VOXEL_FILENAME)
+ voxel_mask, voxel_mask_headers = self._get_voxel_mask(path_voxel_mask)
+
+ else:
+ print (' - Error: Path does not exist: patient_dir', patient_dir)
+
+ return voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers
+
+ except:
+ print (' - [ERROR][get_data()] patient_dir: ', patient_dir.parts[-1])
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_voxel_img(self, path_voxel, histogram=False):
+ try:
+ if path_voxel.exists():
+ voxel_img_data, voxel_img_header = nrrd.read(str(path_voxel)) # shape=[H,W,D]
+
+ if histogram:
+ import matplotlib.pyplot as plt
+ plt.hist(voxel_img_data.flatten())
+ plt.show()
+
+ return voxel_img_data, voxel_img_header
+ else:
+ print (' - Error: Path does not exist: ', path_voxel)
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _get_voxel_mask(self, path_voxel_mask):
+ try:
+
+ # Step 1 - Get mask data and headers
+ if Path(path_voxel_mask).exists():
+ voxel_mask_data, voxel_mask_headers = nrrd.read(str(path_voxel_mask))
+
+ else:
+ path_mask_folder = Path(*Path(path_voxel_mask).parts[:-1]).joinpath('structures')
+ voxel_mask_data, voxel_mask_headers = self._merge_masks(path_mask_folder)
+
+ # Step 2 - Make a list of available headers
+ voxel_mask_headers = dict(voxel_mask_headers)
+ voxel_mask_headers[config.KEYNAME_LABEL_OARS] = []
+ voxel_mask_headers[config.KEYNAME_LABEL_MISSING] = []
+ label_map_inverse = {label_id: label_name for label_name, label_id in self.LABEL_MAP.items()}
+ label_ids_all = self.LABEL_MAP.values()
+ label_ids_voxel_mask = np.unique(voxel_mask_data)
+ for label_id in label_ids_all:
+ label_name = label_map_inverse[label_id]
+ if label_id not in label_ids_voxel_mask:
+ voxel_mask_headers[config.KEYNAME_LABEL_MISSING].append(label_name)
+ else:
+ voxel_mask_headers[config.KEYNAME_LABEL_OARS].append(label_name)
+
+ return voxel_mask_data, voxel_mask_headers
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _merge_masks(self, path_mask_folder):
+
+ try:
+ voxel_mask_full = []
+ voxel_mask_headers = {}
+ labels_oars = []
+ labels_missing = []
+ if Path(path_mask_folder).exists():
+ with tqdm.tqdm(total=len(list(Path(path_mask_folder).glob('*{}'.format(self.DATATYPE_ORIG)))), leave=False, disable=True) as pbar_mask:
+ for filepath_mask in Path(path_mask_folder).iterdir():
+ class_name = Path(filepath_mask).parts[-1].split(self.DATATYPE_ORIG)[0]
+ class_id = -1
+
+ if class_name in self.LABEL_MAP:
+ class_id = self.LABEL_MAP[class_name]
+ labels_oars.append(class_name)
+ voxel_mask, voxel_mask_headers = nrrd.read(str(filepath_mask))
+ # else:
+ # print (' - [ERROR][HaNMICCAI2015Extractor][_merge_masks] Unknown class name: ', class_name)
+
+ if class_id not in self.IGNORE_LABELS and class_id > 0:
+ if len(voxel_mask_full) == 0:
+ voxel_mask_full = copy.deepcopy(voxel_mask)
+ idxs = np.argwhere(voxel_mask > 0)
+ voxel_mask_full[idxs[:,0], idxs[:,1], idxs[:,2]] = class_id
+ if 0:
+ print (' - [merge_masks()] class_id:', class_id, ' || name: ', class_name, ' || idxs: ', len(idxs))
+ print (' --- [merge_masks()] label_ids: ', np.unique(voxel_mask_full))
+
+ pbar_mask.update(1)
+
+ path_mask = Path(*Path(path_mask_folder).parts[:-1]).joinpath(self.MASK_VOXEL_FILENAME)
+ nrrd.write(str(path_mask), voxel_mask_full, voxel_mask_headers)
+
+ else:
+ print (' - Error with path_mask_folder: ', path_mask_folder)
+
+ return voxel_mask_full, voxel_mask_headers
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _save_data3D(self, dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers):
+
+ try:
+
+ # Step 1 - Create directory
+ voxel_savedir = Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D)
+
+ # Step 2.1 - Save voxel
+ voxel_img_headers_new = {config.TYPE_VOXEL_ORIGSHAPE:{}}
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] = voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING][voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING] > 0].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] = voxel_img_headers[config.KEY_NRRD_ORIGIN].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] = voxel_img_headers[config.KEY_NRRD_SHAPE].tolist()
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_OTHERS] = voxel_img_headers
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] = voxel_mask_headers[config.KEYNAME_LABEL_MISSING]
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS] = voxel_mask_headers[config.KEYNAME_LABEL_OARS]
+
+ if len(voxel_img) and len(voxel_mask):
+ resample_save = False
+ if len(self.VOXEL_RESO):
+ resample_save = True
+
+                # Find the midpoint 3D coordinate of self.LABELID_MIDPOINT (used to centre crops later)
+ meanpoint_idxs = np.argwhere(voxel_mask == self.LABELID_MIDPOINT)
+ meanpoint_idxs_mean = np.mean(meanpoint_idxs, axis=0)
+ voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] = meanpoint_idxs_mean.tolist()
+
+ return utils.save_as_mha(voxel_savedir, patient_id, voxel_img, voxel_img_headers_new, voxel_mask
+ , labelid_midpoint=self.LABELID_MIDPOINT
+ , resample_spacing=self.VOXEL_RESO)
+
+ else:
+ print (' - [ERROR][HaNMICCAI2015Extractor] Error with patient_id: ', patient_id)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
diff --git a/src/dataloader/han_deepmindtcia.py b/src/dataloader/han_deepmindtcia.py
new file mode 100644
index 0000000..f2169d2
--- /dev/null
+++ b/src/dataloader/han_deepmindtcia.py
@@ -0,0 +1,726 @@
+# Import private libraries
+import src.config as config
+import src.dataloader.utils as utils
+
+# Import public libraries
+import pdb
+import time
+import json
+import itertools
+import traceback
+import numpy as np
+from pathlib import Path
+
+import tensorflow as tf
+
+
+class HaNDeepMindTCIADataset:
+ """
+ Contains data from https://github.com/deepmind/tcia-ct-scan-dataset
+ - Ref: Deep learning to achieve clinically applicable segmentation of head and neck anatomy for radiotherapy
+ """
+
+ def __init__(self, data_dir, dir_type, annotator_type
+ , grid=True, crop_init=False, resampled=False, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False, pregridnorm=False
+ , patient_shuffle=False
+ , centred_dataloader_prob = 0.0
+ , parallel_calls=None, deterministic=False
+ , debug=False):
+
+ self.name = '{}_DeepMindTCIA'.format(config.HEAD_AND_NECK)
+ self.class_name = 'HaNDeepMindTCIADataset'
+
+ # Params - Source
+ self.data_dir = data_dir
+ self.dir_type = dir_type
+ self.annotator_type = annotator_type
+
+ # Params - Spatial (x,y)
+ self.grid = grid
+ self.crop_init = crop_init
+ self.resampled = resampled
+ self.mask_type = mask_type
+
+ # Params - Transforms/Filters
+ self.transforms = transforms
+ self.filter_grid = filter_grid
+ self.random_grid = random_grid
+ self.pregridnorm = pregridnorm
+
+ # Params - Memory related
+ self.patient_shuffle = patient_shuffle
+
+ # Params - Dataset related
+ self.centred_dataloader_prob = centred_dataloader_prob
+
+ # Params - TFlow Dataloader related
+ self.parallel_calls = parallel_calls # [1, tf.data.experimental.AUTOTUNE]
+ self.deterministic = deterministic
+
+ # Params - Debug
+ self.debug = debug
+
+ # Data items
+ self.data = {}
+ self.paths_img = []
+ self.paths_mask = []
+ self.cache = {}
+ self.filter = None
+
+ # Config
+ self.dataset_config = getattr(config, self.name)
+
+ # Function calls
+ self._download()
+ self._init_data()
+
+ def __len__(self):
+
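+        # Length = (number of patients) x (grid patches per patient). The multipliers below
+        # assume the configured grid overlaps yield 1 start per axis for a [240,240,80] grid,
+        # 2 depth starts for [240,240,40], and 2 starts per axis (2*2*2 = 8) for [140,140,40]
+        # on the [240,240,80] cropped volume.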
+ if self.grid:
+
+ if self.crop_init:
+
+ if self.voxel_shape_cropped == [240,240,80]:
+
+ if self.grid_size == [240,240,80]:
+ return len(self.paths_img) * 1
+ elif self.grid_size == [240,240,40]:
+ return len(self.paths_img) * 2
+ elif self.grid_size == [140,140,40]:
+ return len(self.paths_img) * 8
+ else:
+ return None
+
+ else:
+ return len(self.paths_img)
+
+ def _download(self):
+
+ self.dataset_dir = Path(self.data_dir).joinpath(self.name)
+ self.dataset_dir_raw = Path(self.dataset_dir).joinpath(config.DIRNAME_RAW)
+
+ self.VOXEL_RESO = self.dataset_config[config.KEY_VOXELRESO]
+ if self.VOXEL_RESO == (0.8,0.8,2.5):
+ self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED)
+ self.dataset_dir_processed_3D = Path(self.dataset_dir_processed).joinpath(self.dir_type, self.annotator_type, config.DIRNAME_SAVE_3D)
+ self.dataset_dir_datatypes = [
+ Path(config.DATALOADER_DEEPMINDTCIA_TEST, config.DATALOADER_DEEPMINDTCIA_ONC)
+ , Path(config.DATALOADER_DEEPMINDTCIA_TEST, config.DATALOADER_DEEPMINDTCIA_RAD)
+ , Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_ONC)
+ , Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_RAD)
+ ]
+
+ if not Path(self.dataset_dir_raw).exists() or not Path(self.dataset_dir_processed).exists():
+ print ('')
+ print (' ------------------ {} Dataset ------------------'.format(self.name))
+
+ if not Path(self.dataset_dir_raw).exists():
+ print ('')
+ print (' ------------------ Download Data ------------------')
+ from src.dataloader.extractors.han_deepmindtcia import HaNDeepMindTCIADownloader
+ downloader = HaNDeepMindTCIADownloader(self.dataset_dir_raw, self.dataset_dir_processed)
+ downloader.download()
+ downloader.sort()
+
+ if not Path(self.dataset_dir_processed_3D).exists():
+ print ('')
+ print (' ------------------ Process Data (3D) ------------------')
+ from src.dataloader.extractors.han_deepmindtcia import HaNDeepMindTCIAExtractor
+ extractor = HaNDeepMindTCIAExtractor(self.name, self.dataset_dir_raw, self.dataset_dir_processed, self.dataset_dir_datatypes)
+ extractor.extract3D()
+ print ('')
+ print (' ------------------------- * -------------------------')
+ print ('')
+
+ def _init_data(self):
+
+ # Step 0 - Init vars
+ self.patient_meta_info = {}
+ self.path_img_csv = ''
+ self.path_mask_csv = ''
+        self.patients_z_prob = []  # empty for every voxel resolution in this dataset
+
+ # Step 1 - Define global paths
+ self.data_dir_processed = Path(self.dataset_dir_processed).joinpath(self.dir_type, self.annotator_type)
+ if not Path(self.data_dir_processed).exists():
+ print (' - [ERROR][{}][_init_data()] Processed Dir Path issue: {}'.format(self.class_name, self.data_dir_processed))
+ self.data_dir_processed_3D = Path(self.data_dir_processed).joinpath(config.DIRNAME_SAVE_3D)
+
+ # Step 2.1 - Get paths for 2D/3D
+ if self.resampled is False:
+ self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG)
+ self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK)
+ else:
+ self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG_RESAMPLED)
+ self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK_RESAMPLED)
+
+ # Step 2.2 - Get file paths
+ if Path(self.path_img_csv).exists() and Path(self.path_mask_csv).exists():
+ self.paths_img = utils.read_csv(self.path_img_csv)
+ self.paths_mask = utils.read_csv(self.path_mask_csv)
+
+ exit_condition = False
+ for path_img in self.paths_img:
+ if not Path(path_img).exists():
+ print (' - [ERROR][{}][_init_data()] path issue: path_img: {}/{}, {}'.format(self.class_name, self.dir_type, self.annotator_type, path_img))
+ exit_condition = True
+ for path_mask in self.paths_mask:
+ if not Path(path_mask).exists():
+ print (' - [ERROR][{}][_init_data()] path issue: path_mask: {}/{}, {}'.format(self.class_name, self.dir_type, self.annotator_type, path_mask))
+ exit_condition = True
+
+ if exit_condition:
+ print ('\n - [ERROR][{}][_init_data()] Exiting due to path issues'.format(self.class_name))
+ import sys; sys.exit(1)
+
+ else:
+ print (' - [ERROR][{}] Issue with path'.format(self.class_name))
+ print (' -- [ERROR][{}] self.path_img_csv : ({}) {}'.format(self.class_name, Path(self.path_img_csv).exists(), self.path_img_csv ))
+ print (' -- [ERROR][{}] self.path_mask_csv: ({}) {}'.format(self.class_name, Path(self.path_mask_csv).exists(), self.path_mask_csv ))
+
+ # Step 3.1 - Meta for labels
+ self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP]
+ self.LABEL_COLORS = self.dataset_config[config.KEY_LABEL_COLORS]
+ if len(self.dataset_config[config.KEY_LABEL_WEIGHTS]):
+ self.LABEL_WEIGHTS = np.array(self.dataset_config[config.KEY_LABEL_WEIGHTS])
+ self.LABEL_WEIGHTS = (self.LABEL_WEIGHTS / np.sum(self.LABEL_WEIGHTS)).tolist()
+ else:
+ self.LABEL_WEIGHTS = []
+
+ # Step 3.2 - Meta for voxel HU
+ self.HU_MIN = self.dataset_config['HU_MIN']
+ self.HU_MAX = self.dataset_config['HU_MAX']
+
+ # Step 4 - Get patient meta info
+ if self.resampled is False:
+ print (' - [Warning][{}]: This dataloader is not extracting 3D Volumes which have been resampled to the same 3D voxel spacing: {}/{}'.format(self.class_name, self.dir_type, self.annotator_type))
+ for path_img in self.paths_img:
+ path_img_json = Path(path_img).parent.absolute().joinpath(config.FILENAME_VOXEL_INFO)
+ with open(str(path_img_json), 'r') as fp:
+ patient_config_file = json.load(fp)
+
+ patient_id = Path(path_img).parts[-2]
+ midpoint_idxs_mean = []
+ missing_label_names = []
+ voxel_shape_resampled = []
+ voxel_shape_orig = []
+ if self.resampled is False:
+ midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT]
+ missing_label_names = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING]
+ voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE]
+ else:
+ if config.TYPE_VOXEL_RESAMPLED in patient_config_file:
+ midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_MEAN_MIDPOINT]
+ missing_label_names = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING]
+ voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE]
+ voxel_shape_resampled = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE]
+
+ else:
+ print (' - [ERROR][{}][_init_data()] There is no resampled data and you have set resample=True'.format(self.class_name))
+ print (' -- Delete _data/{}/processed directory and set VOXEL_RESO to a tuple of pixel spacing values for x,y,z axes'.format(self.name))
+ print (' -- Exiting now! ')
+ import sys; sys.exit(1)
+
+ if len(midpoint_idxs_mean):
+ midpoint_idxs_mean = np.array(midpoint_idxs_mean).astype(config.DATATYPE_VOXEL_IMG).tolist()
+ self.patient_meta_info[patient_id] = {
+ config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean
+ , config.KEYNAME_LABEL_MISSING : missing_label_names
+ , config.KEYNAME_SHAPE_ORIG : voxel_shape_orig
+ , config.KEYNAME_SHAPE_RESAMPLED : voxel_shape_resampled
+ }
+ else:
+ if self.crop_init:
+ print ('')
+ print (' - [ERROR][{}][_init_data()] Crop is set to true and there is no midpoint mean idx'.format(self.class_name))
+ print (' -- Exiting now! ')
+ import sys; sys.exit(1)
+
+ # Step 6 - Meta for grid sampling
+ if self.resampled:
+
+ if self.crop_init:
+ self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_RESAMPLED][config.KEY_CROP][str(self.VOXEL_RESO)]
+ self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT]
+ self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT]
+ self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK]
+ self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT]
+ self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP]
+ self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM]
+ self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop
+ , self.h_posterior_crop+self.h_anterior_crop
+ , self.d_cranial_crop+self.d_caudal_crop]
+
+ if self.grid:
+ grid_3D_params = self.dataset_config[config.KEY_GRID_3D][config.TYPE_VOXEL_RESAMPLED][str(self.VOXEL_RESO)]
+ self.grid_size = grid_3D_params[config.KEY_GRID_SIZE]
+ self.grid_overlap = grid_3D_params[config.KEY_GRID_OVERLAP]
+ self.SAMPLER_PERC = grid_3D_params[config.KEY_GRID_SAMPLER_PERC]
+ self.RANDOM_SHIFT_MAX = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_MAX]
+ self.RANDOM_SHIFT_PERC = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_PERC]
+
+ self.w_grid, self.h_grid, self.d_grid = self.grid_size
+ self.w_overlap, self.h_overlap, self.d_overlap = self.grid_overlap
+
+ else:
+
+ if self.crop_init:
+ self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped
+ else:
+                    print (' - [ERROR][{}][_init_data()] No info present for non-grid cropping'.format(self.class_name))
+
+ else:
+ if self.crop_init:
+ self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_ORIGSHAPE][config.KEY_CROP][str(self.VOXEL_RESO)]
+ self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT]
+ self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT]
+ self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK]
+ self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT]
+ self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP]
+ self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM]
+ self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop
+ , self.h_posterior_crop+self.h_anterior_crop
+ , self.d_cranial_crop+self.d_caudal_crop]
+
+ if self.grid:
+ pass
+ else:
+ if self.crop_init:
+ self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped
+ else:
+                    print (' - [ERROR][{}][_init_data()] No grid size info available when crop_init is False'.format(self.class_name))
+
+ def generator(self):
+ """
+ - Note:
+ - In general, even when running your model on an accelerator like a GPU or TPU, the tf.data pipelines are run on the CPU
+ - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow
+ """
+
+ try:
+
+ if len(self.paths_img) and len(self.paths_mask):
+
+ # Step 1 - Create basic generator
+ dataset = None
+ dataset = tf.data.Dataset.from_generator(self._generator3D
+ , output_types=(config.DATATYPE_TF_FLOAT32, config.DATATYPE_TF_UINT8, config.DATATYPE_TF_INT32, tf.string)
+ ,args=())
+
+ # Step 2 - Get 3D data
+ dataset = dataset.map(self._get_data_3D, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic)
+
+ # Step 3 - Filter function
+ if self.filter_grid:
+ dataset = dataset.filter(self.filter.execute)
+
+ # Step 4 - Data augmentations
+ if len(self.transforms):
+ for transform in self.transforms:
+ try:
+ dataset = dataset.map(transform.execute, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic)
+ except:
+ traceback.print_exc()
+ print (' - [ERROR][{}][generator()] Issue with transform: {}'.format(self.class_name, transform.name))
+ else:
+ print ('')
+ print (' - [INFO][{}][generator()] No transformations available! - {}, {}'.format(self.class_name, self.dir_type, self.annotator_type))
+ print ('')
+
+ # Step 6 - Return
+ return dataset
+
+ else:
+ return None
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+ return None
+
+ def _get_paths(self, idx):
+ patient_id = ''
+ study_id = ''
+ path_img, path_mask = '', ''
+
+ if self.debug:
+ path_img = Path(self.paths_img[0]).absolute()
+ path_mask = Path(self.paths_mask[0]).absolute()
+ path_img, path_mask = self.path_debug_3D(path_img, path_mask)
+ else:
+ path_img = Path(self.paths_img[idx]).absolute()
+ path_mask = Path(self.paths_mask[idx]).absolute()
+
+ if path_img.exists() and path_mask.exists():
+ patient_id = Path(path_img).parts[-2]
+ study_id = Path(path_img).parts[-5] + '-' + Path(path_img).parts[-4]
+ else:
+ print (' - [ERROR] Issue with path')
+ print (' -- [ERROR][{}] path_img : {}'.format(self.class_name, path_img))
+ print (' -- [ERROR][{}] path_mask: {}'.format(self.class_name, path_mask))
+
+ return path_img, path_mask, patient_id, study_id
+
+ def _generator3D(self):
+
+ # Step 0 - Init
+ res = []
+
+ # Step 1 - Get patient idxs
+ idxs = np.arange(len(self.paths_img)).tolist() #[:3]
+        if getattr(self, 'single_sample', False): idxs = idxs[0:1]  # 'single_sample' is never set in __init__, so guard with getattr to avoid an AttributeError
+ if self.patient_shuffle: np.random.shuffle(idxs)
+
+ # Step 2 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling
+ if self.grid:
+
+ # Step 2.1 - Get grid sampler info for each patient-idx
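+            # (Hedged note) utils.split_into_overlapping_grids(length, len_grid, len_overlap) is
+            # assumed to return per-axis start/end windows, e.g. for length=240, len_grid=140,
+            # len_overlap=40 something like [(0,140), (100,240)]; itertools.product below then
+            # enumerates every 3D patch of the volume from these per-axis windows.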
+ sampler_info = {}
+ for idx in idxs:
+ path_img = Path(self.paths_img[idx]).absolute()
+ patient_id = path_img.parts[-2]
+
+ if config.TYPE_VOXEL_RESAMPLED in str(path_img):
+ voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_RESAMPLED]
+ else:
+ voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_ORIG]
+
+ if self.crop_init:
+ if patient_id in self.patients_z_prob:
+ voxel_shape[0] = self.voxel_shape_cropped[0]
+ voxel_shape[1] = self.voxel_shape_cropped[1]
+ else:
+ voxel_shape = self.voxel_shape_cropped
+
+ grid_idxs_width = utils.split_into_overlapping_grids(voxel_shape[0], len_grid=self.grid_size[0], len_overlap=self.grid_overlap[0])
+ grid_idxs_height = utils.split_into_overlapping_grids(voxel_shape[1], len_grid=self.grid_size[1], len_overlap=self.grid_overlap[1])
+ grid_idxs_depth = utils.split_into_overlapping_grids(voxel_shape[2], len_grid=self.grid_size[2], len_overlap=self.grid_overlap[2])
+ sampler_info[idx] = list(itertools.product(grid_idxs_width,grid_idxs_height,grid_idxs_depth))
+
+ # Step 2.2 - Loop over all patients and their grids
+ # Note - Grids of a patient are extracted in order
+ for i, idx in enumerate(idxs):
+ path_img, path_mask, patient_id, study_id = self._get_paths(idx)
+ missing_labels = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING]
+ bgd_mask = 1 # by default
+ if len(missing_labels):
+ bgd_mask = 0
+ if path_img.exists() and path_mask.exists():
+ for sample_info in sampler_info[idx]:
+ grid_idxs = sample_info
+ meta1 = [idx] + [grid_idxs[0][0], grid_idxs[1][0], grid_idxs[2][0]] # only include w_start, h_start, d_start
+ meta2 = '-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)])
+ path_img = str(path_img)
+ path_mask = str(path_mask)
+ res.append((path_img, path_mask, meta1, meta2, bgd_mask))
+
+ else:
+ label_names = list(self.LABEL_MAP.keys())
+ for i, idx in enumerate(idxs):
+ path_img, path_mask, patient_id, study_id = self._get_paths(idx)
+ missing_label_names = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING]
+ bgd_mask = 1
+
+ # if len(missing_labels): bgd_mask = 0
+ if len(missing_label_names):
+ if len(set(label_names) - set(missing_label_names)):
+ bgd_mask = 0
+
+ if path_img.exists() and path_mask.exists():
+ meta1 = [idx] + [0,0,0] # dummy for w_start, h_start, d_start
+ meta2 ='-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)])
+ path_img = str(path_img)
+ path_mask = str(path_mask)
+ res.append((path_img, path_mask, meta1, meta2, bgd_mask))
+
+ # Step 3 - Yield
+ for each in res:
+ path_img, path_mask, meta1, meta2, bgd_mask = each
+
+ vol_img_npy, vol_mask_npy, spacing = self._get_cache_item_old(path_img, path_mask)
+ if vol_img_npy is None and vol_mask_npy is None:
+ vol_img_npy, vol_mask_npy, spacing = self._get_volume_from_path(path_img, path_mask)
+ self._set_cache_item_old(path_img, path_mask, vol_img_npy, vol_mask_npy, spacing)
+
+ spacing = tf.constant(spacing, dtype=tf.int32)
+ vol_img_npy_shape = tf.constant(vol_img_npy.shape, dtype=tf.int32)
+            meta1 = tf.concat([meta1, spacing, vol_img_npy_shape, [bgd_mask]], axis=0) # [idx, [grid_idxs], [spacing], [shape], [bgd_mask]]
+
+ yield (vol_img_npy, vol_mask_npy, meta1, meta2)
+
+ def _get_cache_item_old(self, path_img, path_mask):
+ if 'img' in self.cache and 'mask' in self.cache:
+ if path_img in self.cache['img'] and path_mask in self.cache['mask']:
+ # print (' - [_get_cache_item()] ')
+ return self.cache['img'][path_img], self.cache['mask'][path_mask], self.cache['spacing']
+ else:
+ return None, None, None
+ else:
+ return None, None, None
+
+ def _set_cache_item_old(self, path_img, path_mask, vol_img, vol_mask, spacing):
+
+ self.cache = {
+ 'img': {path_img: vol_img}
+ , 'mask': {path_mask: vol_mask}
+ , 'spacing': spacing
+ }
+
+ def _set_cache_item(self, path_img, path_mask, vol_img, vol_mask, spacing):
+ if len(self.cache) == 0:
+ self.cache = {path_img: [vol_img, vol_mask, spacing]}
+ self.cache_id = {path_img:0}
+ elif len(self.cache) == 1:
+ self.cache[path_img] = [vol_img, vol_mask, spacing]
+ self.cache_id[path_img] = 1
+ elif len(self.cache) == 2:
+ max_order_id = max(self.cache_id.values())
+ for path_img_ in self.cache_id:
+ if self.cache_id[path_img_] == max_order_id - 1:
+ self.cache.pop(path_img_)
+ self.cache[path_img] = [vol_img, vol_mask, spacing]
+ self.cache_id[path_img] = max_order_id+1
+
+ def _get_cache_item(self, path_img, path_mask):
+
+ if path_img in self.cache:
+ return self.cache[path_img]
+ else:
+ return None, None, None
+
+ def _get_volume_from_path(self, path_img, path_mask, verbose=False):
+
+ # Step 1 - Get volumes
+ if verbose: t0 = time.time()
+ vol_img_sitk = utils.read_mha(path_img)
+ vol_img_npy = utils.sitk_to_array(vol_img_sitk)
+ vol_mask_sitk = utils.read_mha(path_mask)
+ vol_mask_npy = utils.sitk_to_array(vol_mask_sitk)
+ spacing = np.array(vol_img_sitk.GetSpacing())
+
+ # Step 2 - Perform init crop on volumes
+ if self.crop_init:
+ patient_id = str(Path(path_img).parts[-2])
+ mean_point = np.array(self.patient_meta_info[patient_id][config.KEYNAME_MEAN_MIDPOINT]).astype(np.uint16).tolist()
+ vol_img_npy_shape_prev = vol_img_npy.shape
+
+ # Step 2.1 - Perform crops in H,W region
+ vol_img_npy = vol_img_npy[
+ mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop
+ , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop
+ , :
+ ]
+ vol_mask_npy = vol_mask_npy[
+ mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop
+ , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop
+ , :
+ ]
+
+            # Step 2.2 - Perform crops in D region (the same caudal/cranial crop applies for grid
+            # and non-grid sampling and for every voxel resolution)
+            vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+            vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+
+ # Step 3 - Pad (with=0) if volume is less in z-dimension
+ (vol_img_npy_x, vol_img_npy_y, vol_img_npy_z) = vol_img_npy.shape
+ if vol_img_npy_z < self.d_caudal_crop + self.d_cranial_crop:
+ del_z = self.d_caudal_crop + self.d_cranial_crop - vol_img_npy_z
+ vol_img_npy = np.concatenate((vol_img_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2)
+ vol_mask_npy = np.concatenate((vol_mask_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2)
+
+ if verbose: print (' - [HaNMICCAI2015Dataset._get_volume_from_path()] Time: ({}):{}s'.format(Path(path_img).parts[-2], round(time.time() - t0,2)))
+ if self.pregridnorm:
+ vol_img_npy[vol_img_npy <= self.HU_MIN] = self.HU_MIN
+ vol_img_npy[vol_img_npy >= self.HU_MAX] = self.HU_MAX
+ vol_img_npy = (vol_img_npy -np.mean(vol_img_npy))/np.std(vol_img_npy) #Standardize (z-scoring)
+
+ return tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32), tf.cast(vol_mask_npy, dtype=config.DATATYPE_TF_UINT8), tf.constant(spacing*100, dtype=config.DATATYPE_TF_INT32)
+
+ @tf.function
+ def _get_new_grid_idx(self, start, end, max):
+
+ start_prev = start
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC:
+
+ delta_left = start
+ delta_right = max - end
+ shift_voxels = tf.random.uniform([], minval=0, maxval=self.RANDOM_SHIFT_MAX, dtype=tf.dtypes.int32)
+
+ if delta_left > self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX:
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC:
+ start = start - shift_voxels
+ end = end - shift_voxels
+ else:
+ start = start + shift_voxels
+ end = end + shift_voxels
+
+ elif delta_left > self.RANDOM_SHIFT_MAX and delta_right <= self.RANDOM_SHIFT_MAX:
+ start = start - shift_voxels
+ end = end - shift_voxels
+
+ elif delta_left <= self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX:
+ start = start + shift_voxels
+ end = end + shift_voxels
+
+ return start_prev, start, end
+
+ @tf.function
+ def _get_new_grid_idx_centred(self, grid_size_half, max_pt, mid_pt):
+
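+        # Worked example (assumed numbers): grid_size_half=70, max_pt=240, mid_pt=30 gives
+        # margin_left=30 < 70 and margin_right=210 >= 70+(70-30), so the window is shifted
+        # inward to (start, end) = (0, 140) and the grid keeps its full size of 140 voxels.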
+ # Step 1 - Return vars
+ start, end = 0,0
+
+ # Step 2 - Define margin on either side of mid point
+ margin_left = mid_pt
+ margin_right = max_pt - mid_pt
+
+        # Step 3 - Calculate start/end indices
+ if margin_left >= grid_size_half and margin_right >= grid_size_half:
+ start = mid_pt - grid_size_half
+ end = mid_pt + grid_size_half
+ elif margin_right < grid_size_half:
+ if margin_left >= grid_size_half + (grid_size_half - margin_right):
+ end = mid_pt + margin_right
+ start = mid_pt - grid_size_half - (grid_size_half - margin_right)
+ else:
+                tf.print(' - [ERROR][_get_new_grid_idx_centred()] Grid does not fit: right margin too small and left margin cannot compensate')
+ elif margin_left < grid_size_half:
+ if margin_right >= grid_size_half + (grid_size_half - margin_left):
+ start = mid_pt - margin_left
+ end = mid_pt + grid_size_half + (grid_size_half-margin_left)
+ else:
+                tf.print(' - [ERROR][_get_new_grid_idx_centred()] Grid does not fit: left margin too small and right margin cannot compensate')
+
+ return start, end
+
+ @tf.function
+ def _get_data_3D(self, vol_img, vol_mask, meta1, meta2):
+ """
+ Params
+ ------
+ meta1: [idx, [w_start, h_start, d_start], [spacing_x, spacing_y, spacing_z], [shape_x, shape_y, shape_z], [bgd_mask]]
+ """
+
+ vol_img_npy = None
+ vol_mask_npy = None
+
+ # Step 1 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling
+ if self.grid:
+
+            if 0:  # organ-centred grid sampling is disabled; to enable, restore the condition: tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.centred_dataloader_prob
+
+ # Step 1.1 - Get a label_id and its mean
+                label_id = tf.cast(tf.random.categorical(tf.math.log([self.LABEL_WEIGHTS]), 1), dtype=config.DATATYPE_TF_UINT8)[0][0] # the LABEL_WEIGHTS sum to 1; the log is added as tf.random.categorical expects logits
+ label_id_idxs = tf.where(tf.math.equal(label_id, vol_mask))
+ label_id_idxs_mean = tf.math.reduce_mean(label_id_idxs, axis=0)
+ label_id_idxs_mean = tf.cast(label_id_idxs_mean, dtype=config.DATATYPE_TF_INT32)
+
+ # Step 1.2 - Create a grid around that mid-point
+ w_start_prev = meta1[1]
+ h_start_prev = meta1[2]
+ d_start_prev = meta1[3]
+ w_max = meta1[7]
+ h_max = meta1[8]
+ d_max = meta1[9]
+ w_grid = self.grid_size[0]
+ h_grid = self.grid_size[1]
+ d_grid = self.grid_size[2]
+ w_mid = label_id_idxs_mean[0]
+ h_mid = label_id_idxs_mean[1]
+ d_mid = label_id_idxs_mean[2]
+
+ w_start, w_end = self._get_new_grid_idx_centred(w_grid//2, w_max, w_mid)
+ h_start, h_end = self._get_new_grid_idx_centred(h_grid//2, h_max, h_mid)
+ d_start, d_end = self._get_new_grid_idx_centred(d_grid//2, d_max, d_mid)
+
+ meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0])
+ meta1 = meta1 + meta1_diff
+
+ else:
+
+ # tf.print(' - [INFO] regular dataloader: ', self.dir_type)
+ # Step 1.1 - Get raw images/masks and extract grid
+ w_start = meta1[1]
+ w_end = w_start + self.grid_size[0]
+ h_start = meta1[2]
+ h_end = h_start + self.grid_size[1]
+ d_start = meta1[3]
+ d_end = d_start + self.grid_size[2]
+
+ # Step 1.2 - Randomization of grid
+ if self.random_grid:
+ w_max = meta1[7]
+ h_max = meta1[8]
+ d_max = meta1[9]
+
+ w_start_prev = w_start
+ d_start_prev = d_start
+ w_start_prev, w_start, w_end = self._get_new_grid_idx(w_start, w_end, w_max)
+ h_start_prev, h_start, h_end = self._get_new_grid_idx(h_start, h_end, h_max)
+ d_start_prev, d_start, d_end = self._get_new_grid_idx(d_start, d_end, d_max)
+
+ meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0])
+ meta1 = meta1 + meta1_diff
+
+ # Step 1.3 - Extracting grid
+ vol_img_npy = tf.identity(vol_img[w_start:w_end, h_start:h_end, d_start:d_end])
+ vol_mask_npy = tf.identity(vol_mask[w_start:w_end, h_start:h_end, d_start:d_end])
+
+
+ else:
+ vol_img_npy = vol_img
+ vol_mask_npy = vol_mask
+
+ # Step 2 - One-hot or not
+ vol_mask_classes = []
+ label_ids_mask = []
+ label_ids = sorted(list(self.LABEL_MAP.values()))
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ vol_mask_classes = tf.concat([tf.expand_dims(tf.math.equal(vol_mask_npy, label), axis=-1) for label in label_ids], axis=-1) # [H,W,D,L]
+ for label_id in label_ids:
+ label_ids_mask.append(tf.cast(tf.math.reduce_any(vol_mask_classes[:,:,:,label_id]), dtype=config.DATATYPE_TF_INT32))
+
+ elif self.mask_type == config.MASK_TYPE_COMBINED:
+ vol_mask_classes = vol_mask_npy
+ unique_classes, _ = tf.unique(tf.reshape(vol_mask_npy,[-1]))
+ unique_classes = tf.cast(unique_classes, config.DATATYPE_TF_INT32)
+ for label_id in label_ids:
+ label_ids_mask.append(tf.cast(tf.math.reduce_any(tf.math.equal(unique_classes, label_id)), dtype=config.DATATYPE_TF_INT32))
+
+ # Step 2.2 - Handling background mask explicitly if there is a missing label
+ bgd_mask = meta1[-1]
+ label_ids_mask[0] = bgd_mask
+ meta1 = meta1[:-1] # removes the bgd mask index
+
+        # Step 3 - Dtype conversion and expanding dimensions
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ x = tf.cast(tf.expand_dims(vol_img_npy, axis=3), dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,1]
+ else:
+ x = tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D]
+ y = tf.cast(vol_mask_classes, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,L]
+
+ # Step 4 - Append info to meta1
+ meta1 = tf.concat([meta1, label_ids_mask], axis=0)
+
+ # Step 5 - return
+ return (x, y, meta1, meta2)
+
\ No newline at end of file
diff --git a/src/dataloader/han_miccai2015.py b/src/dataloader/han_miccai2015.py
new file mode 100644
index 0000000..3b0c9c8
--- /dev/null
+++ b/src/dataloader/han_miccai2015.py
@@ -0,0 +1,765 @@
+# Import internal libraries
+import src.config as config
+import src.dataloader.utils as utils
+
+# Import external libraries
+import pdb
+import time
+import json
+import itertools
+import traceback
+import numpy as np
+from pathlib import Path
+import tensorflow as tf
+
+
+
+class HaNMICCAI2015Dataset:
+ """
+ The 2015 MICCAI Challenge contains CT scans of the head and neck along with annotations for 9 organs.
+ It contains train, train_additional, test_onsite and test_offsite folders
+
+ Dataset link: http://www.imagenglab.com/wiki/mediawiki/index.php?title=2015_MICCAI_Challenge
+ """
+
+ def __init__(self, data_dir, dir_type
+ , dimension=3, grid=True, crop_init=False, resampled=False, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False, pregridnorm=False
+ , parallel_calls=None, deterministic=False
+ , patient_shuffle=False
+ , centred_dataloader_prob = 0.0
+ , debug=False):
+
+ self.name = '{}_MICCAI2015'.format(config.HEAD_AND_NECK)
+
+ # Params - Source
+ self.data_dir = data_dir
+ self.dir_type = dir_type
+
+ # Params - Spatial (x,y)
+ self.dimension = dimension
+ self.grid = grid
+ self.crop_init = crop_init
+ self.resampled = resampled
+ self.mask_type = mask_type
+
+ # Params - Transforms/Filters
+ self.transforms = transforms
+ self.filter_grid = filter_grid
+ self.random_grid = random_grid
+ self.pregridnorm = pregridnorm
+
+ # Params - Memory related
+ self.patient_shuffle = patient_shuffle
+
+ # Params - Dataset related
+ self.centred_dataloader_prob = centred_dataloader_prob
+
+ # Params - TFlow Dataloader related
+ self.parallel_calls = parallel_calls # [1, tf.data.experimental.AUTOTUNE]
+ self.deterministic = deterministic
+
+ # Params - Debug
+ self.debug = debug
+
+ # Data items
+ self.data = {}
+ self.paths_img = []
+ self.paths_mask = []
+ self.cache = {}
+ self.filter = None
+
+ # Config
+ self.dataset_config = getattr(config, self.name)
+
+ # Function calls
+ self._download()
+ self._init_data()
+
+ def __len__(self):
+
+ if self.grid:
+
+ if self.crop_init:
+ if self.voxel_shape_cropped == [240,240,80]:
+ if self.grid_size == [96,96,40]:
+ return len(self.paths_img)*(18)
+ elif self.grid_size == [140,140,40]:
+ return len(self.paths_img)*8
+ elif self.grid_size == [240,240,40]:
+ return len(self.paths_img)*2
+ elif self.grid_size == [240,240,80]:
+ return len(self.paths_img)*1
+
+ elif self.voxel_shape_cropped == [240,240,100]:
+ if self.grid_size == [140,140,60]:
+ return len(self.paths_img)*8
+
+ if self.filter is None and self.filter_grid is False:
+ if self.crop_init:
+ return 10*len(self.paths_img)
+ else:
+ return 200*len(self.paths_img) # i.e. approx 200 grids per volume
+ else:
+ sampler_perc_data = 1.0 - getattr(config, self.name)['GRID_3D']['SAMPLER_PERC'] + 0.1
+ return int(200*len(self.paths_img)*sampler_perc_data)
+
+ else:
+ return len(self.paths_img)
+
+ def _download(self):
+ self.dataset_dir = Path(self.data_dir).joinpath(self.name)
+ self.dataset_dir_raw = Path(self.dataset_dir).joinpath(config.DIRNAME_RAW)
+ self.VOXEL_RESO = self.dataset_config[config.KEY_VOXELRESO]
+ if self.VOXEL_RESO == (0.8,0.8,2.5):
+ self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED)
+ elif self.VOXEL_RESO == (1.0,1.0,2.0):
+ self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED + 'v2')
+ else:
+ self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED + '_v3')
+
+ self.dataset_dir_datatypes = ['train', 'train_additional', 'test_offsite', 'test_onsite']
+ self.dataset_dir_datatypes_ranges = [328+1,479+1,746+1,878+1]
+ self.dataset_dir_processed_3D = Path(self.dataset_dir_processed).joinpath('train', config.DIRNAME_SAVE_3D)
+
+ if not Path(self.dataset_dir_raw).exists() or not Path(self.dataset_dir_processed).exists():
+ print ('')
+ print (' ------------------ HaNMICCAI2015 Dataset ------------------')
+
+ if not Path(self.dataset_dir_raw).exists():
+ print ('')
+ print (' ------------------ Download Data ------------------')
+ from src.dataloader.extractors.han_miccai2015 import HaNMICCAI2015Downloader
+ downloader = HaNMICCAI2015Downloader(self.dataset_dir_raw, self.dataset_dir_processed)
+ downloader.download()
+ downloader.sort(self.dataset_dir_datatypes, self.dataset_dir_datatypes_ranges)
+
+ if not Path(self.dataset_dir_processed_3D).exists():
+ print ('')
+ print (' ------------------ Process Data (3D) ------------------')
+ from src.dataloader.extractors.han_miccai2015 import HaNMICCAI2015Extractor
+ extractor = HaNMICCAI2015Extractor(self.name, self.dataset_dir_raw, self.dataset_dir_processed, self.dataset_dir_datatypes)
+ extractor.extract3D()
+ print ('')
+ print (' ------------------------- * -------------------------')
+ print ('')
+
+ def _init_data(self):
+
+ # Step 0 - Init vars
+ self.patient_meta_info = {}
+ self.path_img_csv = ''
+ self.path_mask_csv = ''
+ if self.VOXEL_RESO == (0.8,0.8,2.4):
+ self.patients_z_prob = ['0522c0125']
+ else:
+ self.patients_z_prob = []
+
+ # Step 1 - Define global paths
+ self.data_dir_processed = Path(self.dataset_dir_processed).joinpath(self.dir_type)
+ if not Path(self.data_dir_processed).exists():
+ print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] Processed Dir Path issue: ', self.dir_type, self.data_dir_processed)
+ self.data_dir_processed_3D = Path(self.data_dir_processed).joinpath(config.DIRNAME_SAVE_3D)
+
+ # Step 2.1 - Get paths for 3D volumes
+ if self.dimension == 3:
+ if self.resampled is False:
+ self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG)
+ self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK)
+ else:
+ self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG_RESAMPLED)
+ self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK_RESAMPLED)
+
+ # Step 2.2 - Get file paths
+ if Path(self.path_img_csv).exists() and Path(self.path_mask_csv).exists():
+ self.paths_img = utils.read_csv(self.path_img_csv)
+ self.paths_mask = utils.read_csv(self.path_mask_csv)
+
+ exit_condition = False
+ for path_img in self.paths_img:
+ if not Path(path_img).exists():
+ print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] path issue: path_img: ',self.dir_type, path_img)
+ exit_condition = True
+ for path_mask in self.paths_mask:
+ if not Path(path_mask).exists():
+ print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] path issue: path_mask: ',self.dir_type, path_mask)
+ exit_condition = True
+
+ if exit_condition:
+ print ('\n - [_init_data()] Exiting due to path issues')
+ import sys; sys.exit(1)
+
+ else:
+ print (' - [ERROR] Issue with path')
+ print (' -- [ERROR] self.path_img_csv : ({}) {}'.format(Path(self.path_img_csv).exists(), self.path_img_csv ))
+ print (' -- [ERROR] self.path_mask_csv: ({}) {}'.format(Path(self.path_mask_csv).exists(), self.path_mask_csv ))
+
+ # Step 3.1 - Meta for labels
+ self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP]
+ self.LABEL_COLORS = self.dataset_config[config.KEY_LABEL_COLORS]
+ self.LABEL_WEIGHTS = np.array(self.dataset_config[config.KEY_LABEL_WEIGHTS])
+ self.LABEL_WEIGHTS = (self.LABEL_WEIGHTS / np.sum(self.LABEL_WEIGHTS)).tolist()
+
+ # Step 3.2 - Meta for voxel HU
+ self.HU_MIN = self.dataset_config['HU_MIN']
+ self.HU_MAX = self.dataset_config['HU_MAX']
+
+ # Step 4 - Get patient meta info
+ if self.resampled is False:
+ print (' - [Warning][HaNMICCAI2015Dataset]: This dataloader is not using 3D volumes resampled to a common voxel spacing: ', self.dir_type)
+ for path_img in self.paths_img:
+ path_img_json = Path(path_img).parent.absolute().joinpath(config.FILENAME_VOXEL_INFO)
+ with open(str(path_img_json), 'r') as fp:
+ patient_config_file = json.load(fp)
+
+ patient_id = Path(path_img).parts[-2]
+ midpoint_idxs_mean = []
+ missing_label_names = []
+ voxel_shape_resampled = []
+ voxel_shape_orig = []
+ if self.resampled is False:
+ midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT]
+ missing_label_names = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING]
+ voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE]
+ else:
+ if config.TYPE_VOXEL_RESAMPLED in patient_config_file:
+ midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_MEAN_MIDPOINT]
+ missing_label_names = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING]
+ voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE]
+ voxel_shape_resampled = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE]
+ else:
+ print (' - [ERROR][HaNMICCAI2015Dataset] There is no resampled data and you have set resample=True')
+ print (' -- Delete the data/HaNMICCAI2015Dataset/processed directory and set VOXEL_RESO to a tuple of pixel spacing values for x,y,z axes')
+ print (' -- Exiting now! ')
+ import sys; sys.exit(1)
+
+ if len(midpoint_idxs_mean):
+ midpoint_idxs_mean = np.array(midpoint_idxs_mean).astype(config.DATATYPE_VOXEL_IMG).tolist()
+ self.patient_meta_info[patient_id] = {
+ config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean
+ , config.KEYNAME_LABEL_MISSING : missing_label_names
+ , config.KEYNAME_SHAPE_ORIG : voxel_shape_orig
+ , config.KEYNAME_SHAPE_RESAMPLED : voxel_shape_resampled
+ }
+ else:
+ if self.crop_init:
+ print ('')
+ print (' - [ERROR][HaNMICCAI2015Dataset] Crop is set to true and there is no midpoint mean idx')
+ print (' -- Exiting now! ')
+ import sys; sys.exit(1)
+
+ # Step 6 - Meta for grid sampling
+ if self.dimension == 3:
+ if self.resampled:
+
+ if self.crop_init:
+ self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_RESAMPLED][config.KEY_CROP][str(self.VOXEL_RESO)]
+ self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT]
+ self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT]
+ self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK]
+ self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT]
+ self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP]
+ self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM]
+ self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop
+ , self.h_posterior_crop+self.h_anterior_crop
+ , self.d_cranial_crop+self.d_caudal_crop]
+
+ if self.grid:
+ grid_3D_params = self.dataset_config[config.KEY_GRID_3D][config.TYPE_VOXEL_RESAMPLED][str(self.VOXEL_RESO)]
+ self.grid_size = grid_3D_params[config.KEY_GRID_SIZE]
+ self.grid_overlap = grid_3D_params[config.KEY_GRID_OVERLAP]
+ self.SAMPLER_PERC = grid_3D_params[config.KEY_GRID_SAMPLER_PERC]
+ self.RANDOM_SHIFT_MAX = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_MAX]
+ self.RANDOM_SHIFT_PERC = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_PERC]
+
+ self.w_grid, self.h_grid, self.d_grid = self.grid_size
+ self.w_overlap, self.h_overlap, self.d_overlap = self.grid_overlap
+
+ else:
+
+ if self.crop_init:
+ self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped
+ else:
+ print (' - [ERROR][HaNMICCAI2015Dataset] No info present for non-grid cropping')
+
+ else:
+ if self.crop_init:
+ preprocess_obj = getattr(config, self.name)['PREPROCESS'][config.TYPE_VOXEL_ORIGSHAPE]
+ self.crop_info = preprocess_obj['CROP']
+ self.w_lateral_crop = self.crop_info['MIDPOINT_EXTENSION_W_LEFT']
+ self.w_medial_crop = self.crop_info['MIDPOINT_EXTENSION_W_RIGHT']
+ self.h_posterior_crop = self.crop_info['MIDPOINT_EXTENSION_H_BACK']
+ self.h_anterior_crop = self.crop_info['MIDPOINT_EXTENSION_H_FRONT']
+ self.d_cranial_crop = self.crop_info['MIDPOINT_EXTENSION_D_TOP']
+ self.d_caudal_crop = self.crop_info['MIDPOINT_EXTENSION_D_BOTTOM']
+ self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop
+ , self.h_posterior_crop+self.h_anterior_crop
+ , self.d_cranial_crop+self.d_caudal_crop]
+
+ if self.grid:
+ pass
+ else:
+ if self.crop_init:
+ self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped
+ else:
+ print (' - [ERROR][HaNMICCAI2015Dataset] No size info present when neither grid sampling nor initial cropping is used')
+
+ else:
+ print (' - [ERROR][HaNMICCAI2015Dataset] No info present for 2D data')
+
+ def get_voxel_stats(self, show=False):
+
+ spacing_x = []
+ spacing_y = []
+ spacing_z = []
+
+ info_img_path = Path(self.data_dir).joinpath(self.name, config.DIRNAME_PROCESSED, self.dir_type, config.DIRNAME_SAVE_2D, config.FILENAME_VOXEL_INFO)
+
+ if Path(info_img_path).exists():
+ import json
+ with open(str(info_img_path), 'r') as fp:
+ data = json.load(fp)
+ for patient_id in data:
+ spacing_info = data[patient_id][config.TYPE_VOXEL_ORIGSHAPE]['spacing']
+ spacing_x.append(spacing_info[0])
+ spacing_y.append(spacing_info[1])
+ spacing_z.append(spacing_info[2])
+
+ if show:
+ if len(spacing_x) and len(spacing_y) and len(spacing_z):
+ import matplotlib.pyplot as plt
+ f,axarr = plt.subplots(1,3)
+ axarr[0].hist(spacing_x); axarr[0].set_title('Voxel Spacing (X)')
+ axarr[1].hist(spacing_y); axarr[1].set_title('Voxel Spacing (Y)')
+ axarr[2].hist(spacing_z); axarr[2].set_title('Voxel Spacing (Z)')
+ plt.suptitle(self.name)
+ plt.show()
+
+ else:
+ print (' - [ERROR][get_voxel_stats()] Path issue: info_img_path: ', info_img_path)
+
+ return spacing_x, spacing_y, spacing_z
+
+ def generator(self):
+ """
+ - Note:
+ - In general, even when running your model on an accelerator like a GPU or TPU, the tf.data pipelines are run on the CPU
+ - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow
+ """
+
+ try:
+
+ if len(self.paths_img) and len(self.paths_mask):
+
+ # Step 1 - Create basic generator
+ dataset = None
+ if self.dimension == 3:
+ dataset = tf.data.Dataset.from_generator(self._generator3D
+ , output_types=(config.DATATYPE_TF_FLOAT32, config.DATATYPE_TF_UINT8, config.DATATYPE_TF_INT32, tf.string)
+ ,args=())
+
+ # Step 2 - Get 3D data
+ if self.dimension == 3:
+ dataset = dataset.map(self._get_data_3D, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic)
+ # dataset = dataset.apply(tf.data.experimental.copy_to_device(target_device='/GPU:0'))
+
+ # Step 3 - Filter function
+ if self.filter_grid:
+ dataset = dataset.filter(self.filter.execute)
+
+ # Step 4 - Data augmentations
+ if len(self.transforms):
+ for transform in self.transforms:
+ try:
+ dataset = dataset.map(transform.execute, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic)
+ except:
+ traceback.print_exc()
+ print (' - [ERROR][HaNMICCAI2015Dataset] Issue with transform: ', transform.name)
+ else:
+ print ('')
+ print (' - [INFO][HaNMICCAI2015Dataset] No transformations available!', self.dir_type)
+ print ('')
+
+ # Step 6 - Return
+ return dataset
+
+ else:
+ return None
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+ return None
+
+ def _get_paths(self, idx):
+ patient_id = ''
+ study_id = ''
+ path_img, path_mask = '', ''
+
+ if self.debug:
+ path_img = Path(self.paths_img[0]).absolute()
+ path_mask = Path(self.paths_mask[0]).absolute()
+ path_img, path_mask = self.path_debug_3D(path_img, path_mask)
+ else:
+ path_img = Path(self.paths_img[idx]).absolute()
+ path_mask = Path(self.paths_mask[idx]).absolute()
+
+ if path_img.exists() and path_mask.exists():
+ patient_id = Path(path_img).parts[-2]
+ study_id = Path(path_img).parts[-4]
+ else:
+ print (' - [ERROR] Issue with path')
+ print (' -- [ERROR][HaNMICCAI2015] path_img : ', path_img)
+ print (' -- [ERROR][HaNMICCAI2015] path_mask: ', path_mask)
+
+ return path_img, path_mask, patient_id, study_id
+
+ def _generator3D(self):
+
+ # Step 0 - Init
+ res = []
+
+ # Step 1 - Get patient idxs
+ idxs = np.arange(len(self.paths_img)).tolist() #[:3]
+ if self.patient_shuffle: np.random.shuffle(idxs)
+
+ # Step 2 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling
+ if self.grid:
+
+ # Step 2.1 - Get grid sampler info for each patient-idx
+ sampler_info = {}
+ for idx in idxs:
+ path_img = Path(self.paths_img[idx]).absolute()
+ patient_id = path_img.parts[-2]
+
+ if config.TYPE_VOXEL_RESAMPLED in str(path_img):
+ voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_RESAMPLED]
+ else:
+ voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_ORIG]
+
+ if self.crop_init:
+ if patient_id in self.patients_z_prob:
+ voxel_shape[0] = self.voxel_shape_cropped[0]
+ voxel_shape[1] = self.voxel_shape_cropped[1]
+ else:
+ voxel_shape = self.voxel_shape_cropped
+
+ grid_idxs_width = utils.split_into_overlapping_grids(voxel_shape[0], len_grid=self.grid_size[0], len_overlap=self.grid_overlap[0])
+ grid_idxs_height = utils.split_into_overlapping_grids(voxel_shape[1], len_grid=self.grid_size[1], len_overlap=self.grid_overlap[1])
+ grid_idxs_depth = utils.split_into_overlapping_grids(voxel_shape[2], len_grid=self.grid_size[2], len_overlap=self.grid_overlap[2])
+ sampler_info[idx] = list(itertools.product(grid_idxs_width,grid_idxs_height,grid_idxs_depth))
+
+ if 0: #patient_id in self.patients_z_prob:
+ print (' - [DEBUG] patient_id: ', patient_id, ' || voxel_shape: ', voxel_shape)
+ print (' - [DEBUG] sampler_info: ', sampler_info[idx])
+
+ # Step 2.2 - Loop over all patients and their grids
+ # Note - Grids of a patient are extracted in order
+ for i, idx in enumerate(idxs):
+ path_img, path_mask, patient_id, study_id = self._get_paths(idx)
+ missing_labels = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING]
+ bgd_mask = 1 # by default
+ if len(missing_labels):
+ bgd_mask = 0
+ if path_img.exists() and path_mask.exists():
+ for sample_info in sampler_info[idx]:
+ grid_idxs = sample_info
+ meta1 = [idx] + [grid_idxs[0][0], grid_idxs[1][0], grid_idxs[2][0]] # only include w_start, h_start, d_start
+ meta2 = '-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)])
+ path_img = str(path_img)
+ path_mask = str(path_mask)
+ res.append((path_img, path_mask, meta1, meta2, bgd_mask))
+
+ else:
+ label_names = list(self.LABEL_MAP.keys())
+ for i, idx in enumerate(idxs):
+ path_img, path_mask, patient_id, study_id = self._get_paths(idx)
+ missing_label_names = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING]
+ bgd_mask = 1
+
+ # if len(missing_labels): bgd_mask = 0
+ if len(missing_label_names):
+ if len(set(label_names) - set(missing_label_names)):
+ bgd_mask = 0
+
+ if path_img.exists() and path_mask.exists():
+ meta1 = [idx] + [0,0,0] # dummy for w_start, h_start, d_start
+ meta2 ='-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)])
+ path_img = str(path_img)
+ path_mask = str(path_mask)
+ res.append((path_img, path_mask, meta1, meta2, bgd_mask))
+
+ # Step 3 - Yield
+ for each in res:
+ path_img, path_mask, meta1, meta2, bgd_mask = each
+
+ vol_img_npy, vol_mask_npy, spacing = self._get_cache_item(path_img, path_mask)
+ if vol_img_npy is None and vol_mask_npy is None:
+ vol_img_npy, vol_mask_npy, spacing = self._get_volume_from_path(path_img, path_mask)
+ self._set_cache_item(path_img, path_mask, vol_img_npy, vol_mask_npy, spacing)
+
+ spacing = tf.constant(spacing, dtype=tf.int32)
+ vol_img_npy_shape = tf.constant(vol_img_npy.shape, dtype=tf.int32)
+ meta1 = tf.concat([meta1, spacing, vol_img_npy_shape, [bgd_mask]], axis=0) # [idx,[grid_idxs],[spacing],[shape]]
+
+ yield (vol_img_npy, vol_mask_npy, meta1, meta2)
+
+ def _get_cache_item(self, path_img, path_mask):
+ if 'img' in self.cache and 'mask' in self.cache:
+ if path_img in self.cache['img'] and path_mask in self.cache['mask']:
+ return self.cache['img'][path_img], self.cache['mask'][path_mask], self.cache['spacing']
+ else:
+ return None, None, None
+ else:
+ return None, None, None
+
+ def _set_cache_item(self, path_img, path_mask, vol_img, vol_mask, spacing):
+ # print (' - [_set_cache_item() ]: ', vol_img.shape, vol_mask.shape)
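+ # Note: this acts as a single-item cache - only the most recently loaded (img, mask, spacing) triple is kept,
+ # so consecutive grid samples of the same patient avoid re-reading the .mha files.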
+ self.cache = {
+ 'img': {path_img: vol_img}
+ , 'mask': {path_mask: vol_mask}
+ , 'spacing': spacing
+ }
+
+ def _get_volume_from_path(self, path_img, path_mask, verbose=False):
+
+ # Step 1 - Get volumes
+ if verbose: t0 = time.time()
+ vol_img_sitk = utils.read_mha(path_img)
+ vol_img_npy = utils.sitk_to_array(vol_img_sitk)
+ vol_mask_sitk = utils.read_mha(path_mask)
+ vol_mask_npy = utils.sitk_to_array(vol_mask_sitk)
+ spacing = np.array(vol_img_sitk.GetSpacing())
+
+ # Step 2 - Perform init crop on volumes
+ if self.crop_init:
+ patient_id = str(Path(path_img).parts[-2])
+ mean_point = np.array(self.patient_meta_info[patient_id][config.KEYNAME_MEAN_MIDPOINT]).astype(np.uint16).tolist()
+
+ # Step 2.1 - Perform crops in H,W region
+ vol_img_npy = vol_img_npy[
+ mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop
+ , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop
+ , :
+ ]
+ vol_mask_npy = vol_mask_npy[
+ mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop
+ , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop
+ , :
+ ]
+
+ # Step 2.2 - Perform crops in D region
+ if self.grid:
+ if self.VOXEL_RESO == (0.8,0.8,2.4):
+ if '0522c0125' not in patient_id:
+ vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ else:
+ vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ else:
+ if self.resampled:
+ if self.VOXEL_RESO == (0.8,0.8,2.4):
+ if '0522c0125' in patient_id:
+ # vol_img_npy = vol_img_npy[:, :, 0:self.d_caudal_crop + self.d_cranial_crop]
+ # vol_mask_npy = vol_mask_npy[:, :, 0:self.d_caudal_crop + self.d_cranial_crop]
+ vol_img_npy = vol_img_npy[:, :, -(self.d_caudal_crop + self.d_cranial_crop):]
+ vol_mask_npy = vol_mask_npy[:, :, -(self.d_caudal_crop + self.d_cranial_crop):]
+ else:
+ vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ else:
+ vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ else:
+ vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+ vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop]
+
+ # Step 3 - Pad (with=0) if volume is less in z-dimension
+ (vol_img_npy_x, vol_img_npy_y, vol_img_npy_z) = vol_img_npy.shape
+ if vol_img_npy_z < self.d_caudal_crop + self.d_cranial_crop:
+ del_z = self.d_caudal_crop + self.d_cranial_crop - vol_img_npy_z
+ vol_img_npy = np.concatenate((vol_img_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2)
+ vol_mask_npy = np.concatenate((vol_mask_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2)
+
+ # print (' - [cropping] z_mean: ', mean_point[2], ' || -', self.d_caudal_crop, ' || + ', self.d_cranial_crop)
+ # print (' - [cropping] || shape: ', vol_img_npy.shape)
+ # print (' - [DEBUG] change: ', vol_img_npy_shape_prev, vol_img_npy.shape)
+
+ if verbose: print (' - [HaNMICCAI2015Dataset._get_volume_from_path()] Time: ({}):{}s'.format(Path(path_img).parts[-2], round(time.time() - t0,2)))
+ if self.pregridnorm:
+ vol_img_npy[vol_img_npy <= self.HU_MIN] = self.HU_MIN
+ vol_img_npy[vol_img_npy >= self.HU_MAX] = self.HU_MAX
+ vol_img_npy = (vol_img_npy - np.mean(vol_img_npy)) / np.std(vol_img_npy) # Standardize (z-scoring)
+
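+ # Note: spacing (in mm) is scaled by 100 and cast to int32 so it can travel inside meta1,
+ # whose tf.data output type is int32 (see the from_generator() output_types above).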
+ return tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32), tf.cast(vol_mask_npy, dtype=config.DATATYPE_TF_UINT8), tf.constant(spacing*100, dtype=config.DATATYPE_TF_INT32)
+
+ @tf.function
+ def _get_new_grid_idx(self, start, end, max):
+
+ start_prev = start
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC:
+
+ delta_left = start
+ delta_right = max - end
+ shift_voxels = tf.random.uniform([], minval=0, maxval=self.RANDOM_SHIFT_MAX, dtype=tf.dtypes.int32)
+
+ if delta_left > self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX:
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC:
+ start = start - shift_voxels
+ end = end - shift_voxels
+ else:
+ start = start + shift_voxels
+ end = end + shift_voxels
+
+ elif delta_left > self.RANDOM_SHIFT_MAX and delta_right <= self.RANDOM_SHIFT_MAX:
+ start = start - shift_voxels
+ end = end - shift_voxels
+
+ elif delta_left <= self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX:
+ start = start + shift_voxels
+ end = end + shift_voxels
+
+ return start_prev, start, end
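+ # Worked example: for a grid [w_start, w_end) = [32, 96) on an axis of length w_max = 160 with
+ # RANDOM_SHIFT_MAX = 10, both margins (32 and 64) exceed 10, so the grid may be shifted left or
+ # right by a randomly drawn number of voxels below RANDOM_SHIFT_MAX, staying inside [0, 160).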
+
+ @tf.function
+ def _get_new_grid_idx_centred(self, grid_size_half, max_pt, mid_pt):
+
+ # Step 1 - Return vars
+ start, end = 0,0
+
+ # Step 2 - Define margin on either side of mid point
+ margin_left = mid_pt
+ margin_right = max_pt - mid_pt
+
+ # Step 3 - Calculate vars
+ if margin_left >= grid_size_half and margin_right >= grid_size_half:
+ start = mid_pt - grid_size_half
+ end = mid_pt + grid_size_half
+ elif margin_right < grid_size_half:
+ if margin_left >= grid_size_half + (grid_size_half - margin_right):
+ end = mid_pt + margin_right
+ start = mid_pt - grid_size_half - (grid_size_half - margin_right)
+ else:
+ tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 2 problem')
+ elif margin_left < grid_size_half:
+ if margin_right >= grid_size_half + (grid_size_half - margin_left):
+ start = mid_pt - margin_left
+ end = mid_pt + grid_size_half + (grid_size_half-margin_left)
+ else:
+ tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 3 problem')
+
+ return start, end
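+ # Worked example: grid_size_half = 48, max_pt = 200, mid_pt = 20 -> margin_left (20) < 48 and
+ # margin_right (180) >= 48 + (48 - 20), so start = 20 - 20 = 0 and end = 20 + 48 + 28 = 96,
+ # i.e. a full-size grid of 96 voxels that is pushed right to stay inside the volume.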
+
+ @tf.function
+ def _get_data_3D(self, vol_img, vol_mask, meta1, meta2):
+ """
+ Params
+ ------
+ meta1: [idx, [w_start, h_start, d_start], [spacing_x, spacing_y, spacing_z], [shape_x, shape_y, shape_z], [bgd_mask]]
+ """
+
+ vol_img_npy = None
+ vol_mask_npy = None
+
+ # Step 1 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling
+ if self.grid:
+
+ if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.centred_dataloader_prob:
+
+ # Step 1.1 - Get a label_id and its mean
+ label_id = tf.cast(tf.random.categorical(tf.math.log([self.LABEL_WEIGHTS]), 1), dtype=config.DATATYPE_TF_UINT8)[0][0] # the LABEL_WEIGHTS sum to 1; the log is added as tf.random.categorical expects logits
+ label_id_idxs = tf.where(tf.math.equal(label_id, vol_mask))
+ label_id_idxs_mean = tf.math.reduce_mean(label_id_idxs, axis=0)
+ label_id_idxs_mean = tf.cast(label_id_idxs_mean, dtype=config.DATATYPE_TF_INT32)
+
+ # Step 1.2 - Create a grid around that mid-point
+ w_start_prev = meta1[1]
+ h_start_prev = meta1[2]
+ d_start_prev = meta1[3]
+ w_max = meta1[7]
+ h_max = meta1[8]
+ d_max = meta1[9]
+ w_grid = self.grid_size[0]
+ h_grid = self.grid_size[1]
+ d_grid = self.grid_size[2]
+ w_mid = label_id_idxs_mean[0]
+ h_mid = label_id_idxs_mean[1]
+ d_mid = label_id_idxs_mean[2]
+
+ w_start, w_end = self._get_new_grid_idx_centred(w_grid//2, w_max, w_mid)
+ h_start, h_end = self._get_new_grid_idx_centred(h_grid//2, h_max, h_mid)
+ d_start, d_end = self._get_new_grid_idx_centred(d_grid//2, d_max, d_mid)
+
+ meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0])
+ meta1 = meta1 + meta1_diff
+
+ else:
+
+ # tf.print(' - [INFO] regular dataloader: ', self.dir_type)
+ # Step 1.1 - Get raw images/masks and extract grid
+ w_start = meta1[1]
+ w_end = w_start + self.grid_size[0]
+ h_start = meta1[2]
+ h_end = h_start + self.grid_size[1]
+ d_start = meta1[3]
+ d_end = d_start + self.grid_size[2]
+
+ # Step 1.2 - Randomization of grid
+ if self.random_grid:
+ w_max = meta1[7]
+ h_max = meta1[8]
+ d_max = meta1[9]
+
+ w_start_prev = w_start
+ d_start_prev = d_start
+ w_start_prev, w_start, w_end = self._get_new_grid_idx(w_start, w_end, w_max)
+ h_start_prev, h_start, h_end = self._get_new_grid_idx(h_start, h_end, h_max)
+ d_start_prev, d_start, d_end = self._get_new_grid_idx(d_start, d_end, d_max)
+
+ meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0])
+ meta1 = meta1 + meta1_diff
+
+ # Step 1.3 - Extracting grid
+ vol_img_npy = tf.identity(vol_img[w_start:w_end, h_start:h_end, d_start:d_end])
+ vol_mask_npy = tf.identity(vol_mask[w_start:w_end, h_start:h_end, d_start:d_end])
+
+
+ else:
+ vol_img_npy = vol_img
+ vol_mask_npy = vol_mask
+
+ # Step 2 - One-hot or not
+ vol_mask_classes = []
+ label_ids_mask = []
+ label_ids = sorted(list(self.LABEL_MAP.values()))
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ vol_mask_classes = tf.concat([tf.expand_dims(tf.math.equal(vol_mask_npy, label), axis=-1) for label in label_ids], axis=-1) # [H,W,D,L]
+ for label_id in label_ids:
+ label_ids_mask.append(tf.cast(tf.math.reduce_any(vol_mask_classes[:,:,:,label_id]), dtype=config.DATATYPE_TF_INT32))
+
+ elif self.mask_type == config.MASK_TYPE_COMBINED:
+ vol_mask_classes = vol_mask_npy
+ unique_classes, _ = tf.unique(tf.reshape(vol_mask_npy,[-1]))
+ unique_classes = tf.cast(unique_classes, config.DATATYPE_TF_INT32)
+ for label_id in label_ids:
+ label_ids_mask.append(tf.cast(tf.math.reduce_any(tf.math.equal(unique_classes, label_id)), dtype=config.DATATYPE_TF_INT32))
+
+ # Step 2.2 - Handling background mask explicitly if there is a missing label
+ bgd_mask = meta1[-1]
+ label_ids_mask[0] = bgd_mask
+ meta1 = meta1[:-1] # removes the bgd mask index
+
+ # Step 3 - Dtype conversion and expanding dimensions
+ if self.mask_type == config.MASK_TYPE_ONEHOT:
+ x = tf.cast(tf.expand_dims(vol_img_npy, axis=3), dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,1]
+ else:
+ x = tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D]
+ y = tf.cast(vol_mask_classes, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,L]
+
+ # Step 4 - Append info to meta1
+ meta1 = tf.concat([meta1, label_ids_mask], axis=0)
+
+ # Step 5 - return
+ return (x, y, meta1, meta2)
+
\ No newline at end of file
diff --git a/src/dataloader/utils.py b/src/dataloader/utils.py
new file mode 100644
index 0000000..88c6741
--- /dev/null
+++ b/src/dataloader/utils.py
@@ -0,0 +1,1273 @@
+# Import internal libraries
+import src.config as config
+import src.dataloader.augmentations as aug
+from src.dataloader.dataset import ZipDataset
+from src.dataloader.han_miccai2015 import HaNMICCAI2015Dataset
+from src.dataloader.han_deepmindtcia import HaNDeepMindTCIADataset
+
+# Import external libraries
+import os
+import pdb
+import itk
+import copy
+import time
+import tqdm
+import json
+import urllib
+import psutil
+import pydicom
+import humanize
+import traceback
+import numpy as np
+import tensorflow as tf
+from pathlib import Path
+import SimpleITK as sitk # sitk.Version.ExtendedVersionString()
+
+if len(tf.config.list_physical_devices('GPU')):
+ tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
+
+
+
+############################################################
+# DOWNLOAD RELATED #
+############################################################
+
+class DownloadProgressBar(tqdm.tqdm):
+ def update_to(self, b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+def download_zip(url_zip, filepath_zip, filepath_output):
+ import urllib.request
+ with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=' - [Download]' + str(url_zip.split('/')[-1]) ) as pbar:
+ urllib.request.urlretrieve(url_zip, filename=filepath_zip, reporthook=pbar.update_to)
+ read_zip(filepath_zip, filepath_output)
+
+def read_zip(filepath_zip, filepath_output=None, leave=False):
+
+ import zipfile
+
+ # Step 0 - Init
+ if Path(filepath_zip).exists():
+ if filepath_output is None:
+ filepath_zip_parts = list(Path(filepath_zip).parts)
+ filepath_zip_name = filepath_zip_parts[-1].split(config.EXT_ZIP)[0]
+ filepath_zip_parts[-1] = filepath_zip_name
+ filepath_output = Path(*filepath_zip_parts)
+
+ zip_fp = zipfile.ZipFile(filepath_zip, 'r')
+ zip_fp_members = zip_fp.namelist()
+ with tqdm.tqdm(total=len(zip_fp_members), desc=' - [Unzip] ' + str(filepath_zip.parts[-1]), leave=leave) as pbar_zip:
+ for member in zip_fp_members:
+ zip_fp.extract(member, filepath_output)
+ pbar_zip.update(1)
+
+ return filepath_output
+ else:
+ print (' - [ERROR][read_zip()] Path does not exist: ', filepath_zip)
+ return None
+
+def move_folder(path_src_folder, path_dest_folder):
+
+ import shutil
+ if Path(path_src_folder).exists():
+ shutil.move(str(path_src_folder), str(path_dest_folder))
+ else:
+ print (' - [ERROR][utils.move_folder()] path_src_folder:{} does not exist'.format(path_src_folder))
+
+############################################################
+# ITK/SITK RELATED #
+############################################################
+
+def write_mha(filepath, data_array, spacing=[], origin=[]):
+ # data_array = numpy array
+ if len(spacing) and len(origin):
+ img_sitk = array_to_sitk(data_array, spacing=spacing, origin=origin)
+ elif len(spacing) and not len(origin):
+ img_sitk = array_to_sitk(data_array, spacing=spacing)
+
+ sitk.WriteImage(img_sitk, str(filepath), useCompression=True)
+
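+# NOTE: this second write_mha() definition shadows the one above at import time; only the
+# (path_save, img_data, img_headers, img_dtype) signature is actually available to callers.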
+def write_mha(path_save, img_data, img_headers, img_dtype):
+
+ # Step 0 - Path related
+ path_save_parents = Path(path_save).parent.absolute()
+ path_save_parents.mkdir(exist_ok=True, parents=True)
+
+ # Step 1 - Create ITK volume
+ orig_origin = img_headers[config.KEYNAME_ORIGIN]
+ orig_pixel_spacing = img_headers[config.KEYNAME_PIXEL_SPACING]
+ img_sitk = array_to_sitk(img_data, origin=orig_origin, spacing=orig_pixel_spacing)
+ img_sitk = sitk.Cast(img_sitk, img_dtype)
+ sitk.WriteImage(img_sitk, str(path_save), useCompression=True)
+
+ return img_sitk
+
+def imwrite_sitk(img, filepath, dtype, compression=True):
+ """
+ img: itk () or sitk image
+ """
+ import itk
+
+ def convert_itk_to_sitk(image_itk, dtype):
+
+ img_array = itk.GetArrayFromImage(image_itk)
+ if dtype in ['short', 'int16', config.DATATYPE_VOXEL_IMG]:
+ img_array = np.array(img_array, dtype=config.DATATYPE_VOXEL_IMG)
+ elif dtype in ['unsigned int', 'uint8', config.DATATYPE_VOXEL_MASK]:
+ img_array = np.array(img_array, dtype=config.DATATYPE_VOXEL_MASK)
+
+ image_sitk = sitk.GetImageFromArray(img_array, isVector=image_itk.GetNumberOfComponentsPerPixel()>1)
+ image_sitk.SetOrigin(tuple(image_itk.GetOrigin()))
+ image_sitk.SetSpacing(tuple(image_itk.GetSpacing()))
+ image_sitk.SetDirection(itk.GetArrayFromMatrix(image_itk.GetDirection()).flatten())
+ return image_sitk
+
+ writer = sitk.ImageFileWriter()
+ writer.SetFileName(str(filepath))
+ writer.SetUseCompression(compression)
+ if 'SimpleITK' not in str(type(img)):
+ writer.Execute(convert_itk_to_sitk(img, dtype))
+ else:
+ writer.Execute(img)
+
+def read_itk(img_url, fixed_ct_skip_slices=None):
+
+ if Path(img_url).exists():
+ img = itk.imread(str(img_url), itk.F)
+
+ if fixed_ct_skip_slices is not None:
+ img_array = itk.GetArrayFromImage(img) # [D,H,W]
+ img_array = img_array[fixed_ct_skip_slices:,:,:]
+
+ img_ = itk.GetImageFromArray(img_array)
+ img_.SetOrigin(tuple(img.GetOrigin()))
+ img_.SetSpacing(tuple(img.GetSpacing()))
+ img_.SetDirection(img.GetDirection())
+ img = img_
+
+ return img
+ else:
+ print (' - [read_itk()] Path does not exist: ', img_url)
+ return None
+
+def read_itk_mask(mask_url, fixed_ct_skip_slices=None):
+
+ if Path(mask_url).exists():
+ img = itk.imread(str(mask_url), itk.UC)
+
+ if fixed_ct_skip_slices is not None:
+ img_array = itk.GetArrayFromImage(img) # [D,H,W]
+ img_array = img_array[fixed_ct_skip_slices:,:,:]
+
+ img_ = itk.GetImageFromArray(img_array)
+ img_.SetOrigin(tuple(img.GetOrigin()))
+ img_.SetSpacing(tuple(img.GetSpacing()))
+ img_.SetDirection(img.GetDirection())
+ img = img_
+
+ return img
+
+ else:
+ print (' - [read_itk()] Path does not exist: ', mask_url)
+ return None
+
+def array_to_itk(array, origin, spacing):
+ """
+ array = [H,W,D]
+ origin = [x,y,z]
+ spacing = [x,y,z]
+ """
+
+ try:
+ import itk
+
+ img_itk = itk.GetImageFromArray(np.moveaxis(array, [0,1,2], [2,1,0]).copy())
+ img_itk.SetOrigin(tuple(origin))
+ img_itk.SetSpacing(tuple(spacing))
+ # dir = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(np.float64)
+ # img_itk.SetDirection(itk.GetMatrixFromArray(dir))
+
+ return img_itk
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def array_to_sitk(array_input, size=None, origin=None, spacing=None, direction=None, is_vector=False, im_ref=None):
+ """
+ This function takes an array and converts it into a SimpleITK image.
+
+ Parameters
+ ----------
+ array_input: numpy
+ The numpy array to convert to a SimpleITK image
+ size: tuple, optional
+ The size of the array
+ origin: tuple, optional
+ The origin of the data in physical space
+ spacing: tuple, optional
+ Spacing describes the physical size of each pixel
+ direction: tuple, optional
+ A [nxn] matrix passed as a 1D in a row-major form for a nD matrix (n=[2,3]) to infer the orientation of the data
+ is_vector: bool, optional
+ If is_vector is True, then the image will have a vector pixel type, and the last dimension of the array will be considered the component index.
+ im_ref: sitk image
+ An empty image with meta information
+
+ Ref: https://github.com/hsokooti/RegNet/blob/46f345d25cd6a1e0ee6f230f64c32bd15b7650d3/functions/image/image_processing.py#L86
+ """
+ import SimpleITK as sitk
+ verbose = False
+
+ if origin is None:
+ origin = [0.0, 0.0, 0.0]
+ if spacing is None:
+ spacing = [1, 1, 1] # the voxel spacing
+ if direction is None:
+ direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]
+ if size is None:
+ size = np.array(array_input).shape
+
+ """
+ ITK has a GetPixel which takes an ITK Index object as an argument, which is ordered as (x,y,z).
+ This is the convention that SimpleITK's Image class uses for the GetPixel method and slicing operator as well.
+ In numpy, an array is indexed in the opposite order (z,y,x)
+ """
+ sitk_output = sitk.GetImageFromArray(np.moveaxis(array_input, [0,1,2], [2,1,0]), isVector=is_vector) # np([H,W,D]) -> np([D,W,H]) -> sitk([H,W,D])
+
+ if im_ref is None:
+ sitk_output.SetOrigin(origin)
+ sitk_output.SetSpacing(spacing)
+ sitk_output.SetDirection(direction)
+ else:
+ sitk_output.SetOrigin(im_ref.GetOrigin())
+ sitk_output.SetSpacing(im_ref.GetSpacing())
+ sitk_output.SetDirection(im_ref.GetDirection())
+
+ return sitk_output
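+# Usage sketch: convert a [H,W,D] numpy CT volume into a SimpleITK image with known geometry
+# (the spacing/origin values below are illustrative only):
+#   vol = np.zeros((512, 512, 120), dtype=np.int16)   # [H,W,D]
+#   img = array_to_sitk(vol, origin=[0.0, 0.0, 0.0], spacing=[0.97, 0.97, 2.5])
+#   img.GetSize()                                      # -> (512, 512, 120), i.e. sitk (x, y, z)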
+
+def sitk_to_array(sitk_image):
+ array = sitk.GetArrayFromImage(sitk_image)
+ array = np.moveaxis(array, [0,1,2], [2,1,0]) # [D,W,H] --> [H,W,D]
+ return array
+
+def itk_to_array(sitk_image):
+ import itk
+ array = itk.GetArrayFromImage(sitk_image)
+ array = np.moveaxis(array, [0,1,2], [2,1,0]) # [D,W,H] --> [H,W,D]
+ return array
+
+def resampler_sitk(image_sitk, spacing=None, scale=None, im_ref=None, im_ref_size=None, default_pixel_value=0, interpolator=None, dimension=3):
+ """
+ :param image_sitk: input image
+ :param spacing: desired spacing to set
+ :param scale: if greater than 1 means downsampling, less than 1 means upsampling
+ :param im_ref: if im_ref available, the spacing will be overwritten by the im_ref.GetSpacing()
+ :param im_ref_size: in sitk order: x, y, z
+ :param default_pixel_value:
+ :param interpolator:
+ :param dimension:
+ :return:
+ """
+
+ import math
+ import SimpleITK as sitk
+
+ if spacing is None and scale is None:
+ raise ValueError('spacing and scale cannot both be None')
+ if interpolator is None:
+ interpolator = sitk.sitkBSpline # sitk.Linear, sitk.Nearest
+
+ if spacing is None:
+ spacing = tuple(i * scale for i in image_sitk.GetSpacing())
+ if im_ref_size is None:
+ im_ref_size = tuple(round(i / scale) for i in image_sitk.GetSize())
+
+ elif scale is None:
+ ratio = [spacing_dim / spacing[i] for i, spacing_dim in enumerate(image_sitk.GetSpacing())]
+ if im_ref_size is None:
+ im_ref_size = tuple(math.ceil(size_dim * ratio[i]) - 1 for i, size_dim in enumerate(image_sitk.GetSize()))
+ else:
+ raise ValueError('spacing and scale cannot both have values')
+
+ if im_ref is None:
+ im_ref = sitk.Image(im_ref_size, sitk.sitkInt8)
+ im_ref.SetOrigin(image_sitk.GetOrigin())
+ im_ref.SetDirection(image_sitk.GetDirection())
+ im_ref.SetSpacing(spacing)
+
+
+ identity = sitk.Transform(dimension, sitk.sitkIdentity)
+ resampled_sitk = resampler_by_transform(image_sitk, identity, im_ref=im_ref,
+ default_pixel_value=default_pixel_value,
+ interpolator=interpolator)
+ return resampled_sitk
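+# Usage sketch, mirroring how save_as_mha() uses this below (the path is a placeholder):
+#   img = read_mha('<path to an .mha image>')
+#   img_resampled = resampler_sitk(img, spacing=(1.0, 1.0, 2.0))   # B-spline interpolation by default
+#   img_resampled = sitk.Cast(img_resampled, sitk.sitkInt16)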
+
+def resampler_by_transform(im_sitk, dvf_t, im_ref=None, default_pixel_value=0, interpolator=None):
+ import SimpleITK as sitk
+
+ if im_ref is None:
+ im_ref = sitk.Image(dvf_t.GetDisplacementField().GetSize(), sitk.sitkInt8)
+ im_ref.SetOrigin(dvf_t.GetDisplacementField().GetOrigin())
+ im_ref.SetSpacing(dvf_t.GetDisplacementField().GetSpacing())
+ im_ref.SetDirection(dvf_t.GetDisplacementField().GetDirection())
+
+ if interpolator is None:
+ interpolator = sitk.sitkBSpline
+
+ resampler = sitk.ResampleImageFilter()
+ resampler.SetReferenceImage(im_ref)
+ resampler.SetInterpolator(interpolator)
+ resampler.SetDefaultPixelValue(default_pixel_value)
+ resampler.SetTransform(dvf_t)
+
+ # [DEBUG]
+ resampler.SetOutputPixelType(sitk.sitkFloat32)
+
+ out_im = resampler.Execute(im_sitk)
+ return out_im
+
+def save_as_mha_mask(data_dir, patient_id, voxel_mask, voxel_img_headers):
+
+ try:
+ voxel_save_folder = Path(data_dir).joinpath(patient_id)
+ Path(voxel_save_folder).mkdir(parents=True, exist_ok=True)
+ study_id = Path(voxel_save_folder).parts[-1]
+ if config.STR_ACCESSION_PREFIX in str(study_id):
+ study_id = 'acc_' + str(study_id.split(config.STR_ACCESSION_PREFIX)[-1])
+ study_id = study_id.replace('.', '')
+
+
+ orig_origin = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN]
+ orig_pixel_spacing = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING]
+
+ if len(voxel_mask):
+ voxel_mask_sitk = array_to_sitk(voxel_mask.astype(config.DATATYPE_VOXEL_MASK)
+ , origin=orig_origin, spacing=orig_pixel_spacing)
+ path_voxel_mask = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_TUMOR_3D.format(study_id))
+ sitk.WriteImage(voxel_mask_sitk, str(path_voxel_mask), useCompression=True)
+ else:
+ print (' - [ERROR][save_as_mha_mask()] Patient: ', patient_id)
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+ pass
+
+def save_as_mha(data_dir, patient_id, voxel_img, voxel_img_headers, voxel_mask
+ , voxel_img_reg_dict={}, labelid_midpoint=None
+ , resample_spacing=[]):
+
+ try:
+ """
+ This function converts the raw numpy data into a SimpleITK image and saves it as .mha
+
+ Parameters
+ ----------
+ data_dir: Path
+ The path where you would like to save the data
+ patient_id: str
+ A reference to the patient
+ voxel_img: numpy
+ An nD numpy array in [H,W,D] format containing radiodensity data in Hounsfield units
+ voxel_img_headers: dict
+ A python dictionary containing information on 'origin' and 'pixel_spacing'
+ voxel_mask: numpy
+ An nD array with a label id on each voxel
+ resample_spacing: list
+ The (x, y, z) voxel spacing to resample to; if non-empty, resampled copies of the image and mask are also saved
+ """
+
+ # Step 1 - Original Voxel resolution
+ ## Step 1.1 - Create save dir
+ voxel_save_folder = Path(data_dir).joinpath(patient_id)
+ Path(voxel_save_folder).mkdir(parents=True, exist_ok=True)
+ study_id = Path(voxel_save_folder).parts[-1]
+ if config.STR_ACCESSION_PREFIX in str(study_id):
+ study_id = config.FILEPREFIX_ACCENSION + get_acc_id_from_str(study_id)
+
+ ## Step 1.2 - Save img voxel
+ orig_origin = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN]
+ orig_pixel_spacing = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING]
+ voxel_img_sitk = array_to_sitk(voxel_img.astype(config.DATATYPE_VOXEL_IMG)
+ , origin=orig_origin, spacing=orig_pixel_spacing)
+ path_voxel_img = Path(voxel_save_folder).joinpath(config.FILENAME_IMG_3D.format(study_id))
+ sitk.WriteImage(voxel_img_sitk, str(path_voxel_img), useCompression=True)
+
+ ## Step 1.3 - Save mask voxel
+ voxel_mask_sitk = array_to_sitk(voxel_mask.astype(config.DATATYPE_VOXEL_MASK)
+ , origin=orig_origin, spacing=orig_pixel_spacing)
+ path_voxel_mask = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_3D.format(study_id))
+ sitk.WriteImage(voxel_mask_sitk, str(path_voxel_mask), useCompression=True)
+
+ ## Step 1.4 - Save registration params
+ paths_voxel_reg = {}
+ for study_id in voxel_img_reg_dict:
+ if type(voxel_img_reg_dict[study_id]) == sitk.AffineTransform:
+ path_voxel_reg = str(Path(voxel_save_folder).joinpath('{}.tfm'.format(study_id)))
+ sitk.WriteTransform(voxel_img_reg_dict[study_id], path_voxel_reg)
+ paths_voxel_reg[study_id] = str(path_voxel_reg)
+
+ # Step 2 - Resampled Voxel
+ if len(resample_spacing):
+ new_spacing = resample_spacing
+
+ ## Step 2.1 - Save resampled img
+ voxel_img_sitk_resampled = resampler_sitk(voxel_img_sitk, spacing=new_spacing, interpolator=sitk.sitkBSpline)
+ voxel_img_sitk_resampled = sitk.Cast(voxel_img_sitk_resampled, sitk.sitkInt16)
+ path_voxel_img_resampled = Path(voxel_save_folder).joinpath(config.FILENAME_IMG_RESAMPLED_3D.format(study_id))
+ sitk.WriteImage(voxel_img_sitk_resampled, str(path_voxel_img_resampled), useCompression=True)
+ interpolator_img = 'sitk.sitkBSpline'
+
+ ## Step 2.2 - Save resampled mask
+ voxel_mask_sitk_resampled = []
+ interpolator_mask = ''
+ if 0:
+ voxel_mask_sitk_resampled = resampler_sitk(voxel_mask_sitk, spacing=new_spacing, interpolator=sitk.sitkNearestNeighbor)
+ interpolator_mask = 'sitk.sitkNearestNeighbor'
+
+ elif 1:
+ interpolator_mask = 'sitk.sitkLinear'
+ new_size = voxel_img_sitk_resampled.GetSize()
+ voxel_mask_resampled = np.zeros(new_size)
+ for label_id in np.unique(voxel_mask):
+ if label_id != 0:
+ voxel_mask_singlelabel = copy.deepcopy(voxel_mask).astype(config.DATATYPE_VOXEL_MASK)
+ voxel_mask_singlelabel[voxel_mask_singlelabel != label_id] = 0
+ voxel_mask_singlelabel[voxel_mask_singlelabel == label_id] = 1
+ voxel_mask_singlelabel_sitk = array_to_sitk(voxel_mask_singlelabel
+ , origin=orig_origin, spacing=orig_pixel_spacing)
+ voxel_mask_singlelabel_sitk_resampled = resampler_sitk(voxel_mask_singlelabel_sitk, spacing=new_spacing
+ , interpolator=sitk.sitkLinear)
+ if 0:
+ voxel_mask_singlelabel_sitk_resampled = sitk.Cast(voxel_mask_singlelabel_sitk_resampled, sitk.sitkUInt8)
+ voxel_mask_singlelabel_array_resampled = sitk_to_array(voxel_mask_singlelabel_sitk_resampled)
+ idxs = np.argwhere(voxel_mask_singlelabel_array_resampled > 0)
+ else:
+ voxel_mask_singlelabel_array_resampled = sitk_to_array(voxel_mask_singlelabel_sitk_resampled)
+ idxs = np.argwhere(voxel_mask_singlelabel_array_resampled >= 0.5)
+ voxel_mask_resampled[idxs[:,0], idxs[:,1], idxs[:,2]] = label_id
+
+ voxel_mask_sitk_resampled = array_to_sitk(voxel_mask_resampled
+ , origin=orig_origin, spacing=new_spacing)
+
+ voxel_mask_sitk_resampled = sitk.Cast(voxel_mask_sitk_resampled, sitk.sitkUInt8)
+ path_voxel_mask_resampled = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_RESAMPLED_3D.format(study_id))
+ sitk.WriteImage(voxel_mask_sitk_resampled, str(path_voxel_mask_resampled), useCompression=True)
+
+ ## Step 2.3 - Save voxel info for resampled data
+ midpoint_idxs_mean = []
+ if labelid_midpoint is not None:
+ voxel_mask_resampled_data = sitk_to_array(voxel_mask_sitk_resampled)
+ midpoint_idxs = np.argwhere(voxel_mask_resampled_data == labelid_midpoint)
+ midpoint_idxs_mean = np.mean(midpoint_idxs, axis=0)
+
+ path_voxel_headers = Path(voxel_save_folder).joinpath(config.FILENAME_VOXEL_INFO)
+
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED] = {config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean.tolist()}
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_PIXEL_SPACING] = voxel_img_sitk_resampled.GetSpacing()
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_ORIGIN] = voxel_img_sitk_resampled.GetOrigin()
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE] = voxel_img_sitk_resampled.GetSize()
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.TYPE_VOXEL_ORIGSHAPE] = {
+ config.KEYNAME_INTERPOLATOR_IMG: interpolator_img
+ , config.KEYNAME_INTERPOLATOR_MASK: interpolator_mask
+ }
+ if config.KEYNAME_LABEL_OARS in voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE]:
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_OARS] = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS]
+ if config.KEYNAME_LABEL_MISSING in voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE]:
+ voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING] = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING]
+
+ write_json(voxel_img_headers, path_voxel_headers)
+
+ ## Step 3 - Save img headers
+ path_voxel_headers = Path(voxel_save_folder).joinpath(config.FILENAME_VOXEL_INFO)
+ write_json(voxel_img_headers, path_voxel_headers)
+
+ if len(voxel_img_reg_dict):
+ return str(path_voxel_img), str(path_voxel_mask), paths_voxel_reg
+ else:
+ return str(path_voxel_img), str(path_voxel_mask), {}
+
+ except:
+ print ('\n --------------- [ERROR][save_as_mha()]')
+ traceback.print_exc()
+ print (' --------------- [ERROR][save_as_mha()]\n')
+ pdb.set_trace()
+
+def read_mha(path_file):
+ try:
+
+ if Path(path_file).exists():
+ img_mha = sitk.ReadImage(str(path_file))
+ return img_mha
+ else:
+ print (' - [ERROR][read_mha()] Path issue: path_file: ', path_file)
+ pdb.set_trace()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def write_itk(data, filepath):
+ filepath_parent = Path(filepath).parent.absolute()
+ filepath_parent.mkdir(parents=True, exist_ok=True)
+ itk.imwrite(data, str(filepath), compression=True)
+
+def resample_img(path_img, path_new_img, spacing):
+
+ try:
+
+ img_sitk = sitk.ReadImage(str(path_img))
+ img_resampled = resampler_sitk(img_sitk, spacing=spacing, interpolator=sitk.sitkBSpline)
+ img_resampled = sitk.Cast(img_resampled, sitk.sitkInt16)
+
+ path_new_parent = Path(path_new_img).parent.absolute()
+ Path(path_new_parent).mkdir(exist_ok=True, parents=True)
+ sitk.WriteImage(img_resampled, str(path_new_img), useCompression=True)
+
+ return img_resampled
+
+ except:
+ traceback.print_exc()
+
+def resample_mask(path_mask, path_new_mask, spacing, size, labels_allowed = []):
+
+ try:
+
+ # Step 0 - Init
+ mask_array_resampled = np.zeros(size)
+
+ # Step 1 - Read mask
+ mask_sitk = sitk.ReadImage(str(path_mask))
+ mask_spacing = mask_sitk.GetSpacing()
+ mask_origin = mask_sitk.GetOrigin()
+ mask_array = sitk_to_array(mask_sitk)
+
+ # Step 2 - Loop over mask labels
+ for label_id in np.unique(mask_array):
+ if label_id != 0 and label_id in labels_allowed:
+ mask_array_label = np.array(mask_array, copy=True)
+ mask_array_label[mask_array_label != label_id] = 0
+ mask_array_label[mask_array_label == label_id] = 1
+ mask_sitk_label = array_to_sitk(mask_array_label, origin=mask_origin, spacing=mask_spacing)
+ mask_sitk_label_resampled = resampler_sitk(mask_sitk_label, spacing=spacing, interpolator=sitk.sitkLinear)
+ mask_array_label_resampled = sitk_to_array(mask_sitk_label_resampled)
+ idxs = np.argwhere(mask_array_label_resampled >= 0.5)
+ mask_array_resampled[idxs[:,0], idxs[:,1], idxs[:,2]] = label_id
+
+ mask_sitk_resampled = array_to_sitk(mask_array_resampled , origin=mask_origin, spacing=spacing)
+ mask_sitk_resampled = sitk.Cast(mask_sitk_resampled, sitk.sitkUInt8)
+
+ path_new_parent = Path(path_new_mask).parent.absolute()
+ Path(path_new_parent).mkdir(exist_ok=True, parents=True)
+ sitk.WriteImage(mask_sitk_resampled, str(path_new_mask), useCompression=True)
+
+ return mask_sitk_resampled
+
+ except:
+ traceback.print_exc()
+
+############################################################
+# 3D VOXEL RELATED #
+############################################################
+
+def split_into_overlapping_grids(len_total, len_grid, len_overlap, res_type='boundary'):
+ res_range = []
+ res_boundary = []
+
+ A = np.arange(len_total)
+ l_start = 0
+ l_end = len_grid
+ while(l_end < len(A)):
+ res_range.append(np.arange(l_start, l_end))
+ res_boundary.append([l_start, l_end])
+ l_start = l_start + len_grid - len_overlap
+ l_end = l_start + len_grid
+
+ res_range.append(np.arange(len(A)-len_grid, len(A)))
+ res_boundary.append([len(A)-len_grid, len(A)])
+ if res_type == 'boundary':
+ return res_boundary
+ elif res_type == 'range':
+ return res_range
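+# Worked example: split_into_overlapping_grids(10, len_grid=4, len_overlap=2)
+#   -> [[0, 4], [2, 6], [4, 8], [6, 10]] (the last grid is anchored to the end of the axis).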
+
+def extract_numpy_from_dcm(patient_dir, skip_slices=None):
+ """
+ Given the path of the folder containing the .dcm files, this function extracts
+ a numpy array by combining them
+
+ Parameters
+ ----------
+ patient_dir: Path
+ The path of the folder containing the .dcm files
+ """
+ slices = []
+ voxel_img_data = []
+ try:
+ import pydicom
+ from pathlib import Path
+
+ if Path(patient_dir).exists():
+ for path_ct in Path(patient_dir).iterdir():
+ try:
+ ds = pydicom.filereader.dcmread(path_ct)
+ slices.append(ds)
+ except:
+ pass
+
+ if len(slices):
+ slices = list(filter(lambda x: 'ImagePositionPatient' in x, slices))
+ slices.sort(key = lambda x: float(x.ImagePositionPatient[2])) # inferior to superior
+ slices_data = []
+ for s_id, s in enumerate(slices):
+ try:
+ if skip_slices is not None:
+ if s_id < skip_slices:
+ continue
+ slices_data.append(s.pixel_array.T)
+ except:
+ pass
+
+ if len(slices_data):
+ voxel_img_data = np.stack(slices_data, axis=-1) # [row, col, plane], [H,W,D]
+
+ except:
+ pass
+ traceback.print_exc()
+ # pdb.set_trace()
+
+ return slices, voxel_img_data
+
+def perform_hu(voxel_img, intercept, slope):
+ """
+ Rescale Intercept != 0, Rescale Slope != 1 or Dose Grid Scaling != 1.
+ The pixel data has not been transformed according to these values.
+ Consider using the module ApplyDicomPixelModifiers after
+ importing the volume to transform the image data appropriately.
+ """
+ try:
+ import copy
+
+ slope = float(slope)
+ intercept = float(intercept)
+ voxel_img_hu = copy.deepcopy(voxel_img).astype(np.float64)
+
+ # Convert to Hounsfield units (HU)
+ if slope != 1:
+ voxel_img_hu = slope * voxel_img_hu
+
+ voxel_img_hu += intercept
+
+ return voxel_img_hu.astype(config.DATATYPE_VOXEL_IMG)
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
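+# Worked example for perform_hu(): with the common CT rescale of slope=1 and intercept=-1024,
+# a stored pixel value of 1024 maps to 0 HU (water) and 0 maps to -1024 HU (air).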
+
+def print_final_message():
+ print ('')
+ print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer')
+ print (' - Note: Raw Voxel Data ({}/{}) is in Hounsfield units (HU) with int16 datatype'.format(config.FILENAME_IMG_3D, config.FILENAME_IMG_RESAMPLED_3D))
+ print ('')
+
+def volumize_ct(patient_ct_dir, skip_slices=None):
+ """
+ Params
+ ------
+ patient_ct_dir: Path - contains .dcm files for CT scans
+ """
+ voxel_ct_data = []
+ voxel_ct_headers = {}
+
+ try:
+
+ if Path(patient_ct_dir).exists():
+
+ # Step1 - Accumulate all slices and sort
+ slices, voxel_ct_data = extract_numpy_from_dcm(Path(patient_ct_dir), skip_slices=skip_slices)
+
+ # Step 2 - Get parameters related to the 3D scan
+ if len(voxel_ct_data):
+ slice_thickness = 0
+ if slices[0].SliceThickness < 0.01:
+ slice_thickness = float(slices[1].ImagePositionPatient[2] - slices[0].ImagePositionPatient[2])
+ else:
+ slice_thickness = float(slices[0].SliceThickness)
+
+ voxel_ct_headers[config.KEYNAME_PIXEL_SPACING] = np.array(slices[0].PixelSpacing).tolist() + [slice_thickness]
+ voxel_ct_headers[config.KEYNAME_ORIGIN] = np.array(slices[0].ImagePositionPatient).tolist()
+ voxel_ct_headers[config.KEYNAME_SHAPE] = voxel_ct_data.shape
+ voxel_ct_headers[config.KEYNAME_ZVALS] = [round(s.ImagePositionPatient[2]) for s in slices]
+ voxel_ct_headers[config.KEYNAME_INTERCEPT] = slices[0].RescaleIntercept
+ voxel_ct_headers[config.KEYNAME_SLOPE] = slices[0].RescaleSlope
+
+ # Step 3 - Postprocess the 3D voxel data (data in HU + reso=[config.VOXEL_DATA_RESO])
+ voxel_ct_data = perform_hu(voxel_ct_data, intercept=slices[0].RescaleIntercept, slope=slices[0].RescaleSlope)
+
+ else:
+ print (' - [ERROR][volumize_ct()] Error with numpy extraction of CT volume: ', patient_ct_dir)
+
+ else:
+ print (' - [ERROR][volumize_ct()] Path issue: ', patient_ct_dir)
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ return voxel_ct_data, voxel_ct_headers
+
+############################################################
+# 3D VOXEL RELATED - RTSTRUCT #
+############################################################
+
+def get_self_label_id(label_name_orig, RAW_TO_SELF_LABEL_MAPPING, LABEL_MAP):
+
+ try:
+ if label_name_orig in RAW_TO_SELF_LABEL_MAPPING:
+ label_name_self = RAW_TO_SELF_LABEL_MAPPING[label_name_orig]
+ if label_name_self in LABEL_MAP:
+ return LABEL_MAP[label_name_self], label_name_self
+ else:
+ return 0, ''
+ else:
+ return 0, ''
+ except:
+ traceback.print_exc()
+ print (' - [ERROR][get_self_label_id()] label_name_orig: ', label_name_orig)
+ # pdb.set_trace()
+ return 0, ''
+
+def extract_contours_from_rtstruct(rtstruct_ds, RAW_TO_SELF_LABEL_MAPPING=None, LABEL_MAP=None, verbose=False):
+ """
+ Params
+ ------
+ rtstruct_ds: pydicom.dataset.FileDataset
+ - .StructureSetROISequence: contains identifying info on all contours
+ - .ROIContourSequence : contains a set of contours for particular ROI
+ """
+
+ # Step 0 - Init
+ contours = []
+ labels_debug = {}
+
+ # Step 1 - Loop and extract all different contours
+ for i in range(len(rtstruct_ds.ROIContourSequence)):
+
+ try:
+
+ # Step 1.1 - Get contour
+ contour = {}
+ contour['name'] = str(rtstruct_ds.StructureSetROISequence[i].ROIName)
+ contour['color'] = list(rtstruct_ds.ROIContourSequence[i].ROIDisplayColor)
+ contour['contours'] = [s.ContourData for s in rtstruct_ds.ROIContourSequence[i].ContourSequence]
+
+ if RAW_TO_SELF_LABEL_MAPPING is None and LABEL_MAP is None:
+ contour['number'] = rtstruct_ds.ROIContourSequence[i].ReferencedROINumber
+ assert contour['number'] == rtstruct_ds.StructureSetROISequence[i].ROINumber
+ elif RAW_TO_SELF_LABEL_MAPPING is None and LABEL_MAP is not None:
+ contour['number'] = int(LABEL_MAP.get(contour['name'], -1))
+ else:
+ label_id, _ = get_self_label_id(contour['name'], RAW_TO_SELF_LABEL_MAPPING, LABEL_MAP)
+ contour['number'] = label_id
+
+ if verbose: print (' - name: ', contour['name'], LABEL_MAP.get(contour['name'], -1))
+
+ # Step 1.2 - Keep or not condition
+ if contour['number'] > 0:
+ contours.append(contour)
+
+ # Step 1.3 - Some debugging
+ labels_debug[contour['name']] = {'id': len(labels_debug) + 1}
+
+ except:
+ if verbose: print ('\n ---------- [ERROR][utils.extract_contours_from_rtstruct()] name: ', rtstruct_ds.StructureSetROISequence[i].ROIName)
+
+ # Step 2 - Order your contours
+ if len(contours):
+ contours = list(sorted(contours, key = lambda obj: obj['number']))
+
+ return contours, labels_debug
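+# Shape of each returned contour dict (the label name below is illustrative only):
+#   {'name': 'Parotid_L', 'color': [R, G, B], 'number': <self label id>,
+#    'contours': [[x1, y1, z1, x2, y2, z2, ...], ...]}
+# i.e. every ContourData entry is a flat list of (x, y, z) points in patient coordinates (mm).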
+
+def process_contours(contour_obj_list, params, voxel_mask_data, special_processing_ids = []):
+ """
+ Goal: Convert contours to voxel mask
+ Params
+ ------
+ contour_obj_list: [{'name':'', 'contours': [[], [], ... ,[]]}, {}, ..., {}]
+ special_processing_ids: For donut shaped contours
+ """
+ import skimage
+ import skimage.draw
+
+ # Step 1 - Get some position and spacing params
+ z = params[config.KEYNAME_ZVALS]
+ pos_r = params[config.KEYNAME_ORIGIN][1]
+ spacing_r = params[config.KEYNAME_PIXEL_SPACING][1]
+ pos_c = params[config.KEYNAME_ORIGIN][0]
+ spacing_c = params[config.KEYNAME_PIXEL_SPACING][0]
+ shape = params[config.KEYNAME_SHAPE]
+
+ # Step 2 - Loop over contour objects
+ for contour_obj in contour_obj_list:
+
+ try:
+
+ class_id = int(contour_obj['number'])
+
+ if class_id not in special_processing_ids:
+
+ # Step 2.1 - Pick a contour for a particular ROI
+ for c_id, contour in enumerate(contour_obj['contours']):
+ coords = np.array(contour).reshape((-1, 3))
+ if len(coords) > 1:
+
+ # Step 2.2 - Get the z-index of a particular z-value
+ assert np.amax(np.abs(np.diff(coords[:, 2]))) == 0
+ z_index = z.index(pydicom.valuerep.DSfloat(float(round(coords[0, 2]))))
+
+ # Step 2.4 - Polygonize
+ rows = (coords[:, 1] - pos_r) / spacing_r #pixel_idx = f(real_world_idx, ct_resolution)
+ cols = (coords[:, 0] - pos_c) / spacing_c
+ if 1:
+ rr, cc = skimage.draw.polygon(rows, cols) # rr --> y-axis, cc --> x-axis
+ voxel_mask_data[cc, rr, z_index] = class_id
+ else:
+ contour_mask = skimage.draw.polygon2mask(voxel_mask_data.shape[:2], np.hstack((rows[np.newaxis].T, cols[np.newaxis].T)))
+ contour_idxs = np.argwhere(contour_mask > 0)
+ rr, cc = contour_idxs[:,0], contour_idxs[:,1]
+ voxel_mask_data[cc, rr, z_index] = class_id
+
+ # Step 2.99 - Debug
+ # f,axarr = plt.subplots(1,2); axarr[0].scatter(cols, rows); axarr[0].invert_yaxis(); axarr[1].imshow(contour_mask);plt.suptitle('Z={:f}'.format(contour[0, 2])); plt.show()
+
+ else:
+
+ # Step 2.1 - Gather all contours for a particular ROI
+ contours_all = []
+ for contour in contour_obj['contours']: contours_all.extend(contour)
+ contours_all = np.array(contours_all).reshape((-1,3))
+
+ # Step 2.2 - Split contours on the basis of z-value
+ for c_id, contour in enumerate([contours_all[contours_all[:,2]==z_pos] for z_pos in np.unique(contours_all[:,2])]):
+
+ # Step 2.3 - Get the z-index of a particular z-value
+ assert np.amax(np.abs(np.diff(contour[:, 2]))) == 0
+ z_index = z.index(pydicom.valuerep.DSfloat(float(round(contour[0, 2]))))
+
+ # Step 2.4 - Polygonize
+ rows = (contour[:, 1] - pos_r) / spacing_r # pixel_idx = f(real_world_idx, ct_resolution)
+ cols = (contour[:, 0] - pos_c) / spacing_c
+ rr, cc = skimage.draw.polygon(rows, cols)
+ voxel_mask_data[cc, rr, z_index] = class_id
+
+ # Step 2.99 - Debug
+ # if class_id == 8 and c_id > 3 and c_id < 7: # larynx
+ # import matplotlib.pyplot as plt
+ # f,axarr = plt.subplots(1,2); axarr[0].scatter(cols, rows); axarr[0].invert_yaxis(); axarr[1].scatter(cc, rr);axarr[1].invert_yaxis();plt.suptitle('Z={:f}'.format(contour[0, 2])); plt.show()
+ # pdb.set_trace()
+
+ except:
+ print (' --- [ERROR][utils.process_contours()] contour-number:', contour_obj['number'], ' || contour-label:', contour_obj['name'])
+ traceback.print_exc()
+
+ return voxel_mask_data
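+# Worked example of the index mapping above: with origin_y (pos_r) = -250.0 mm and a row spacing
+# of 0.98 mm, a contour point at y = -152.0 mm lands at row (-152.0 - (-250.0)) / 0.98 = 100.0.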
+
+def volumize_rtstruct(patient_rtstruct_path, params, params_labelinfo):
+ """
+ This function takes a .dcm file (modality=RTSTRUCT) in a folder and converts it into a numpy mask
+
+ Parameters
+ ----------
+ patient_rtstruct_path: Path
+ path to the .dcm file (modality=RTSTRUCT)
+ params: dictionary
+ A python dictionary containing the following keys - ['pixel_spacing', 'origin', 'z_vals'].
+ 'z_vals' is a list of all the depth values of the raw .dcm slices
+ """
+
+ # Step 0 - Init
+ mask_data_oars = []
+ mask_data_external = []
+ mask_headers = {config.KEYNAME_LABEL_OARS:[], config.KEYNAME_LABEL_EXTERNAL:[]}
+ LABEL_MAP_DCM = params_labelinfo.get(config.KEY_LABELMAP_DCM, None)
+ LABEL_MAP_FULL = params_labelinfo.get(config.KEY_LABEL_MAP_FULL, None)
+ LABEL_IDS_SPECIAL = params_labelinfo.get(config.KEY_ARTFORCE_DONUTSHAPED_IDS, None)
+ LABEL_MAP_EXTERNAL = params_labelinfo.get(config.KEY_LABEL_MAP_EXTERNAL, None)
+
+ try:
+
+ ds = pydicom.filereader.dcmread(patient_rtstruct_path)
+
+ if ds.Modality == config.MODALITY_RTSTRUCT:
+
+ if config.KEYNAME_SHAPE in params:
+
+ # Step 1 - Extract all different contours (for OARs)
+ if LABEL_MAP_DCM is not None or LABEL_MAP_FULL is not None:
+ contours_oars, _ = extract_contours_from_rtstruct(ds, LABEL_MAP_DCM, LABEL_MAP_FULL)
+ if len(contours_oars):
+ mask_headers[config.KEYNAME_LABEL_OARS] = [contour['name'] for contour in contours_oars]
+ mask_data_oars = np.zeros(params[config.KEYNAME_SHAPE], dtype=np.uint8)
+ mask_data_oars = process_contours(contours_oars, params, mask_data_oars, special_processing_ids=LABEL_IDS_SPECIAL)
+ else:
+ print (' - [ERROR][volumize_rtstruct()] Len of OAR contours is 0')
+
+ # Step 2 - Extract contour for "External"
+ if LABEL_MAP_EXTERNAL is not None:
+ contours_external, _ = extract_contours_from_rtstruct(ds, LABEL_MAP_DCM, LABEL_MAP_EXTERNAL)
+ if len(contours_external):
+ mask_headers[config.KEYNAME_LABEL_EXTERNAL] = [contour['name'] for contour in contours_external]
+ mask_data_external = np.zeros(params[config.KEYNAME_SHAPE], dtype=np.uint8)
+ mask_data_external = process_contours(contours_external, params, mask_data_external)
+ else:
+ print (' - [ERROR][volumize_rtstruct()] Len of External contours is 0')
+
+ else:
+ print (' - [ERROR][volumize_rtstruct()] Issue with voxel params: ', params)
+
+ else:
+ print (' - [ERROR][volumize_rtstruct()] Could not capture RTSTRUCT file')
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ return mask_data_oars, mask_data_external, mask_headers
+
+############################################################
+# SAVING/READING RELATED #
+############################################################
+
+def save_csv(filepath, data_array):
+ Path(filepath).parent.absolute().mkdir(parents=True, exist_ok=True)
+ np.savetxt(filepath, data_array, fmt='%s')
+
+def read_csv(filepath):
+ data = np.loadtxt(filepath, dtype='str')
+ return data
+
+def write_json(json_data, json_filepath):
+
+ Path(json_filepath).parent.absolute().mkdir(parents=True, exist_ok=True)
+
+ with open(str(json_filepath), 'w') as fp:
+ json.dump(json_data, fp, indent=4, cls=NpEncoder)
+
+def read_json(json_filepath, verbose=True):
+
+ if Path(json_filepath).exists():
+ with open(str(json_filepath), 'r') as fp:
+ data = json.load(fp)
+ return data
+ else:
+ if verbose: print (' - [ERROR][read_json()] json_filepath does not exist: ', json_filepath)
+ return None
+
+class NpEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return super(NpEncoder, self).default(obj)
+
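+# Minimal usage sketch (illustrative path and values): NpEncoder lets write_json()
+# serialize numpy scalars/arrays that the default json encoder would reject, e.g.
+#   write_json({'spacing': np.array([0.8, 0.8, 2.5]), 'patients': np.int64(42)}
+#              , Path('_data').joinpath('example', 'params.json'))
+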
+def write_nrrd(filepath, data, spacing):
+ nrrd_headers = {'space':'left-posterior-superior', 'kinds': ['domain', 'domain', 'domain'], 'encoding':'gzip'}
+ space_directions = np.zeros((3,3), dtype=np.float32)
+ space_directions[[0,1,2],[0,1,2]] = np.array(spacing)
+ nrrd_headers['space directions'] = space_directions
+
+ import nrrd
+ nrrd.write(str(filepath), data, nrrd_headers)
+
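+# Minimal usage sketch (illustrative values): save a [H,W,D] uint8 mask with
+# 0.98 x 0.98 x 2.5 mm spacing so it can be opened in an external viewer.
+#   write_nrrd(Path('_data').joinpath('example', 'mask.nrrd'), mask_data.astype(np.uint8), spacing=(0.98, 0.98, 2.5))
+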
+############################################################
+# RANDOM #
+############################################################
+def get_name_patient_study_id(meta):
+ try:
+ meta = np.array(meta)
+ meta = str(meta.astype(str))
+
+ meta_split = meta.split('-')
+ name = None
+ patient_id = None
+ study_id = None
+
+ if len(meta_split) == 2:
+ name = meta_split[0]
+ patient_id = meta_split[1]
+ elif len(meta_split) == 3:
+ name = meta_split[0]
+ patient_id = meta_split[1]
+ study_id = meta_split[2]
+
+ return name, patient_id, study_id
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def get_numbers(string):
+ return ''.join([s for s in string if s.isdigit()])
+
+def create_folds(idxs, folds=4):
+
+ # Step 0 - Init
+ res = {}
+ idxs = np.array(idxs)
+
+ try:
+ for fold in range(folds):
+            print (' - [create_folds()] fold: ', fold)
+ val_idxs = list(np.random.choice(np.arange(len((idxs))), size=int(len(idxs)/folds), replace=False))
+ train_idxs = list( set(np.arange(len(idxs))) - set(val_idxs) )
+ val_patients = idxs[val_idxs]
+ train_patients = idxs[train_idxs]
+
+ res[fold+1] = {config.MODE_TRAIN: train_patients, config.MODE_VAL: val_patients}
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ return res
+
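+# Hedged usage sketch: each fold draws its validation patients independently at random,
+# so this is random subsampling rather than strict non-overlapping cross-validation folds.
+#   folds = create_folds(['P01', 'P02', 'P03', 'P04'], folds=2)
+#   # folds[1][config.MODE_TRAIN], folds[1][config.MODE_VAL], folds[2][...], ...
+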
+############################################################
+# DEBUG RELATED #
+############################################################
+
+def benchmark_model(model_time):
+ time.sleep(model_time)
+
+def benchmark(dataset_generator, model_time=0.05):
+
+ import psutil
+ import humanize
+ import pynvml
+ pynvml.nvmlInit()
+
+ device_id = pynvml.nvmlDeviceGetHandleByIndex(0)
+ process = psutil.Process(os.getpid())
+
+ print ('\n - [benchmark()]')
+ t99 = time.time()
+ steps = 0
+ t0 = time.time()
+ for X,_,meta1,meta2 in dataset_generator:
+ t1 = time.time()
+ benchmark_model(model_time)
+ t2 = time.time()
+ steps += 1
+ # print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s', '(',X.shape,'), (',meta2.numpy(),')')
+ print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s' \
+ # , '(', humanize.naturalsize( process.memory_info().rss),'), ' \
+ # , '(', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(device_id).used/1024/1024/1000),'GB), ' \
+ , '(', meta2.numpy(),')'
+ # , np.sum(meta1[:,-9:].numpy(), axis=1)
+ # , meta1[:,1:4].numpy()
+ )
+ t0 = time.time()
+ t99 = time.time() - t99
+ print ('\n - Total steps: ', steps)
+ print (' - Total time for dataset: ', round(t99,2), 's')
+
+def benchmark2(dataset_generator, model_time=0.05):
+
+ import psutil
+ import humanize
+ import pynvml
+ pynvml.nvmlInit()
+
+ device_id = pynvml.nvmlDeviceGetHandleByIndex(0)
+ process = psutil.Process(os.getpid())
+
+ print ('\n - [benchmark2()]')
+ t99 = time.time()
+ steps = 0
+ t0 = time.time()
+ for X_moving,_,_,_,meta1,meta2 in dataset_generator:
+ t1 = time.time()
+ benchmark_model(model_time)
+ t2 = time.time()
+ steps += 1
+ # print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s', '(',X_moving.shape,'), (',meta2.numpy(),')')
+ print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s' \
+ , '(', humanize.naturalsize( process.memory_info().rss),'), ' \
+ , '(', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(device_id).used/1024/1024/1000),'GB), ' \
+ # , '(', meta2.numpy(),')'
+ # , np.sum(meta1[:,-9:].numpy(), axis=1)
+ # , meta1[:,1:4].numpy()
+ )
+ t0 = time.time()
+ t99 = time.time() - t99
+ print ('\n - Total steps: ', steps)
+ print (' - Total time for dataset: ', round(t99,2), 's')
+
+def print_debug_header():
+ print (' ============================================== ')
+ print (' DEBUG ')
+ print (' ============================================== ')
+
+def get_memory(pid):
+ try:
+ process = psutil.Process(pid)
+ return humanize.naturalsize(process.memory_info().rss)
+ except:
+ return '-1'
+
+############################################################
+# DATALOADER RELATED #
+############################################################
+
+def get_dataloader_3D_train(data_dir, dir_type=['train', 'train_additional']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=True
+ , parallel_calls=None, deterministic=False
+ , patient_shuffle=True
+ , centred_dataloader_prob=0.0
+ , debug=False
+ , pregridnorm=True):
+
+    debug = False  # Note: debug is force-disabled here, overriding the function argument
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Training transforms
+ x_shape_w = dataset_han_miccai2015.w_grid
+ x_shape_h = dataset_han_miccai2015.h_grid
+ x_shape_d = dataset_han_miccai2015.d_grid
+ label_map = dataset_han_miccai2015.LABEL_MAP
+ transforms = [
+ aug.Rotate3DSmall(label_map, mask_type)
+ , aug.Deform2Punt5D((x_shape_h, x_shape_w, x_shape_d), label_map, grid_points=50, stddev=4, div_factor=2, debug=False)
+ , aug.Translate(label_map, translations=[40,40])
+ , aug.Noise(x_shape=(x_shape_h, x_shape_w, x_shape_d, 1), mean=0.0, std=0.1)
+ ]
+ dataset_han_miccai2015.transforms = transforms
+
+ # Step 3 - Training filters for background-only grids
+ if filter_grid:
+ dataset_han_miccai2015.filter = aug.FilterByMask(len(dataset_han_miccai2015.LABEL_MAP), dataset_han_miccai2015.SAMPLER_PERC)
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
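+# Hedged usage sketch (path is a placeholder): downstream code iterates the returned
+# ZipDataset as (X, Y, meta1, meta2) batches (cf. benchmark() above in this file), e.g.
+#   dataset = get_dataloader_3D_train(data_dir=Path('_data').joinpath('miccai2015'), dir_type=['train'])
+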
+def get_dataloader_3D_train_eval(data_dir, dir_type=['train_train_additional']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False
+ , parallel_calls=None, deterministic=True
+ , patient_shuffle=False
+ , centred_dataloader_prob=0.0
+ , debug=False
+ , pregridnorm=True):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Training transforms
+ # None
+
+ # Step 3 - Training filters for background-only grids
+ # None
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+def get_dataloader_3D_test_eval(data_dir, dir_type=['test_offsite']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False
+ , parallel_calls=None, deterministic=True
+ , patient_shuffle=False
+ , debug=False
+ , pregridnorm=True):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , debug=debug
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Testing transforms
+ # None
+
+ # Step 3 - Testing filters for background-only grids
+ # None
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+def get_dataloader_deepmindtcia(data_dir
+ , dir_type=[config.DATALOADER_DEEPMINDTCIA_TEST]
+ , annotator_type=[config.DATALOADER_DEEPMINDTCIA_ONC]
+ , grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False, pregridnorm=True
+ , parallel_calls=None, deterministic=False
+ , patient_shuffle=True
+ , centred_dataloader_prob = 0.0
+ , debug=False):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ for anno_type_ in annotator_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_deepmindtcia = HaNDeepMindTCIADataset(data_dir=data_dir, dir_type=dir_type_, annotator_type=anno_type_
+ , grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type, pregridnorm=pregridnorm
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug)
+
+ # Step 2 - Append to list
+ datasets.append(dataset_han_deepmindtcia)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
diff --git a/src/dataloader/utils_viz.py b/src/dataloader/utils_viz.py
new file mode 100644
index 0000000..54a0bc4
--- /dev/null
+++ b/src/dataloader/utils_viz.py
@@ -0,0 +1,437 @@
+import pdb
+import copy
+import traceback
+import numpy as np
+from pathlib import Path
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+import medloader.dataloader.utils as utils
+import medloader.dataloader.config as config
+
+def cmap_for_dataset(dataset):
+ # dataset = ZipDataset
+ LABEL_COLORS = dataset.get_label_colors()
+ cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0)
+ norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N)
+
+ return cmap_me, norm
+
+############################################################
+# 2D #
+############################################################
+
+def get_info_from_label_id(label_id, LABEL_MAP, LABEL_COLORS=None):
+ """
+ The label_id param has to be greater than 0
+ """
+
+ label_name = [label for label in LABEL_MAP if LABEL_MAP[label] == label_id]
+ if len(label_name):
+ label_name = label_name[0]
+ else:
+ label_name = None
+
+ label_color = None
+ if LABEL_COLORS is not None:
+ label_color = np.array(LABEL_COLORS[label_id])
+ if np.any(label_color > 1):
+ label_color = label_color/255.0
+
+ return label_name, label_color
+
+def viz_slice_raw_batch(slices_raw, slices_mask, meta, dataset):
+ try:
+
+ slices_raw = slices_raw[:,:,:,0]
+ batch_size = slices_raw.shape[0]
+
+ LABEL_COLORS = getattr(config, dataset.name)['LABEL_COLORS']
+ cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0)
+ norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N)
+ # for i in range(cmap_me.N): print (i, ':', norm.__call__(i))
+
+ f, axarr = plt.subplots(2,batch_size, figsize=config.FIGSIZE)
+ if batch_size == 1:
+ axarr = axarr.reshape(2,batch_size)
+
+ # Loop over all batches
+ for batch_id in range(batch_size):
+ slice_raw= slices_raw[batch_id].numpy()
+ idxs_update = np.argwhere(slice_raw < -200)
+ slice_raw[idxs_update[:,0], idxs_update[:,1]] = -200
+ idxs_update = np.argwhere(slice_raw > 400)
+ slice_raw[idxs_update[:,0], idxs_update[:,1]] = 400
+ axarr[1,batch_id].imshow(slice_raw, cmap='gray')
+ axarr[0,batch_id].imshow(slice_raw, cmap='gray')
+
+ # Create a mask
+ label_mask_show = np.zeros(slices_mask[batch_id,:,:,0].shape)
+ for label_id in range(slices_mask.shape[-1]):
+ label_mask = slices_mask[batch_id,:,:,label_id].numpy()
+ if np.sum(label_mask) > 0:
+ label_idxs = np.argwhere(label_mask > 0)
+ if np.any(label_mask_show[label_idxs[:,0], label_idxs[:,1]] > 0):
+ print (' - [utils.viz_slice_raw_batch()] Label overload by label_id:{}'.format(label_id))
+
+ label_mask_show[label_idxs[:,0], label_idxs[:,1]] = label_id
+
+ if len(label_idxs) < 20:
+ label_name, _ = get_info_from_label_id(label_id, dataset)
+ print (' - [utils.viz_slice_raw_batch()] label_id:{} || label: {} || count: {}'.format(label_id+1, label_name, len(label_idxs)))
+
+
+ # Show the mask
+ axarr[0,batch_id].imshow(label_mask_show.astype(np.float32), cmap=cmap_me, norm=norm, interpolation='none', alpha=0.3)
+
+ # Add legend
+ for label_id_mask in np.unique(label_mask_show):
+ if label_id_mask != 0:
+ label, color = get_info_from_label_id(label_id_mask, dataset)
+ axarr[0,batch_id].scatter(0,0, color=color, label=label + '(' + str(int(label_id_mask))+')')
+ leg = axarr[0,batch_id].legend(fontsize=8)
+ for legobj in leg.legendHandles:
+ legobj.set_linewidth(5.0)
+
+ # Add title
+ if meta is not None and dataset is not None:
+ filename = '\n'.join(meta[batch_id].numpy().decode('utf-8').split('-'))
+ axarr[0,batch_id].set_title(dataset.name + '\n' + filename + '\n(BatchId={})'.format(batch_id))
+
+ mng = plt.get_current_fig_manager()
+ try:mng.window.showMaximized()
+ except:pass
+ try:mng.window.maxsize()
+ except:pass
+ plt.show()
+
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def viz_slice_raw_batch_datasets(slices_raw, slices_mask, meta1, meta2, datasets):
+ try:
+
+ slices_raw = slices_raw[:,:,:,0]
+ batch_size = slices_raw.shape[0]
+
+ f, axarr = plt.subplots(2,batch_size, figsize=config.FIGSIZE)
+ if batch_size == 1:
+ axarr = axarr.reshape(2,batch_size)
+
+ # Loop over all batches
+ for batch_id in range(batch_size):
+ axarr[1,batch_id].imshow(slices_raw[batch_id], cmap='gray')
+ axarr[0,batch_id].imshow(slices_raw[batch_id], cmap='gray')
+
+ meta2_batchid = meta2[batch_id].numpy().decode('utf-8').split('-')[0]
+ dataset_batch = ''
+ for dataset in datasets:
+ if dataset.name == meta2_batchid:
+ dataset_batch = dataset
+
+ LABELID_BACKGROUND = getattr(config, dataset_batch.name)['LABELID_BACKGROUND']
+ LABEL_COLORS = getattr(config, dataset_batch.name)['LABEL_COLORS']
+ cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0)
+ norm = matplotlib.colors.BoundaryNorm(boundaries=range(1,12), ncolors=cmap_me.N)
+
+ # Create a mask
+ label_mask_show = np.zeros(slices_mask[batch_id,:,:,0].shape)
+ for label_id in range(slices_mask.shape[-1]):
+ label_mask = slices_mask[batch_id,:,:,label_id].numpy()
+ if np.sum(label_mask) > 0:
+ label_idxs = np.argwhere(label_mask > 0)
+ if label_id == slices_mask.shape[-1] - 1:
+ label_id_actual = LABELID_BACKGROUND
+ else:
+ label_id_actual = label_id + 1
+
+ if np.any(label_mask_show[label_idxs[:,0], label_idxs[:,1]] > 0):
+                        print (' - [utils.viz_slice_raw_batch_datasets()] Label overload by label_id:{}'.format(label_id))
+ if len(label_idxs) < 20:
+ label_name, _ = get_info_from_label_id(label_id_actual, dataset_batch)
+                        print (' - [utils.viz_slice_raw_batch_datasets()] label_id:{} || label: {} || count: {}'.format(label_id+1, label_name, len(label_idxs)))
+ label_mask_show[label_idxs[:,0], label_idxs[:,1]] = label_id_actual
+ # label_name, label_color = get_info_from_label_id(label_id+1)
+
+ # Show the mask
+ # axarr[0,batch_id].imshow(label_mask_show, alpha=0.5, cmap=cmap_me, norm=norm)
+ axarr[0,batch_id].imshow(label_mask_show.astype(np.float32), cmap=cmap_me, norm=norm)
+
+ # Add legend
+ for label_id_mask in np.unique(label_mask_show):
+ if label_id_mask != 0:
+ label, color = get_info_from_label_id(label_id_mask, dataset_batch)
+ axarr[0,batch_id].scatter(0,0, color=color, label=label + '(' + str(int(label_id_mask))+')')
+ leg = axarr[0,batch_id].legend(fontsize=8)
+ for legobj in leg.legendHandles:
+ legobj.set_linewidth(5.0)
+
+ # Add title
+ filepath = dataset_batch.paths_raw[meta1[batch_id]]
+ if len(filepath) == 2:
+ filepath = filepath[0] + ' ' + filepath[1]
+ filename = Path(filepath).parts[-1]
+ axarr[0,batch_id].set_title(dataset_batch.name + '\n' + filename + '\n(BatchId={})'.format(batch_id))
+
+ mng = plt.get_current_fig_manager()
+ mng.window.showMaximized()
+ plt.show()
+
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def viz_slice(slice_raw, slice_mask, meta=None, dataset=None):
+ """
+ - slice_raw: [H,W]
+ - slice_mask: [H,W]
+ """
+ try:
+ import matplotlib.pyplot as plt
+ f, axarr = plt.subplots(2,2, figsize=config.FIGSIZE)
+ axarr[0,0].imshow(slice_raw, cmap='gray')
+ axarr[0,1].imshow(slice_mask)
+ axarr[1,0].hist(slice_raw)
+ axarr[1,1].imshow(slice_raw, cmap='gray')
+ axarr[1,1].imshow(slice_mask, alpha=0.5)
+
+ axarr[0,0].set_title('Raw image')
+ axarr[0,1].set_title('Raw mask')
+ axarr[1,0].set_title('Raw Value Histogram')
+ axarr[1,1].set_title('Raw + Mask')
+
+ if meta is not None and dataset is not None:
+ filename = Path(dataset.paths_raw[meta[0]]).parts[-1]
+ plt.suptitle(filename)
+
+ try:
+ mng = plt.get_current_fig_manager()
+ mng.window.showMaximized()
+ except: pass
+
+ plt.show()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+############################################################
+# 3D #
+############################################################
+
+def viz_3d_slices(voxel_img, voxel_mask, dataset, meta1, meta2, plots=4):
+ """
+ voxel_img : [B,H,W,D, C=1]
+ voxel_mask: [B,H,W,D, C]
+ """
+ try:
+ cmap_me, norm = cmap_for_dataset(dataset)
+
+ for batch_id in range(voxel_img.shape[0]):
+ voxel_img_batch = voxel_img[batch_id,:,:,:,0]*(dataset.HU_MAX - dataset.HU_MIN) + dataset.HU_MIN
+ voxel_mask_batch = np.argmax(voxel_mask[batch_id,:,:,:,:], axis=-1)
+ height = voxel_img_batch.shape[-1]
+
+ f,axarr = plt.subplots(2,plots)
+ for plt_idx, z_idx in enumerate(np.random.choice(height, plots, replace=False)):
+ axarr[0][plt_idx].imshow(voxel_img_batch[:,:,z_idx], cmap='gray')
+ axarr[1][plt_idx].imshow(voxel_img_batch[:,:,z_idx], cmap='gray', alpha=0.2)
+ axarr[1][plt_idx].imshow(voxel_mask_batch[:,:,z_idx], cmap=cmap_me, norm=norm, interpolation='none')
+ axarr[0][plt_idx].set_title('Slice: {}/{}'.format(z_idx+1, height))
+
+ name, patient_id, study_id = utils.get_name_patient_study_id(meta2[batch_id])
+ if study_id is not None:
+ filename = patient_id + '\n' + study_id
+ else:
+ filename = patient_id
+ plt.suptitle(filename)
+ plt.show()
+
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def viz_3d_data(voxel_imgs, voxel_masks, meta1, meta2, dataset):
+ """
+ - voxel_imgs: [B,H,W,D]
+ """
+
+ try:
+ """
+ Ref: https://plotly.com/python/visualizing-mri-volume-slices/
+ """
+ import plotly.graph_objects as go
+
+ voxel_imgs = voxel_imgs.numpy()
+ for batch_id, voxel_img in enumerate(voxel_imgs):
+
+ # Plot Volume
+ r,c,d = voxel_img.shape
+ d_ = d - 1
+
+ fig = go.Figure(
+ frames=[
+ go.Frame(data=go.Surface(
+ z=(d_ - k) * np.ones((r, c)),
+ surfacecolor=np.flipud(voxel_img[:,:,d_ - k]),
+ # surfacecolor=voxel_img[:,:,d_ - k],
+ cmin=-1000, cmax=3000
+ )
+ , name=str(k) # you need to name the frame for the animation to behave properly
+ ) for k in range(d)
+ ]
+ )
+
+ # Add data to be displayed before animation starts
+ fig.add_trace(go.Surface(
+ z=d_ * np.ones((r, c))
+ , surfacecolor=np.flipud(voxel_img[:,:,d_])
+ # , surfacecolor=voxel_img[:,:,d_]
+ , colorscale='Gray'
+ # , cmin=config.HU_MIN, cmax=config.HU_MIN
+ , cmin=-1000, cmax=3000
+ , colorbar=dict(thickness=20, ticklen=4)
+ )
+ )
+
+ def frame_args(duration):
+ return {
+ "frame": {"duration": duration},
+ "mode": "immediate",
+ "fromcurrent": True,
+ "transition": {"duration": duration, "easing": "linear"},
+ }
+
+ sliders = [
+ {
+ "pad": {"b": 10, "t": 60},
+ "len": 0.9,
+ "x": 0.1,
+ "y": 0,
+ "steps": [
+ {
+ "args": [[f.name], frame_args(0)],
+ "label": str(k),
+ "method": "animate",
+ }
+ for k, f in enumerate(fig.frames)
+ ],
+ }
+ ]
+
+ fig.update_layout(
+ title='Slices in volumetric data',
+ width=600,
+ height=600,
+ scene=dict(
+ # zaxis=dict(range=[0, d*0.1], autorange=False),
+ zaxis=dict(range=[0, d], autorange=False),
+ aspectratio=dict(x=1, y=1, z=1),
+ ),
+ updatemenus = [
+ {
+ "buttons": [
+ {
+ "args": [None, frame_args(50)],
+ "label": "▶", # play symbol
+ "method": "animate",
+ },
+ {
+ "args": [[None], frame_args(0)],
+ "label": "◼", # pause symbol
+ "method": "animate",
+ },
+ ],
+ "direction": "left",
+ "pad": {"r": 10, "t": 70},
+ "type": "buttons",
+ "x": 0.1,
+ "y": 0,
+ }
+ ],
+ sliders=sliders
+ )
+ fig.show()
+
+ # Plot Mask (3D)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def viz_3d_mask(voxel_masks, dataset, meta1, meta2, label_map_full=False):
+ """
+ Expects a [B,H,W,D] shaped mask
+ """
+ try:
+ import plotly.graph_objects as go
+ import skimage
+ import skimage.measure
+
+ LABEL_MAP = dataset.get_label_map(label_map_full=label_map_full)
+ LABEL_COLORS = dataset.get_label_colors()
+ label_ids = meta1[:,-len(LABEL_MAP):].numpy()
+ if np.sum(label_ids) < len(meta1): # if <= batch_size: return 0
+ print (' - [viz_3d_mask()] No labels present: np.sum(label_ids): ', np.sum(label_ids))
+ return 0
+
+ # Step 1 - Loop over all batch_ids
+ print (' ------------------------ VIZ 3D ------------------------')
+ for batch_id, voxel_mask in enumerate(voxel_masks):
+ fig = go.Figure()
+ label_ids = np.unique(voxel_mask)
+ print (' - label_ids: ', label_ids)
+
+ import tensorflow as tf
+ # voxel_mask_ = tf.image.rot90(voxel_mask, k=3)
+ # voxel_mask_ = tf.transpose(tf.reverse(voxel_mask, [0]), [1,0,2])
+ voxel_mask_ = voxel_mask
+ print (' -voxel_mask: ', voxel_mask.shape)
+
+ # Step 2 - Loop over all label_ids
+ for i_, label_id in enumerate(label_ids):
+
+ if label_id == 0 : continue
+ name, color = get_info_from_label_id(label_id, LABEL_MAP, LABEL_COLORS)
+ print (' - label_id: ', label_id, '(',name,')')
+
+ # Get surface information
+ voxel_mask_tmp = np.array(copy.deepcopy(voxel_mask_)).astype(config.DATATYPE_VOXEL_MASK)
+ voxel_mask_tmp[voxel_mask_tmp != label_id] = 0
+ verts, faces, _, _ = skimage.measure.marching_cubes(voxel_mask_tmp, step_size=1)
+
+ # https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.Mesh3d.html
+ visible=True
+ fig.add_trace(
+ go.Mesh3d(
+ x=verts[:,0], y=verts[:,1], z=verts[:,2]
+ , i=faces[:,0],j=faces[:,1],k=faces[:,2]
+ , color='rgb({},{},{})'.format(*color)
+ , name=name, showlegend=True
+ , visible=visible
+ # , lighting=go.mesh3d.Lighting(ambient=0.5)
+ )
+ )
+
+ fig.update_layout(
+ scene = dict(
+ xaxis = dict(nticks=10, range=[0,voxel_mask.shape[0]], title='X-axis'),
+ yaxis = dict(nticks=10, range=[0,voxel_mask.shape[1]]),
+ zaxis = dict(nticks=10, range=[0,voxel_mask.shape[2]]),
+ )
+ ,width=700,
+ margin=dict(r=20, l=50, b=10, t=50)
+ )
+ fig.update_layout(legend_title_text='Labels', showlegend=True)
+ fig.update_layout(scene_aspectmode='cube')
+ fig.update_layout(title_text='{} (BatchId={})'.format(meta2, batch_id))
+ fig.show()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
diff --git a/src/model/__init__.py b/src/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/model/losses.py b/src/model/losses.py
new file mode 100644
index 0000000..61b4833
--- /dev/null
+++ b/src/model/losses.py
@@ -0,0 +1,488 @@
+# Import internal libraries
+import src.config as config
+
+# Import external libraries
+import pdb
+import copy
+import time
+import skimage
+import skimage.util
+import traceback
+import numpy as np
+from pathlib import Path
+import SimpleITK as sitk
+import tensorflow as tf
+import tensorflow_addons as tfa
+try: import tensorflow_probability as tfp
+except: pass
+import matplotlib.pyplot as plt
+
+_EPSILON = tf.keras.backend.epsilon()
+MAX_FUNC_TIME = 300
+
+############################################################
+# UTILS #
+############################################################
+
+@tf.function
+def get_mask(mask_1D, Y):
+ # mask_1D: [[1,0,0,0, ...., 1]] - [B,L] something like this
+ # Y : [B,H,W,D,L]
+ if 0:
+ mask = tf.expand_dims(tf.expand_dims(tf.expand_dims(mask_1D, axis=1),axis=1),axis=1) # mask.shape=[B,1,1,1,L]
+ mask = tf.tile(mask, multiples=[1,Y.shape[1],Y.shape[2],Y.shape[3],1]) # mask.shape = [B,H,W,D,L]
+ mask = tf.cast(mask, tf.float32)
+ return mask
+ else:
+ return mask_1D
+
+def get_largest_component(y, verbose=True):
+ """
+    Takes as input predicted probs and returns a binary mask with the largest component for each label
+
+ Params
+ ------
+ y: [H,W,D,L] --> predicted probabilities of np.float32
+ - Ref: https://simpleitk.readthedocs.io/en/v1.2.4/Documentation/docs/source/filters.html
+ """
+
+ try:
+
+ label_count = y.shape[-1]
+ y = np.argmax(y, axis=-1)
+ y = np.concatenate([np.expand_dims(y == label_id, axis=-1) for label_id in range(label_count)],axis=-1)
+ ccFilter = sitk.ConnectedComponentImageFilter()
+
+ for label_id in range(y.shape[-1]):
+
+ if label_id > 0 and label_id not in []: # Note: pointless to do it for background
+ if verbose: t0 = time.time()
+ y_label = y[:,:,:,label_id] # [H,W,D]
+
+ component_img = ccFilter.Execute(sitk.GetImageFromArray(y_label.astype(np.uint8)))
+                component_array = sitk.GetArrayFromImage(component_img) # will contain the pseudo-labels assigned to the different components
+ component_count = ccFilter.GetObjectCount()
+
+ if component_count >= 2: # background and foreground
+                    component_sizes = np.bincount(component_array.flatten()) # count the voxels belonging to each component
+ component_sizes_sorted = np.asarray(sorted(component_sizes, reverse=True))
+ if verbose: print ('\n - [INFO][losses.get_largest_component()] label_id: ', label_id, ' || sizes: ', component_sizes_sorted)
+
+ component_largest_sortedidx = np.argwhere(component_sizes == component_sizes_sorted[1])[0][0] # Note: idx=1 as idx=0 is background # Risk: for optic nerves
+ y_label_mask = (component_array == component_largest_sortedidx).astype(np.float32)
+ y[:,:,:,label_id] = y_label_mask
+ if verbose: print (' - [INFO][losses.get_largest_component()]: label_id: ', label_id, '(',round(time.time() - t0,3),'s)')
+ else:
+ if verbose: print (' - [INFO][losses.get_largest_component()] label_id: ', label_id, ' has only background!!')
+
+ # [TODO]: set other components as background (i.e. label=0)
+
+ y = y.astype(np.float32)
+ return y
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
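+# Hedged usage sketch: given softmax output y_prob of shape [H,W,D,L] (np.float32),
+# get_largest_component(y_prob) returns a [H,W,D,L] float32 binary mask in which every
+# foreground label keeps only its single largest connected component; the background
+# channel (label_id=0) is left as the plain argmax mask.
+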
+def remove_smaller_components(y_true, y_pred, meta='', label_ids_small = [], verbose=False):
+ """
+ Takes as input predicted probs and returns a binary mask by removing some of the smallest components for each label
+
+ Params
+ ------
+ y: [H,W,D,L] --> predicted probabilities of np.float32
+ - Ref: https://simpleitk.readthedocs.io/en/v1.2.4/Documentation/docs/source/filters.html
+ """
+ t0 = time.time()
+
+ try:
+
+        # Step 0 - Preprocess by assigning a single class to each voxel (argmax) and re-expanding to a binary mask
+ y_pred_copy = copy.deepcopy(y_pred) # [H,W,D,L] with probs
+ label_count = y_pred_copy.shape[-1]
+ y_pred_copy = np.argmax(y_pred_copy, axis=-1) # [H,W,D]
+ y_pred_copy = np.concatenate([np.expand_dims(y_pred_copy == label_id, axis=-1) for label_id in range(label_count)],axis=-1) # [H,W,D,L] as a binary mask
+
+ for label_id in range(y_pred_copy.shape[-1]):
+
+ if label_id > 0: # Note: pointless to do it for background
+ if verbose: t0 = time.time()
+ y_label = y_pred_copy[:,:,:,label_id] # [H,W,D]
+
+ # Step 1 - Get different components
+ ccFilter = sitk.ConnectedComponentImageFilter()
+ component_img = ccFilter.Execute(sitk.GetImageFromArray(y_label.astype(np.uint8)))
+                component_array = sitk.GetArrayFromImage(component_img) # will contain the pseudo-labels assigned to the different components
+                component_count = ccFilter.GetObjectCount()
+                component_sizes = np.bincount(component_array.flatten()) # count the voxels belonging to each component
+
+ # Step 2 - Evaluate each component
+ if component_count >= 1: # at least a foreground (along with background)
+
+ # Step 2.1 - Sort them on the basis of voxel count
+ component_sizes_sorted = np.asarray(sorted(component_sizes, reverse=True))
+ if verbose:
+                        print ('\n - [INFO][losses.remove_smaller_components()] label_id: ', label_id, ' || sizes: ', component_sizes_sorted)
+                        print (' - [INFO][losses.remove_smaller_components()] unique_comp_labels: ', np.unique(component_array))
+
+                    # Step 2.2 - Remove really small components for good Hausdorff calculation
+ component_sizes_sorted_unique = np.unique(component_sizes_sorted[::-1]) # ascending order
+ for comp_size_id, comp_size in enumerate(component_sizes_sorted_unique):
+ if label_id in label_ids_small:
+ if comp_size <= config.MIN_SIZE_COMPONENT:
+ components_labels = [each[0] for each in np.argwhere(component_sizes == comp_size)]
+ for component_label in components_labels:
+ component_array[component_array == component_label] = 0
+ else:
+ if comp_size_id < len(component_sizes_sorted_unique) - 2: # set to 0 except background and foreground
+ components_labels = [each[0] for each in np.argwhere(component_sizes == comp_size)]
+ for component_label in components_labels:
+ component_array[component_array == component_label] = 0
+                    if verbose: print (' - [INFO][losses.remove_smaller_components()] unique_comp_labels: ', np.unique(component_array))
+                    y_pred_copy[:,:,:,label_id] = component_array.astype(np.bool).astype(np.float32)
+                    if verbose: print (' - [INFO][losses.remove_smaller_components()] label_id: ', label_id, '(',round(time.time() - t0,3),'s)')
+
+ if 0:
+ # Step 1 - Hausdorff
+ y_true_label = y_true[:,:,:,label_id]
+ y_pred_label = y_pred_copy[:,:,:,label_id]
+
+ hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
+ hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_label.astype(np.uint8)), sitk.GetImageFromArray(y_pred_label.astype(np.uint8)))
+ print (' - hausdorff: ', hausdorff_distance_filter.GetHausdorffDistance())
+
+ # Step 2 - 95% Hausdorff
+ y_true_contour = sitk.LabelContour(sitk.GetImageFromArray(y_true_label.astype(np.uint8)), False)
+ y_pred_contour = sitk.LabelContour(sitk.GetImageFromArray(y_pred_label.astype(np.uint8)), False)
+ y_true_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_true_contour, squaredDistance=False, useImageSpacing=True)) # i.e. euclidean distance
+ y_pred_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_pred_contour, squaredDistance=False, useImageSpacing=True))
+ dist_y_pred = sitk.GetArrayViewFromImage(y_pred_distance_map)[sitk.GetArrayViewFromImage(y_true_distance_map)==0] # pointless?
+ dist_y_true = sitk.GetArrayViewFromImage(y_true_distance_map)[sitk.GetArrayViewFromImage(y_pred_distance_map)==0]
+ print (' - 95 hausdorff:', np.percentile(dist_y_true,95), np.percentile(dist_y_pred,95))
+
+ else:
+                    print (' - [INFO][losses.remove_smaller_components()] for meta: {} || label_id: {} has only background!! ({}) '.format(meta, label_id, component_sizes))
+
+ if time.time() - t0 > MAX_FUNC_TIME:
+                print (' - [INFO][losses.remove_smaller_components()] Taking too long: ', round(time.time() - t0,2),'s')
+
+ y = y_pred_copy.astype(np.float32)
+ return y
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def get_hausdorff(y_true, y_pred, spacing, verbose=False):
+ """
+ :param y_true: [H, W, D, L]
+ :param y_pred: [H, W, D, L]
+ - Ref: https://simpleitk.readthedocs.io/en/master/filters.html?highlight=%20HausdorffDistanceImageFilter()#simpleitk-filters
+ - Ref: http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/34_Segmentation_Evaluation.html
+ """
+
+ try:
+ hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
+
+ hausdorff_labels = []
+ for label_id in range(y_pred.shape[-1]):
+
+ y_true_label = y_true[:,:,:,label_id] # [H,W,D]
+ y_pred_label = y_pred[:,:,:,label_id] # [H,W,D]
+
+ # Calculate loss (over all pixels)
+ if np.sum(y_true_label) > 0:
+ if label_id > 0:
+ try:
+ if np.sum(y_true_label) > 0:
+ y_true_sitk = sitk.GetImageFromArray(y_true_label.astype(np.uint8))
+ y_pred_sitk = sitk.GetImageFromArray(y_pred_label.astype(np.uint8))
+ y_true_sitk.SetSpacing(tuple(spacing))
+ y_pred_sitk.SetSpacing(tuple(spacing))
+ hausdorff_distance_filter.Execute(y_true_sitk, y_pred_sitk)
+ hausdorff_labels.append(hausdorff_distance_filter.GetHausdorffDistance())
+ if verbose: print (' - [INFO][get_hausdorff()] label_id: {} || hausdorff: {}'.format(label_id, hausdorff_labels[-1]))
+ else:
+ hausdorff_labels.append(-1)
+ except:
+ print (' - [ERROR][get_hausdorff()] label_id: {}'.format(label_id))
+ hausdorff_labels.append(-1)
+ else:
+ hausdorff_labels.append(0)
+ else:
+ hausdorff_labels.append(0)
+
+ hausdorff_labels = np.array(hausdorff_labels)
+ hausdorff = np.mean(hausdorff_labels[hausdorff_labels>0])
+ return hausdorff, hausdorff_labels
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def get_surface_distances(y_true, y_pred, spacing, verbose=False):
+
+ """
+ :param y_true: [H, W, D, L] --> binary mask of np.float32
+ :param y_pred: [H, W, D, L] --> also, binary mask of np.float32
+ - Ref: https://discourse.itk.org/t/computing-95-hausdorff-distance/3832/3
+ - Ref: https://git.lumc.nl/mselbiallyelmahdy/jointregistrationsegmentation-via-crossstetch/-/blob/master/lib/label_eval.py
+ """
+
+ try:
+ hausdorff_labels = []
+ hausdorff95_labels = []
+ msd_labels = []
+ for label_id in range(y_pred.shape[-1]):
+
+ y_true_label = y_true[:,:,:,label_id] # [H,W,D]
+ y_pred_label = y_pred[:,:,:,label_id] # [H,W,D]
+
+ # Calculate loss (over all pixels)
+ if np.sum(y_true_label) > 0:
+
+ if label_id > 0:
+ if np.sum(y_pred_label) > 0:
+ y_true_sitk = sitk.GetImageFromArray(y_true_label.astype(np.uint8))
+ y_pred_sitk = sitk.GetImageFromArray(y_pred_label.astype(np.uint8))
+ y_true_sitk.SetSpacing(tuple(spacing))
+ y_pred_sitk.SetSpacing(tuple(spacing))
+ y_true_contour = sitk.LabelContour(y_true_sitk, False, backgroundValue=0)
+ y_pred_contour = sitk.LabelContour(y_pred_sitk, False, backgroundValue=0)
+
+ y_true_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_true_sitk, squaredDistance=False, useImageSpacing=True))
+ y_pred_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_pred_sitk, squaredDistance=False, useImageSpacing=True))
+ dist_y_true = sitk.GetArrayFromImage(y_true_distance_map*sitk.Cast(y_pred_contour, sitk.sitkFloat32))
+ dist_y_pred = sitk.GetArrayFromImage(y_pred_distance_map*sitk.Cast(y_true_contour, sitk.sitkFloat32))
+ dist_y_true = dist_y_true[dist_y_true != 0]
+ dist_y_pred = dist_y_pred[dist_y_pred != 0]
+
+ if len(dist_y_true):
+ msd_labels.append(np.mean(np.array(list(dist_y_true) + list(dist_y_pred))))
+ if len(dist_y_true) and len(dist_y_pred):
+ hausdorff_labels.append( np.max( [np.max(dist_y_true), np.max(dist_y_pred)] ) )
+ hausdorff95_labels.append(np.max([np.percentile(dist_y_true, 95), np.percentile(dist_y_pred, 95)]))
+ elif len(dist_y_true) and not len(dist_y_pred):
+ hausdorff_labels.append(np.max(dist_y_true))
+ hausdorff95_labels.append(np.percentile(dist_y_true, 95))
+                        elif not len(dist_y_true) and len(dist_y_pred):
+                            hausdorff_labels.append(np.max(dist_y_pred))
+                            hausdorff95_labels.append(np.percentile(dist_y_pred, 95))
+                            msd_labels.append(-1) # keep the per-label lists aligned
+ else:
+ hausdorff_labels.append(-1)
+ hausdorff95_labels.append(-1)
+ msd_labels.append(-1)
+
+ else:
+ hausdorff_labels.append(-1)
+ hausdorff95_labels.append(-1)
+ msd_labels.append(-1)
+
+ else:
+ hausdorff_labels.append(0)
+ hausdorff95_labels.append(0)
+ msd_labels.append(0)
+
+ else:
+ hausdorff_labels.append(0)
+ hausdorff95_labels.append(0)
+ msd_labels.append(0)
+
+ hausdorff_labels = np.array(hausdorff_labels)
+ hausdorff_mean = np.mean(hausdorff_labels[hausdorff_labels > 0])
+ hausdorff95_labels = np.array(hausdorff95_labels)
+ hausdorff95_mean = np.mean(hausdorff95_labels[hausdorff95_labels > 0])
+ msd_labels = np.array(msd_labels)
+ msd_mean = np.mean(msd_labels[msd_labels > 0])
+ return hausdorff_mean, hausdorff_labels, hausdorff95_mean, hausdorff95_labels, msd_mean, msd_labels
+
+ except:
+ traceback.print_exc()
+        return -1, [], -1, [], -1, []
+
+def dice_numpy_slice(y_true_slice, y_pred_slice):
+ """
+ Specifically designed for 2D slices
+
+ Params
+ ------
+ y_true_slice: [H,W]
+ y_pred_slice: [H,W]
+ """
+
+ sum_true = np.sum(y_true_slice)
+ sum_pred = np.sum(y_pred_slice)
+ if sum_true > 0 and sum_pred > 0:
+ num = 2 * np.sum(y_true_slice *y_pred_slice)
+ den = sum_true + sum_pred
+ return num/den
+ elif sum_true > 0 and sum_pred == 0:
+ return 0
+ elif sum_true == 0 and sum_pred > 0:
+ return -0.1
+ else:
+ return -1
+
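+# Quick numeric sanity sketch for dice_numpy_slice(): identical binary masks give 1.0
+# and disjoint masks give 0.0, e.g.
+#   dice_numpy_slice(np.ones((2,2)), np.ones((2,2)))   # -> 1.0  (num=8, den=8)
+#   dice_numpy_slice(np.eye(2), 1 - np.eye(2))         # -> 0.0  (no overlap)
+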
+def dice_numpy(y_true, y_pred):
+ """
+ :param y_true: [H, W, D, L]
+ :param y_pred: [H, W, D, L]
+ - Ref: V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
+ - Ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py
+ """
+
+ dice_labels = []
+ for label_id in range(y_pred.shape[-1]):
+
+ y_true_label = y_true[:,:,:,label_id] # [H,W,D]
+ y_pred_label = y_pred[:,:,:,label_id] + 1e-8 # [H,W,D]
+
+ # Calculate loss (over all pixels)
+ if np.sum(y_true_label) > 0:
+ num = 2*np.sum(y_true_label * y_pred_label)
+ den = np.sum(y_true_label + y_pred_label)
+ dice_label = num/den
+ else:
+ dice_label = -1.0
+
+ dice_labels.append(dice_label)
+
+ dice_labels = np.array(dice_labels)
+ dice = np.mean(dice_labels[dice_labels>0])
+ return dice, dice_labels
+
+############################################################
+# LOSSES #
+############################################################
+
+@tf.function
+def loss_dice_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False):
+
+ """
+ Calculates soft-DICE loss
+
+    :param y_true: [B, H, W, D, L]
+    :param y_pred: [B, H, W, D, L]
+ :param label_mask: [B,L]
+ - Ref: V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
+ - Ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py
+ """
+
+ # Step 0 - Init
+ dice_labels = []
+ label_mask = tf.cast(label_mask, dtype=tf.float32) # [B,L]
+
+ # Step 1 - Get DICE of each sample in each label
+ y_pred = y_pred + _EPSILON
+ dice_labels = (2*tf.math.reduce_sum(y_true * y_pred, axis=[1,2,3]))/(tf.math.reduce_sum(y_true + y_pred, axis=[1,2,3])) # [B,H,W,D,L] -> [B,L]
+ dice_labels = dice_labels*label_mask # if mask of a label (e.g. background) has been explicitly set to 0, do not consider its loss
+
+ # Step 2 - Mask results on the basis of ground truth availability
+ label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # to handle division by 0
+ dice_for_train = None
+ dice_labels_for_train = None
+ dice_labels_for_report = tf.math.reduce_sum(dice_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ dice_for_report = tf.math.reduce_mean(tf.math.reduce_sum(dice_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+
+ # Step 3 - Weighted DICE
+ if len(weights):
+        label_weights = weights / tf.math.reduce_sum(weights) # normalized
+ dice_labels_w = dice_labels * label_weights
+ dice_labels_for_train = tf.math.reduce_sum(dice_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ dice_for_train = tf.math.reduce_mean(tf.math.reduce_sum(dice_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+ else:
+ dice_labels_for_train = dice_labels_for_report
+ dice_for_train = dice_for_report
+
+ # Step 4 - Return results
+ return 1.0 - dice_for_train, 1.0 - dice_labels_for_train, dice_for_report, dice_labels_for_report
+
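+# Hedged shape sketch for loss_dice_3d_tf_func() (values are illustrative): with B=2,
+# a 32^3 grid and L=10 labels it returns (scalar loss, per-label loss [L], scalar DICE,
+# per-label DICE [L]), e.g.
+#   y_true = tf.one_hot(tf.random.uniform((2,32,32,32), 0, 10, dtype=tf.int32), 10)
+#   y_pred = tf.nn.softmax(tf.random.normal((2,32,32,32,10)), axis=-1)
+#   loss, loss_labels, dice, dice_labels = loss_dice_3d_tf_func(y_true, y_pred, tf.ones((2,10)))
+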
+@tf.function
+def loss_ce_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False):
+ """
+ Calculates cross entropy loss
+
+ Params
+ ------
+    y_true : [B, H, W, D, L]
+    y_pred : [B, H, W, D, L]
+ label_mask: [B,L]
+ - Ref: https://www.dlology.com/blog/multi-class-classification-with-focal-loss-for-imbalanced-datasets/
+ """
+
+ # Step 0 - Init
+ loss_labels = []
+ label_mask = tf.cast(label_mask, dtype=tf.float32)
+ y_pred = y_pred + _EPSILON
+
+ # Step 1.1 - Foreground loss
+ loss_labels = -1.0 * y_true * tf.math.log(y_pred) # [B,H,W,D,L]
+ loss_labels = label_mask * tf.math.reduce_sum(loss_labels, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L]
+
+ # Step 1.2 - Background loss
+ y_pred_neg = 1 - y_pred + _EPSILON
+ loss_labels_neg = -1.0 * (1 - y_true) * tf.math.log(y_pred_neg) # [B,H,W,D,L]
+ loss_labels_neg = label_mask * tf.math.reduce_sum(loss_labels_neg, axis=[1,2,3])
+ loss_labels = loss_labels + loss_labels_neg
+
+ # Step 2 - Mask results on the basis of ground truth availability
+    label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # to avoid division by zero
+ loss_for_train = None
+ loss_labels_for_train = None
+ loss_labels_for_report = tf.math.reduce_sum(loss_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+
+    # Step 3 - Weighted CE
+    if len(weights):
+        label_weights = weights / tf.math.reduce_sum(weights) # normalized
+ loss_labels_w = loss_labels * label_weights
+ loss_labels_for_train = tf.math.reduce_sum(loss_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+ else:
+ loss_labels_for_train = loss_labels_for_report
+ loss_for_train = loss_for_report
+
+ # Step 4 - Return results
+ return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report
+
+@tf.function
+def loss_cebasic_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False):
+ """
+    Calculates cross entropy loss (foreground term only; unlike loss_ce_3d_tf_func above, no background term is added)
+
+ Params
+ ------
+    y_true : [B, H, W, D, L]
+    y_pred : [B, H, W, D, L]
+ label_mask: [B,L]
+ - Ref: https://www.dlology.com/blog/multi-class-classification-with-focal-loss-for-imbalanced-datasets/
+ """
+
+ # Step 0 - Init
+ loss_labels = []
+ label_mask = tf.cast(label_mask, dtype=tf.float32)
+ y_pred = y_pred + _EPSILON
+
+ # Step 1 - Foreground loss
+ loss_labels = -1.0 * y_true * tf.math.log(y_pred) # [B,H,W,D,L]
+ loss_labels = label_mask * tf.math.reduce_sum(loss_labels, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L]
+
+ # Step 2 - Mask results on the basis of ground truth availability
+    label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # to avoid division by zero
+ loss_for_train = None
+ loss_labels_for_train = None
+ loss_labels_for_report = tf.math.reduce_sum(loss_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+
+    # Step 3 - Weighted CE
+    if len(weights):
+        label_weights = weights / tf.math.reduce_sum(weights) # normalized
+ loss_labels_w = loss_labels * label_weights
+ loss_labels_for_train = tf.math.reduce_sum(loss_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0)
+ loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1))
+ else:
+ loss_labels_for_train = loss_labels_for_report
+ loss_for_train = loss_for_report
+
+ # Step 4 - Return results
+ return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report
\ No newline at end of file
diff --git a/src/model/models.py b/src/model/models.py
new file mode 100644
index 0000000..dbb27cb
--- /dev/null
+++ b/src/model/models.py
@@ -0,0 +1,572 @@
+# Import internal libraries
+import src.config as config
+
+# Import external libraries
+import pdb
+import traceback
+import numpy as np
+import tensorflow as tf
+import tensorflow_addons as tfa
+import tensorflow_probability as tfp
+
+############################################################
+# 3D MODEL BLOCKS #
+############################################################
+
+class ConvBlock3D(tf.keras.layers.Layer):
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , pool=False
+ , name=''):
+ super(ConvBlock3D, self).__init__(name='{}_ConvBlock3D'.format(name))
+
+ # Step 0 - Init
+ self.pool = pool
+ self.filters = filters
+ self.trainable = trainable
+ if type(filters) == int:
+ filters = [filters]
+
+ # Step 1 - Conv Blocks
+ self.conv_layer = tf.keras.Sequential()
+ for filter_id, filter_count in enumerate(filters):
+ self.conv_layer.add(
+ tf.keras.layers.Conv3D(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ , name='Conv_{}'.format(filter_id))
+ )
+ self.conv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable))
+ # https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
+ ## with the argument training=False (which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training
+ # if filter_id == 0 and dropout is not None:
+ # self.conv_layer.add(tf.keras.layers.Dropout(rate=dropout, name='DropOut'))
+
+ # Step 2 - Pooling Blocks
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool')
+
+ @tf.function
+ def call(self, x):
+
+ x = self.conv_layer(x)
+
+ if self.pool:
+ return x, self.pool_layer(x)
+ else:
+ return x
+
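+# Hedged usage sketch (shapes are illustrative): with pool=True the block returns both
+# the conv output and its 2x-downsampled version, e.g.
+#   conv, conv_pool = ConvBlock3D(filters=[16,16], pool=True, name='Demo')(tf.random.normal((1,32,32,32,1)))
+#   # conv.shape == (1,32,32,32,16), conv_pool.shape == (1,16,16,16,16)
+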
+class ConvBlock3DSERes(tf.keras.layers.Layer):
+ """
+ For channel-wise attention
+ """
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , pool=False
+ , squeeze_ratio=None
+ , init=False
+ , name=''):
+
+ super(ConvBlock3DSERes, self).__init__(name='{}_ConvBlock3DSERes'.format(name))
+
+ # Step 0 - Init
+ self.init = init
+ self.trainable = trainable
+
+ # Step 1 - Init (to get equivalent feature map count)
+ if self.init:
+ self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'
+ )
+
+ # Step 2- Conv Block
+ self.convblock_res = ConvBlock3D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ , trainable=trainable
+ , pool=False
+ , name=name
+ )
+
+ # Step 3 - Squeeze Block
+ """
+ Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py
+ """
+ self.squeeze_ratio = squeeze_ratio
+ if self.squeeze_ratio is not None:
+ self.seblock = tf.keras.Sequential()
+ self.seblock.add(tf.keras.layers.GlobalAveragePooling3D())
+ self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0])))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='sigmoid'))
+
+ self.pool = pool
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name))
+
+ @tf.function
+ def call(self, x):
+
+ if self.init:
+ x = self.convblock_filterequalizer(x)
+
+ x_res = self.convblock_res(x)
+
+ if self.squeeze_ratio is not None:
+ x_se = self.seblock(x_res) # squeeze and then get excitation factor
+ x_res = tf.math.multiply(x_res, x_se) # excited block
+
+ y = x + x_res
+
+ if self.pool:
+ return y, self.pool_layer(y)
+ else:
+ return y
+
+
+class ConvBlock3DDropOut(tf.keras.layers.Layer):
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , dropout=None
+ , pool=False
+ , name=''):
+ super(ConvBlock3DDropOut, self).__init__(name='{}_ConvBlock3DDropOut'.format(name))
+
+ self.pool = pool
+ self.filters = filters
+ self.trainable = trainable
+
+ if type(filters) == int:
+ filters = [filters]
+
+ self.conv_layer = tf.keras.Sequential()
+ for filter_id, filter_count in enumerate(filters):
+
+ if dropout is not None:
+ self.conv_layer.add(tf.keras.layers.Dropout(rate=dropout, name='DropOut_{}'.format(filter_id))) # before every conv layer (could also be after every layer?)
+
+ self.conv_layer.add(
+ tf.keras.layers.Conv3D(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ , name='Conv_{}'.format(filter_id))
+ )
+ self.conv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable))
+
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool')
+
+ @tf.function
+ def call(self, x):
+
+ x = self.conv_layer(x)
+
+ if self.pool:
+ return x, self.pool_layer(x)
+ else:
+ return x
+
+class ConvBlock3DSEResDropout(tf.keras.layers.Layer):
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , dropout=None
+ , pool=False
+ , squeeze_ratio=None
+ , init=False
+ , name=''):
+
+ super(ConvBlock3DSEResDropout, self).__init__(name='{}_ConvBlock3DSEResDropout'.format(name))
+
+ self.init = init
+ self.trainable = trainable
+
+ if self.init:
+ self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'
+ )
+
+ self.convblock_res = ConvBlock3DDropOut(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ , trainable=trainable
+ , dropout=dropout
+ , pool=False
+ , name=name
+ )
+
+ """
+ Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py
+ """
+ self.squeeze_ratio = squeeze_ratio
+ if self.squeeze_ratio is not None:
+ self.seblock = tf.keras.Sequential()
+ self.seblock.add(tf.keras.layers.GlobalAveragePooling3D())
+ self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0])))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='sigmoid'))
+
+ self.pool = pool
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name))
+
+ @tf.function
+ def call(self, x):
+
+ if self.init:
+ x = self.convblock_filterequalizer(x)
+
+ x_res = self.convblock_res(x)
+
+ if self.squeeze_ratio is not None:
+ x_se = self.seblock(x_res) # squeeze and then get excitation factor
+ x_res = tf.math.multiply(x_res, x_se) # excited block
+
+ y = x + x_res
+
+ if self.pool:
+ return y, self.pool_layer(y)
+ else:
+ return y
+
+
+class UpConvBlock3D(tf.keras.layers.Layer):
+
+ def __init__(self, filters, kernel_size=(2,2,2), strides=(2, 2, 2), padding='same', trainable=False, name=''):
+ super(UpConvBlock3D, self).__init__(name='{}_UpConv3D'.format(name))
+
+ self.trainable = trainable
+ self.upconv_layer = tf.keras.Sequential()
+ self.upconv_layer.add(tf.keras.layers.Conv3DTranspose(filters, kernel_size, strides, padding=padding
+ , activation='relu'
+ , kernel_regularizer=None
+ , name='UpConv_{}'.format(self.name))
+ )
+ # self.upconv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable))
+
+ @tf.function
+ def call(self, x):
+ return self.upconv_layer(x)
+
+
+class ConvBlock3DFlipOut(tf.keras.layers.Layer):
+ """
+ Ref
+ - https://www.tensorflow.org/probability/api_docs/python/tfp/layers/Convolution3DFlipout
+ """
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , pool=False
+ , name=''):
+ super(ConvBlock3DFlipOut, self).__init__(name='{}ConvBlock3DFlipOut'.format(name))
+
+ self.pool = pool
+ self.filters = filters
+
+ if type(filters) == int:
+ filters = [filters]
+
+ self.conv_layer = tf.keras.Sequential()
+ for filter_id, filter_count in enumerate(filters):
+ self.conv_layer.add(
+ tfp.layers.Convolution3DFlipout(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ # , kernel_prior_fn=?
+ , name='Conv3DFlip_{}'.format(filter_id))
+ )
+ self.conv_layer.add(tfa.layers.GroupNormalization(groups=filter_count//2, trainable=trainable))
+
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool')
+
+ def call(self, x):
+
+ x = self.conv_layer(x)
+
+ if self.pool:
+ return x, self.pool_layer(x)
+ else:
+ return x
+
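+# Hedged note: tfp.layers.Convolution3DFlipout samples a weight perturbation on every
+# forward pass, so two calls on the same input generally differ (illustrative sketch):
+#   layer = ConvBlock3DFlipOut(filters=[8], name='Demo_')
+#   x = tf.random.normal((1,16,16,16,1))
+#   tf.reduce_max(tf.abs(layer(x) - layer(x)))  # > 0, stochastic by design
+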
+class ConvBlock3DSEResFlipOut(tf.keras.layers.Layer):
+ """
+ For channel-wise attention
+ """
+
+ def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='relu'
+ , trainable=False
+ , pool=False
+ , squeeze_ratio=None
+ , init=False
+ , name=''):
+
+ super(ConvBlock3DSEResFlipOut, self).__init__(name='{}ConvBlock3DSEResFlipOut'.format(name))
+
+ self.init = init
+ if self.init:
+ self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'
+ , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None)
+
+ self.convblock_res = ConvBlock3DFlipOut(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding
+ , dilation_rate=dilation_rate
+ , activation=activation
+ , trainable=trainable
+ , pool=False
+ , name=name
+ )
+
+ """
+ Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py
+ """
+ self.squeeze_ratio = squeeze_ratio
+ if self.squeeze_ratio is not None:
+ self.seblock = tf.keras.Sequential()
+ self.seblock.add(tf.keras.layers.GlobalAveragePooling3D())
+ self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0])))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='relu'
+ , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None))
+ self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same'
+ , activation='sigmoid'
+ , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None))
+
+ self.pool = pool
+ if self.pool:
+ self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name))
+
+ def call(self, x):
+
+ if self.init:
+ x = self.convblock_filterequalizer(x)
+
+ x_res = self.convblock_res(x)
+
+ if self.squeeze_ratio is not None:
+ x_se = self.seblock(x_res) # squeeze and then get excitation factor
+ x_res = tf.math.multiply(x_res, x_se) # excited block
+
+ y = x + x_res
+
+ if self.pool:
+ return y, self.pool_layer(y)
+ else:
+ return y
+
+
+############################################################
+# 3D MODELS #
+############################################################
+
+class ModelFocusNetDropOut(tf.keras.Model):
+
+ def __init__(self, class_count, trainable=False, verbose=False):
+ """
+ Params
+ ------
+ class_count: number of output classes, i.e. how many class activation maps are produced
+ trainable: set to False when the normalization layers should not update their parameters, e.g. during testing
+ """
+ super(ModelFocusNetDropOut, self).__init__(name='ModelFocusNetDropOut')
+
+ # Step 0 - Init
+ self.verbose = verbose
+ self.trainable = trainable
+
+ dropout = [None, 0.25, 0.25, 0.25, 0.25, 0.25, None, None]
+ filters = [[16,16], [32,32]]
+ dilation_xy = [1, 2, 3, 6, 12, 18]
+ dilation_z = [1, 1, 1, 1, 1 , 1]
+
+
+ # Se-Res Blocks
+ self.convblock1 = ConvBlock3DSEResDropout(filters=filters[0], kernel_size=(3,3,1), dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[0], pool=True , squeeze_ratio=2, name='Block1') # Dim/2 (e.g. 240/2=120)(rp=(3,5,10),(1,1,2))
+ self.convblock2 = ConvBlock3DSEResDropout(filters=filters[0], kernel_size=(3,3,3), dilation_rate=(dilation_xy[1], dilation_xy[1], dilation_z[1]), trainable=trainable, dropout=dropout[0], pool=False, squeeze_ratio=2, name='Block2') # Dim/2 (e.g. 240/2=120)(rp=(14,18) ,(4,6))
+
+ # Dense ASPP
+ self.convblock3 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[2], dilation_xy[2], dilation_z[2]), trainable=trainable, dropout=dropout[1], pool=False, name='Block3_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(24,30),(8,10))
+ self.convblock4 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[3], dilation_xy[3], dilation_z[3]), trainable=trainable, dropout=dropout[2], pool=False, name='Block4_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(42,54),(12,14))
+ self.convblock5 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[4], dilation_xy[4], dilation_z[4]), trainable=trainable, dropout=dropout[3], pool=False, name='Block5_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(78,102),(16,18))
+ self.convblock6 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[5], dilation_xy[5], dilation_z[5]), trainable=trainable, dropout=dropout[4], pool=False, name='Block6_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(138,176),(20,22))
+ self.convblock7 = ConvBlock3DSEResDropout(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[5], pool=False, squeeze_ratio=2, init=True, name='Block7') # Dim/2 (e.g. 240/2=120) (rp=(178,180),(24,26))
+
+ # Upstream
+ self.convblock8 = ConvBlock3DSEResDropout(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[6], pool=False, squeeze_ratio=2, init=True, name='Block8') # Dim/2 (e.g. 240/2=120)
+
+ self.upconvblock9 = UpConvBlock3D(filters=filters[0][0], trainable=trainable, name='Block9_1') # Dim/1 (e.g. 240/1=240)
+ self.convblock9 = ConvBlock3DSEResDropout(filters=filters[0], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[7], pool=False, squeeze_ratio=2, init=True, name='Block9') # Dim/1 (e.g. 240/1=240)
+
+ # Final
+ self.convblock10 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(3,3,3), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='softmax'
+ , name='Block10')
+
+ @tf.function
+ def call(self, x):
+
+ # SE-Res Blocks
+ conv1, pool1 = self.convblock1(x)
+ conv2 = self.convblock2(pool1)
+
+ # Dense ASPP
+ conv3 = self.convblock3(conv2)
+ conv3_op = tf.concat([conv2, conv3], axis=-1)
+
+ conv4 = self.convblock4(conv3_op)
+ conv4_op = tf.concat([conv3_op, conv4], axis=-1)
+
+ conv5 = self.convblock5(conv4_op)
+ conv5_op = tf.concat([conv4_op, conv5], axis=-1)
+
+ conv6 = self.convblock6(conv5_op)
+ conv6_op = tf.concat([conv5_op, conv6], axis=-1)
+
+ conv7 = self.convblock7(conv6_op)
+
+ # Upstream
+ conv8 = self.convblock8(tf.concat([pool1, conv7], axis=-1))
+
+ up9 = self.upconvblock9(conv8)
+ conv9 = self.convblock9(tf.concat([conv1, up9], axis=-1))
+
+ # Final
+ conv10 = self.convblock10(conv9)
+
+ if self.verbose:
+ print (' ---------- Model: ', self.name)
+ print (' - x: ', x.shape)
+ print (' - conv1: ', conv1.shape)
+ print (' - conv2: ', conv2.shape)
+ print (' - conv3_op: ', conv3_op.shape)
+ print (' - conv4_op: ', conv4_op.shape)
+ print (' - conv5_op: ', conv5_op.shape)
+ print (' - conv6_op: ', conv6_op.shape)
+ print (' - conv7: ', conv7.shape)
+ print (' - conv8: ', conv8.shape)
+ print (' - conv9: ', conv9.shape)
+ print (' - conv10: ', conv10.shape)
+
+
+ return conv10
+
+class ModelFocusNetFlipOut(tf.keras.Model):
+
+ def __init__(self, class_count, trainable=False, verbose=False):
+ super(ModelFocusNetFlipOut, self).__init__(name='ModelFocusNetFlipOut')
+
+ # Step 0 - Init
+ self.verbose = verbose
+
+ filters = [[16,16], [32,32]]
+ dilation_xy = [1, 2, 3, 6, 12, 18]
+ dilation_z = [1, 1, 1, 1, 1 , 1]
+
+ # Se-Res Blocks
+ self.convblock1 = ConvBlock3DSERes(filters=filters[0], kernel_size=(3,3,1), dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=True , squeeze_ratio=2, name='Block1') # Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(3,5,10),(1,1,2))
+ self.convblock2 = ConvBlock3DSERes(filters=filters[0] , dilation_rate=(dilation_xy[1], dilation_xy[1], dilation_z[1]), trainable=trainable, pool=False, squeeze_ratio=2, name='Block2') # Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(14,18) ,(4,6))
+
+ # Dense ASPP
+ self.convblock3 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[2], dilation_xy[2], dilation_z[2]), trainable=trainable, pool=False, name='Block3_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(24,30),(16,18))
+ self.convblock4 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[3], dilation_xy[3], dilation_z[3]), trainable=trainable, pool=False, name='Block4_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(42,54),(20,22))
+ self.convblock5 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[4], dilation_xy[4], dilation_z[4]), trainable=trainable, pool=False, name='Block5_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(78,102),(24,26))
+ self.convblock6 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[5], dilation_xy[5], dilation_z[5]), trainable=trainable, pool=False, name='Block6_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(138,176),(28,30))
+ self.convblock7 = ConvBlock3DSEResFlipOut(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block7') # Dim/2 (e.g. 96/2=48) (rp=(176,180),(32,34))
+
+ # Upstream
+ self.convblock8 = ConvBlock3DSERes(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block8') # Dim/2 (e.g. 96/2=48)
+
+ self.upconvblock9 = UpConvBlock3D(filters=filters[0][0], trainable=trainable, name='Block9_1') # Dim/1 (e.g. 96/1 = 96)
+ self.convblock9 = ConvBlock3DSERes(filters=filters[0], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block9') # Dim/1 (e.g. 96/1 = 96)
+
+ self.convblock10 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(3,3,3), padding='same'
+ , dilation_rate=(1,1,1)
+ , activation='softmax'
+ , name='Block10')
+
+ # @tf.function (can't call model.losses if this is enabled)
+ def call(self, x):
+
+ # SE-Res Blocks
+ conv1, pool1 = self.convblock1(x)
+ conv2 = self.convblock2(pool1)
+
+ # Dense ASPP
+ conv3 = self.convblock3(conv2)
+ conv3_op = tf.concat([conv2, conv3], axis=-1)
+
+ conv4 = self.convblock4(conv3_op)
+ conv4_op = tf.concat([conv3_op, conv4], axis=-1)
+
+ conv5 = self.convblock5(conv4_op)
+ conv5_op = tf.concat([conv4_op, conv5], axis=-1)
+
+ conv6 = self.convblock6(conv5_op)
+ conv6_op = tf.concat([conv5_op, conv6], axis=-1)
+
+ conv7 = self.convblock7(conv6_op)
+
+ # Upstream
+ conv8 = self.convblock8(tf.concat([pool1, conv7], axis=-1))
+
+ up9 = self.upconvblock9(conv8)
+ conv9 = self.convblock9(tf.concat([conv1, up9], axis=-1))
+
+ # Final
+ conv10 = self.convblock10(conv9)
+
+ if self.verbose:
+ print (' ---------- Model: ', self.name)
+ print (' - x: ', x.shape)
+ print (' - conv1: ', conv1.shape)
+ print (' - conv2: ', conv2.shape)
+ print (' - conv3_op: ', conv3_op.shape)
+ print (' - conv4_op: ', conv4_op.shape)
+ print (' - conv5_op: ', conv5_op.shape)
+ print (' - conv6_op: ', conv6_op.shape)
+ print (' - conv7: ', conv7.shape)
+ print (' - conv8: ', conv8.shape)
+ print (' - conv9: ', conv9.shape)
+ print (' - conv10: ', conv10.shape)
+
+
+ return conv10
+
+############################################################
+# MAIN #
+############################################################
+
+if __name__ == "__main__":
+
+ X = tf.random.normal((2,140,140,40,1))
+
+ print ('\n ------------------- ModelFocusNetDropOut ------------------- ')
+ model = ModelFocusNetDropOut(class_count=10, trainable=True)
+ y_predict = model(X, training=True)
+ model.summary()
+
+ print ('\n ------------------- ModelFocusNetFlipOut ------------------- ')
+ model = ModelFocusNetFlipOut(class_count=10, trainable=True)
+ y_predict = model(X, training=True)
+ model.summary()
\ No newline at end of file
diff --git a/src/model/trainer.py b/src/model/trainer.py
new file mode 100644
index 0000000..aba24a3
--- /dev/null
+++ b/src/model/trainer.py
@@ -0,0 +1,1718 @@
+# Import internal libraries
+import src.config as config
+import src.model.models as models
+import src.model.utils as modutils
+import src.model.losses as losses
+import src.dataloader.utils as datautils
+
+# Import external libraries
+import os
+import gc
+import pdb
+import copy
+import time
+import tqdm
+import datetime
+import traceback
+import numpy as np
+import tensorflow as tf
+from pathlib import Path
+import tensorflow_probability as tfp
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+_EPSILON = tf.keras.backend.epsilon()
+DEBUG = True
+
+############################################################
+# MODEL RELATED #
+############################################################
+
+def set_lr(epoch, optimizer, init_lr):
+
+ if epoch > 1 and epoch % 20 == 0:
+ optimizer.lr.assign(optimizer.lr * 0.98)
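+ # Worked example of the decay above: the learning rate is multiplied by 0.98 once every
+ # 20 epochs, so starting from e.g. init_lr = 1e-3, after 100 epochs it is roughly
+ # 1e-3 * 0.98**5 ~= 9.04e-4 (a gentle exponential decay).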
+
+############################################################
+# METRICS RELATED #
+############################################################
+class ModelMetrics():
+
+ def __init__(self, metric_type, params):
+
+ self.params = params
+
+ self.label_map = params['internal']['label_map']
+ self.label_ids = params['internal']['label_ids']
+ self.logging_tboard = params['metrics']['logging_tboard']
+ self.metric_type = metric_type
+
+ self.losses_obj = self.get_losses_obj(params)
+
+ self.init_metrics(params)
+ if self.logging_tboard:
+ self.init_tboard_writers(params)
+
+ self.reset_metrics(params)
+ self.init_epoch0(params)
+ self.reset_metrics(params)
+
+ def get_losses_obj(self, params):
+ losses_obj = {}
+ for loss_key in params['metrics']['metrics_loss']:
+ if config.LOSS_DICE == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_dice_3d_tf_func
+ if config.LOSS_CE == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_ce_3d_tf_func
+ if config.LOSS_CE_BASIC == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_cebasic_3d_tf_func
+
+ return losses_obj
+
+ def init_metrics(self, params):
+ """
+ Metrics built from TensorFlow's tf.keras.metrics module
+ """
+ # Metrics for losses (during training for smaller grids)
+ self.metrics_loss_obj = {}
+ for metric_key in params['metrics']['metrics_loss']:
+ self.metrics_loss_obj[metric_key] = {}
+ self.metrics_loss_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type))
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ for label_id in self.label_ids:
+ self.metrics_loss_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type))
+
+ # Metrics for eval (for full 3D volume)
+ self.metrics_eval_obj = {}
+ for metric_key in params['metrics']['metrics_eval']:
+ self.metrics_eval_obj[metric_key] = {}
+ self.metrics_eval_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type))
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.metrics_eval_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type))
+
+ # Time Metrics
+ self.metric_time_dataloader = tf.keras.metrics.Mean(name='AvgTime-Dataloader-{}'.format(self.metric_type))
+ self.metric_time_model_predict = tf.keras.metrics.Mean(name='AvgTime-ModelPredict-{}'.format(self.metric_type))
+ self.metric_time_model_loss = tf.keras.metrics.Mean(name='AvgTime-ModelLoss-{}'.format(self.metric_type))
+ self.metric_time_model_backprop = tf.keras.metrics.Mean(name='AvgTime-ModelBackProp-{}'.format(self.metric_type))
+
+ def reset_metrics(self, params):
+
+ # Metrics for losses (during training for smaller grids)
+ for metric_key in params['metrics']['metrics_loss']:
+ self.metrics_loss_obj[metric_key]['total'].reset_states()
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ for label_id in self.label_ids:
+ self.metrics_loss_obj[metric_key][label_id].reset_states()
+
+ # Metrics for eval (for full 3D volume)
+ for metric_key in params['metrics']['metrics_eval']:
+ self.metrics_eval_obj[metric_key]['total'].reset_states()
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.metrics_eval_obj[metric_key][label_id].reset_states()
+
+ # Time Metrics
+ self.metric_time_dataloader.reset_states()
+ self.metric_time_model_predict.reset_states()
+ self.metric_time_model_loss.reset_states()
+ self.metric_time_model_backprop.reset_states()
+
+ def init_tboard_writers(self, params):
+ """
+ TensorBoard summary writers for losses, evaluation metrics, timings and the learning rate
+ """
+ # Writers for loss (during training for smaller grids)
+ self.writers_loss_obj = {}
+ for metric_key in params['metrics']['metrics_loss']:
+ self.writers_loss_obj[metric_key] = {}
+ self.writers_loss_obj[metric_key]['total'] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss')
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ for label_id in self.label_ids:
+ self.writers_loss_obj[metric_key][label_id] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss-' + str(label_id))
+
+ # Writers for eval (for full 3D volume)
+ self.writers_eval_obj = {}
+ for metric_key in params['metrics']['metrics_eval']:
+ self.writers_eval_obj[metric_key] = {}
+ self.writers_eval_obj[metric_key]['total'] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval')
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.writers_eval_obj[metric_key][label_id] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval-' + str(label_id))
+
+ # Time and other writers
+ self.writer_lr = modutils.get_tensorboard_writer(params['exp_name'], suffix='LR')
+ self.writer_time_dataloader = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Dataloader')
+ self.writer_time_model_predict = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Predict')
+ self.writer_time_model_loss = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Loss')
+ self.writer_time_model_backprop = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Backprop')
+
+ def init_epoch0(self, params):
+
+ for metric_str in self.metrics_loss_obj:
+ self.update_metric_loss(metric_str, 1e-6)
+ if params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ self.update_metric_loss_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids})
+
+ for metric_str in self.metrics_eval_obj:
+ self.update_metric_eval_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids})
+
+ self.update_metrics_time(time_dataloader=1e-6, time_predict=1e-6, time_loss=1e-6, time_backprop=1e-6)
+
+ self.write_epoch_summary(epoch=0, label_map=self.label_map, params=None, eval_condition=True)
+
+ def update_metrics_time(self, time_dataloader, time_predict, time_loss, time_backprop):
+ if time_dataloader is not None:
+ self.metric_time_dataloader.update_state(time_dataloader)
+ if time_predict is not None:
+ self.metric_time_model_predict.update_state(time_predict)
+ if time_loss is not None:
+ self.metric_time_model_loss.update_state(time_loss)
+ if time_backprop is not None:
+ self.metric_time_model_backprop.update_state(time_backprop)
+
+ def update_metric_loss(self, metric_str, metric_val):
+ # Metrics for losses (during training for smaller grids)
+ self.metrics_loss_obj[metric_str]['total'].update_state(metric_val)
+
+ @tf.function
+ def update_metric_loss_labels(self, metric_str, metric_vals_labels):
+ # Metrics for losses (during training for smaller grids)
+
+ for label_id in self.label_ids:
+ if metric_vals_labels[label_id] > 0.0:
+ self.metrics_loss_obj[metric_str][label_id].update_state(metric_vals_labels[label_id])
+
+ def update_metric_eval(self, metric_str, metric_val):
+ # Metrics for eval (for full 3D volume)
+ self.metrics_eval_obj[metric_str]['total'].update_state(metric_val)
+
+ def update_metric_eval_labels(self, metric_str, metric_vals_labels, do_average=False):
+ # Metrics for eval (for full 3D volume)
+
+ metric_avg = []
+ for label_id in self.label_ids:
+ if label_id in metric_vals_labels:
+ if metric_vals_labels[label_id] > 0:
+ self.metrics_eval_obj[metric_str][label_id].update_state(metric_vals_labels[label_id])
+ if do_average:
+ if label_id > 0:
+ metric_avg.append(metric_vals_labels[label_id])
+
+ if do_average:
+ if len(metric_avg):
+ self.metrics_eval_obj[metric_str]['total'].update_state(np.mean(metric_avg))
+
+ def write_epoch_summary(self, epoch, label_map, params=None, eval_condition=False):
+
+ if self.logging_tboard:
+ # Metrics for losses (during training for smaller grids)
+ for metric_str in self.metrics_loss_obj:
+ modutils.make_summary('Loss/{}'.format(metric_str), epoch, writer1=self.writers_loss_obj[metric_str]['total'], value1=self.metrics_loss_obj[metric_str]['total'].result())
+ if self.params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ if len(self.metrics_loss_obj[metric_str]) > 1: # i.e. has label ids
+ for label_id in self.label_ids:
+ label_name, _ = modutils.get_info_from_label_id(label_id, label_map)
+ modutils.make_summary('Loss/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_loss_obj[metric_str][label_id], value1=self.metrics_loss_obj[metric_str][label_id].result())
+
+ # Metrics for eval (for full 3D volume)
+ if eval_condition:
+ for metric_str in self.metrics_eval_obj:
+ modutils.make_summary('Eval3D/{}'.format(metric_str), epoch, writer1=self.writers_eval_obj[metric_str]['total'], value1=self.metrics_eval_obj[metric_str]['total'].result())
+ if len(self.metrics_eval_obj[metric_str]) > 1: # i.e. has label ids
+ for label_id in self.label_ids:
+ label_name, _ = modutils.get_info_from_label_id(label_id, label_map)
+ modutils.make_summary('Eval3D/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_eval_obj[metric_str][label_id], value1=self.metrics_eval_obj[metric_str][label_id].result())
+
+ # Time Metrics
+ modutils.make_summary('Info/Time/Dataloader' , epoch, writer1=self.writer_time_dataloader , value1=self.metric_time_dataloader.result())
+ modutils.make_summary('Info/Time/ModelPredict' , epoch, writer1=self.writer_time_model_predict , value1=self.metric_time_model_predict.result())
+ modutils.make_summary('Info/Time/ModelLoss' , epoch, writer1=self.writer_time_model_loss , value1=self.metric_time_model_loss.result())
+ modutils.make_summary('Info/Time/ModelBackProp', epoch, writer1=self.writer_time_model_backprop, value1=self.metric_time_model_backprop.result())
+
+ # Learning Rate
+ if params is not None:
+ if 'optimizer' in params:
+ modutils.make_summary('Info/LR', epoch, writer1=self.writer_lr, value1=params['optimizer'].lr)
+
+ def update_pbar(self, pbar):
+ desc_str = ''
+
+ # Metrics for losses (during training for smaller grids)
+ for metric_str in self.metrics_loss_obj:
+ if len(desc_str): desc_str += ','
+
+ if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ metric_avg = []
+ for label_id in self.label_ids:
+ if label_id > 0:
+ label_result = self.metrics_loss_obj[metric_str][label_id].result().numpy()
+ if label_result > 0:
+ metric_avg.append(label_result)
+
+ mean_val = 0
+ if len(metric_avg):
+ mean_val = np.mean(metric_avg)
+ loss_text = '{}Loss:{:.2f}'.format(metric_str, mean_val)
+ desc_str += loss_text
+
+ # GPU Memory
+ if 1:
+ try:
+ if len(self.metrics_loss_obj) > 1:
+ desc_str = desc_str[:-1] # to remove the extra ','
+ desc_str += ',' + str(modutils.get_tf_gpu_memory())
+ except:
+ pass
+
+ pbar.set_description(desc=desc_str, refresh=True)
+
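+ # Rough usage sketch for the ModelMetrics class above ('Dice' and the variable names
+ # are placeholders, not keys defined in this repository's config):
+ # metrics_train = ModelMetrics(metric_type='Train', params=params)
+ # metrics_train.update_metric_loss('Dice', loss_value)
+ # metrics_train.write_epoch_summary(epoch, label_map=params['internal']['label_map'], params=params)
+ # metrics_train.reset_metrics(params)
+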
+def eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_processed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , patient_id_curr
+ , model_folder_epoch_patches
+ , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals
+ , spacing, label_map
+ , save=False):
+
+ try:
+
+ # Step 1 - Save 3D grid to visualize in 3D Slicer (drag-and-drop mechanism)
+ if save:
+
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_img.nrrd' , patient_img[:,:,:,0], spacing)
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_mask.nrrd', np.argmax(patient_gt, axis=3),spacing)
+ maskpred_labels = np.argmax(patient_pred_processed, axis=3) # not "np.argmax(patient_pred, axis=3)" since it does not contain any postprocessing
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpred.nrrd', maskpred_labels, spacing)
+
+ maskpred_labels_probmean = np.take_along_axis(patient_pred, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0]
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmean.nrrd', maskpred_labels_probmean, spacing)
+
+ if np.sum(patient_pred_std):
+ maskpred_labels_std = np.take_along_axis(patient_pred_std, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0]
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstd.nrrd', maskpred_labels_std, spacing)
+
+ maskpred_std_max = np.max(patient_pred_std, axis=-1)
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstdmax.nrrd', maskpred_std_max, spacing)
+
+ if np.sum(patient_pred_ent):
+ maskpred_ent = patient_pred_ent # [H,W,D]
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredent.nrrd', maskpred_ent, spacing)
+
+ if np.sum(patient_pred_mif):
+ maskpred_mif = patient_pred_mif
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmif.nrrd', maskpred_mif, spacing)
+
+ if np.sum(patient_pred_unc):
+ if len(patient_pred_unc.shape) == 4:
+ maskpred_labels_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] # [H,W,D,C] --> [H,W,D]
+ else:
+ maskpred_labels_unc = patient_pred_unc
+ datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredunc.nrrd', maskpred_labels_unc, spacing)
+
+ try:
+ # Step 3.1.3.2 - Plot results for that patient
+ f, axarr = plt.subplots(3,1, figsize=(15,10))
+ boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95 = {}, {}, {}
+ boxplot_dice_mean_list = []
+ for label_id in range(len(loss_labels_val)):
+ label_name, _ = modutils.get_info_from_label_id(label_id, label_map)
+ boxplot_dice[label_name] = [loss_labels_val[label_id]]
+ boxplot_hausdorff[label_name] = [hausdorff_labels_val[label_id]]
+ boxplot_hausdorff95[label_name] = [hausdorff95_labels_val[label_id]]
+ if label_id > 0 and loss_labels_val[label_id] > 0:
+ boxplot_dice_mean_list.append(loss_labels_val[label_id])
+
+ axarr[0].boxplot(boxplot_dice.values())
+ axarr[0].set_xticks(range(1, len(boxplot_dice)+1))
+ axarr[0].set_xticklabels(boxplot_dice.keys())
+ axarr[0].set_ylim([0.0,1.1])
+ axarr[0].grid()
+ axarr[0].set_title('DICE - Avg: {} \n w/o chiasm: {}'.format(
+ '%.4f' % (np.mean(boxplot_dice_mean_list))
+ , '%.4f' % (np.mean(boxplot_dice_mean_list[0:1] + boxplot_dice_mean_list[2:])) # avoid label_id=2
+ )
+ )
+
+ axarr[1].boxplot(boxplot_hausdorff.values())
+ axarr[1].set_xticks(range(1,len(boxplot_hausdorff)+1))
+ axarr[1].set_xticklabels(boxplot_hausdorff.keys())
+ axarr[1].grid()
+ axarr[1].set_title('Hausdorff')
+
+ axarr[2].boxplot(boxplot_hausdorff95.values())
+ axarr[2].set_xticks(range(1,len(boxplot_hausdorff95)+1))
+ axarr[2].set_xticklabels(boxplot_hausdorff95.keys())
+ axarr[2].set_title('95% Hausdorff')
+ axarr[2].grid()
+
+ plt.savefig(str(Path(model_folder_epoch_patches).joinpath('results_' + patient_id_curr + '.png')), bbox_inches='tight') # , bbox_inches='tight'
+ plt.close()
+
+ except:
+ traceback.print_exc()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def get_ece(y_true, y_predict, patient_id, res_global, verbose=False):
+ """
+ Params
+ ------
+ y_true : [H,W,D,C], np.array, binary
+ y_predict: [H,W,D,C], np.array, with softmax probability values
+ - Ref: https://github.com/sirius8050/Expected-Calibration-Error/blob/master/ECE.py
+ : On Calibration of Modern Neural Networks
+ - Ref(future): https://github.com/yding5/AdaptiveBinning
+ : Revisiting the evaluation of uncertainty estimation and its application to explore model complexity-uncertainty trade-off
+ """
+ res = {}
+ nan_value = -0.1
+
+ if verbose: print (' - [get_ece()] patient_id: ', patient_id)
+
+ try:
+
+ # Step 0 - Init
+ label_count = y_true.shape[-1]
+
+ # Step 1 - Calculate o_predict
+ o_true = np.argmax(y_true, axis=-1)
+ o_predict = np.argmax(y_predict, axis=-1)
+
+ # Step 2 - Loop over different classes
+ for label_id in range(label_count):
+
+ if label_id > -1:
+
+ if verbose: print (' --- [get_ece()] label_id: ', label_id)
+
+ if label_id not in res_global: res_global[label_id] = {'o_predict_label':[], 'y_predict_label':[], 'o_true_label':[]}
+
+ # Step 2.1 - Get o_predict_label(label_ids), o_true_label(label_ids), y_predict_label(probs) [and append to global list]
+ ## NB: You are considering TP + FP here
+ o_true_label = o_true[o_predict == label_id]
+ o_predict_label = o_predict[o_predict == label_id]
+ y_predict_label = y_predict[:,:,:,label_id][o_predict == label_id]
+ res_global[label_id]['o_true_label'].extend(o_true_label.flatten().tolist())
+ res_global[label_id]['o_predict_label'].extend(o_predict_label.flatten().tolist())
+ res_global[label_id]['y_predict_label'].extend(y_predict_label.flatten().tolist())
+
+ if len(o_true_label) and len(y_predict_label):
+
+ # Step 2.2 - Bin the probs and calculate their mean
+ y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1
+ y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+ # Step 2.3 - Calculate the accuracy of each bin
+ o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)]
+ y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)]
+
+ # Step 2.4 - Wrapup
+ N = np.prod(y_predict_label.shape)
+ ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+ ce[ce == 0] = nan_value # i.e. where y_predict_bins_accuracy[bin_id] == y_predict_bins_mean[bin_id], mark as nan_value
+ res[label_id] = ce
+
+ else:
+ res[label_id] = -1
+
+ if 0:
+ if label_id == 1:
+ diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ print (' - [get_ece()][BStem] diff: ', ['%.4f' % each for each in diff])
+
+ # NB: This considers the whole volume
+ o_true_label = y_true[:,:,:,label_id] # [1=this label, 0=other label]
+ o_predict_label = np.array(o_predict, copy=True)
+ o_predict_label[o_predict_label != label_id] = 0
+ o_predict_label[o_predict_label == label_id] = 1 # [1 - predicted this label, 0 = predicted other label]
+ y_predict_label = y_predict[:,:,:,label_id]
+
+ # Step x.2 - Bin the probs and calculate their mean
+ y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1
+ y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+ # Step x.3 - Calculate the accuracy of each bin
+ o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)]
+ y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)]
+
+ N_new = np.prod(y_predict_label.shape)
+ diff_new = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ ce_new = np.array((np.array(y_predict_bins_len)/N_new)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+ print (' - [get_ece()][BStem] diff_new: ', ['%.4f' % each for each in diff_new])
+
+ pdb.set_trace()
+
+ if verbose:
+ print (' --- [get_ece()] y_predict_bins_accuracy: ', ['%.4f' % (each) for each in np.array(y_predict_bins_accuracy)])
+ print (' --- [get_ece()] CE : ', ['%.4f' % (each) for each in np.array(res[label_id])])
+ print (' --- [get_ece()] ECE: ', np.sum(np.abs(res[label_id][res[label_id] != nan_value])))
+
+ # Prob bars
+ # plt.hist(y_predict[:,:,:,label_id].flatten(), bins=10)
+ # plt.title('Softmax Probs (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.show()
+
+ # GT Prob bars
+ # plt.bar(np.arange(len(y_predict_bins_len))/10.0 + 0.1, y_predict_bins_len, width=0.05)
+ # plt.title('Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.xlabel('Probabilities')
+ # plt.show()
+
+ # GT Probs (sorted) in plt.plot (with equally-spaced bins)
+ # from collections import Counter
+ # tmp = np.sort(y_predict_label)
+ # plt.plot(range(len(tmp)), tmp, color='orange')
+ # tmp_bins = np.digitize(tmp, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])) - 1
+ # tmp_bins_len = Counter(tmp_bins)
+ # boundary_start = 0
+ # plt.plot([0,0],[0.0,1.0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-spaced)')
+ # for boundary in np.arange(0,len(tmp_bins_len)): plt.plot([boundary_start+tmp_bins_len[boundary], boundary_start+tmp_bins_len[boundary]], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed'); boundary_start+=tmp_bins_len[boundary]
+ # plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.legend()
+ # plt.show()
+
+ # GT Probs (sorted) in plt.plot (with equally-sized bins)
+ if label_id == 1:
+ Path('tmp').mkdir(parents=True, exist_ok=True)
+ tmp = np.sort(y_predict_label)
+ tmp_len = len(tmp)
+ plt.plot(range(len(tmp)), tmp, color='orange')
+ for boundary in np.arange(0,tmp_len, int(tmp_len//10)): plt.plot([boundary, boundary], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed')
+ plt.plot([0,0],[0,0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-sized)')
+ plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ plt.legend()
+ # plt.show()
+ plt.savefig('./tmp/ECE_SortedProbs_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close()
+
+ # ECE plot
+ plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8)
+ plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred')
+ plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy')
+ diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink')
+ plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE')
+ plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0)
+ plt.title('CE (label={})\nPatient:{}'.format(label_id, patient_id))
+ plt.xlabel('Probability')
+ plt.ylabel('Accuracy')
+ plt.legend()
+ # plt.show()
+ plt.savefig('./tmp/ECE_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close()
+
+ # pdb.set_trace()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ return res_global, res
+
+def eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, show=False, save=False, verbose=False):
+
+ try:
+
+ pid = os.getpid()
+
+ ###############################################################################
+ # Summarize #
+ ###############################################################################
+
+ # Step 1 - Summarize DICE + Surface Distances
+ if 1:
+ loss_labels_avg, loss_labels_std = [], []
+ hausdorff_labels_avg, hausdorff_labels_std = [], []
+ hausdorff95_labels_avg, hausdorff95_labels_std = [], []
+ msd_labels_avg, msd_labels_std = [], []
+
+ loss_labels_list = np.array([res[patient_id][config.KEY_DICE_LABELS] for patient_id in res])
+ hausdorff_labels_list = np.array([res[patient_id][config.KEY_HD_LABELS] for patient_id in res])
+ hausdorff95_labels_list = np.array([res[patient_id][config.KEY_HD95_LABELS] for patient_id in res])
+ msd_labels_list = np.array([res[patient_id][config.KEY_MSD_LABELS] for patient_id in res])
+
+ for label_id in range(loss_labels_list.shape[1]):
+ tmp_vals = loss_labels_list[:,label_id]
+ loss_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ loss_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ if label_id > 0:
+ tmp_vals = hausdorff_labels_list[:,label_id]
+ hausdorff_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) # avoids -1 for "erroneous" HD, and 0 for "not to be calculated" HD
+ hausdorff_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ tmp_vals = hausdorff95_labels_list[:,label_id]
+ hausdorff95_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ hausdorff95_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ tmp_vals = msd_labels_list[:,label_id]
+ msd_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ msd_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ else:
+ hausdorff_labels_avg.append(0)
+ hausdorff_labels_std.append(0)
+ hausdorff95_labels_avg.append(0)
+ hausdorff95_labels_std.append(0)
+ msd_labels_avg.append(0)
+ msd_labels_std.append(0)
+
+ loss_avg = np.mean([res[patient_id][config.KEY_DICE_AVG] for patient_id in res])
+ print (' --------------------------- eval_type: ', eval_type)
+ print (' - dice_labels_3D (avg) : ', ['%.4f' % each for each in loss_labels_avg])
+ print (' - dice_labels_3D (std) : ', ['%.4f' % each for each in loss_labels_std])
+ print (' - dice_3D : %.4f' % np.mean(loss_labels_avg))
+ print (' - dice_3D (w/o bgd): %.4f' % np.mean(loss_labels_avg[1:]))
+ print (' - dice_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:]))
+ print ('')
+ print (' - hausdorff_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff_labels_avg])
+ print (' - hausdorff_labels_3D (std) : ', ['%.4f' % each for each in hausdorff_labels_std])
+ print (' - hausdorff_3D (w/o bgd): %.4f' % np.mean(hausdorff_labels_avg[1:]))
+ print (' - hausdorff_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff_labels_avg[1:2] + hausdorff_labels_avg[3:]))
+ print ('')
+ print (' - hausdorff95_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff95_labels_avg])
+ print (' - hausdorff95_labels_3D (std) : ', ['%.4f' % each for each in hausdorff95_labels_std])
+ print (' - hausdorff95_3D (w/o bgd): %.4f' % np.mean(hausdorff95_labels_avg[1:]))
+ print (' - hausdorff95_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff95_labels_avg[1:2] + hausdorff95_labels_avg[3:]))
+ print ('')
+ print (' - msd_labels_3D (avg) : ', ['%.4f' % each for each in msd_labels_avg])
+ print (' - msd_labels_3D (std) : ', ['%.4f' % each for each in msd_labels_std])
+ print (' - msd_3D (w/o bgd): %.4f' % np.mean(msd_labels_avg[1:]))
+ print (' - msd_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(msd_labels_avg[1:2] + msd_labels_avg[3:]))
+
+ # Step 2 - Summarize ECE
+ if 1:
+ print ('')
+ gc.collect()
+ nan_value = config.VAL_ECE_NAN
+ ece_labels_obj = {}
+ ece_labels = []
+ label_count = len(ece_global_obj)
+ pbar_desc_prefix = '[ECE]'
+ ece_global_obj_keys = list(ece_global_obj.keys())
+ res[config.KEY_PATIENT_GLOBAL] = {}
+ with tqdm.tqdm(total=label_count, desc=pbar_desc_prefix, disable=True) as pbar_ece:
+ for label_id in ece_global_obj_keys:
+ o_true_label = np.array(ece_global_obj[label_id]['o_true_label'])
+ o_predict_label = np.array(ece_global_obj[label_id]['o_predict_label'])
+ y_predict_label = np.array(ece_global_obj[label_id]['y_predict_label'])
+ if label_id in ece_global_obj: del ece_global_obj[label_id]
+ gc.collect()
+
+ # Step 1.1 - Bin the probs and calculate their mean
+ y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1
+ y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+ # Step 1.2 - Calculate the accuracy of each bin
+ o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)]
+ y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)]
+
+ # Step 1.3 - Wrapup
+ N = np.prod(y_predict_label.shape)
+ ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+ ce[ce == 0] = nan_value
+ ece_label = np.sum(np.abs(ce[ce != nan_value]))
+ ece_labels.append(ece_label)
+ ece_labels_obj[label_id] = {'y_predict_bins_mean':y_predict_bins_mean, 'y_predict_bins_accuracy':y_predict_bins_accuracy, 'ce':ce, 'ece':ece_label}
+
+ pbar_ece.update(1)
+ memory = pbar_desc_prefix + ' [' + str(modutils.get_memory(pid)) + ']'
+ pbar_ece.set_description(desc=memory, refresh=True)
+
+ res[config.KEY_PATIENT_GLOBAL][label_id] = {'ce':ce, 'ece':ece_label}
+
+ print (' - ece_labels : ', ['%.4f' % each for each in ece_labels])
+ print (' - ece : %.4f' % np.mean(ece_labels))
+ print (' - ece (w/o bgd): %.4f' % np.mean(ece_labels[1:]))
+ print (' - ece (w/o bgd, w/o chiasm): %.4f' % np.mean(ece_labels[1:2] + ece_labels[3:]))
+ print ('')
+
+ del ece_global_obj
+ gc.collect()
+
+ # Step 3 - Plot
+ if 1:
+ if save and not deepsup_eval:
+ f, axarr = plt.subplots(3,1, figsize=(15,10))
+ boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95, boxplot_msd = {}, {}, {}, {}
+ for label_id in range(len(loss_labels_list[0])):
+ label_name, _ = modutils.get_info_from_label_id(label_id, label_map)
+ boxplot_dice[label_name] = loss_labels_list[:,label_id]
+ boxplot_hausdorff[label_name] = hausdorff_labels_list[:,label_id]
+ boxplot_hausdorff95[label_name] = hausdorff95_labels_list[:,label_id]
+ boxplot_msd[label_name] = msd_labels_list[:,label_id]
+
+ axarr[0].boxplot(boxplot_dice.values())
+ axarr[0].set_xticks(range(1, len(boxplot_dice)+1))
+ axarr[0].set_xticklabels(boxplot_dice.keys())
+ axarr[0].set_ylim([0.0,1.1])
+ axarr[0].set_title('DICE (Avg: {}) \n w/o chiasm:{}'.format(
+ '%.4f' % np.mean(loss_labels_avg[1:])
+ , '%.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:])
+ )
+ )
+
+ axarr[1].boxplot(boxplot_hausdorff.values())
+ axarr[1].set_xticks(range(1, len(boxplot_hausdorff)+1))
+ axarr[1].set_xticklabels(boxplot_hausdorff.keys())
+ axarr[1].set_ylim([0.0,10.0])
+ axarr[1].set_title('Hausdorff')
+
+ axarr[2].boxplot(boxplot_hausdorff95.values())
+ axarr[2].set_xticks(range(1, len(boxplot_hausdorff95)+1))
+ axarr[2].set_xticklabels(boxplot_hausdorff95.keys())
+ axarr[2].set_ylim([0.0,6.0])
+ axarr[2].set_title('95% Hausdorff')
+
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_dice.values())
+ axarr.set_xticks(range(1, len(boxplot_dice)+1))
+ axarr.set_xticklabels(boxplot_dice.keys())
+ axarr.set_ylim([0.0,1.1])
+ axarr.set_yticks(np.arange(0,1.1,0.05))
+ axarr.set_title('DICE')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_dice.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_hausdorff95.values())
+ axarr.set_xticks(range(1, len(boxplot_hausdorff95)+1))
+ axarr.set_xticklabels(boxplot_hausdorff95.keys())
+ axarr.set_title('95% HD')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd95.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_hausdorff.values())
+ axarr.set_xticks(range(1, len(boxplot_hausdorff)+1))
+ axarr.set_xticklabels(boxplot_hausdorff.keys())
+ axarr.set_title('Hausdorff Distance')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_msd.values())
+ axarr.set_xticks(range(1, len(boxplot_msd)+1))
+ axarr.set_xticklabels(boxplot_msd.keys())
+ axarr.set_title('MSD')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_msd.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ # ECE
+ for label_id in ece_labels_obj:
+ y_predict_bins_mean = ece_labels_obj[label_id]['y_predict_bins_mean']
+ y_predict_bins_accuracy = ece_labels_obj[label_id]['y_predict_bins_accuracy']
+ ece = ece_labels_obj[label_id]['ece']
+
+ plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8)
+ plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred')
+ plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy')
+ for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink')
+ plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE')
+ plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0)
+ plt.title('CE (label={})\nECE: {}'.format(label_id, '%.5f' % (ece)))
+ plt.xlabel('Probability')
+ plt.ylabel('Accuracy')
+ plt.ylim([-0.15, 1.05])
+ plt.legend()
+
+ # plt.show()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_ece_label{}.png'.format(label_id)))
+ plt.savefig(str(path_results), bbox_inches='tight')
+ plt.close()
+
+ # Step 5 - Save data as .json
+ if 1:
+ try:
+ filename = str(Path(model_folder_epoch_patches).joinpath(config.FILENAME_EVAL3D_JSON))
+ modutils.write_json(res, filename)
+
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+
+ model.trainable=True
+ print ('\n - [eval_3D] Avg of times_mcruns : {:f} +- {:f}'.format(np.mean(times_mcruns), np.std(times_mcruns)))
+ print (' - [eval_3D()] Total time passed (save={}) : {}s \n'.format(save, round(time.time() - ttotal, 2)))
+ if verbose: pdb.set_trace()
+
+ return loss_avg, {i:loss_labels_avg[i] for i in range(len(loss_labels_avg))}
+
+ except:
+ model.trainable=True
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+ return -1, {}
+
+def eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=False, save=False, verbose=False):
+
+ try:
+
+ # Step 3.1.1 - Get stitched patient grid
+ if verbose: t0 = time.time()
+ patient_pred_ent = patient_pred_ent/patient_pred_overlap # [H,W,D]/[H,W,D]
+ patient_pred_mif = patient_pred_mif/patient_pred_overlap
+ patient_pred_overlap = np.expand_dims(patient_pred_overlap, -1)
+ patient_pred = patient_pred_vals/patient_pred_overlap # [H,W,D,C]/[H,W,D,1]
+ patient_pred_std = patient_pred_std/patient_pred_overlap
+ patient_pred_unc = patient_pred_unc/patient_pred_overlap
+ del patient_pred_vals
+ del patient_pred_overlap
+ patient_pred_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(np.argmax(patient_pred, axis=-1),axis=-1), axis=-1)[:,:,:,0]
+
+ gc.collect() # returns number of unreachable objects collected by GC
+ patient_pred_postprocessed = losses.remove_smaller_components(patient_gt, patient_pred, meta=patient_id_curr, label_ids_small = [2,4,5])
+ if verbose: print (' - [eval_3D_process_outputs()] Post-Process time : ', time.time() - t0,'s')
+
+ # Step 3.1.2 - Loss Calculation
+ spacing = np.array([meta1_batch[4], meta1_batch[5], meta1_batch[6]])/100.0
+ try:
+ if verbose: t0 = time.time()
+ loss_avg_val, loss_labels_val = losses.dice_numpy(patient_gt, patient_pred_postprocessed)
+ hausdorff_avg_val, hausdorff_labels_val, hausdorff95_avg_val, hausdorff95_labels_val, msd_avg_val, msd_labels_vals = losses.get_surface_distances(patient_gt, patient_pred_postprocessed, spacing)
+ if verbose:
+ print (' - [eval_3D_process_outputs()] DICE : ', ['%.4f' % (each) for each in loss_labels_val])
+ print (' - [eval_3D_process_outputs()] HD95 : ', ['%.4f' % (each) for each in hausdorff95_labels_val])
+
+ if loss_avg_val != -1 and len(loss_labels_val):
+ res[patient_id_curr] = {
+ config.KEY_DICE_AVG : loss_avg_val
+ , config.KEY_DICE_LABELS : loss_labels_val
+ , config.KEY_HD_AVG : hausdorff95_avg_val
+ , config.KEY_HD_LABELS : hausdorff_labels_val
+ , config.KEY_HD95_AVG : hausdorff95_avg_val
+ , config.KEY_HD95_LABELS : hausdorff95_labels_val
+ , config.KEY_MSD_AVG : msd_avg_val
+ , config.KEY_MSD_LABELS : msd_labels_vals
+ }
+ else:
+ print (' - [INFO][eval_3D_process_outputs()] patient_id: ', patient_id_curr)
+ if verbose: print (' - [eval_3D_process_outputs()] Loss calculation time: ', time.time() - t0,'s')
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+
+ # Step 3.1.3 - ECE calculation
+ if verbose: t0 = time.time()
+ ece_global_obj, ece_patient_obj = get_ece(patient_gt, patient_pred, patient_id_curr, ece_global_obj)
+ res[patient_id_curr][config.KEY_ECE_LABELS] = ece_patient_obj
+ if verbose: print (' - [eval_3D_process_outputs()] ECE time : ', time.time() - t0,'s')
+
+ # Step 3.1.5 - Save/Visualize
+ if not deepsup_eval:
+ if verbose: t0 = time.time()
+ eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_postprocessed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , patient_id_curr
+ , model_folder_epoch_patches
+ , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals
+ , spacing, label_map
+ , save=save)
+ if verbose: print (' - [eval_3D_process_outputs()] Save as .nrrd time : ', time.time() - t0,'s')
+
+ if verbose: print (' - [eval_3D_process_outputs()] Total patient time : ', time.time() - t99,'s')
+
+ # Step 3.1.6
+ del patient_img
+ del patient_gt
+ del patient_pred
+ del patient_pred_std
+ del patient_pred_ent
+ del patient_pred_postprocessed
+ del patient_pred_mif
+ del patient_pred_unc
+ gc.collect()
+
+ return res, ece_global_obj
+
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+ return res, ece_global_obj
+
+def eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval):
+
+ # Step 0 - Init
+ DO_KEYS = [config.KEY_MIF, config.KEY_ENT]
+ # DO_KEYS = [config.KEY_STD]
+
+ # Step 1 - Warm up model
+ _ = model(X, training=training_bool)
+
+ # Step 2 - Run Monte-Carlo predictions
+ error_mcruns = False
+ try:
+ tic_mcruns = time.time()
+ y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ toc_mcruns = time.time()
+ except tf.errors.ResourceExhaustedError as e:
+ print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS={}'.format(MC_RUNS))
+ error_mcruns = True
+
+ if error_mcruns:
+ try:
+ MC_RUNS = 5
+ tic_mcruns = time.time()
+ if deepsup:
+ if deepsup_eval:
+ y_predict = tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ X = X[:,::2,::2,::2,:]
+ Y = Y[:,::2,::2,::2,:]
+ else:
+ y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ else:
+ y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ toc_mcruns = time.time()
+ except tf.errors.ResourceExhaustedError as e:
+ print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS=5')
+ import sys; sys.exit(1)
+
+ # Step 3 - Calculate different metrics
+ if config.KEY_MIF in DO_KEYS:
+ y_predict_mif = y_predict * tf.math.log(y_predict + _EPSILON) # [MC,B,H,W,D,C], p*log(p) per MC sample
+ y_predict_mif = tf.math.reduce_sum(y_predict_mif, axis=[0,-1])/MC_RUNS # [MC,B,H,W,D,C] -> [B,H,W,D]; equals -E[H(p)] (negative expected entropy)
+ else:
+ y_predict_mif = []
+
+ if config.KEY_STD in DO_KEYS:
+ y_predict_std = tf.math.reduce_std(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C]
+ else:
+ y_predict_std = []
+
+ if config.KEY_PERC in DO_KEYS:
+ y_predict_perc = tfp.stats.percentile(y_predict, q=[30,70], axis=0, interpolation='nearest')
+ y_predict_unc = y_predict_perc[1] - y_predict_perc[0]
+ del y_predict_perc
+ gc.collect()
+ else:
+ y_predict_unc = []
+
+ y_predict = tf.math.reduce_mean(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C]
+
+ if config.KEY_ENT in DO_KEYS:
+ y_predict_ent = -1*tf.math.reduce_sum(y_predict * tf.math.log(y_predict + _EPSILON), axis=-1) # [B,H,W,D,C] -> [B,H,W,D]; predictive entropy H(E[p])
+ y_predict_mif = y_predict_ent + y_predict_mif # [B,H,W,D] + [B,H,W,D] = [B,H,W,D]; MI = H(E[p]) - E[H(p)] (y_predict_mif holds -E[H(p)])
+ else:
+ y_predict_ent = []
+ y_predict_mif = []
+
+ return Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, toc_mcruns-tic_mcruns
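+
+ # Illustrative summary (variable names here are for illustration only) of the
+ # uncertainty maps returned above, given an MC stack p of shape [MC,B,H,W,D,C]:
+ # mean_p = tf.math.reduce_mean(p, axis=0) # predictive mean
+ # ent = -tf.math.reduce_sum(mean_p * tf.math.log(mean_p + _EPSILON), axis=-1) # predictive entropy H(E[p])
+ # exp_ent = -tf.math.reduce_mean(tf.math.reduce_sum(p * tf.math.log(p + _EPSILON), axis=-1), axis=0) # expected entropy E[H(p)]
+ # mi = ent - exp_ent # mutual information (epistemic uncertainty)
+ # std = tf.math.reduce_std(p, axis=0) # per-class std across MC samples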
+
+def eval_3D(model, dataset_eval, dataset_eval_gen, params, show=False, save=False, verbose=False):
+
+ try:
+
+ # Step 0.0 - Variables under debugging
+ pass
+
+ # Step 0.1 - Extract params
+ PROJECT_DIR = config.PROJECT_DIR
+ exp_name = params['exp_name']
+ pid = params['pid']
+ eval_type = params['eval_type']
+ batch_size = 2
+ epoch = params['epoch']
+ deepsup = params['deepsup']
+ deepsup_eval = params['deepsup_eval']
+ label_map = dict(dataset_eval.get_label_map())
+ label_colors = dict(dataset_eval.get_label_colors())
+
+ if verbose: print (''); print (' --------------------- eval_3D({}) ---------------------'.format(eval_type))
+
+ # Step 0.2 - Init results array
+ res = {}
+ ece_global_obj = {}
+ patient_grid_count = {}
+
+ # Step 0.3 - Init temp variables
+ patient_id_curr = None
+ w_grid, h_grid, d_grid = None, None, None
+ meta1_batch = None
+ patient_gt = None
+ patient_img = None
+ patient_pred_overlap = None
+ patient_pred_vals = None
+ model_folder_epoch_patches = None
+ model_folder_epoch_imgs = None
+
+ mc_runs = params.get(config.KEY_MC_RUNS, None)
+ training_bool = params.get(config.KEY_TRAINING_BOOL, None)
+ model_folder_epoch_patches, model_folder_epoch_imgs = modutils.get_eval_folders(PROJECT_DIR, exp_name, epoch, eval_type, mc_runs, training_bool, create=True)
+
+ # Step 0.4 - Debug vars
+ ttotal,t0, t99 = time.time(), None, None
+ times_mcruns = []
+
+ # Step 1 - Loop over dataset_eval (which provides patients & grids in an ordered manner)
+ print ('')
+ model.trainable = False
+ pbar_desc_prefix = 'Eval3D_{} [batch={}]'.format(eval_type, batch_size)
+ training_bool = params.get(config.KEY_TRAINING_BOOL,True) # [True, False]
+ with tqdm.tqdm(total=len(dataset_eval), desc=pbar_desc_prefix, leave=False) as pbar_eval:
+ for (X,Y,meta1,meta2) in dataset_eval_gen.repeat(1):
+
+ MC_RUNS = params.get(config.KEY_MC_RUNS,config.VAL_MC_RUNS_DEFAULT)
+ Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, mcruns_time = eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval)
+ times_mcruns.append(mcruns_time)
+
+ for batch_id in range(X.shape[0]):
+
+ # Step 2 - Get grid info
+ patient_id_running = meta2[batch_id].numpy().decode('utf-8')
+ if patient_id_running in patient_grid_count: patient_grid_count[patient_id_running] += 1
+ else: patient_grid_count[patient_id_running] = 1
+
+ meta1_batch = meta1[batch_id].numpy()
+ w_start, h_start, d_start = meta1_batch[1], meta1_batch[2], meta1_batch[3]
+
+ # Step 3 - Check if its a new patient
+ if patient_id_running != patient_id_curr:
+
+ # Step 3.1 - Sort out old patient (patient_id_curr)
+ if patient_id_curr != None:
+
+ res, ece_global_obj = eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose)
+
+ # Step 3.2 - Create variables for new patient
+ if verbose: t99 = time.time()
+ patient_id_curr = patient_id_running
+ patient_scan_size = meta1_batch[7:10]
+ dataset_name = patient_id_curr.split('-')[0]
+ dataset_this = dataset_eval.get_subdataset(param_name=dataset_name)
+ w_grid, h_grid, d_grid = dataset_this.w_grid, dataset_this.h_grid, dataset_this.d_grid
+ if deepsup_eval:
+ patient_scan_size = patient_scan_size//2
+ w_grid, h_grid, d_grid = w_grid//2, h_grid//2, d_grid//2
+
+ patient_pred_size = list(patient_scan_size) + [len(dataset_this.LABEL_MAP)]
+ patient_pred_overlap = np.zeros(patient_scan_size, dtype=np.uint8)
+ patient_pred_ent = np.zeros(patient_scan_size, dtype=np.float32)
+ patient_pred_mif = np.zeros(patient_scan_size, dtype=np.float32)
+ patient_pred_vals = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_pred_std = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_pred_unc = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_gt = np.zeros(patient_pred_size, dtype=np.float32)
+ if show or save:
+ patient_img = np.zeros(list(patient_scan_size) + [1], dtype=np.float32)
+ else:
+ patient_img = []
+
+ # Step 4 - If not new patient anymore, fill up data
+ if deepsup_eval:
+ w_start, h_start, d_start = w_start//2, h_start//2, d_start//2
+ patient_pred_vals[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict[batch_id]
+ if len(y_predict_std):
+ patient_pred_std[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_std[batch_id]
+ if len(y_predict_ent):
+ patient_pred_ent[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_ent[batch_id]
+ if len(y_predict_mif):
+ patient_pred_mif[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_mif[batch_id]
+ if len(y_predict_unc):
+ patient_pred_unc[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_unc[batch_id]
+
+ patient_pred_overlap[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += np.ones(y_predict[batch_id].shape[:-1], dtype=np.uint8)
+ patient_gt[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = Y[batch_id]
+ if show or save:
+ patient_img[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = X[batch_id]
+
+ pbar_eval.update(batch_size)
+ mem_used = modutils.get_memory(pid)
+ memory = pbar_desc_prefix + ' [' + mem_used + ']'
+ pbar_eval.set_description(desc=memory, refresh=True)
+
+ # Step 3 - For last patient
+ res, ece_global_obj = eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose)
+ print ('\n - [eval_3D()] Time passed to accumulate grids & process patients: ', round(time.time() - ttotal, 2), 's')
+
+ return eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=save, show=show, verbose=verbose)
+
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+ model.trainable=True
+ return -1, {}
+
+############################################################
+# VAL #
+############################################################
+
+class Validator:
+
+ def __init__(self, params):
+
+ self.params = params
+
+ # Dataloader Params
+ self.data_dir = self.params['dataloader']['data_dir']
+ self.dir_type = self.params['dataloader']['dir_type']
+ self.grid = self.params['dataloader']['grid']
+ self.crop_init = self.params['dataloader']['crop_init']
+ self.resampled = self.params['dataloader']['resampled']
+
+        self._set_dataloader()
+        self._set_model()
+
+ def _set_dataloader(self):
+
+ # Step 1 - Init params
+ batch_size = self.params['dataloader']['batch_size']
+ prefetch_batch = self.params['dataloader']['prefetch_batch']
+
+ # Step 2 - Load dataset
+ if self.dir_type == config.DATALOADER_MICCAI2015_TEST:
+
+ self.dataset_test_eval = datautils.get_dataloader_3D_test_eval(self.data_dir, dir_type=self.dir_type
+ , grid=self.grid, crop_init=self.crop_init
+ , resampled=self.resampled
+ )
+
+ elif self.dir_type == config.DATALOADER_DEEPMINDTCIA_TEST:
+
+ self.annotator_type = self.params['dataloader']['annotator_type']
+ self.dataset_test_eval = datautils.get_dataloader_deepmindtcia(self.data_dir, dir_type=self.dir_type, annotator_type=self.annotator_type
+ , grid=self.grid, crop_init=self.crop_init, resampled=self.resampled
+ )
+
+ # Step 3 - Load dataloader/datagenerator
+ self.datagen_test_eval = self.dataset_test_eval.generator().batch(batch_size).prefetch(prefetch_batch)
+
+ def _set_model(self):
+
+ # Step 1 - Init params
+ exp_name = self.params['exp_name']
+ load_epoch = self.params['model']['load_epoch']
+ class_count = len(self.dataset_test_eval.datasets[0].LABEL_MAP.values())
+
+ # Step 2 - Get model arch
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_DROPOUT:
+            print (' - [Validator][_set_model()] ModelFocusNetDropOut')
+ self.model = models.ModelFocusNetDropOut(class_count=class_count, trainable=False)
+
+ # Step 3 - Load model
+ load_model_params = {'PROJECT_DIR': config.PROJECT_DIR
+ , 'exp_name': exp_name
+ , 'load_epoch': load_epoch
+ , 'optimizer': tf.keras.optimizers.Adam()
+ }
+ modutils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params)
+ print ('')
+        print (' - [Validator][_set_model()] Model({}) loaded for {} at epoch-{} (validation purposes)!'.format(str(self.model), exp_name, load_epoch))
+ print ('')
+
+ def validate(self, verbose=False):
+
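+        # NOTE: self.params is forwarded unchanged to eval_3D(), so it is also expected to carry the keys read there
+        # (e.g. 'pid', 'epoch', 'eval_type', 'deepsup', 'deepsup_eval') in addition to the 'dataloader'/'model'/'save' entries used by this class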
+ save = self.params['save']
+ loss_avg, loss_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.datagen_test_eval, self.params, save=save, verbose=verbose)
+
+############################################################
+# TRAINER #
+############################################################
+
+class Trainer:
+
+ def __init__(self, params):
+
+ # Init
+ self.params = params
+
+ # Print
+ self._train_preprint()
+
+ # Random Seeds
+ self._set_seed()
+
+ # Set the dataloaders
+ self._set_dataloaders()
+
+ # Set the model
+ self._set_model()
+
+ # Set Metrics
+ self._set_metrics()
+
+ # Other flags
+ self.write_model_done = False
+
+ def _train_preprint(self):
+ print ('')
+ print (' -------------- {} ({})'.format(self.params['exp_name'], str(datetime.datetime.now())))
+
+ print ('')
+ print (' DATALOADER ')
+ print (' ---------- ')
+ print (' - dir_type: ', self.params['dataloader']['dir_type'])
+
+ print (' -- resampled: ', self.params['dataloader']['resampled'])
+ print (' -- crop_init: ', self.params['dataloader']['crop_init'])
+ print (' -- grid: ', self.params['dataloader']['grid'])
+ print (' --- filter_grid : ', self.params['dataloader']['filter_grid'])
+ print (' --- random_grid : ', self.params['dataloader']['random_grid'])
+ print (' --- centred_prob : ', self.params['dataloader']['centred_prob'])
+
+ print (' -- batch_size: ', self.params['dataloader']['batch_size'])
+ print (' -- prefetch_batch : ', self.params['dataloader']['prefetch_batch'])
+ print (' -- parallel_calls : ', self.params['dataloader']['parallel_calls'])
+ print (' -- shuffle : ', self.params['dataloader']['shuffle'])
+
+ print ('')
+ print (' MODEL ')
+ print (' ----- ')
+ print (' - Model: ', str(self.params['model']['name']))
+ print (' -- Model TBoard: ', self.params['model']['model_tboard'])
+ print (' -- Profiler: ', self.params['model']['profiler']['profile'])
+ if self.params['model']['profiler']['profile']:
+ print (' ---- Profiler Epochs: ', self.params['model']['profiler']['epochs'])
+ print (' ---- Step Per Epochs: ', self.params['model']['profiler']['steps_per_epoch'])
+ print (' - Optimizer: ', str(self.params['model']['optimizer']))
+ print (' -- Init LR: ', self.params['model']['init_lr'])
+ print (' -- Fixed LR: ', self.params['model']['fixed_lr'])
+
+ print (' - Epochs: ', self.params['model']['epochs'])
+ print (' -- Save : every {} epochs'.format(self.params['model']['epochs_save']))
+ print (' -- Eval3D : every {} epochs '.format(self.params['model']['epochs_eval']))
+ print (' -- Viz3D : every {} epochs '.format(self.params['model']['epochs_viz']))
+
+ print ('')
+ print (' METRICS ')
+ print (' ------- ')
+ print (' - Logging-TBoard : ', self.params['metrics']['logging_tboard'])
+ if not self.params['metrics']['logging_tboard']:
+ print (' !!!!!!!!!!!!!!!!!!! NO LOGGING-TBOARD !!!!!!!!!!!!!!!!!!!')
+ print ('')
+ print (' - Eval : ', self.params['metrics']['metrics_eval'])
+ print (' - Loss : ', self.params['metrics']['metrics_loss'])
+ print (' -- Weighted Loss : ', self.params['metrics']['loss_weighted'])
+ print (' -- Masked Loss : ', self.params['metrics']['loss_mask'])
+ print (' -- Combo : ', self.params['metrics']['loss_combo'])
+
+ print ('')
+ print (' DEVOPS ')
+ print (' ------ ')
+ self.pid = os.getpid()
+ print (' - OS-PID: ', self.pid)
+ print (' - Seed: ', self.params['random_seed'])
+
+ print ('')
+
+ def _set_seed(self):
+ np.random.seed(self.params['random_seed'])
+ tf.random.set_seed(self.params['random_seed'])
+
+ def _set_dataloaders(self):
+
+ print ('')
+ print (' DATALOADER OBJECTS')
+ print (' ---------- ')
+
+ # Params - Directories
+ data_dir = self.params['dataloader']['data_dir']
+ dir_type = self.params['dataloader']['dir_type']
+ dir_type_eval = ['_'.join(dir_type)]
+
+ # Params - Single volume
+ resampled = self.params['dataloader']['resampled']
+ crop_init = self.params['dataloader']['crop_init']
+ grid = self.params['dataloader']['grid']
+ filter_grid = self.params['dataloader']['filter_grid']
+ random_grid = self.params['dataloader']['random_grid']
+ centred_prob = self.params['dataloader']['centred_prob']
+
+ # Params - Dataloader
+ batch_size = self.params['dataloader']['batch_size']
+ prefetch_batch = self.params['dataloader']['prefetch_batch']
+ parallel_calls = self.params['dataloader']['parallel_calls']
+ shuffle_size = self.params['dataloader']['shuffle']
+
+ # Params - Debug
+ pass
+
+ # Datasets
+ self.dataset_train = datautils.get_dataloader_3D_train(data_dir, dir_type=dir_type
+ , grid=grid, crop_init=crop_init, filter_grid=filter_grid
+ , random_grid=random_grid
+ , resampled=resampled
+ , parallel_calls=parallel_calls
+ , centred_dataloader_prob=centred_prob
+ )
+ self.dataset_train_eval = datautils.get_dataloader_3D_train_eval(data_dir, dir_type=dir_type_eval
+ , grid=grid, crop_init=crop_init
+ , resampled=resampled
+ )
+ self.dataset_test_eval = datautils.get_dataloader_3D_test_eval(data_dir
+ , grid=grid, crop_init=crop_init
+ , resampled=resampled
+ )
+
+ # Get labels Ids
+ self.label_map = dict(self.dataset_train.get_label_map())
+ self.label_ids = self.label_map.values()
+ self.params['internal'] = {}
+ self.params['internal']['label_map'] = self.label_map # for use in Metrics
+ self.params['internal']['label_ids'] = self.label_ids # for use in Metrics
+ self.label_weights = list(self.dataset_train.get_label_weights())
+
+ # Generators
+ # repeat() -> shuffle() --> batch() --> prefetch(): endlessly-output-data --> shuffle-within-buffer --> create-batch-from-shuffle --> prefetch-enough-batches
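+        # Note: shuffling after repeat() mixes samples across epoch boundaries, and prefetch_to_device() overlaps the host-to-GPU copy of upcoming batches with training on the current one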
+ self.dataset_train_gen = self.dataset_train.generator().repeat().shuffle(shuffle_size).batch(batch_size).apply(tf.data.experimental.prefetch_to_device(device='/GPU:0', buffer_size=prefetch_batch))
+ self.dataset_train_eval_gen = self.dataset_train_eval.generator().batch(2).prefetch(4)
+ self.dataset_test_eval_gen = self.dataset_test_eval.generator().batch(2).prefetch(4)
+
+ def _set_model(self):
+
+ print ('')
+ print (' MODEL OBJECTS')
+ print (' ---------- ')
+
+ # Step 1 - Get class ids
+ class_count = len(self.label_ids)
+
+ # Step 2 - Get model arch
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_DROPOUT:
+            print (' - [Trainer][_set_model()] ModelFocusNetDropOut')
+ self.model = models.ModelFocusNetDropOut(class_count=class_count, trainable=True)
+ self.kl_alpha_init = 0.0
+
+ elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT:
+            print (' - [Trainer][_set_model()] ModelFocusNetFlipOut')
+ self.model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=True)
+ self.kl_schedule = self.params['model']['kl_schedule']
+ self.kl_alpha_init = 0.1
+ self.kl_alpha_increase_per_epoch = 0.0001
+ self.initial_epoch = 250
+
+ # Step 3 - Get optimizer
+ if self.params['model']['optimizer'] == config.OPTIMIZER_ADAM:
+ self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['model']['init_lr'])
+
+ # Step 4 - Load model (if needed)
+ epochs = self.params['model']['epochs']
+ if not self.params['model']['load_model']['load']:
+ # Step 4.1 - Set epoch range under non-loading situations
+ self.epoch_range = range(1,epochs+1)
+ else:
+
+ # Step 4.2.1 - Some model-loading params
+ load_epoch = self.params['model']['load_model']['load_epoch']
+ load_exp_name = self.params['model']['load_model']['load_exp_name']
+ load_optimizer_lr = self.params['model']['load_model']['load_optimizer_lr']
+ load_model_params = {'PROJECT_DIR': config.PROJECT_DIR, 'load_epoch': load_epoch, 'optimizer':self.optimizer}
+
+ print ('')
+ print (' - [Trainer][_set_model()] Loading pretrained model')
+ print (' - [Trainer][_set_model()] Model: ', self.model)
+
+ # Step 4.2.2.1 - If loading is done from the same exp_name
+ if load_exp_name is None:
+ load_model_params['exp_name'] = self.params['exp_name']
+ self.epoch_range = range(load_epoch+1, epochs+1)
+ print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(load_epoch, epochs))
+ # Step 4.2.2.1 - If loading is done from another exp_name
+ else:
+ self.epoch_range = range(1, epochs+1)
+ load_model_params['exp_name'] = load_exp_name
+ print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(1, epochs))
+
+ print (' - [Trainer][_set_model()] exp_name: ', load_model_params['exp_name'])
+
+ # Step 4.3 - Finally, load model from the checkpoint
+ modutils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params)
+ print (' - [Trainer][_set_model()] Model Loaded at epoch-{} !'.format(load_epoch))
+ print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy())
+ if load_optimizer_lr is not None:
+ self.optimizer.lr.assign(load_optimizer_lr)
+ print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy())
+
+ # Step 5 - Create model weights
+ _ = self.model(tf.random.normal((1,140,140,40,1))) # [NOTE: This needs to be the same as the input grid size in FlipOut]
+ print (' -- [Trainer][_set_model()] Created model weights ')
+ try:
+ print (' --------------------------------------- ')
+ print (self.model.summary(line_length=150))
+ print (' --------------------------------------- ')
+ count = 0
+ for var in self.model.trainable_variables:
+ print (' - var: ', var.name)
+ count += 1
+ if count > 20:
+ print (' ... ')
+ break
+ except:
+ print (' - [Trainer][_set_model()] model.summary() failed')
+ pass
+
+ def _set_metrics(self):
+ self.metrics = {}
+ self.metrics[config.MODE_TRAIN] = ModelMetrics(metric_type=config.MODE_TRAIN, params=self.params)
+ self.metrics[config.MODE_TEST] = ModelMetrics(metric_type=config.MODE_TEST, params=self.params)
+
+ def _set_profiler(self, epoch, epoch_step):
+ """
+ - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow
+ """
+ exp_name = self.params['exp_name']
+
+ if self.params['model']['profiler']['profile']:
+ if epoch in self.params['model']['profiler']['epochs']:
+ if epoch_step == self.params['model']['profiler']['starting_step']:
+ self.logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, 'profiler', str(epoch))
+ tf.profiler.experimental.start(str(self.logdir))
+ print (' - tf.profiler.experimental.start(logdir)')
+ print ('')
+ elif epoch_step == self.params['model']['profiler']['starting_step'] + self.params['model']['profiler']['steps_per_epoch']:
+ print (' - tf.profiler.experimental.stop()')
+ tf.profiler.experimental.stop()
+ print ('')
+
+ @tf.function
+ def _train_loss(self, Y, y_predict, meta1, mode):
+
+ trainMetrics = self.metrics[mode]
+ metrics_loss = self.params['metrics']['metrics_loss']
+ loss_weighted = self.params['metrics']['loss_weighted']
+ loss_mask = self.params['metrics']['loss_mask']
+ loss_combo = self.params['metrics']['loss_combo']
+
+ label_ids = self.label_ids
+ label_weights = tf.cast(self.label_weights, dtype=tf.float32)
+
+ loss_vals = tf.constant(0.0, dtype=tf.float32)
+ mask = losses.get_mask(meta1[:,-len(label_ids):], Y)
+
+ inf_flag = False
+ nan_flag = False
+
+ for metric_str in metrics_loss:
+
+ weights = []
+ if loss_weighted[metric_str]:
+ weights = label_weights
+
+ if not loss_mask[metric_str]:
+ mask = tf.cast(tf.cast(mask + 1, dtype=tf.bool), dtype=tf.float32)
+
+ if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ loss_val_train, loss_labellist_train, metric_val_report, metric_labellist_report = trainMetrics.losses_obj[metric_str](Y, y_predict, mask, weights=weights)
+ if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]:
+ nan_list = tf.math.is_nan(loss_labellist_train)
+ nan_val = tf.math.is_nan(loss_val_train)
+ inf_list = tf.math.is_inf(loss_labellist_train)
+ inf_val = tf.math.is_inf(loss_val_train)
+ if nan_val or tf.math.reduce_any(nan_list):
+ nan_flag = True
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || nan_list: ', nan_list, ' || nan_val: ', nan_val)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || mask: ', mask)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+ elif inf_val or tf.math.reduce_any(inf_list):
+ inf_flag = True
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_val_train: ', loss_val_train)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || inf_list: ', inf_list, ' || inf_val: ', inf_val)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || mask: ', mask)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+ else:
+ if len(metric_labellist_report):
+ trainMetrics.update_metric_loss_labels(metric_str, metric_labellist_report) # in sub-3D settings, this value is only indicative of performance
+ trainMetrics.update_metric_loss(metric_str, loss_val_train)
+
+ if 0: # [DEBUG]
+ tf.print(' - metric_str: ', metric_str, ' || loss_val_train: ', loss_val_train)
+
+ if metric_str in loss_combo:
+ loss_val_train = loss_val_train*loss_combo[metric_str]
+ loss_vals = tf.math.add(loss_vals, loss_val_train) # Averaged loss
+
+
+ if nan_flag or inf_flag:
+ loss_vals = 0.0 # no backprop when something was wrong
+
+ return loss_vals
+
+ @tf.function
+ def _train_step(self, X, Y, meta1, meta2, kl_alpha):
+
+ try:
+ model = self.model
+ optimizer = self.optimizer
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT:
+ kl_scale_fac = self.params['model']['kl_scale_factor']
+ else:
+ kl_scale_fac = 0.0
+
+ y_predict = None
+ loss_vals = None
+ gradients = None
+
+ # Step 1 - Calculate loss
+ with tf.GradientTape() as tape:
+
+ t2 = tf.timestamp()
+ y_predict = model(X, training=True)
+ t2_ = tf.timestamp()
+
+ loss_vals = self._train_loss(Y, y_predict, meta1, mode=config.MODE_TRAIN)
+
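+                # For the FlipOut model the objective is ELBO-style: data loss + kl_alpha * KL / kl_scale_factor,
+                # where model.losses collects the per-layer KL divergences added by the tfp Flipout layers and
+                # kl_scale_factor is assumed to normalise the KL term (e.g. by the number of training samples/batches)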
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT:
+ print (' - [Trainer][_train_step()] Model FlipOut')
+ kl = tf.math.add_n(model.losses)
+ kl_loss = kl*kl_alpha/kl_scale_fac
+ loss_vals = loss_vals + kl_loss
+
+ # Step 2 - Calculate gradients
+ t3 = tf.timestamp()
+ if not tf.math.reduce_any(tf.math.is_nan(loss_vals)):
+ all_vars = model.trainable_variables
+
+ gradients = tape.gradient(loss_vals, all_vars) # dL/dW
+
+ # Step 3 - Apply gradients
+ optimizer.apply_gradients(zip(gradients, all_vars))
+
+ else:
+ tf.print('\n ====================== [NaN Error] ====================== ')
+ tf.print(' - [ERROR][Trainer][_train_step()] Loss NaN spotted || loss_vals: ', loss_vals)
+ tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1)
+
+ t3_ = tf.timestamp()
+ return t2_-t2, t3-t2_, t3_-t3
+
+ except tf.errors.ResourceExhaustedError as e:
+ tf.print('\n ====================== [OOM Error] ====================== ')
+ tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1)
+ traceback.print_exc()
+ return None, None, None
+
+ except:
+ tf.print('\n ====================== [Some Error] ====================== ')
+ tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1)
+ traceback.print_exc()
+ return None, None, None
+
+ def train(self):
+
+ # Global params
+ exp_name = self.params['exp_name']
+
+ # Dataloader params
+ batch_size = self.params['dataloader']['batch_size']
+
+ # Model/Training params
+ fixed_lr = self.params['model']['fixed_lr']
+ init_lr = self.params['model']['init_lr']
+ max_epoch = self.params['model']['epochs']
+ epoch_range = iter(self.epoch_range)
+ epoch_length = len(self.dataset_train)
+ params_save_model = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'optimizer':self.optimizer}
+
+ # Metrics params
+ metrics_eval = self.params['metrics']['metrics_eval']
+ trainMetrics = self.metrics[config.MODE_TRAIN]
+
+ # KL Divergence Params
+ kl_alpha = self.kl_alpha_init # [0.0, self.kl_alpha_init]
+
+ # Eval Params
+ params_eval = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'pid': self.pid
+ , 'eval_type': config.MODE_TRAIN, 'batch_size': batch_size}
+
+ # Viz params
+ epochs_save = self.params['model']['epochs_save']
+ epochs_viz = self.params['model']['epochs_viz']
+ epochs_eval = self.params['model']['epochs_eval']
+
+ # Random vars
+ t_start_time = time.time()
+
+ try:
+
+ epoch_step = 0
+ epoch = None
+ pbar = None
+ t1 = time.time()
+ for (X,Y,meta1,meta2) in self.dataset_train_gen:
+ t1_ = time.time()
+
+ try:
+ # Epoch starter code
+ if epoch_step == 0:
+
+ # Get Epoch
+ epoch = next(epoch_range)
+
+ # Metrics
+ trainMetrics.reset_metrics(self.params)
+
+ # LR
+ if not fixed_lr:
+ set_lr(epoch, self.optimizer, init_lr)
+ self.model.trainable = True
+
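+                        # NOTE: the annealing branch below also reads self.kl_epochs_change and self.kl_alpha_max,
+                        # which are assumed to be set wherever the FlipOut model/schedule is configured (they are not set in _set_model() here)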
+ # Calculate kl_alpha (only for FlipOut model)
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT:
+ if self.kl_schedule == config.KL_DIV_ANNEALING:
+ if epoch > self.initial_epoch:
+ if epoch % self.kl_epochs_change == 0:
+ kl_alpha = tf.math.minimum(self.kl_alpha_max, self.kl_alpha_init + (epoch - self.initial_epoch)/float(self.kl_epochs_change) * self.kl_alpha_increase_per_epoch)
+ else:
+ kl_alpha = 0.0
+
+ # Pretty print
+ print ('')
+                        print (' ===== [{}] EPOCH:{}/{} (LR={:.3f}) =================='.format(exp_name, epoch, max_epoch, self.optimizer.lr.numpy()))
+
+ # Start a fresh pbar
+ pbar = tqdm.tqdm(total=epoch_length, desc='')
+
+ # Model Writing to tensorboard
+ if self.params['model']['model_tboard'] and self.write_model_done is False :
+ self.write_model_done = True
+ modutils.write_model(self.model, X, self.params)
+
+ # Start/Stop Profiling (after dataloader is kicked off)
+ self._set_profiler(epoch, epoch_step)
+
+ # Calculate loss and gradients from them
+ time_predict, time_loss, time_backprop = self._train_step(X, Y, meta1, meta2, kl_alpha)
+
+ # Update metrics (time + eval + plots)
+ time_dataloader = t1_ - t1
+ trainMetrics.update_metrics_time(time_dataloader, time_predict, time_loss, time_backprop)
+
+ # Update looping stuff
+ epoch_step += batch_size
+ pbar.update(batch_size)
+ trainMetrics.update_pbar(pbar)
+
+ except:
+ modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ params_save_model['epoch'] = epoch
+ modutils.save_model(self.model, params_save_model)
+ traceback.print_exc()
+ pdb.set_trace()
+
+ if epoch_step >= epoch_length:
+
+ # Reset epoch-loop params
+ pbar.close()
+ epoch_step = 0
+
+ try:
+ # Model save
+ if epoch % epochs_save == 0:
+ params_save_model['epoch'] = epoch
+ modutils.save_model(self.model, params_save_model)
+
+ # Eval on full 3D
+ if epoch % epochs_eval == 0:
+ self.params['epoch'] = epoch
+ save=False
+ if epoch > 0 and epoch % epochs_viz == 0:
+ save=True
+
+ self.model.trainable = False
+ for metric_str in metrics_eval:
+ if metrics_eval[metric_str] in [config.LOSS_DICE]:
+ params_eval['epoch'] = epoch
+ params_eval['eval_type'] = config.MODE_TRAIN
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save)
+ trainMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ # Test
+ if epoch % epochs_eval == 0:
+ self._test()
+ self.model.trainable = True
+
+ # Epochs summary/wrapup
+ eval_condition = epoch % epochs_eval == 0
+ trainMetrics.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition)
+
+ if epoch > 0 and epoch % self.params['others']['epochs_timer'] == 0:
+ elapsed_seconds = time.time() - t_start_time
+ print (' - Total time elapsed : {}'.format( str(datetime.timedelta(seconds=elapsed_seconds)) ))
+ if epoch % self.params['others']['epochs_memory'] == 0:
+ mem_before = modutils.get_memory(self.pid)
+ gc_n = gc.collect()
+ mem_after = modutils.get_memory(self.pid)
+ print(' - Unreachable objects collected by GC: {} || ({}) -> ({})'.format(gc_n, mem_before, mem_after))
+
+ # Break out of loop at end of all epochs
+ if epoch == max_epoch:
+ print ('\n\n - [Trainer][train()] All epochs finished')
+ break
+
+ except:
+ modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ params_save_model['epoch'] = epoch
+ modutils.save_model(self.model, params_save_model)
+ traceback.print_exc()
+ pdb.set_trace()
+
+ t1 = time.time() # reset dataloader time calculator
+
+ except:
+ modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ traceback.print_exc()
+ pdb.set_trace()
+
+ def _test(self):
+
+ exp_name = None
+ epoch = None
+ try:
+
+ # Step 1.1 - Params
+ exp_name = self.params['exp_name']
+ epoch = self.params['epoch']
+ deepsup = self.params['model']['deepsup']
+
+ metrics_eval = self.params['metrics']['metrics_eval']
+ epochs_viz = self.params['model']['epochs_viz']
+ batch_size = self.params['dataloader']['batch_size']
+
+ # vars
+ testMetrics = self.metrics[config.MODE_TEST]
+ testMetrics.reset_metrics(self.params)
+ params_eval = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'pid': self.pid
+ , 'eval_type': config.MODE_TEST, 'batch_size': batch_size
+ , 'epoch':epoch}
+
+ # Step 2 - Eval on full 3D
+ save=False
+ if epoch > 0 and epoch % epochs_viz == 0:
+ save=True
+ for metric_str in metrics_eval:
+ if metrics_eval[metric_str] in [config.LOSS_DICE]:
+ params_eval['eval_type'] = config.MODE_TEST
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save)
+ testMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ testMetrics.write_epoch_summary(epoch, self.label_map, {}, True)
+
+ except:
+ modutils.print_exp_name(exp_name + '-' + config.MODE_TEST, epoch)
+ traceback.print_exc()
+ pdb.set_trace()
+
diff --git a/src/model/trainer_flipout.py b/src/model/trainer_flipout.py
new file mode 100644
index 0000000..2b45341
--- /dev/null
+++ b/src/model/trainer_flipout.py
@@ -0,0 +1,2435 @@
+# Import internal libraries
+import src.config as config
+import src.model.utils as utils
+import src.model.models as models
+import src.model.losses as losses
+
+import medloader.dataloader.config as medconfig
+import medloader.dataloader.tensorflow.augmentations as aug
+from medloader.dataloader.tensorflow.dataset import ZipDataset
+from medloader.dataloader.tensorflow.han_miccai2015v2 import HaNMICCAI2015Dataset
+from medloader.dataloader.tensorflow.han_deepmindtcia import HaNDeepMindTCIADataset
+import medloader.dataloader.utils as medutils
+
+# Import external libraries
+import os
+import gc
+import pdb
+import copy
+import time
+import tqdm
+import datetime
+import traceback
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+_EPSILON = tf.keras.backend.epsilon()
+DEBUG = False
+
+############################################################
+# DATALOADER RELATED #
+############################################################
+
+def get_dataloader_3D_train(data_dir, dir_type=['train', 'train_additional']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=True
+ , parallel_calls=None, deterministic=False
+ , patient_shuffle=True
+ , centred_dataloader_prob=0.0
+ , debug=False, single_sample=False
+ , pregridnorm=True):
+
+    debug = False # NOTE: force-disables the debug argument for this training dataloader
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug, single_sample=single_sample
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Training transforms
+ x_shape_w = dataset_han_miccai2015.w_grid
+ x_shape_h = dataset_han_miccai2015.h_grid
+ x_shape_d = dataset_han_miccai2015.d_grid
+ label_map = dataset_han_miccai2015.LABEL_MAP
+ HU_MIN = dataset_han_miccai2015.HU_MIN
+ HU_MAX = dataset_han_miccai2015.HU_MAX
+ transforms = [
+ # minmax
+ aug.Rotate3DSmall(label_map, mask_type)
+ , aug.Deform2Punt5D((x_shape_h, x_shape_w, x_shape_d), label_map, grid_points=50, stddev=4, div_factor=2, debug=False)
+ , aug.Translate(label_map, translations=[40,40])
+ , aug.Noise(x_shape=(x_shape_h, x_shape_w, x_shape_d, 1), mean=0.0, std=0.1)
+ ]
+ dataset_han_miccai2015.transforms = transforms
+
+ # Step 3 - Training filters for background-only grids
+ if filter_grid:
+ dataset_han_miccai2015.filter = aug.FilterByMask(len(dataset_han_miccai2015.LABEL_MAP), dataset_han_miccai2015.SAMPLER_PERC)
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+def get_dataloader_3D_train_eval(data_dir, dir_type=['train_train_additional']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False
+ , parallel_calls=None, deterministic=True
+ , patient_shuffle=False
+ , centred_dataloader_prob=0.0
+ , debug=False, single_sample=False
+ , pregridnorm=True):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug, single_sample=single_sample
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Training transforms
+ # None
+
+ # Step 3 - Training filters for background-only grids
+ # None
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+def get_dataloader_3D_test_eval(data_dir, dir_type=['test_offsite']
+ , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False
+ , parallel_calls=None, deterministic=True
+ , patient_shuffle=False
+ , debug=False, single_sample=False
+ , pregridnorm=True):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_
+ , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , debug=debug, single_sample=single_sample
+ , pregridnorm=pregridnorm)
+
+ # Step 2 - Testing transforms
+ # None
+
+ # Step 3 - Testing filters for background-only grids
+ # None
+
+ # Step 4 - Append to list
+ datasets.append(dataset_han_miccai2015)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+def get_dataloader_deepmindtcia(data_dir
+ , dir_type=[medconfig.DATALOADER_DEEPMINDTCIA_TEST]
+ , annotator_type=[medconfig.DATALOADER_DEEPMINDTCIA_ONC]
+ , grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT
+ , transforms=[], filter_grid=False, random_grid=False, pregridnorm=True
+ , parallel_calls=None, deterministic=False
+ , patient_shuffle=True
+ , centred_dataloader_prob = 0.0
+ , debug=False, single_sample=False):
+
+ datasets = []
+
+ # Dataset 1
+ for dir_type_ in dir_type:
+
+ for anno_type_ in annotator_type:
+
+ # Step 1 - Get dataset class
+ dataset_han_deepmindtcia = HaNDeepMindTCIADataset(data_dir=data_dir, dir_type=dir_type_, annotator_type=anno_type_
+ , grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type, pregridnorm=pregridnorm
+ , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid
+ , parallel_calls=parallel_calls, deterministic=deterministic
+ , patient_shuffle=patient_shuffle
+ , centred_dataloader_prob=centred_dataloader_prob
+ , debug=debug, single_sample=single_sample)
+
+ # Step 2 - Append to list
+ datasets.append(dataset_han_deepmindtcia)
+
+ dataset = ZipDataset(datasets)
+ return dataset
+
+############################################################
+# MODEL RELATED #
+############################################################
+
+def set_lr(epoch, optimizer, init_lr):
+
+ # if epoch == 1: # for models that are preloaded from another model
+ # print (' - [set_lr()] Setting optimizer lr to ', init_lr)
+ # optimizer.lr.assign(init_lr)
+
+ if epoch > 1 and epoch % 20 == 0:
+ optimizer.lr.assign(optimizer.lr * 0.98)
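+    # e.g. over 1000 epochs this applies 50 decay steps, leaving the LR at roughly 0.98**50 ~= 0.36 of its initial value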
+
+############################################################
+# METRICS RELATED #
+############################################################
+class ModelMetrics():
+
+ def __init__(self, metric_type, params):
+
+ self.params = params
+
+ self.label_map = params['internal']['label_map']
+ self.label_ids = params['internal']['label_ids']
+ self.logging_tboard = params['metrics']['logging_tboard']
+ self.metric_type = metric_type
+
+ self.losses_obj = self.get_losses_obj(params)
+ self.metrics_layers_kl_divergence = {} # empty for now
+
+ self.init_metrics(params)
+ if self.logging_tboard:
+ self.init_tboard_writers(params)
+
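+        # Write a (near-)zero baseline point at epoch 0 to TensorBoard via init_epoch0(), resetting the metric states before and after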
+ self.reset_metrics(params)
+ self.init_epoch0()
+ self.reset_metrics(params)
+
+ def get_losses_obj(self, params):
+ losses_obj = {}
+ for loss_key in params['metrics']['metrics_loss']:
+ if config.LOSS_DICE == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_dice_3d_tf_func
+ if config.LOSS_FOCAL == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_focal_3d_tf_func
+ if config.LOSS_CE == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_ce_3d_tf_func
+ if config.LOSS_NCC == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_ncc_numpy
+ if config.LOSS_PAVPU == params['metrics']['metrics_loss'][loss_key]:
+ losses_obj[loss_key] = losses.loss_avu_3d_tf_func
+
+ return losses_obj
+
+ def init_metrics(self, params):
+ """
+        Metrics tracked with TensorFlow's tf.keras.metrics library.
+ """
+ # Metrics for losses (during training for smaller grids)
+ self.metrics_loss_obj = {}
+ for metric_key in params['metrics']['metrics_loss']:
+ self.metrics_loss_obj[metric_key] = {}
+ self.metrics_loss_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type))
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ for label_id in self.label_ids:
+ self.metrics_loss_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type))
+
+ # Metrics for eval (for full 3D volume)
+ self.metrics_eval_obj = {}
+ for metric_key in params['metrics']['metrics_eval']:
+ self.metrics_eval_obj[metric_key] = {}
+ self.metrics_eval_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type))
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.metrics_eval_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type))
+
+ # Time Metrics
+ self.metric_time_dataloader = tf.keras.metrics.Mean(name='AvgTime-Dataloader-{}'.format(self.metric_type))
+ self.metric_time_model_predict = tf.keras.metrics.Mean(name='AvgTime-ModelPredict-{}'.format(self.metric_type))
+ self.metric_time_model_loss = tf.keras.metrics.Mean(name='AvgTime-ModelLoss-{}'.format(self.metric_type))
+ self.metric_time_model_backprop = tf.keras.metrics.Mean(name='AvgTime-ModelBackProp-{}'.format(self.metric_type))
+
+ # FlipOut Metrics
+ self.metric_kl_alpha = tf.keras.metrics.Mean(name='KL-Alpha')
+ self.metric_kl_divergence = tf.keras.metrics.Mean(name='KL-Divergence')
+
+ # Scalar Losses
+ self.metric_scalarloss_data = tf.keras.metrics.Mean(name='ScalarLoss-Data')
+ self.metric_scalarloss_reg = tf.keras.metrics.Mean(name='ScalarLoss-Reg')
+
+ def init_metrics_layers_kl_std(self, params, layers_kl_std):
+
+ self.metrics_layers_kl_divergence = {}
+ self.tboard_layers_kl_divergence = {}
+ self.writer_tboard_layers_std = {}
+ self.writer_tboard_layers_mean = {}
+ for layer_name in layers_kl_std:
+ self.metrics_layers_kl_divergence[layer_name] = tf.keras.metrics.Mean(name='KL-Divergence-{}'.format(layer_name))
+ self.metrics_layers_kl_divergence[layer_name].update_state(0)
+ self.tboard_layers_kl_divergence[layer_name] = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Divergence-Layer-{}'.format(layer_name))
+ utils.make_summary('BayesLossExtras/FlipOut/KLDivergence-{}'.format(layer_name), 0, writer1=self.tboard_layers_kl_divergence[layer_name], value1=self.metrics_layers_kl_divergence[layer_name].result())
+ self.metrics_layers_kl_divergence[layer_name].reset_states()
+
+ if 'std' in layers_kl_std[layer_name]:
+ keyname = layer_name + '-std'
+ self.writer_tboard_layers_std[keyname] = utils.get_tensorboard_writer(params['exp_name'], suffix='Std-Layer-{}'.format(keyname))
+ utils.make_summary_hist('Std/{}'.format(keyname), 0, writer1=self.writer_tboard_layers_std[keyname], value1=layers_kl_std[layer_name]['std'])
+
+ if 'mean' in layers_kl_std[layer_name]: # keep this between [-2,+2] for better visualization in tf.summary.histogram()
+ keyname = layer_name + '-mean'
+ self.writer_tboard_layers_mean[keyname] = utils.get_tensorboard_writer(params['exp_name'], suffix='Mean-Layer-{}'.format(keyname))
+ mean_vals = layers_kl_std[layer_name]['mean'].numpy()
+ mean_vals = mean_vals[mean_vals >= -2]
+ mean_vals = mean_vals[mean_vals <= 2]
+ utils.make_summary_hist('Mean/{}'.format(keyname), 0, writer1=self.writer_tboard_layers_mean[keyname], value1=mean_vals)
+
+ def reset_metrics(self, params):
+
+ # Metrics for losses (during training for smaller grids)
+ for metric_key in params['metrics']['metrics_loss']:
+ self.metrics_loss_obj[metric_key]['total'].reset_states()
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ for label_id in self.label_ids:
+ self.metrics_loss_obj[metric_key][label_id].reset_states()
+
+ # Metrics for eval (for full 3D volume)
+ for metric_key in params['metrics']['metrics_eval']:
+ self.metrics_eval_obj[metric_key]['total'].reset_states()
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.metrics_eval_obj[metric_key][label_id].reset_states()
+
+ # Time Metrics
+ self.metric_time_dataloader.reset_states()
+ self.metric_time_model_predict.reset_states()
+ self.metric_time_model_loss.reset_states()
+ self.metric_time_model_backprop.reset_states()
+
+ # FlipOut Metrics
+ self.metric_kl_alpha.reset_states()
+ self.metric_kl_divergence.reset_states()
+
+ # Scalar Losses
+ self.metric_scalarloss_data.reset_states()
+ self.metric_scalarloss_reg.reset_states()
+
+ # FlipOut-Layers
+ for layer_name in self.metrics_layers_kl_divergence:
+ self.metrics_layers_kl_divergence[layer_name].reset_states()
+
+ def init_tboard_writers(self, params):
+ """
+        These are the TensorBoard writers.
+ """
+ # Writers for loss (during training for smaller grids)
+ self.writers_loss_obj = {}
+ for metric_key in params['metrics']['metrics_loss']:
+ self.writers_loss_obj[metric_key] = {}
+ self.writers_loss_obj[metric_key]['total'] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss')
+ if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ for label_id in self.label_ids:
+ self.writers_loss_obj[metric_key][label_id] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss-' + str(label_id))
+
+ # Writers for eval (for full 3D volume)
+ self.writers_eval_obj = {}
+ for metric_key in params['metrics']['metrics_eval']:
+ self.writers_eval_obj[metric_key] = {}
+ self.writers_eval_obj[metric_key]['total'] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval')
+ if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]:
+ for label_id in self.label_ids:
+ self.writers_eval_obj[metric_key][label_id] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval-' + str(label_id))
+
+ # Time and other writers
+ self.writer_lr = utils.get_tensorboard_writer(params['exp_name'], suffix='LR')
+ self.writer_time_dataloader = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Dataloader')
+ self.writer_time_model_predict = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Predict')
+ self.writer_time_model_loss = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Loss')
+ self.writer_time_model_backprop = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Backprop')
+
+ # FlipOut writers
+ self.writer_kl_alpha = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Alpha')
+ self.writer_kl_divergence = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Divergence')
+
+ # Scalar Losses
+ self.writer_scalarloss_data = utils.get_tensorboard_writer(params['exp_name'], suffix='ScalarLoss-Data')
+ self.writer_scalarloss_reg = utils.get_tensorboard_writer(params['exp_name'], suffix='ScalarLoss-Reg')
+
+ def init_epoch0(self):
+
+ for metric_str in self.metrics_loss_obj:
+ self.update_metric_loss(metric_str, 1e-6)
+            if self.params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ self.update_metric_loss_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids})
+
+ for metric_str in self.metrics_eval_obj:
+ self.update_metric_eval_labels(metric_str, {label_id: 0 for label_id in self.label_ids})
+
+ self.update_metrics_time(time_dataloader=0, time_predict=0, time_loss=0, time_backprop=0)
+ self.update_metrics_kl(kl_alpha=0, kl_divergence=0, kl_divergence_layers={})
+ self.update_metrics_scalarloss(loss_data=0, loss_reg=0)
+
+ self.write_epoch_summary(epoch=0, label_map=self.label_map, params=None, eval_condition=True)
+
+ @tf.function
+ def update_metrics_kl(self, kl_alpha, kl_divergence, kl_divergence_layers):
+ self.metric_kl_alpha.update_state(kl_alpha)
+ self.metric_kl_divergence.update_state(kl_divergence)
+
+ for layer_name in kl_divergence_layers:
+ if layer_name in self.metrics_layers_kl_divergence:
+ self.metrics_layers_kl_divergence[layer_name].update_state(kl_divergence_layers[layer_name]['kl'])
+
+ @tf.function
+ def update_metrics_scalarloss(self, loss_data, loss_reg):
+ self.metric_scalarloss_data.update_state(loss_data)
+ self.metric_scalarloss_reg.update_state(loss_reg)
+
+ def update_metrics_time(self, time_dataloader, time_predict, time_loss, time_backprop):
+ if time_dataloader is not None:
+ self.metric_time_dataloader.update_state(time_dataloader)
+ if time_predict is not None:
+ self.metric_time_model_predict.update_state(time_predict)
+ if time_loss is not None:
+ self.metric_time_model_loss.update_state(time_loss)
+ if time_backprop is not None:
+ self.metric_time_model_backprop.update_state(time_backprop)
+
+ def update_metric_loss(self, metric_str, metric_val):
+ # Metrics for losses (during training for smaller grids)
+ self.metrics_loss_obj[metric_str]['total'].update_state(metric_val)
+
+ @tf.function
+ def update_metric_loss_labels(self, metric_str, metric_vals_labels):
+ # Metrics for losses (during training for smaller grids)
+
+ for label_id in self.label_ids:
+ if metric_vals_labels[label_id] > 0.0:
+ self.metrics_loss_obj[metric_str][label_id].update_state(metric_vals_labels[label_id])
+
+ def update_metric_eval(self, metric_str, metric_val):
+ # Metrics for eval (for full 3D volume)
+ self.metrics_eval_obj[metric_str]['total'].update_state(metric_val)
+
+ def update_metric_eval_labels(self, metric_str, metric_vals_labels, do_average=False):
+ # Metrics for eval (for full 3D volume)
+
+ try:
+ metric_avg = []
+ for label_id in self.label_ids:
+ if label_id in metric_vals_labels:
+ if metric_vals_labels[label_id] > 0:
+ self.metrics_eval_obj[metric_str][label_id].update_state(metric_vals_labels[label_id])
+ if do_average:
+ if label_id > 0:
+ metric_avg.append(metric_vals_labels[label_id])
+
+ if do_average:
+ if len(metric_avg):
+ self.metrics_eval_obj[metric_str]['total'].update_state(np.mean(metric_avg))
+
+ except:
+ traceback.print_exc()
+
+ def write_epoch_summary(self, epoch, label_map, params=None, eval_condition=False):
+
+ if self.logging_tboard:
+ # Metrics for losses (during training for smaller grids)
+ for metric_str in self.metrics_loss_obj:
+ utils.make_summary('Loss/{}'.format(metric_str), epoch, writer1=self.writers_loss_obj[metric_str]['total'], value1=self.metrics_loss_obj[metric_str]['total'].result())
+ if self.params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ if len(self.metrics_loss_obj[metric_str]) > 1: # i.e. has label ids
+ for label_id in self.label_ids:
+ label_name, _ = utils.get_info_from_label_id(label_id, label_map)
+ utils.make_summary('Loss/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_loss_obj[metric_str][label_id], value1=self.metrics_loss_obj[metric_str][label_id].result())
+
+ # Metrics for eval (for full 3D volume)
+ if eval_condition:
+ for metric_str in self.metrics_eval_obj:
+ utils.make_summary('Eval3D/{}'.format(metric_str), epoch, writer1=self.writers_eval_obj[metric_str]['total'], value1=self.metrics_eval_obj[metric_str]['total'].result())
+ if len(self.metrics_eval_obj[metric_str]) > 1: # i.e. has label ids
+ for label_id in self.label_ids:
+ label_name, _ = utils.get_info_from_label_id(label_id, label_map)
+ utils.make_summary('Eval3D/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_eval_obj[metric_str][label_id], value1=self.metrics_eval_obj[metric_str][label_id].result())
+
+ # Time Metrics
+ utils.make_summary('Info/Time/Dataloader' , epoch, writer1=self.writer_time_dataloader , value1=self.metric_time_dataloader.result())
+ utils.make_summary('Info/Time/ModelPredict' , epoch, writer1=self.writer_time_model_predict , value1=self.metric_time_model_predict.result())
+ utils.make_summary('Info/Time/ModelLoss' , epoch, writer1=self.writer_time_model_loss , value1=self.metric_time_model_loss.result())
+ utils.make_summary('Info/Time/ModelBackProp', epoch, writer1=self.writer_time_model_backprop, value1=self.metric_time_model_backprop.result())
+
+ # FlipOut Metrics
+ utils.make_summary('BayesLoss/FlipOut/KLAlpha' , epoch, writer1=self.writer_kl_alpha , value1=self.metric_kl_alpha.result())
+ utils.make_summary('BayesLoss/FlipOut/KLDivergence' , epoch, writer1=self.writer_kl_divergence , value1=self.metric_kl_divergence.result())
+ for layer_name in self.metrics_layers_kl_divergence:
+ utils.make_summary('BayesLossExtras/FlipOut/KLDivergence-{}'.format(layer_name), epoch, writer1=self.tboard_layers_kl_divergence[layer_name], value1=self.metrics_layers_kl_divergence[layer_name].result())
+
+ # Scalar Loss Metrics
+ utils.make_summary('BayesLoss/FlipOut/ScalarLossData' , epoch, writer1=self.writer_scalarloss_data , value1=self.metric_scalarloss_data.result())
+ utils.make_summary('BayesLoss/FlipOut/ScalarLossReg' , epoch, writer1=self.writer_scalarloss_reg , value1=self.metric_scalarloss_reg.result())
+
+ # Learning Rate
+ if params is not None:
+ if 'optimizer' in params:
+ utils.make_summary('Info/LR', epoch, writer1=self.writer_lr, value1=params['optimizer'].lr)
+
+ def write_epoch_summary_std(self, layers_kl_std, epoch):
+
+ for layer_name in layers_kl_std:
+
+ if 'std' in layers_kl_std[layer_name]:
+ keyname = layer_name + '-std'
+ utils.make_summary_hist('Std/{}'.format(keyname), epoch, writer1=self.writer_tboard_layers_std[keyname], value1=layers_kl_std[layer_name]['std'])
+
+ if 'mean' in layers_kl_std[layer_name]:
+ keyname = layer_name + '-mean'
+ mean_vals = layers_kl_std[layer_name]['mean'].numpy()
+ mean_vals = mean_vals[mean_vals >= -2]
+ mean_vals = mean_vals[mean_vals <= 2]
+ utils.make_summary_hist('Mean/{}'.format(keyname), epoch, writer1=self.writer_tboard_layers_mean[keyname], value1=mean_vals)
+
+ def update_pbar(self, pbar):
+ desc_str = ''
+
+ # Metrics for losses (during training for smaller grids)
+ if config.LOSS_PURATIO in self.metrics_loss_obj or config.LOSS_PAVPU in self.metrics_loss_obj:
+ for metric_str in self.metrics_loss_obj:
+ if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PURATIO, config.LOSS_PAVPU]:
+ result = self.metrics_loss_obj[metric_str]['total'].result().numpy()
+                    loss_text = '{}:{:.2f},'.format(metric_str, result)
+ desc_str += loss_text
+ else:
+ for metric_str in self.metrics_loss_obj:
+ if len(desc_str): desc_str += ','
+
+ if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]:
+ metric_avg = []
+ for label_id in self.label_ids:
+ if label_id > 0:
+ label_result = self.metrics_loss_obj[metric_str][label_id].result().numpy()
+ if label_result > 0:
+ metric_avg.append(label_result)
+
+ mean_val = 0
+ if len(metric_avg):
+ mean_val = np.mean(metric_avg)
+                    loss_text = '{}Loss:{:.2f}'.format(metric_str, mean_val)
+ desc_str += loss_text
+
+ # GPU Memory
+ if 1:
+ try:
+ if len(self.metrics_loss_obj) > 1:
+ desc_str = desc_str[:-1] # to remove the extra ','
+ desc_str += ',' + str(utils.get_tf_gpu_memory())
+ except:
+ pass
+
+ pbar.set_description(desc=desc_str, refresh=True)
+
+def eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_processed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc, patient_pred_error
+ , patient_id_curr
+ , model_folder_epoch_imgs, model_folder_epoch_patches
+ , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals
+ , spacing, label_map, label_colors
+ , show=False, save=False):
+
+ try:
+
+ # Step 3.1.2 - Vizualize
+ if show:
+ if len(patient_pred_std):
+ maskpred_std = np.max(patient_pred_std, axis=-1)
+ maskpred_std = np.expand_dims(maskpred_std, axis=-1)
+ maskpred_std = np.repeat(maskpred_std, repeats=10, axis=-1)
+
+ maskpred_ent = np.expand_dims(patient_pred_ent, axis=-1) # [H,W,D] --> [H,W,D,1]
+ maskpred_ent = np.repeat(maskpred_ent, repeats=10, axis=-1) # [H,W,D,1] --> [H,W,D,10]
+
+ maskpred_mif = np.expand_dims(patient_pred_mif, axis=-1) # [H,W,D] --> [H,W,D,1]
+ maskpred_mif = np.repeat(maskpred_mif, repeats=10, axis=-1) # [H,W,D,1] --> [H,W,D,10]
+
+ if 1:
+ print (' - patient_id_curr: ', patient_id_curr)
+ f,axarr = plt.subplots(1,2)
+ axarr[0].hist(maskpred_ent[:,:,:,0].flatten(), bins=30)
+ axarr[0].set_title('Entropy')
+ axarr[1].hist(maskpred_mif[:,:,:,0].flatten(), bins=30)
+ axarr[1].set_title('MutInf')
+ plt.suptitle('Exp: {}\nPatient:{}'.format(exp_name, patient_id_curr))
+ plt.show()
+ pdb.set_trace()
+
+ utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_std, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors
+ , vmax_unc=0.06, unc_title='Predictive Std', unc_savesufix='stdmax')
+
+ utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_ent, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors
+ , vmax_unc=1.2, unc_title='Predictive Entropy', unc_savesufix='ent')
+
+ utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_mif, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors
+ , vmax_unc=0.06, unc_title='Mutual Information', unc_savesufix='mif')
+
+ # Step 3.1.3 - Save 3D grid to visualize in 3D Slicer (drag-and-drop mechanism)
+ if save:
+
+ import medloader.dataloader.utils as medutils
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_img.nrrd' , patient_img[:,:,:,0], spacing)
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_mask.nrrd', np.argmax(patient_gt, axis=3),spacing)
+ maskpred_labels = np.argmax(patient_pred_processed, axis=3) # not "np.argmax(patient_pred, axis=3)" since it does not contain any postprocessing
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpred.nrrd', maskpred_labels, spacing)
+
+ maskpred_labels_probmean = np.take_along_axis(patient_pred, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0]
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmean.nrrd', maskpred_labels_probmean, spacing)
+
+ if np.sum(patient_pred_std):
+ maskpred_labels_std = np.take_along_axis(patient_pred_std, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0]
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstd.nrrd', maskpred_labels_std, spacing)
+
+ maskpred_std_max = np.max(patient_pred_std, axis=-1)
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstdmax.nrrd', maskpred_std_max, spacing)
+
+ if np.sum(patient_pred_ent):
+ maskpred_ent = patient_pred_ent # [H,W,D]
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredent.nrrd', maskpred_ent, spacing)
+
+ if np.sum(patient_pred_mif):
+ maskpred_mif = patient_pred_mif
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmif.nrrd', maskpred_mif, spacing)
+
+ if np.sum(patient_pred_unc):
+ if len(patient_pred_unc.shape) == 4:
+ maskpred_labels_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] # [H,W,D,C] --> [H,W,D]
+ else:
+ maskpred_labels_unc = patient_pred_unc
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredunc.nrrd', maskpred_labels_unc, spacing)
+
+ if np.sum(patient_pred_error):
+ medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskprederror.nrrd', patient_pred_error, spacing)
+
+ try:
+ # Step 3.1.3.2 - PLot results for that patient
+ f, axarr = plt.subplots(3,1, figsize=(15,10))
+ boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95 = {}, {}, {}
+ boxplot_dice_mean_list = []
+ for label_id in range(len(loss_labels_val)):
+ label_name, _ = utils.get_info_from_label_id(label_id, label_map)
+ boxplot_dice[label_name] = [loss_labels_val[label_id]]
+ boxplot_hausdorff[label_name] = [hausdorff_labels_val[label_id]]
+ boxplot_hausdorff95[label_name] = [hausdorff95_labels_val[label_id]]
+ if label_id > 0 and loss_labels_val[label_id] > 0:
+ boxplot_dice_mean_list.append(loss_labels_val[label_id])
+
+ axarr[0].boxplot(boxplot_dice.values())
+ axarr[0].set_xticks(range(1, len(boxplot_dice)+1))
+ axarr[0].set_xticklabels(boxplot_dice.keys())
+ axarr[0].set_ylim([0.0,1.1])
+ axarr[0].grid()
+ axarr[0].set_title('DICE - Avg: {} \n w/o chiasm: {}'.format(
+ '%.4f' % (np.mean(boxplot_dice_mean_list))
+                            , '%.4f' % (np.mean(boxplot_dice_mean_list[0:1] + boxplot_dice_mean_list[2:])) # drop index 1 (optic chiasm, label_id=2) from the mean
+ )
+ )
+
+ axarr[1].boxplot(boxplot_hausdorff.values())
+ axarr[1].set_xticks(range(1,len(boxplot_hausdorff)+1))
+ axarr[1].set_xticklabels(boxplot_hausdorff.keys())
+ axarr[1].grid()
+ axarr[1].set_title('Hausdorff')
+
+ axarr[2].boxplot(boxplot_hausdorff95.values())
+ axarr[2].set_xticks(range(1,len(boxplot_hausdorff95)+1))
+ axarr[2].set_xticklabels(boxplot_hausdorff95.keys())
+ axarr[2].set_title('95% Hausdorff')
+ axarr[2].grid()
+
+ plt.savefig(str(Path(model_folder_epoch_patches).joinpath('results_' + patient_id_curr + '.png')), bbox_inches='tight') # , bbox_inches='tight'
+ plt.close()
+
+ except:
+ traceback.print_exc()
+
+ except:
+ pdb.set_trace()
+ traceback.print_exc()
+
+def get_ece(y_true, y_predict, patient_id, res_global, verbose=False):
+    """
+    Params
+    ------
+    y_true    : [H,W,D,C], np.array, binary (one-hot ground truth)
+    y_predict : [H,W,D,C], np.array, softmax probability values
+    patient_id: str, used for logging and plot titles
+    res_global: dict, accumulates per-label flattened predictions across patients for a global ECE
+
+    Returns
+    -------
+    res_global: dict, updated with this patient's voxels
+    res       : dict, per-label array of signed per-bin calibration errors (-1 if the label was never predicted)
+
+    Refs
+    ----
+    - https://github.com/sirius8050/Expected-Calibration-Error/blob/master/ECE.py
+      : On Calibration of Modern Neural Networks
+    - Ref(future): https://github.com/yding5/AdaptiveBinning
+      : Revisiting the evaluation of uncertainty estimation and its application to explore model complexity-uncertainty trade-off
+    """
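+    # A quick recap of the quantity computed below (cf. "On Calibration of Modern Neural Networks"):
+    #   ECE = sum_b (n_b / N) * |acc(b) - conf(b)|
+    # with 10 equal-width probability bins, acc(b) the fraction of correctly predicted voxels in bin b
+    # and conf(b) the mean softmax probability in bin b. This function returns the signed per-bin
+    # terms; the abs/sum over bins is taken by the caller.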
+ res = {}
+ nan_value = -0.1
+
+ if verbose: print (' - [get_ece()] patient_id: ', patient_id)
+
+ try:
+
+ # Step 0 - Init
+ label_count = y_true.shape[-1]
+
+ # Step 1 - Calculate o_predict
+ o_true = np.argmax(y_true, axis=-1)
+ o_predict = np.argmax(y_predict, axis=-1)
+
+ # Step 2 - Loop over different classes
+ for label_id in range(label_count):
+
+ if label_id > -1:
+
+ if verbose: print (' --- [get_ece()] label_id: ', label_id)
+
+ if label_id not in res_global: res_global[label_id] = {'o_predict_label':[], 'y_predict_label':[], 'o_true_label':[]}
+
+ # Step 2.1 - Get o_predict_label(label_ids), o_true_label(label_ids), y_predict_label(probs) [and append to global list]
+ ## NB: You are considering TP + FP here
+ o_true_label = o_true[o_predict == label_id]
+ o_predict_label = o_predict[o_predict == label_id]
+ y_predict_label = y_predict[:,:,:,label_id][o_predict == label_id]
+ res_global[label_id]['o_true_label'].extend(o_true_label.flatten().tolist())
+ res_global[label_id]['o_predict_label'].extend(o_predict_label.flatten().tolist())
+ res_global[label_id]['y_predict_label'].extend(y_predict_label.flatten().tolist())
+
+ if len(o_true_label) and len(y_predict_label):
+
+                    # Step 2.2 - Bin the probs and calculate their mean
+                    #            (10 equal-width probability bins; the bin count is derived from the bin
+                    #             edges rather than from label_count, which only coincidentally equals 10)
+                    bin_edges               = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])
+                    num_bins                 = len(bin_edges) - 1
+                    y_predict_label_bin_ids  = np.digitize(y_predict_label, bin_edges, right=False) - 1
+                    y_predict_binned_vals    = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    y_predict_bins_mean      = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+                    # Step 2.3 - Calculate the accuracy of each bin
+                    o_predict_label_bins     = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    o_true_label_bins        = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    y_predict_bins_accuracy  = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(num_bins)]
+                    y_predict_bins_len       = [len(o_predict_label_bins[bin_id]) for bin_id in range(num_bins)]
+
+ # Step 2.4 - Wrapup
+ N = np.prod(y_predict_label.shape)
+ ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+                    ce[ce == 0] = nan_value # a zero entry means the bin was empty or its accuracy equals its mean confidence; flag it so it is excluded from the ECE sum
+ res[label_id] = ce
+
+ else:
+ res[label_id] = -1
+
+ if 0:
+ if label_id == 1:
+ diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ print (' - [get_ece()][BStem] diff: ', ['%.4f' % each for each in diff])
+
+ # NB: This considers the whole volume
+ o_true_label = y_true[:,:,:,label_id] # [1=this label, 0=other label]
+ o_predict_label = np.array(o_predict, copy=True)
+ o_predict_label[o_predict_label != label_id] = 0
+ o_predict_label[o_predict_label == label_id] = 1 # [1 - predicted this label, 0 = predicted other label]
+ y_predict_label = y_predict[:,:,:,label_id]
+
+ # Step x.2 - Bin the probs and calculate their mean
+ y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1
+ y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+ # Step x.3 - Calculate the accuracy of each bin
+ o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)]
+ y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)]
+ y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)]
+
+ N_new = np.prod(y_predict_label.shape)
+ diff_new = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ ce_new = np.array((np.array(y_predict_bins_len)/N_new)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+ print (' - [get_ece()][BStem] diff_new: ', ['%.4f' % each for each in diff_new])
+
+ pdb.set_trace()
+
+ if verbose:
+ print (' --- [get_ece()] y_predict_bins_accuracy: ', ['%.4f' % (each) for each in np.array(y_predict_bins_accuracy)])
+ print (' --- [get_ece()] CE : ', ['%.4f' % (each) for each in np.array(res[label_id])])
+ print (' --- [get_ece()] ECE: ', np.sum(np.abs(res[label_id][res[label_id] != nan_value])))
+
+ # Prob bars
+ # plt.hist(y_predict[:,:,:,label_id].flatten(), bins=10)
+ # plt.title('Softmax Probs (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.show()
+
+ # GT Prob bars
+ # plt.bar(np.arange(len(y_predict_bins_len))/10.0 + 0.1, y_predict_bins_len, width=0.05)
+ # plt.title('Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.xlabel('Probabilities')
+ # plt.show()
+
+ # GT Probs (sorted) in plt.plot (with equally-spaced bins)
+ # from collections import Counter
+ # tmp = np.sort(y_predict_label)
+ # plt.plot(range(len(tmp)), tmp, color='orange')
+ # tmp_bins = np.digitize(tmp, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])) - 1
+ # tmp_bins_len = Counter(tmp_bins)
+ # boundary_start = 0
+ # plt.plot([0,0],[0.0,1.0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-spaced)')
+ # for boundary in np.arange(0,len(tmp_bins_len)): plt.plot([boundary_start+tmp_bins_len[boundary], boundary_start+tmp_bins_len[boundary]], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed'); boundary_start+=tmp_bins_len[boundary]
+ # plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ # plt.legend()
+ # plt.show()
+
+ # GT Probs (sorted) in plt.plot (with equally-sized bins)
+ if label_id == 1:
+ Path('tmp').mkdir(parents=True, exist_ok=True)
+ tmp = np.sort(y_predict_label)
+ tmp_len = len(tmp)
+ plt.plot(range(len(tmp)), tmp, color='orange')
+ for boundary in np.arange(0,tmp_len, int(tmp_len//10)): plt.plot([boundary, boundary], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed')
+ plt.plot([0,0],[0,0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-sized)')
+ plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id))
+ plt.legend()
+ # plt.show()
+ plt.savefig('./tmp/ECE_SortedProbs_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close()
+
+ # ECE plot
+ plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8)
+ plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred')
+ plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy')
+ diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)
+ for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink')
+ plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE')
+ plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0)
+ plt.title('CE (label={})\nPatient:{}'.format(label_id, patient_id))
+ plt.xlabel('Probability')
+ plt.ylabel('Accuracy')
+ plt.legend()
+ # plt.show()
+ plt.savefig('./tmp/ECE_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close()
+
+ # pdb.set_trace()
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ return res_global, res
+
+def eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=False, show=False, verbose=False):
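+    """
+    Aggregates the per-patient results in `res` into dataset-level summaries: mean/std of Dice, HD,
+    HD95 and MSD per label, AvU statistics, and a global ECE recomputed from the voxels accumulated
+    in `ece_global_obj`. Optionally saves boxplots, reliability diagrams and a .json dump.
+    """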
+
+ try:
+
+ pid = os.getpid()
+
+ ###############################################################################
+ # Summarize #
+ ###############################################################################
+
+ # Step 1 - Summarize DICE + Surface Distances
+ if 1:
+ loss_labels_avg, loss_labels_std = [], []
+ hausdorff_labels_avg, hausdorff_labels_std = [], []
+ hausdorff95_labels_avg, hausdorff95_labels_std = [], []
+ msd_labels_avg, msd_labels_std = [], []
+
+ loss_labels_list = np.array([res[patient_id][config.KEY_DICE_LABELS] for patient_id in res])
+ hausdorff_labels_list = np.array([res[patient_id][config.KEY_HD_LABELS] for patient_id in res])
+ hausdorff95_labels_list = np.array([res[patient_id][config.KEY_HD95_LABELS] for patient_id in res])
+ msd_labels_list = np.array([res[patient_id][config.KEY_MSD_LABELS] for patient_id in res])
+
+ for label_id in range(loss_labels_list.shape[1]):
+ tmp_vals = loss_labels_list[:,label_id]
+ loss_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ loss_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ if label_id > 0:
+ tmp_vals = hausdorff_labels_list[:,label_id]
+ hausdorff_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) # avoids -1 for "erroneous" HD, and 0 for "not to be calculated" HD
+ hausdorff_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ tmp_vals = hausdorff95_labels_list[:,label_id]
+ hausdorff95_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ hausdorff95_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ tmp_vals = msd_labels_list[:,label_id]
+ msd_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+ msd_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+ else:
+ hausdorff_labels_avg.append(0)
+ hausdorff_labels_std.append(0)
+ hausdorff95_labels_avg.append(0)
+ hausdorff95_labels_std.append(0)
+ msd_labels_avg.append(0)
+ msd_labels_std.append(0)
+
+ loss_avg = np.mean([res[patient_id][config.KEY_DICE_AVG] for patient_id in res])
+ print (' --------------------------- eval_type: ', eval_type)
+            print (' - dice_labels_3D (avg) : ', ['%.4f' % each for each in loss_labels_avg])
+            print (' - dice_labels_3D (std) : ', ['%.4f' % each for each in loss_labels_std])
+ print (' - dice_3D : %.4f' % np.mean(loss_labels_avg))
+ print (' - dice_3D (w/o bgd): %.4f' % np.mean(loss_labels_avg[1:]))
+ print (' - dice_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:]))
+ print ('')
+            print (' - hausdorff_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff_labels_avg])
+            print (' - hausdorff_labels_3D (std) : ', ['%.4f' % each for each in hausdorff_labels_std])
+ print (' - hausdorff_3D (w/o bgd): %.4f' % np.mean(hausdorff_labels_avg[1:]))
+ print (' - hausdorff_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff_labels_avg[1:2] + hausdorff_labels_avg[3:]))
+ print ('')
+            print (' - hausdorff95_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff95_labels_avg])
+            print (' - hausdorff95_labels_3D (std) : ', ['%.4f' % each for each in hausdorff95_labels_std])
+ print (' - hausdorff95_3D (w/o bgd): %.4f' % np.mean(hausdorff95_labels_avg[1:]))
+ print (' - hausdorff95_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff95_labels_avg[1:2] + hausdorff95_labels_avg[3:]))
+ print ('')
+            print (' - msd_labels_3D (avg) : ', ['%.4f' % each for each in msd_labels_avg])
+            print (' - msd_labels_3D (std) : ', ['%.4f' % each for each in msd_labels_std])
+ print (' - msd_3D (w/o bgd): %.4f' % np.mean(msd_labels_avg[1:]))
+ print (' - msd_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(msd_labels_avg[1:2] + msd_labels_avg[3:]))
+
+ # Step 2 - Summarize AvU
+ if 1:
+ try:
+ print ('')
+ if config.KEY_AVU_PAC_ENT in res[list(res.keys())[0]]:
+ p_ac_list_ent = np.array([res[patient_id][config.KEY_AVU_PAC_ENT] for patient_id in res])
+ p_ui_list_ent = np.array([res[patient_id][config.KEY_AVU_PUI_ENT] for patient_id in res])
+ pavpu_list_ent = np.array([res[patient_id][config.KEY_AVU_ENT] for patient_id in res])
+ pavpu_unc_threshold_list_ent = np.array([res[patient_id][config.KEY_THRESH_ENT] for patient_id in res])
+ print (' - AvU values for entropy')
+ print (' - p(acc|cer) : %.4f +- %.4f' % ( np.mean([p_ac_list_ent[p_ac_list_ent > -1]]) , np.std([p_ac_list_ent[p_ac_list_ent > -1]]) ))
+ print (' - p(unc|inac) : %.4f +- %.4f' % ( np.mean([p_ui_list_ent[p_ui_list_ent > -1]]) , np.std([p_ui_list_ent[p_ui_list_ent > -1]]) ))
+ print (' - pavpu_3D : %.4f +- %.4f' % ( np.mean([pavpu_list_ent[pavpu_list_ent > -1]]) , np.std([pavpu_list_ent[pavpu_list_ent > -1]]) ))
+ print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_ent[pavpu_unc_threshold_list_ent > -1]), np.std(pavpu_unc_threshold_list_ent[pavpu_unc_threshold_list_ent > -1]) ))
+
+ if config.KEY_AVU_PAC_MIF in res[list(res.keys())[0]]:
+ p_ac_list_mif = np.array([res[patient_id][config.KEY_AVU_PAC_MIF] for patient_id in res])
+ p_ui_list_mif = np.array([res[patient_id][config.KEY_AVU_PUI_MIF] for patient_id in res])
+ pavpu_list_mif = np.array([res[patient_id][config.KEY_AVU_MIF] for patient_id in res])
+ pavpu_unc_threshold_list_mif = np.array([res[patient_id][config.KEY_THRESH_MIF] for patient_id in res])
+ print (' - AvU values for mutual info')
+ print (' - p(acc|cer) : %.4f +- %.4f' % ( np.mean([p_ac_list_mif[p_ac_list_mif > -1]]) , np.std([p_ac_list_mif[p_ac_list_mif > -1]]) ))
+ print (' - p(unc|inac) : %.4f +- %.4f' % ( np.mean([p_ui_list_mif[p_ui_list_mif > -1]]) , np.std([p_ui_list_mif[p_ui_list_mif > -1]]) ))
+ print (' - pavpu_3D : %.4f +- %.4f' % ( np.mean([pavpu_list_mif[pavpu_list_mif > -1]]) , np.std([pavpu_list_mif[pavpu_list_mif > -1]]) ))
+ print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_mif[pavpu_unc_threshold_list_mif > -1]), np.std(pavpu_unc_threshold_list_mif[pavpu_unc_threshold_list_mif > -1]) ))
+
+ if config.KEY_AVU_PAC_UNC in res[list(res.keys())[0]]:
+ p_ac_list_unc = np.array([res[patient_id][config.KEY_AVU_PAC_UNC] for patient_id in res])
+ p_ui_list_unc = np.array([res[patient_id][config.KEY_AVU_PUI_UNC] for patient_id in res])
+ pavpu_list_unc = np.array([res[patient_id][config.KEY_AVU_UNC] for patient_id in res])
+ pavpu_unc_threshold_list_unc = np.array([res[patient_id][config.KEY_THRESH_UNC] for patient_id in res])
+ print (' - AvU values for percentile subtracts')
+ print (' - p(acc|cer) : %.4f +- %.4f' % ( np.mean([p_ac_list_unc[p_ac_list_unc > -1]]) , np.std([p_ac_list_unc[p_ac_list_unc > -1]]) ))
+ print (' - p(unc|inac) : %.4f +- %.4f' % ( np.mean([p_ui_list_unc[p_ui_list_unc > -1]]) , np.std([p_ui_list_unc[p_ui_list_unc > -1]]) ))
+ print (' - pavpu_3D : %.4f +- %.4f' % ( np.mean([pavpu_list_unc[pavpu_list_unc > -1]]) , np.std([pavpu_list_unc[pavpu_list_unc > -1]]) ))
+ print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_unc[pavpu_unc_threshold_list_unc > -1]), np.std(pavpu_unc_threshold_list_unc[pavpu_unc_threshold_list_unc > -1]) ))
+
+ print (' - pavpu params: PAVPU_UNC_THRESHOLD: ', config.PAVPU_UNC_THRESHOLD)
+ # print (' - pavpu params: PAVPU_GRID_SIZE : ', PAVPU_GRID_SIZE)
+ # print (' - pavpu params: PAVPU_RATIO_NEG : ', PAVPU_RATIO_NEG)
+
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+
+ # Step 3 - Summarize ECE
+ if 1:
+ print ('')
+ gc.collect()
+
+ nan_value = config.VAL_ECE_NAN
+ ece_labels_obj = {}
+ ece_labels = []
+ label_count = len(ece_global_obj)
+ pbar_desc_prefix = '[ECE]'
+ ece_global_obj_keys = list(ece_global_obj.keys())
+ res[config.KEY_PATIENT_GLOBAL] = {}
+ with tqdm.tqdm(total=label_count, desc=pbar_desc_prefix, disable=True) as pbar_ece:
+ for label_id in ece_global_obj_keys:
+ o_true_label = np.array(ece_global_obj[label_id]['o_true_label'])
+ o_predict_label = np.array(ece_global_obj[label_id]['o_predict_label'])
+ y_predict_label = np.array(ece_global_obj[label_id]['y_predict_label'])
+ if label_id in ece_global_obj: del ece_global_obj[label_id]
+ gc.collect()
+
+                    # Step 1.1 - Bin the probs and calculate their mean
+                    #            (10 equal-width probability bins, derived from the bin edges rather than from label_count)
+                    bin_edges = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])
+                    num_bins  = len(bin_edges) - 1
+                    y_predict_label_bin_ids = np.digitize(y_predict_label, bin_edges, right=False) - 1
+                    y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals]
+
+                    # Step 1.2 - Calculate the accuracy of each bin
+                    o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(num_bins)]
+                    y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(num_bins)]
+                    y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(num_bins)]
+
+ # Step 1.3 - Wrapup
+ N = np.prod(y_predict_label.shape)
+ ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean)))
+ ce[ce == 0] = nan_value
+ ece_label = np.sum(np.abs(ce[ce != nan_value]))
+ ece_labels.append(ece_label)
+ ece_labels_obj[label_id] = {'y_predict_bins_mean':y_predict_bins_mean, 'y_predict_bins_accuracy':y_predict_bins_accuracy, 'ce':ce, 'ece':ece_label}
+
+ pbar_ece.update(1)
+ memory = pbar_desc_prefix + ' [' + str(utils.get_memory(pid)) + ']'
+ pbar_ece.set_description(desc=memory, refresh=True)
+
+ res[config.KEY_PATIENT_GLOBAL][label_id] = {'ce':ce, 'ece':ece_label}
+
+ print (' - ece_labels : ', ['%.4f' % each for each in ece_labels])
+ print (' - ece : %.4f' % np.mean(ece_labels))
+ print (' - ece (w/o bgd): %.4f' % np.mean(ece_labels[1:]))
+ print (' - ece (w/o bgd, w/o chiasm): %.4f' % np.mean(ece_labels[1:2] + ece_labels[3:]))
+ print ('')
+
+ del ece_global_obj
+ gc.collect()
+
+ # Step 4 - Plot
+ if 1:
+ if save and not deepsup_eval:
+ f, axarr = plt.subplots(3,1, figsize=(15,10))
+ boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95, boxplot_msd = {}, {}, {}, {}
+ for label_id in range(len(loss_labels_list[0])):
+ label_name, _ = utils.get_info_from_label_id(label_id, label_map)
+ boxplot_dice[label_name] = loss_labels_list[:,label_id]
+ boxplot_hausdorff[label_name] = hausdorff_labels_list[:,label_id]
+ boxplot_hausdorff95[label_name] = hausdorff95_labels_list[:,label_id]
+ boxplot_msd[label_name] = msd_labels_list[:,label_id]
+
+ axarr[0].boxplot(boxplot_dice.values())
+ axarr[0].set_xticks(range(1, len(boxplot_dice)+1))
+ axarr[0].set_xticklabels(boxplot_dice.keys())
+ axarr[0].set_ylim([0.0,1.1])
+ axarr[0].set_title('DICE (Avg: {}) \n w/o chiasm:{}'.format(
+ '%.4f' % np.mean(loss_labels_avg[1:])
+ , '%.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:])
+ )
+ )
+
+ axarr[1].boxplot(boxplot_hausdorff.values())
+ axarr[1].set_xticks(range(1, len(boxplot_hausdorff)+1))
+ axarr[1].set_xticklabels(boxplot_hausdorff.keys())
+ axarr[1].set_ylim([0.0,10.0])
+ axarr[1].set_title('Hausdorff')
+
+ axarr[2].boxplot(boxplot_hausdorff95.values())
+ axarr[2].set_xticks(range(1, len(boxplot_hausdorff95)+1))
+ axarr[2].set_xticklabels(boxplot_hausdorff95.keys())
+ axarr[2].set_ylim([0.0,6.0])
+ axarr[2].set_title('95% Hausdorff')
+
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_dice.values())
+ axarr.set_xticks(range(1, len(boxplot_dice)+1))
+ axarr.set_xticklabels(boxplot_dice.keys())
+ axarr.set_ylim([0.0,1.1])
+ axarr.set_yticks(np.arange(0,1.1,0.05))
+ axarr.set_title('DICE')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_dice.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_hausdorff95.values())
+ axarr.set_xticks(range(1, len(boxplot_hausdorff95)+1))
+ axarr.set_xticklabels(boxplot_hausdorff95.keys())
+ axarr.set_title('95% HD')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd95.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_hausdorff.values())
+ axarr.set_xticks(range(1, len(boxplot_hausdorff)+1))
+ axarr.set_xticklabels(boxplot_hausdorff.keys())
+ axarr.set_title('Hausdorff Distance')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ f, axarr = plt.subplots(1, figsize=(15,10))
+ axarr.boxplot(boxplot_msd.values())
+ axarr.set_xticks(range(1, len(boxplot_msd)+1))
+ axarr.set_xticklabels(boxplot_msd.keys())
+ axarr.set_title('MSD')
+ axarr.grid()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_msd.png'))
+ plt.savefig(path_results, bbox_inches='tight')
+ plt.close()
+
+ # ECE
+ for label_id in ece_labels_obj:
+ y_predict_bins_mean = ece_labels_obj[label_id]['y_predict_bins_mean']
+ y_predict_bins_accuracy = ece_labels_obj[label_id]['y_predict_bins_accuracy']
+ ece = ece_labels_obj[label_id]['ece']
+
+ plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8)
+ plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred')
+ plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy')
+ for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink')
+ plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE')
+ plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0)
+ plt.title('CE (label={})\nECE: {}'.format(label_id, '%.5f' % (ece)))
+ plt.xlabel('Probability')
+ plt.ylabel('Accuracy')
+ plt.ylim([-0.15, 1.05])
+ plt.legend()
+
+ # plt.show()
+ path_results = str(Path(model_folder_epoch_patches).joinpath('results_ece_label{}.png'.format(label_id)))
+ plt.savefig(str(path_results), bbox_inches='tight')
+ plt.close()
+
+ # Step 5 - Save data as .json
+ if 1:
+ try:
+
+ filename = str(Path(model_folder_epoch_patches).joinpath(config.FILENAME_EVAL3D_JSON))
+ utils.write_json(res, filename)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+ model.trainable=True
+ print ('\n - [eval_3D] Avg of times_mcruns : {:f} +- {:f}'.format(np.mean(times_mcruns), np.std(times_mcruns)))
+ print (' - [eval_3D()] Total time passed (save={}) : {}s \n'.format(save, round(time.time() - ttotal, 2)))
+ if verbose: pdb.set_trace()
+
+ return loss_avg, {i:loss_labels_avg[i] for i in range(len(loss_labels_avg))}
+
+ except:
+ model.trainable=True
+ traceback.print_exc()
+ return -1, {}
+
+def eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=False, save=False, verbose=False):
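+    """
+    Takes the accumulated patch-wise sums for one patient, normalizes them by the per-voxel overlap
+    count, post-processes the prediction, and computes Dice / (95%) Hausdorff / MSD, ECE and PAvPU.
+    Optionally saves .nrrd volumes and per-patient plots via eval_3D_finalize().
+    """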
+
+ try:
+
+ # Step 3.1.1 - Get stitched patient grid
+ if verbose: t0 = time.time()
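+        # [Note] Every voxel may be covered by several sliding-window grids; the accumulated sums are
+        # divided by the per-voxel overlap count below to obtain grid-averaged probabilities/uncertainties.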
+ patient_pred_ent = patient_pred_ent/patient_pred_overlap # [H,W,D]/[H,W,D]
+ patient_pred_mif = patient_pred_mif/patient_pred_overlap
+ patient_pred_overlap = np.expand_dims(patient_pred_overlap, -1)
+ patient_pred = patient_pred_vals/patient_pred_overlap # [H,W,D,C]/[H,W,D,1]
+ patient_pred_std = patient_pred_std/patient_pred_overlap
+ patient_pred_unc = patient_pred_unc/patient_pred_overlap
+ del patient_pred_vals
+ del patient_pred_overlap
+ if 1:
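+            # Reduce the per-class uncertainty map [H,W,D,C] to a scalar per voxel by picking the
+            # channel of the (mean-softmax) predicted class.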
+ patient_pred_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(np.argmax(patient_pred, axis=-1),axis=-1), axis=-1)[:,:,:,0]
+
+
+ gc.collect() # returns number of unreachable objects collected by GC
+ patient_pred_postprocessed = losses.remove_smaller_components(patient_gt, patient_pred, meta=patient_id_curr, label_ids_small = [2,4,5])
+ if verbose: print (' - [eval_3D()] Post-Process time : ', time.time() - t0,'s')
+
+ # Step 3.1.2 - Loss Calculation
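+        # [Assumption] meta1 appears to carry the voxel spacing scaled by 100, hence the division by 100.0 below.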
+ spacing = np.array([meta1_batch[4], meta1_batch[5], meta1_batch[6]])/100.0
+ try:
+ if verbose: t0 = time.time()
+ loss_avg_val, loss_labels_val = losses.dice_numpy(patient_gt, patient_pred_postprocessed)
+ hausdorff_avg_val, hausdorff_labels_val, hausdorff95_avg_val, hausdorff95_labels_val, msd_avg_val, msd_labels_vals = losses.get_surface_distances(patient_gt, patient_pred_postprocessed, spacing)
+ if verbose:
+ print (' - [eval_3D()] DICE : ', ['%.4f' % (each) for each in loss_labels_val])
+ print (' - [eval_3D()] HD95 : ', ['%.4f' % (each) for each in hausdorff95_labels_val])
+
+ if loss_avg_val != -1 and len(loss_labels_val):
+ res[patient_id_curr] = {
+ config.KEY_DICE_AVG : loss_avg_val
+ , config.KEY_DICE_LABELS : loss_labels_val
+                        , config.KEY_HD_AVG     : hausdorff_avg_val
+ , config.KEY_HD_LABELS : hausdorff_labels_val
+ , config.KEY_HD95_AVG : hausdorff95_avg_val
+ , config.KEY_HD95_LABELS : hausdorff95_labels_val
+ , config.KEY_MSD_AVG : msd_avg_val
+ , config.KEY_MSD_LABELS : msd_labels_vals
+ }
+ else:
+ print (' - [ERROR][eval_3D()] patient_id: ', patient_id_curr)
+ if verbose: print (' - [eval_3D()] Loss calculation time: ', time.time() - t0,'s')
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+
+ # Step 3.1.3 - ECE calculation
+ if verbose: t0 = time.time()
+ ece_global_obj, ece_patient_obj = get_ece(patient_gt, patient_pred, patient_id_curr, ece_global_obj)
+ res[patient_id_curr][config.KEY_ECE_LABELS] = ece_patient_obj
+ if verbose: print (' - [eval_3D()] ECE time : ', time.time() - t0,'s')
+
+ # Step 3.1.4 - Uncertainty Quantification PAvPU
+ if verbose: t0 = time.time()
+ if 0:
+ prob_acc_cer_ent, prob_unc_inacc_ent, pavpu_ent, thresh_ent, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_ent, unc_threshold=config.PAVPU_ENT_THRESHOLD, unc_type='entropy')
+ res[patient_id_curr][config.KEY_AVU_PAC_ENT] = prob_acc_cer_ent
+ res[patient_id_curr][config.KEY_AVU_PUI_ENT] = prob_unc_inacc_ent
+ res[patient_id_curr][config.KEY_AVU_ENT] = pavpu_ent
+ res[patient_id_curr][config.KEY_THRESH_ENT] = thresh_ent
+
+ prob_acc_cer_mif, prob_unc_inacc_mif, pavpu_mif, thresh_mif, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_mif, unc_threshold=config.PAVPU_MIF_THRESHOLD, unc_type='mutual info')
+ res[patient_id_curr][config.KEY_AVU_PAC_MIF] = prob_acc_cer_mif
+ res[patient_id_curr][config.KEY_AVU_PUI_MIF] = prob_unc_inacc_mif
+ res[patient_id_curr][config.KEY_AVU_MIF] = pavpu_mif
+ res[patient_id_curr][config.KEY_THRESH_MIF] = thresh_mif
+
+ # prob_acc_cer_unc, prob_unc_inacc_unc, pavpu_unc, thresh_unc, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_unc, unc_threshold=config.PAVPU_UNC_THRESHOLD, unc_type='percentile subtracts')
+ # res[patient_id_curr][config.KEY_AVU_PAC_UNC] = prob_acc_cer_unc
+ # res[patient_id_curr][config.KEY_AVU_PUI_UNC] = prob_unc_inacc_unc
+ # res[patient_id_curr][config.KEY_AVU_UNC] = pavpu_unc
+ # res[patient_id_curr][config.KEY_THRESH_UNC] = thresh_unc
+ elif 1:
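+            # p(accurate|certain), p(uncertain|inaccurate) and PAvPU, presumably following the standard
+            # accuracy-versus-uncertainty formulation, thresholded on entropy / mutual information
+            # (thresholds taken from config). The exact definitions live in losses.get_pavpu_gtai().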
+ prob_acc_cer_ent, prob_unc_inacc_ent, pavpu_ent, patient_pred_error = losses.get_pavpu_gtai(patient_gt, patient_pred, patient_pred_ent, unc_threshold=config.PAVPU_ENT_THRESHOLD, unc_type='entropy')
+ res[patient_id_curr][config.KEY_AVU_PAC_ENT] = prob_acc_cer_ent
+ res[patient_id_curr][config.KEY_AVU_PUI_ENT] = prob_unc_inacc_ent
+ res[patient_id_curr][config.KEY_AVU_ENT] = pavpu_ent
+ res[patient_id_curr][config.KEY_THRESH_ENT] = config.PAVPU_ENT_THRESHOLD
+
+ prob_acc_cer_mif, prob_unc_inacc_mif, pavpu_mif, patient_pred_error = losses.get_pavpu_gtai(patient_gt, patient_pred, patient_pred_mif, unc_threshold=config.PAVPU_MIF_THRESHOLD, unc_type='mutual info')
+ res[patient_id_curr][config.KEY_AVU_PAC_MIF] = prob_acc_cer_mif
+ res[patient_id_curr][config.KEY_AVU_PUI_MIF] = prob_unc_inacc_mif
+ res[patient_id_curr][config.KEY_AVU_MIF] = pavpu_mif
+ res[patient_id_curr][config.KEY_THRESH_MIF] = config.PAVPU_MIF_THRESHOLD
+
+ if verbose: print (' - [eval_3D()] PAvPU time : ', time.time() - t0,'s')
+
+ # Step 3.1.5 - Save/Visualize
+ if not deepsup_eval:
+ if verbose: t0 = time.time()
+ eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_postprocessed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc, patient_pred_error
+ , patient_id_curr
+ , model_folder_epoch_imgs, model_folder_epoch_patches
+ , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals
+ , spacing, label_map, label_colors
+ , show=show, save=save)
+ if verbose: print (' - [eval_3D()] Save as .nrrd time : ', time.time() - t0,'s')
+
+ if verbose: print (' - [eval_3D()] Total patient time : ', time.time() - t99,'s')
+
+ # Step 3.1.6
+ del patient_img
+ del patient_gt
+ del patient_pred
+ del patient_pred_std
+ del patient_pred_ent
+ del patient_pred_postprocessed
+ del patient_pred_mif
+ del patient_pred_unc
+ gc.collect()
+
+ return res, ece_global_obj
+
+ except:
+ traceback.print_exc()
+ if DEBUG: pdb.set_trace()
+ return res, ece_global_obj
+
+def eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval):
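+    """
+    Runs MC_RUNS stochastic forward passes of `model` on X and returns:
+    (Y, mean softmax, per-class std, predictive entropy, mutual information, inter-percentile range, MC-run time).
+    Quantities not listed in DO_KEYS are returned as empty lists. When `deepsup_eval` is set, the first
+    (half-resolution) model output is used and X/Y are downsampled by 2 to match it.
+    """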
+
+ # Step 0 - Init
+ # DO_KEYS = [config.KEY_PERC]
+ DO_KEYS = [config.KEY_MIF, config.KEY_ENT]
+
+ # Step 1 - Warm up model
+ _ = model(X, training=training_bool)
+
+ # Step 2 - Run Monte-Carlo predictions
+ try:
+ tic_mcruns = time.time()
+ if deepsup:
+ if deepsup_eval:
+ y_predict = tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ X = X[:,::2,::2,::2,:]
+ Y = Y[:,::2,::2,::2,:]
+ else:
+ y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ else:
+ y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ toc_mcruns = time.time()
+ except tf.errors.ResourceExhaustedError as e:
+ print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS={}'.format(MC_RUNS))
+
+ try:
+ MC_RUNS = 5
+ tic_mcruns = time.time()
+ if deepsup:
+ if deepsup_eval:
+ y_predict = tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ X = X[:,::2,::2,::2,:]
+ Y = Y[:,::2,::2,::2,:]
+ else:
+ y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ else:
+ y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ toc_mcruns = time.time()
+ except tf.errors.ResourceExhaustedError as e:
+            print (' - [eval_3D_get_outputs()] OOM error even with reduced MC_RUNS={}'.format(MC_RUNS))
+
+ # Step 3 - Calculate different metrics
+ if config.KEY_MIF in DO_KEYS:
+ y_predict_mif = y_predict * tf.math.log(y_predict + _EPSILON) # [MC,B,H,W,D,C]
+ y_predict_mif = tf.math.reduce_sum(y_predict_mif, axis=[0,-1])/MC_RUNS # [MC,B,H,W,D,C] -> [B,H,W,D]
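+        # At this point y_predict_mif = (1/MC) * sum_mc sum_c p*log(p) = -E_mc[H(p)]; the predictive
+        # entropy H(E_mc[p]) is added further below (KEY_ENT branch) to complete the mutual information.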
+ else:
+ y_predict_mif = []
+
+ if config.KEY_STD in DO_KEYS:
+ y_predict_std = tf.math.reduce_std(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C]
+ else:
+ y_predict_std = []
+
+ if config.KEY_PERC in DO_KEYS:
+ y_predict_perc = tfp.stats.percentile(y_predict, q=[30,70], axis=0, interpolation='nearest')
+ y_predict_unc = y_predict_perc[1] - y_predict_perc[0]
+ del y_predict_perc
+ gc.collect()
+ else:
+ y_predict_unc = []
+
+ y_predict = tf.math.reduce_mean(y_predict, axis=0)
+
+ if config.KEY_ENT in DO_KEYS:
+        y_predict_ent  = -1*tf.math.reduce_sum(y_predict * tf.math.log(y_predict + _EPSILON), axis=-1) # [B,H,W,D,C] -> [B,H,W,D]; predictive entropy H(E[p]) of the MC-mean softmax
+        y_predict_mif  = y_predict_ent + y_predict_mif # [B,H,W,D] + [B,H,W,D] = [B,H,W,D]; MI = H(E[p]) - E[H(p)] (the -E[H(p)] term was accumulated above)
+    else:
+        y_predict_ent  = []
+        y_predict_mif  = [] # the mutual-information estimate builds on the entropy term, so it is also reset when KEY_ENT is not requested
+
+ return Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, toc_mcruns-tic_mcruns
+
+def eval_3D(model, dataset_eval, dataset_eval_gen, params, show=False, save=False, verbose=False):
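+    """
+    Patch-wise 3D evaluation: iterates over the ordered grid patches of dataset_eval_gen, stitches them
+    back into full patient volumes (averaging overlaps), and triggers per-patient metric computation via
+    eval_3D_process_outputs() before summarizing everything in eval_3D_summarize().
+    """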
+
+ try:
+
+ # Step 0.0 - Variables under debugging
+ pass
+
+ # Step 0.1 - Extract params
+ PROJECT_DIR = params['PROJECT_DIR']
+ exp_name = params['exp_name']
+ pid = params['pid']
+ eval_type = params['eval_type']
+ batch_size = params['batch_size']
+        batch_size             = 2 # [Note] hard-coded override of params['batch_size'] for 3D eval
+ epoch = params['epoch']
+ deepsup = params['deepsup']
+ deepsup_eval = params['deepsup_eval']
+ label_map = dict(dataset_eval.get_label_map())
+ label_colors = dict(dataset_eval.get_label_colors())
+
+ if verbose: print (''); print (' --------------------- eval_3D({}) ---------------------'.format(eval_type))
+
+ # Step 0.2 - Init results array
+ res = {}
+ ece_global_obj = {}
+ patient_grid_count = {}
+
+ # Step 0.3 - Init temp variables
+ patient_id_curr = None
+ w_grid, h_grid, d_grid = None, None, None
+ meta1_batch = None
+ patient_gt = None
+ patient_img = None
+ patient_pred_overlap = None
+ patient_pred_vals = None
+ model_folder_epoch_patches = None
+ model_folder_epoch_imgs = None
+
+ mc_runs = params.get(config.KEY_MC_RUNS, None)
+ training_bool = params.get(config.KEY_TRAINING_BOOL, None)
+ model_folder_epoch_patches, model_folder_epoch_imgs = utils.get_eval_folders(PROJECT_DIR, exp_name, epoch, eval_type, mc_runs, training_bool, create=True)
+
+ # Step 0.4 - Debug vars
+ ttotal,t0, t99 = time.time(), None, None
+ times_mcruns = []
+
+ # Step 1 - Loop over dataset_eval (which provides patients & grids in an ordered manner)
+ print ('')
+ model.trainable = False
+ pbar_desc_prefix = 'Eval3D_{} [batch={}]'.format(eval_type, batch_size)
+        training_bool    = params.get('training_bool', True) # [True, False]; [Note] re-reads the training flag (default True) used for the MC forward passes below
+ with tqdm.tqdm(total=len(dataset_eval), desc=pbar_desc_prefix, leave=False) as pbar_eval:
+ for (X,Y,meta1,meta2) in dataset_eval_gen.repeat(1):
+
+ # Step 1.1 - Get MC results
+ MC_RUNS = params.get(config.KEY_MC_RUNS,10)
+ Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, mcruns_time = eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval)
+ times_mcruns.append(mcruns_time)
+
+ for batch_id in range(X.shape[0]):
+
+ # Step 2 - Get grid info
+ patient_id_running = meta2[batch_id].numpy().decode('utf-8')
+ if patient_id_running in patient_grid_count: patient_grid_count[patient_id_running] += 1
+ else: patient_grid_count[patient_id_running] = 1
+
+ meta1_batch = meta1[batch_id].numpy()
+ w_start, h_start, d_start = meta1_batch[1], meta1_batch[2], meta1_batch[3]
+
+ # Step 3 - Check if its a new patient
+ if patient_id_running != patient_id_curr:
+
+ # Step 3.1 - Sort out old patient (patient_id_curr)
+ if patient_id_curr != None:
+
+ res, ece_global_obj = eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose)
+
+ # Step 3.2 - Create variables for new patient
+ if verbose: t99 = time.time()
+ patient_id_curr = patient_id_running
+ patient_scan_size = meta1_batch[7:10]
+ dataset_name = patient_id_curr.split('-')[0]
+ dataset_this = dataset_eval.get_subdataset(param_name=dataset_name)
+ w_grid, h_grid, d_grid = dataset_this.w_grid, dataset_this.h_grid, dataset_this.d_grid
+ if deepsup_eval:
+ patient_scan_size = patient_scan_size//2
+ w_grid, h_grid, d_grid = w_grid//2, h_grid//2, d_grid//2
+
+ patient_pred_size = list(patient_scan_size) + [len(dataset_this.LABEL_MAP)]
+ patient_pred_overlap = np.zeros(patient_scan_size, dtype=np.uint8)
+ patient_pred_ent = np.zeros(patient_scan_size, dtype=np.float32)
+ patient_pred_mif = np.zeros(patient_scan_size, dtype=np.float32)
+ patient_pred_vals = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_pred_std = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_pred_unc = np.zeros(patient_pred_size, dtype=np.float32)
+ patient_gt = np.zeros(patient_pred_size, dtype=np.float32)
+ if show or save:
+ patient_img = np.zeros(list(patient_scan_size) + [1], dtype=np.float32)
+ else:
+ patient_img = []
+
+ # Step 4 - If not new patient anymore, fill up data
+ if deepsup_eval:
+ w_start, h_start, d_start = w_start//2, h_start//2, d_start//2
+ patient_pred_vals[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict[batch_id]
+ if len(y_predict_std):
+ patient_pred_std[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_std[batch_id]
+ if len(y_predict_ent):
+ patient_pred_ent[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_ent[batch_id]
+ if len(y_predict_mif):
+ patient_pred_mif[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_mif[batch_id]
+ if len(y_predict_unc):
+ patient_pred_unc[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_unc[batch_id]
+
+ patient_pred_overlap[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += np.ones(y_predict[batch_id].shape[:-1], dtype=np.uint8)
+ patient_gt[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = Y[batch_id]
+ if show or save:
+ patient_img[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = X[batch_id]
+
+ pbar_eval.update(batch_size)
+ mem_used = utils.get_memory(pid)
+ memory = pbar_desc_prefix + ' [' + mem_used + ']'
+ pbar_eval.set_description(desc=memory, refresh=True)
+
+        # Step 5 - Process the last patient (the loop above only flushes a patient once the next one begins)
+ res, ece_global_obj = eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc
+ , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose)
+ print ('\n - [eval_3D()] Time passed to accumulate grids & process patients: ', round(time.time() - ttotal, 2), 's')
+
+ return eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=save, show=show, verbose=verbose)
+
+ except:
+ traceback.print_exc()
+ model.trainable = True
+ return -1, {}
+
+############################################################
+# VAL #
+############################################################
+
+def val(model, dataset, params, show=False, save=False, verbose=False):
+ try:
+
+ # Step 1 - Load Model
+
+ load_model_params = {'PROJECT_DIR': params['PROJECT_DIR']
+ , 'exp_name': params['exp_name']
+ , 'load_epoch': params['epoch']
+ , 'optimizer': tf.keras.optimizers.Adam()
+ }
+ utils.load_model(model, load_type=config.MODE_VAL, params=load_model_params)
+ print ('')
+ print (' - [train.py][val()] Model({}) Loaded for {} at epoch-{} (validation purposes) !'.format(str(model), params['exp_name'], params['epoch']))
+ print ('')
+
+ # Step 3 - Calculate losses
+        # Step 2 - Build the eval generator and calculate losses
+ loss_avg, loss_labels_avg = eval_3D(model, dataset, dataset_gen, params, show=show, save=save, verbose=verbose)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def val_dropout(exp_name, model, dataset, epoch, batch_size=32, show=False):
+ try:
+ utils.load_model(exp_name, model, epoch, {'optimizer':tf.keras.optimizers.Adam()}, load_type='test')
+ print (' - [train.py][val()] Model Loaded at epoch-{} !'.format(epoch))
+
+ MC_RUNS = 10
+
+ with tqdm.tqdm(total=len(dataset), desc='') as pbar:
+ for (X,Y,meta1, meta2) in dataset.generator().batch(batch_size):
+
+ if 1:
+ y_predict = tf.stack([model(X, training=True) for _ in range(MC_RUNS)]) # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+ y_predict_mean = tf.math.reduce_mean(y_predict, axis=0)
+ y_predict_std = tf.math.reduce_std(y_predict, axis=0)
+ utils.viz_model_output_dropout(X,Y,y_predict_mean, y_predict_std, exp_name, epoch, dataset, meta2, mode='val_dropout')
+
+ if 1:
+ y_predict = model(X, training=False)
+ utils.viz_model_output(X,Y,y_predict, exp_name, epoch, dataset, meta2, mode='val')
+
+ pbar.update(batch_size)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+############################################################
+# TRAINER #
+############################################################
+
+class Trainer:
+
+ def __init__(self, params):
+
+ # Init
+ self.params = params
+
+ # Print
+ self._train_preprint()
+
+ # Random Seeds
+ self._set_seed()
+
+ # Set the dataloaders
+ self._set_dataloaders()
+
+ # Set the model
+ self._set_model()
+
+ # Set Metrics
+ self._set_metrics()
+
+ # Other flags
+ self.write_model_done = False
+
+ def _train_preprint(self):
+ print ('')
+ print (' -------------- {} ({})'.format(self.params['exp_name'], str(datetime.datetime.now())))
+
+ print ('')
+ print (' DATALOADER ')
+ print (' ---------- ')
+ print (' - dir_type: ', self.params['dataloader']['dir_type'])
+
+ print (' -- resampled: ', self.params['dataloader']['resampled'])
+ print (' -- crop_init: ', self.params['dataloader']['crop_init'])
+ print (' -- grid: ', self.params['dataloader']['grid'])
+ print (' --- filter_grid : ', self.params['dataloader']['filter_grid'])
+ print (' --- random_grid : ', self.params['dataloader']['random_grid'])
+ print (' --- centred_prob : ', self.params['dataloader']['centred_prob'])
+
+ print (' -- batch_size: ', self.params['dataloader']['batch_size'])
+ print (' -- prefetch_batch : ', self.params['dataloader']['prefetch_batch'])
+ print (' -- parallel_calls : ', self.params['dataloader']['parallel_calls'])
+ print (' -- shuffle : ', self.params['dataloader']['shuffle'])
+
+ print (' -- single_sample: ', self.params['dataloader']['single_sample'])
+ if self.params['dataloader']['single_sample']:
+ print (' !!!!!!!!!!!!!!!!!!! SINGLE SAMPLE !!!!!!!!!!!!!!!!!!!')
+ print ('')
+
+ print ('')
+ print (' MODEL ')
+ print (' ----- ')
+ print (' - Model: ', str(self.params['model']['name']))
+ print (' -- KL Schedule : ', self.params['model']['kl_schedule'])
+ print (' -- KL Alpha Init: ', self.params['model']['kl_alpha_init'])
+ print (' -- KL Scaler : ', self.params['model']['kl_scale_factor'])
+ print (' -- Activation : ', self.params['model']['activation'])
+ print (' -- Kernel Reg : ', self.params['model']['kernel_reg'])
+ print (' -- Model TBoard : ', self.params['model']['model_tboard'])
+ print (' -- Profiler : ', self.params['model']['profiler']['profile'])
+ if self.params['model']['profiler']['profile']:
+ print (' ---- Profiler Epochs: ', self.params['model']['profiler']['epochs'])
+ print (' ---- Step Per Epochs: ', self.params['model']['profiler']['steps_per_epoch'])
+ print (' - Optimizer: ', str(self.params['model']['optimizer']))
+ print (' -- Init LR : ', self.params['model']['init_lr'])
+ print (' -- Fixed LR : ', self.params['model']['fixed_lr'])
+ print (' -- Grad Persistent: ', self.params['model']['grad_persistent'])
+ if self.params['model']['grad_persistent']:
+ print (' !!!!!!!!!!!!!!!!!!! GRAD PERSISTENT !!!!!!!!!!!!!!!!!!!')
+ print ('')
+ print (' - Epochs: ', self.params['model']['epochs'])
+ print (' -- Save : every {} epochs'.format(self.params['model']['epochs_save']))
+ print (' -- Eval3D : every {} epochs '.format(self.params['model']['epochs_eval']))
+ print (' -- Viz3D : every {} epochs '.format(self.params['model']['epochs_viz']))
+
+ print ('')
+ print (' METRICS ')
+ print (' ------- ')
+ print (' - Logging-TBoard: ', self.params['metrics']['logging_tboard'])
+ if not self.params['metrics']['logging_tboard']:
+ print (' !!!!!!!!!!!!!!!!!!! NO LOGGING-TBOARD !!!!!!!!!!!!!!!!!!!')
+ print ('')
+ print (' - Eval: ', self.params['metrics']['metrics_eval'])
+ print (' - Loss: ', self.params['metrics']['metrics_loss'])
+ print (' -- Type of Loss : ', self.params['metrics']['loss_type'])
+ print (' -- Weighted Loss : ', self.params['metrics']['loss_weighted'])
+ print (' -- Masked Loss : ', self.params['metrics']['loss_mask'])
+ print (' -- Combo : ', self.params['metrics']['loss_combo'])
+ print (' -- Loss Epoch : ', self.params['metrics']['loss_epoch'])
+ print (' -- Loss Rate : ', self.params['metrics']['loss_rate'])
+
+ print ('')
+ print (' DEVOPS ')
+ print (' ------ ')
+ self.pid = os.getpid()
+ print (' - OS-PID: ', self.pid)
+ print (' - Seed: ', self.params['random_seed'])
+
+ print ('')
+
+ def _set_seed(self):
+ np.random.seed(self.params['random_seed'])
+ tf.random.set_seed(self.params['random_seed'])
+
+ def _set_dataloaders(self):
+
+ # Params - Directories
+ data_dir = self.params['dataloader']['data_dir']
+ dir_type = self.params['dataloader']['dir_type']
+ dir_type_eval = ['_'.join(dir_type)]
+
+ # Params - Single volume
+ resampled = self.params['dataloader']['resampled']
+ crop_init = self.params['dataloader']['crop_init']
+ grid = self.params['dataloader']['grid']
+ filter_grid = self.params['dataloader']['filter_grid']
+ random_grid = self.params['dataloader']['random_grid']
+ centred_prob = self.params['dataloader']['centred_prob']
+
+ # Params - Dataloader
+ batch_size = self.params['dataloader']['batch_size']
+ prefetch_batch = self.params['dataloader']['prefetch_batch']
+ parallel_calls = self.params['dataloader']['parallel_calls']
+ shuffle_size = self.params['dataloader']['shuffle']
+
+ # Params - Debug
+ single_sample = self.params['dataloader']['single_sample']
+
+
+ # Datasets
+ self.dataset_train = get_dataloader_3D_train(data_dir, dir_type=dir_type
+ , grid=grid, crop_init=crop_init, filter_grid=filter_grid
+ , random_grid=random_grid
+ , resampled=resampled, single_sample=single_sample
+ , parallel_calls=parallel_calls
+ , centred_dataloader_prob=centred_prob
+ )
+ self.dataset_train_eval = get_dataloader_3D_train_eval(data_dir, dir_type=dir_type_eval
+ , grid=grid, crop_init=crop_init
+ , resampled=resampled, single_sample=single_sample
+ )
+ self.dataset_test_eval = get_dataloader_3D_test_eval(data_dir
+ , grid=grid, crop_init=crop_init
+ , resampled=resampled, single_sample=single_sample
+ )
+
+ # Get labels Ids
+ self.label_map = dict(self.dataset_train.get_label_map())
+ self.label_ids = self.label_map.values()
+ self.params['internal'] = {}
+ self.params['internal']['label_map'] = self.label_map # for use in Metrics
+ self.params['internal']['label_ids'] = self.label_ids # for use in Metrics
+ self.label_weights = list(self.dataset_train.get_label_weights())
+
+ # Generators
+ self.dataset_train_gen = self.dataset_train.generator().repeat().shuffle(shuffle_size).batch(batch_size).apply(tf.data.experimental.prefetch_to_device(device='/GPU:0', buffer_size=prefetch_batch))
+ self.dataset_train_eval_gen = self.dataset_train_eval.generator().batch(2).prefetch(prefetch_batch)
+ self.dataset_test_eval_gen = self.dataset_test_eval.generator().batch(2).prefetch(prefetch_batch)
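+        # [Note] Training batches are prefetched directly onto the GPU; the eval generators use the
+        # default host-side prefetching since 3D evaluation runs patch-by-patch with batch size 2.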
+
+ def _set_model(self):
+
+ # Step 1 - Get class ids
+ class_count = len(self.label_ids)
+ deepsup = self.params['model']['deepsup']
+ activation = self.params['model']['activation']
+
+ # Step 2 - Get model arch
+ self.kl_schedule = self.params['model']['kl_schedule']
+ if self.kl_schedule == config.KL_DIV_FIXED:
+ self.kl_alpha_init = self.params['model']['kl_alpha_init']
+ elif self.kl_schedule == config.KL_DIV_ANNEALING:
+ if 0:
+ self.kl_alpha_init = 0.100 # [0.100, 0.050, 0.020, 0.010]
+ self.kl_alpha_increase_per_epoch = 0.001
+ self.kl_alpha_max = 0.030
+ self.initial_epoch = 250
+ self.kl_epochs_change = 10
+ elif 1:
+ self.kl_alpha_init = 0.05
+ self.kl_alpha_increase_per_epoch = 0.0001
+ self.initial_epoch = 250
+ elif 0:
+ self.kl_alpha_init = 0.001
+ self.kl_alpha_increase_per_epoch = 0.0001
+ self.initial_epoch = 250
+
+ if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT:
+ print (' - [Trainer][_models()] ModelFocusNetFlipOut')
+ self.model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup)
+
+ elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT_V2:
+ print (' - [Trainer][_models()] ModelFocusNetFlipOutV2')
+ self.model = models.ModelFocusNetFlipOutV2(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup)
+
+ elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT_POOL2:
+ print (' - [Trainer][_models()] ModelFocusNetFlipOutPool2')
+ self.model = models.ModelFocusNetFlipOutPool2(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup)
+
+ # print (' - [Trainer][_set_model()] initial_epoch: ', self.initial_epoch)
+ print (' - [Trainer][_set_model()] kl_alpha_init: ', self.kl_alpha_init)
+ # print (' - [Trainer][_set_model()] kl_alpha_increase_per_epoch: ', self.kl_alpha_increase_per_epoch)
+
+ # Step 3 - Get optimizer
+ if self.params['model']['optimizer'] == config.OPTIMIZER_ADAM:
+ self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['model']['init_lr'])
+
+ # Step 4 - Load model if needed
+ epochs = self.params['model']['epochs']
+ if not self.params['model']['load_model']['load']:
+ # Step 4.1 - Set epoch range under non-loading situations
+ self.epoch_range = range(1,epochs+1)
+ else:
+
+ # Step 4.2.1 - Some model-loading params
+ load_epoch = self.params['model']['load_model']['load_epoch']
+ load_exp_name = self.params['model']['load_model']['load_exp_name']
+ load_optimizer_lr = self.params['model']['load_model']['load_optimizer_lr']
+ load_model_params = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'load_epoch': load_epoch, 'optimizer':self.optimizer}
+
+ print ('')
+ print (' - [Trainer][_set_model()] Loading pretrained model')
+ print (' - [Trainer][_set_model()] Model: ', self.model)
+
+ # Step 4.2.2.1 - If loading is done from the same exp_name
+ if load_exp_name is None:
+                load_model_params['exp_name'] = self.params['exp_name']
+ self.epoch_range = range(load_epoch+1, epochs+1)
+ print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(load_epoch, epochs))
+            # Step 4.2.2.2 - If loading is done from another exp_name
+ else:
+ self.epoch_range = range(1, epochs+1)
+ load_model_params['exp_name'] = load_exp_name
+ print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(1, epochs))
+
+ print (' - [Trainer][_set_model()] exp_name: ', load_model_params['exp_name'])
+
+ # Step 4.3 - Finally, load model from the checkpoint
+ utils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params)
+ print (' - [Trainer][_set_model()] Model Loaded at epoch-{} !'.format(load_epoch))
+ print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy())
+ if load_optimizer_lr is not None:
+ self.optimizer.lr.assign(load_optimizer_lr)
+ print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy())
+
+        # Step 5 - Create model weights
+ # init_size = ((1,240,240,40,1))
+ init_size = ((1,140,140,40,1))
+        print ('\n -- [Trainer][_set_model()] Model weight creation with ', init_size, '\n')
+ X_tmp = tf.random.normal(init_size) # if the final dataloader does not have the same input size, the weight initialization gets screwed up.
+ _ = self.model(X_tmp)
+ self.layers_kl_std = self.get_layers_kl_std(std=True)
+ print (' -- [Trainer][_set_model()] Created model weights ')
+ try:
+ print (' --------------------------------------- ')
+ print (self.model.summary(line_length=150))
+ print (' --------------------------------------- ')
+ count = 0
+ for var in self.model.trainable_variables:
+ print (' - var: ', var.name)
+ count += 1
+ if count > 20:
+ print (' ... ')
+ break
+ except:
+ print (' - [Trainer][_set_model()] model.summary() failed')
+ pass
+
+ def _set_metrics(self):
+
+ self.metrics = {}
+ self.metrics[config.MODE_TRAIN] = ModelMetrics(metric_type=config.MODE_TRAIN, params=self.params)
+ self.metrics[config.MODE_TEST] = ModelMetrics(metric_type=config.MODE_TEST, params=self.params)
+
+ deepsup = self.params['model']['deepsup']
+ if deepsup:
+ self.metrics[config.MODE_TRAIN_DEEPSUP] = ModelMetrics(metric_type=config.MODE_TRAIN_DEEPSUP, params=self.params)
+ self.metrics[config.MODE_TEST_DEEPSUP] = ModelMetrics(metric_type=config.MODE_TEST_DEEPSUP, params=self.params)
+
+ def _set_profiler(self, epoch, epoch_step):
+ exp_name = self.params['exp_name']
+
+ if self.params['model']['profiler']['profile']:
+ if epoch in self.params['model']['profiler']['epochs']:
+ if epoch_step == self.params['model']['profiler']['starting_step']:
+ self.logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, 'profiler', str(epoch))
+ tf.profiler.experimental.start(str(self.logdir))
+ print (' - tf.profiler.experimental.start(logdir)')
+ print ('')
+ elif epoch_step == self.params['model']['profiler']['starting_step'] + self.params['model']['profiler']['steps_per_epoch']:
+ print (' - tf.profiler.experimental.stop()')
+ tf.profiler.experimental.stop()
+ print ('')
+
+ @tf.function
+ def get_layers_kl_std(self, std=False):
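+        """
+        Collects, per layer, the KL-divergence loss terms registered by the (flipout) layers and, when
+        std=True, the posterior mean (loc) and scale of the corresponding convolution kernels.
+        Returns a dict keyed by '<layer_name>_<loss_id>'.
+        """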
+
+ res = {}
+ for layer in self.model.layers:
+ for loss_id, loss in enumerate(layer.losses):
+ layer_name = layer.name + '_' + str(loss_id)
+ res[layer_name] = {'kl': loss}
+
+ if std:
+ if hasattr(layer, 'conv_layer'):
+ if loss_id == 0:
+ mean_vals = layer.conv_layer.submodules[1].kernel_posterior.distribution.loc
+ std_vals = layer.conv_layer.submodules[1].kernel_posterior.distribution.scale
+ res[layer_name]['std'] = std_vals
+ res[layer_name]['mean'] = mean_vals
+ elif loss_id == 1:
+ mean_vals = layer.conv_layer.submodules[3].kernel_posterior.distribution.loc
+ std_vals = layer.conv_layer.submodules[3].kernel_posterior.distribution.scale
+ res[layer_name]['std'] = std_vals
+ res[layer_name]['mean'] = mean_vals
+
+ elif hasattr(layer, 'convblock_res'):
+ if loss_id == 0:
+ mean_vals = layer.convblock_res.conv_layer.submodules[1].kernel_posterior.distribution.loc
+ std_vals = layer.convblock_res.conv_layer.submodules[1].kernel_posterior.distribution.scale
+ res[layer_name]['std'] = std_vals
+ res[layer_name]['mean'] = mean_vals
+ elif loss_id == 1:
+ mean_vals = layer.convblock_res.conv_layer.submodules[3].kernel_posterior.distribution.loc
+ std_vals = layer.convblock_res.conv_layer.submodules[3].kernel_posterior.distribution.scale
+ res[layer_name]['std'] = std_vals
+ res[layer_name]['mean'] = mean_vals
+
+ elif type(layer) == tfp.layers.Convolution3DFlipout:
+ mean_vals = layer.kernel_posterior.distribution.loc
+ std_vals = layer.kernel_posterior.distribution.scale
+ res[layer_name]['std'] = std_vals
+ res[layer_name]['mean'] = mean_vals
+
+ return res
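+
+    # [Illustrative, not on the training path] The dict above maps '<layer_name>_<loss_id>' to
+    # {'kl': <KL loss>, 'mean': posterior loc, 'std': posterior scale} for the flipout conv layers.
+    # A consumer such as ModelMetrics.write_epoch_summary_std() could, for example, log one scalar per layer:
+    #   for name, vals in layers_kl_std.items():
+    #       if 'std' in vals: tf.summary.scalar('std/' + name, tf.reduce_mean(vals['std']), step=epoch)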
+
+ @tf.function
+ def _train_loss(self, Y, y_predict, meta1, epoch, mode):
+
+ trainMetrics = self.metrics[mode]
+ metrics_loss = self.params['metrics']['metrics_loss']
+ loss_weighted = self.params['metrics']['loss_weighted']
+ loss_mask = self.params['metrics']['loss_mask']
+ loss_type = self.params['metrics']['loss_type']
+ loss_combo = self.params['metrics']['loss_combo']
+ loss_epoch = self.params['metrics']['loss_epoch']
+ loss_rate = self.params['metrics']['loss_rate']
+
+ label_ids = self.label_ids
+ label_weights = tf.cast(self.label_weights, dtype=tf.float32)
+
+ loss_vals = tf.constant(0.0, dtype=tf.float32)
+ mask = losses.get_mask(meta1[:,-len(label_ids):], Y)
+
+ inf_flag = False
+ nan_flag = False
+
+ for metric_str in metrics_loss:
+
+ weights = []
+ if loss_weighted[metric_str]:
+ weights = label_weights
+
+ if not loss_mask[metric_str]:
+                mask = tf.cast(tf.cast(mask + 1, dtype=tf.bool), dtype=tf.float32) # masking disabled for this loss: use an all-ones mask
+
+ loss_epoch_metric = tf.constant(loss_epoch[metric_str], dtype=tf.float32)
+
+ if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PAVPU] and epoch >= loss_epoch_metric:
+ loss_val_train, loss_labellist_train, metric_val_report, metric_labellist_report = trainMetrics.losses_obj[metric_str](Y, y_predict, label_mask=mask, weights=weights)
+ if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PAVPU]:
+ nan_list = tf.math.is_nan(loss_labellist_train)
+ nan_val = tf.math.is_nan(loss_val_train)
+ inf_list = tf.math.is_inf(loss_labellist_train)
+ inf_val = tf.math.is_inf(loss_val_train)
+ if nan_val or tf.math.reduce_any(nan_list):
+ nan_flag = True
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || nan_list: ', nan_list, ' || nan_val: ', nan_val)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || mask: ', mask)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+ elif inf_val or tf.math.reduce_any(inf_list):
+ inf_flag = True
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_val_train: ', loss_val_train)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || inf_list: ', inf_list, ' || inf_val: ', inf_val)
+ tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || mask: ', mask)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+ else:
+ if len(metric_labellist_report):
+ trainMetrics.update_metric_loss_labels(metric_str, metric_labellist_report) # in sub-3D settings, this value is only indicative of performance
+ trainMetrics.update_metric_loss(metric_str, loss_val_train)
+
+ if 0: # [DEBUG]
+ tf.print(' - metric_str: ', metric_str, ' || loss_val_train: ', loss_val_train)
+
+ if loss_type[metric_str] == config.LOSS_SCALAR:
+ if metric_str in loss_combo:
+ print (' - [Trainer][_train_loss()] Scalar loss')
+ if loss_epoch_metric == 0.0:
+ loss_val_train = loss_val_train*loss_combo[metric_str]
+ else:
+ loss_rate_metric = tf.constant(loss_rate[metric_str], dtype=tf.float32)
+ loss_factor = (epoch - loss_epoch_metric)/loss_rate_metric
+                                    if loss_factor > 1:
+                                        loss_factor = tf.constant(1.0, dtype=tf.float32)
+                                    loss_val_train = loss_val_train*loss_combo[metric_str]*loss_factor
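+                                    # Worked example with the (commented-out) pavpu settings in __main__
+                                    # (loss_epoch=300, loss_rate=10): at epoch 305 the ramp factor is
+                                    # (305-300)/10 = 0.5; from epoch 310 onwards it is clipped to 1.0.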
+ loss_vals = tf.math.add(loss_vals, loss_val_train) # Averaged loss
+
+ elif loss_type[metric_str] == config.LOSS_VECTOR:
+ if metric_str in loss_combo:
+ print (' - [Trainer][_train_loss()] Vector loss')
+ loss_labellist_train = loss_labellist_train*loss_combo[metric_str]
+ loss_vals = tf.math.add(loss_vals, loss_labellist_train) # Averaged loss by label
+
+ if nan_flag or inf_flag:
+ print (' - [INFO][Trainer][_train_loss()] loss_vals:', loss_vals)
+            loss_vals = 0.0 # return 0.0 so the caller skips backprop when a NaN/Inf was spotted
+
+ return loss_vals
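+
+        # Return contract (as consumed by _train_step below): a scalar total loss when every entry in
+        # loss_type is config.LOSS_SCALAR, a per-label loss vector when config.LOSS_VECTOR is used,
+        # and 0.0 when a NaN/Inf was spotted so that the caller skips backpropagation for this batch.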
+
+ @tf.function
+ def _train_step(self, X, Y, meta1, meta2, kl_alpha, epoch):
+
+ try:
+
+ if 1:
+ model = self.model
+ deepsup = self.params['model']['deepsup']
+ optimizer = self.optimizer
+ grad_persistent = self.params['model']['grad_persistent']
+ trainMetrics = self.metrics[config.MODE_TRAIN]
+ kl_scale_fac = self.params['model']['kl_scale_factor']
+                MC_RUNS = 5 # note: defined here but not used within this training step
+
+ y_predict = None
+ loss_vals = None
+ gradients = None
+
+ # Step 1 - Calculate loss
+ with tf.GradientTape(persistent=grad_persistent) as tape:
+
+
+ t2 = tf.timestamp()
+ if deepsup: (y_predict_deepsup, y_predict) = model(X, training=True)
+ else : y_predict = model(X, training=True)
+ t2_ = tf.timestamp()
+
+ loss_vals = self._train_loss(Y, y_predict, meta1, epoch, mode=config.MODE_TRAIN)
+
+ if loss_vals > 0:
+ if deepsup:
+ print (' - [Trainer][_train_step()] deepsup training')
+ Y_deepsup = Y[:,::2,::2,::2,:]
+ loss_vals_deepsup = self._train_loss(Y_deepsup, y_predict_deepsup, meta1, epoch, mode=config.MODE_TRAIN_DEEPSUP)
+ loss_vals += loss_vals_deepsup
+
+ print (' - [Trainer][_train_step()] Model FlipOut')
+ kl = tf.math.add_n(model.losses)
+ kl_loss = kl*kl_alpha/kl_scale_fac
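+                    # With the __main__ settings below (kl_alpha_init=0.01, kl_scale_factor=154, i.e.
+                    # ~308 grid patches / batch_size 2), the summed KL of all flipout layers is divided
+                    # by the number of batches per epoch, spreading the KL penalty over an epoch, and is
+                    # additionally down-weighted by kl_alpha.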
+
+ kl_layers = {}
+ for layer in model.layers:
+ for loss_id, loss in enumerate(layer.losses):
+ layer_name = layer.name + '_' + str(loss_id)
+ kl_layers[layer_name] = {'kl': loss}
+ trainMetrics.update_metrics_kl(kl_alpha, kl, kl_layers)
+ trainMetrics.update_metrics_scalarloss(loss_vals, kl_loss)
+
+                    if 0: # disabled: average the loss over multiple stochastic forward passes
+                        with tf.GradientTape(watch_accessed_variables=False) as tape_unwatched:
+                            print (' - [Trainer][_train_step()] Loopy Loss Expectation')
+                            loss_loops = 2
+                            for _ in tf.range(loss_loops-1):
+                                y_predict = model(X, training=False)
+                                loss_vals = tf.math.add(loss_vals, self._train_loss(Y, y_predict, meta1, epoch, mode=config.MODE_TRAIN))
+                            loss_vals = tf.math.divide(loss_vals, loss_loops)
+
+ loss_vals = loss_vals + kl_loss
+
+ # Step 2 - Calculate gradients
+ t3 = tf.timestamp()
+ if not tf.math.reduce_any(tf.math.is_nan(loss_vals)) and loss_vals > 0:
+ all_vars = model.trainable_variables
+
+ gradients = tape.gradient(loss_vals, all_vars) # dL/dW
+
+ # Step 3 - Apply gradients
+ optimizer.apply_gradients(zip(gradients, all_vars))
+ else:
+ tf.print('\n ====================== [NaN Error] ====================== ')
+ tf.print(' - [ERROR][Trainer][_train_step()] Loss NaN spotted || loss_vals: ', loss_vals)
+ tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1)
+
+ t3_ = tf.timestamp()
+                return t2_-t2, t3-t2_, t3_-t3 # (time_predict, time_loss, time_backprop)
+
+ except tf.errors.ResourceExhaustedError as e:
+ print (' - [ERROR][Trainer][_train_step()] OOM error')
+ return None, None, None
+
+ except:
+ tf.print('\n ====================== [Some Error] ====================== ')
+ tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1)
+ traceback.print_exc()
+ return None, None, None
+
+ def train(self):
+
+ # Global params
+ exp_name = self.params['exp_name']
+
+ # Dataloader params
+ batch_size = self.params['dataloader']['batch_size']
+
+ # Model/Training params
+ fixed_lr = self.params['model']['fixed_lr']
+ init_lr = self.params['model']['init_lr']
+ max_epoch = self.params['model']['epochs']
+ epoch_range = iter(self.epoch_range)
+ epoch_length = len(self.dataset_train)
+ deepsup = self.params['model']['deepsup']
+ params_save_model = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'optimizer':self.optimizer}
+
+ # Metrics params
+ metrics_eval = self.params['metrics']['metrics_eval']
+ trainMetrics = self.metrics[config.MODE_TRAIN]
+ trainMetrics.init_metrics_layers_kl_std(self.params, self.layers_kl_std)
+ trainMetricsDeepSup = None
+ if deepsup: trainMetricsDeepSup = self.metrics[config.MODE_TRAIN_DEEPSUP]
+
+ # Eval Params
+ params_eval = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'pid': self.pid
+ , 'eval_type': config.MODE_TRAIN, 'batch_size': batch_size}
+
+ # Viz params
+ epochs_save = self.params['model']['epochs_save']
+ epochs_viz = self.params['model']['epochs_viz']
+ epochs_eval = self.params['model']['epochs_eval']
+
+ # KL Divergence Params
+ kl_alpha = self.kl_alpha_init # [0.0, self.kl_alpha_init]
+
+ # Random vars
+ t_start_time = time.time()
+
+ epoch = None
+ try:
+
+ epoch_step = 0
+ pbar = None
+ t1 = time.time()
+ for (X,Y,meta1,meta2) in self.dataset_train_gen:
+ t1_ = time.time()
+
+ try:
+ # Epoch starter code
+ if epoch_step == 0:
+
+ # Get Epoch
+ epoch = next(epoch_range)
+
+ # Metrics
+ trainMetrics.reset_metrics(self.params)
+
+ # LR
+ if not fixed_lr:
+ set_lr(epoch, self.optimizer, init_lr)
+ self.model.trainable = True
+
+ # Calculate kl_alpha (commented if alpha is fixed)
+ if self.kl_schedule == config.KL_DIV_ANNEALING:
+ if epoch > self.initial_epoch:
+ if epoch % self.kl_epochs_change == 0:
+ kl_alpha = tf.math.minimum(self.kl_alpha_max, self.kl_alpha_init + (epoch - self.initial_epoch)/float(self.kl_epochs_change) * self.kl_alpha_increase_per_epoch)
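+                                        # Annealing sketch (illustrative numbers; the actual values come from the
+                                        # Trainer's KL settings): with kl_alpha_init=0.01, kl_epochs_change=10 and
+                                        # kl_alpha_increase_per_epoch=0.01, kl_alpha grows by 0.01 every 10 epochs
+                                        # after initial_epoch, capped at kl_alpha_max.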
+
+ # Pretty print
+ print ('')
+                        print (' ===== [{}] EPOCH:{}/{} (LR={:.3f}, kl_alpha={:.3f}) =================='.format(exp_name, epoch, max_epoch, self.optimizer.lr.numpy(), kl_alpha))
+
+ # Start a fresh pbar
+ pbar = tqdm.tqdm(total=epoch_length, desc='')
+
+ # Model Writing to tensorboard
+ if self.params['model']['model_tboard'] and self.write_model_done is False :
+ self.write_model_done = True
+ utils.write_model_tboard(self.model, X, self.params)
+
+ # Start/Stop Profiling (after dataloader is kicked off)
+ self._set_profiler(epoch, epoch_step)
+
+ # Calculate loss and gradients from them
+ time_predict, time_loss, time_backprop = self._train_step(X, Y, meta1, meta2, tf.constant(kl_alpha, dtype=tf.float32), tf.constant(epoch, dtype=tf.float32))
+
+ # Update metrics (time + eval + plots)
+ time_dataloader = t1_ - t1
+ trainMetrics.update_metrics_time(time_dataloader, time_predict, time_loss, time_backprop)
+
+ # Update looping stuff
+ epoch_step += batch_size
+ pbar.update(batch_size)
+ trainMetrics.update_pbar(pbar)
+
+ except:
+ utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ params_save_model['epoch'] = epoch
+ utils.save_model(self.model, params_save_model)
+ traceback.print_exc()
+
+ if epoch_step >= epoch_length:
+
+ # Reset epoch-loop params
+ pbar.close()
+ epoch_step = 0
+
+ try:
+ # Model save
+ if epoch % epochs_save == 0:
+ params_save_model['epoch'] = epoch
+ utils.save_model(self.model, params_save_model)
+
+ # Tensorboard for std
+ if epoch % epochs_save == 0:
+ layers_kl_std = self.get_layers_kl_std(std=True)
+ trainMetrics.write_epoch_summary_std(layers_kl_std, epoch=epoch)
+
+ # Eval on full 3D
+ if epoch % epochs_eval == 0:
+ self.params['epoch'] = epoch
+ save=False
+ if epoch > 0 and epoch % epochs_viz == 0:
+ save=True
+
+ self.model.trainable = False
+ for metric_str in metrics_eval:
+ if metrics_eval[metric_str] in [config.LOSS_DICE]:
+ params_eval['epoch'] = epoch
+ params_eval['deepsup'] = deepsup
+ params_eval['deepsup_eval'] = False
+ params_eval['eval_type'] = config.MODE_TRAIN
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save)
+ trainMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ if deepsup:
+ params_eval['deepsup'] = deepsup
+ params_eval['deepsup_eval'] = True
+ params_eval['eval_type'] = config.MODE_TRAIN_DEEPSUP
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save)
+ trainMetricsDeepSup.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ # Test
+ if epoch % epochs_eval == 0:
+ self._test()
+ self.model.trainable = True
+
+ # Epochs summary/wrapup
+ eval_condition = epoch % epochs_eval == 0
+ trainMetrics.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition)
+ if deepsup:
+ trainMetricsDeepSup.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition)
+
+ if epoch > 0 and epoch % self.params['others']['epochs_timer'] == 0:
+ elapsed_seconds = time.time() - t_start_time
+ print (' - Total time elapsed : {}'.format( str(datetime.timedelta(seconds=elapsed_seconds)) ))
+ if epoch % self.params['others']['epochs_memory'] == 0:
+ mem_before = utils.get_memory(self.pid)
+ gc_n = gc.collect()
+ mem_after = utils.get_memory(self.pid)
+ print(' - Unreachable objects collected by GC: {} || ({}) -> ({})'.format(gc_n, mem_before, mem_after))
+
+ # Break out of loop at end of all epochs
+ if epoch == max_epoch:
+ print ('\n\n - [Trainer][train()] All epochs finished')
+ break
+
+ except:
+ utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ params_save_model['epoch'] = epoch
+ utils.save_model(self.model, params_save_model)
+ traceback.print_exc()
+ pdb.set_trace()
+
+ t1 = time.time() # reset dataloader time calculator
+
+ except:
+ utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch)
+ traceback.print_exc()
+
+ def _test(self):
+
+ exp_name = None
+ epoch = None
+ try:
+
+ # Step 1.1 - Params
+ exp_name = self.params['exp_name']
+ epoch = self.params['epoch']
+ deepsup = self.params['model']['deepsup']
+
+ metrics_eval = self.params['metrics']['metrics_eval']
+ epochs_viz = self.params['model']['epochs_viz']
+ batch_size = self.params['dataloader']['batch_size']
+
+ # vars
+ testMetrics = self.metrics[config.MODE_TEST]
+ testMetrics.reset_metrics(self.params)
+ testMetricsDeepSup = None
+ if deepsup:
+ testMetricsDeepSup = self.metrics[config.MODE_TEST_DEEPSUP]
+ testMetricsDeepSup.reset_metrics(self.params)
+ params_eval = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'pid': self.pid
+ , 'eval_type': config.MODE_TEST, 'batch_size': batch_size
+ , 'epoch':epoch}
+
+ # Step 2 - Eval on full 3D
+ save=False
+ if epoch > 0 and epoch % epochs_viz == 0:
+ save=True
+ for metric_str in metrics_eval:
+ if metrics_eval[metric_str] in [config.LOSS_DICE]:
+ params_eval['deepsup'] = deepsup
+ params_eval['deepsup_eval'] = False
+ params_eval['eval_type'] = config.MODE_TEST
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save)
+ testMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ if deepsup:
+ params_eval['deepsup'] = deepsup
+ params_eval['deepsup_eval'] = True
+ params_eval['eval_type'] = config.MODE_TEST_DEEPSUP
+ eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save)
+ testMetricsDeepSup.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True)
+
+ testMetrics.write_epoch_summary(epoch, self.label_map, {}, True)
+ if deepsup:
+ testMetricsDeepSup.write_epoch_summary(epoch, self.label_map, {}, True)
+
+ except:
+ utils.print_exp_name(exp_name + '-' + config.MODE_TEST, epoch)
+ traceback.print_exc()
+ pdb.set_trace()
+
+if __name__ == "__main__":
+
+ ## Step 0 - Exp Name
+
+ # exp_name = 'Grid_FocusNetFlipV21632GNorm_FixedKL010_33samB214014040PreGridNorm_WithNegCEScalar10_seed42' # <---- This
+
+ # exp_name = 'Fin__FocusNetFlipV2-GNorm-1632-FixedKl010__33sam-B2-PreWindowNorm-24024040__CEScalar10-WithNeg_seed42'
+ # exp_name = 'Fin__FocusNetFlipV2-GNorm-1632-FixedKl010__33sam-B2-PreWindowNorm-24024040__10PA01Thresh05vPU-CEScalar10-WithNeg_seed42'
+
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl010__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl010__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42_v2'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl005__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl001__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42'
+
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th05-nAC005-1CEScalar10-WithoutNeg_seed42_v2'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th05-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th04-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42'
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__10PAvPU-Th04-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42'
+
+ # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-HErr-Th05-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42'
+ # exp_name = 'tmp_Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-HErr-Th05-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42'
+
+ exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__38sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42'
+
+ if 0:
+
+ ## Step 1 - Params
+ params = {
+ 'PROJECT_DIR': config.PROJECT_DIR
+ , 'random_seed':42
+ , 'exp_name': exp_name
+ , 'dataloader':{
+ 'data_dir': Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ , 'dir_type': [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD, config.DATALOADER_MICCAI2015_TESTONSITE]
+ # , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD]
+ , 'resampled' : True
+ , 'crop_init' : True
+ , 'grid' : True
+ , 'random_grid' : True
+ , 'filter_grid' : False
+ , 'centred_prob' : 0.3
+ , 'batch_size' : 2 # [2,4,8]
+ , 'shuffle' : 5 # [5,5,12]
+ , 'prefetch_batch': 4 # [4,3,2]
+ , 'parallel_calls': 4 # [4,4,4]
+ , 'single_sample' : False # [ !!!!!!!! WATCH OUT !!!!!!!!! ]
+ }
+ , 'model': {
+ 'name': config.MODEL_FOCUSNET_FLIPOUT_V2 # [MODEL_FOCUSNET_FLIPOUT, MODEL_FOCUSNET_FLIPOUT_V2, MODEL_FOCUSNET_FLIPOUT_POOL2]
+ , 'kl_alpha_init' : 0.01
+ , 'kl_schedule' : config.KL_DIV_FIXED # [config.KL_DIV_FIXED, config.KL_DIV_ANNEALING]
+            , 'kl_scale_factor' : 154 # distribute the kl_div loss 'kl_scale_factor' times across an epoch, i.e. kl_scale_factor = dataset_len/batch_size
+                                      # e.g. grid (240,240,80): 78/2=39 | (240,240,40): 66/2=33 | (140,140,40): 308/2=154 or 264/2=132 | (96,96,40): 684/2=342
+ , 'deepsup' : False
+ , 'activation': tf.keras.activations.softmax # ['softmax', 'tf.keras.activations.sigmoid']
+ , 'kernel_reg': False
+ , 'optimizer' : config.OPTIMIZER_ADAM
+ , 'grad_persistent': False
+ , 'init_lr' : 0.001 # 0.01 # 0.005
+ , 'fixed_lr' : False
+ , 'epochs' : 1500 # 1000
+ , 'epochs_save': 50 # 20
+ , 'epochs_eval': 50 # 40
+ , 'epochs_viz' : 500 # 100
+ , 'load_model':{
+ 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None # DEFAULT OPTION
+ # 'load_exp_name':'Grid_FocusNetFlip1632_E250KL0001_WSum1My_14014040_NoAug_FocalScalar09_seed42', 'load':True, 'load_epoch':200
+ # 'load_exp_name':None, 'load':True, 'load_epoch':100
+ }
+ , 'profiler': {
+ 'profile': False
+ , 'epochs': [252,253]
+ , 'steps_per_epoch': 30
+ , 'starting_step': 2
+ }
+ , 'model_tboard': False
+ }
+ , 'metrics' : {
+ 'logging_tboard': True
+ # for full 3D volume
+ , 'metrics_eval': {'Dice': config.LOSS_DICE}
+ ## for smaller grid
+ # , 'metrics_loss':{'Dice': config.LOSS_DICE} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL]
+ # , 'loss_weighted': {'Dice': True}
+ # , 'loss_mask': {'Dice':True}
+ # , 'loss_type': {'Dice': config.LOSS_VECTOR} # [config.LOSS_SCALAR, config.LOSS_VECTOR]
+ # , 'loss_combo': {'Dice':1.0}
+ , 'metrics_loss' : {'CE': config.LOSS_CE}
+ , 'loss_type' : {'CE': config.LOSS_SCALAR}
+ , 'loss_weighted': {'CE': True}
+ , 'loss_mask' : {'CE': True}
+ , 'loss_combo' : {'CE': 1.0}
+ , 'loss_epoch' : {'CE': 0}
+ , 'loss_rate' : {'CE': 1}
+ # , 'metrics_loss' : {'Focal': config.LOSS_FOCAL}
+ # , 'loss_type' : {'Focal': config.LOSS_SCALAR}
+ # , 'loss_weighted': {'Focal': True}
+ # , 'loss_mask' : {'Focal': True}
+ # , 'loss_combo' : {'Focal': 1.0}
+ # , 'metrics_loss':{'Focal': config.LOSS_FOCAL, 'Dice': config.LOSS_DICE} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL]
+ # , 'loss_type': {'Focal': config.LOSS_SCALAR, 'Dice': config.LOSS_SCALAR}
+ # , 'loss_weighted': {'Focal': True, 'Dice': False}
+ # , 'loss_mask': {'Focal':True, 'Dice': True}
+ # , 'loss_combo': {'Focal': 1.0, 'Dice': 1.0} #[, ['Dice', 'CE']]
+ # , 'metrics_loss' : {'CE': config.LOSS_CE, 'pavpu': config.LOSS_PAVPU} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL]
+ # , 'loss_weighted' : {'CE': True, 'pavpu': False}
+ # , 'loss_mask' : {'CE': True, 'pavpu': False}
+ # , 'loss_type' : {'CE': config.LOSS_SCALAR, 'pavpu': config.LOSS_SCALAR}
+ # , 'loss_combo' : {'CE': 1.0, 'pavpu': 5.0}
+ # , 'loss_epoch' : {'CE': 0 , 'pavpu': 300}
+ # , 'loss_rate' : {'CE': 1 , 'pavpu':10}
+ }
+ , 'others': {
+ 'epochs_timer': 20
+ , 'epochs_memory':5
+ }
+ }
+
+ # Call the trainer
+ trainer = Trainer(params)
+ trainer.train()
+
+ elif 1:
+
+ # Step 1 - Dataloaders
+ resampled = True
+ single_sample = False
+ grid = True
+ crop_init = True
+ data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data')
+ if 0:
+ # dir_type = [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD, config.DATALOADER_MICCAI2015_TESTONSITE]; eval_type = config.MODE_TRAIN_VAL; dir_type = ['_'.join(dir_type)]
+ dir_type = [config.DATALOADER_MICCAI2015_TEST]; eval_type = config.MODE_TEST
+ # dir_type = [config.DATALOADER_MICCAI2015_TESTONSITE]; eval_type = config.MODE_TEST
+
+ dataset_test_eval = get_dataloader_3D_test_eval(data_dir, dir_type=dir_type
+ , grid=grid, crop_init=crop_init
+ , resampled=resampled, single_sample=single_sample
+ )
+ else:
+ dir_type = [medconfig.DATALOADER_DEEPMINDTCIA_TEST]; annotator_type = [medconfig.DATALOADER_DEEPMINDTCIA_ONC]; eval_type = config.MODE_DEEPMINDTCIA_TEST_ONC
+ # dir_type = [medconfig.DATALOADER_DEEPMINDTCIA_TEST]; annotator_type = [medconfig.DATALOADER_DEEPMINDTCIA_RAD]; eval_type = config.MODE_DEEPMINDTCIA_TEST_RAD
+ dataset_test_eval = get_dataloader_deepmindtcia(data_dir, dir_type=dir_type, annotator_type=annotator_type
+ , grid=grid, crop_init=crop_init, resampled=resampled
+ , single_sample=single_sample
+ )
+
+ # Step 2 - Model Loading
+ class_count = len(dataset_test_eval.datasets[0].LABEL_MAP.values())
+ activation = 'softmax' # tf.keras.activations.softmax
+ deepsup = False
+ # model = models.ModelFocusNetFlipOutPool2(class_count=class_count, trainable=False, activation=activation)
+ # model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=False, activation=activation)
+ model = models.ModelFocusNetFlipOutV2(class_count=class_count, trainable=False, activation=activation)
+
+ # Step 3 - Final function call
+ params = {
+ 'PROJECT_DIR' : config.PROJECT_DIR
+ , 'exp_name' : exp_name
+ , 'pid' : os.getpid()
+ , 'batch_size' : 2
+ , 'prefetch_batch': 1
+ , 'epoch' : 500 # [1000, 1500, 2500]
+ , 'eval_type' : eval_type
+ , 'MC_RUNS' : 30
+ , 'deepsup' : deepsup
+ , 'deepsup_eval' : False
+ , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time]
+ }
+ print ('\n - eval_type: ', params['eval_type'])
+ print (' - epoch : ', params['epoch'])
+ print (' - MC_RUNS : ', params['MC_RUNS'])
+ val(model, dataset_test_eval, params, show=False, save=True, verbose=False)
+
+
diff --git a/src/model/utils.py b/src/model/utils.py
new file mode 100644
index 0000000..3c47491
--- /dev/null
+++ b/src/model/utils.py
@@ -0,0 +1,585 @@
+# Import internal libraries
+import src.config as config
+
+# Import external libraries
+import os
+import pdb
+import copy
+import tqdm
+import json
+import psutil
+import humanize
+import traceback
+import numpy as np
+import matplotlib
+from pathlib import Path
+import tensorflow as tf
+
+from src.config import PROJECT_DIR
+
+############################################################
+# MODEL RELATED #
+############################################################
+def save_model(model, params):
+ """
+    "Saving a TensorFlow model" typically means one of two things:
+     - the Checkpoint format (what this function uses)
+        - Ref: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
+        - ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model); ckpt_obj.save(file_prefix='my_model_ckpt')
+        - tf.keras.Model.save_weights('my_model_save_weights') also writes checkpoint-style weight files
+     - the SavedModel format, e.g. tf.keras.Model.save('my_model_save'), which additionally saves the execution graph
+ """
+ try:
+ PROJECT_DIR = params['PROJECT_DIR']
+ exp_name = params['exp_name']
+ epoch = params['epoch']
+
+ folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(epoch)
+ model_folder = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name)
+ model_folder.mkdir(parents=True, exist_ok=True)
+ model_path = Path(model_folder).joinpath(folder_name)
+
+ optimizer = params['optimizer']
+ ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model)
+ ckpt_obj.save(file_prefix=model_path)
+
+ if 0: # CHECK
+ model.save(str(Path(model_path).joinpath('model_save'))) # SavedModel format: creates a folder "model_save" with assets/, variables/ (contains weights) and a saved_model.pb (model architecture)
+ model.save_weights(str(Path(model_path).joinpath('model_save_weights')))
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
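+
+# [Illustrative sketch] Minimal tf.train.Checkpoint round-trip behind save_model()/load_model();
+# the toy model and the '/tmp/demo_ckpt' path are hypothetical:
+#   model     = tf.keras.Sequential([tf.keras.layers.Dense(1)])
+#   optimizer = tf.keras.optimizers.Adam()
+#   ckpt_obj  = tf.train.Checkpoint(optimizer=optimizer, model=model)
+#   ckpt_obj.save(file_prefix='/tmp/demo_ckpt/ckpt')                                # writes ckpt-1.index / ckpt-1.data-*
+#   ckpt_obj.restore(tf.train.latest_checkpoint('/tmp/demo_ckpt')).expect_partial() # restores weights (and optimizer slots)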
+
+def load_model(model, load_type, params):
+ """
+ - Ref: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
+ """
+ try:
+
+ PROJECT_DIR = params['PROJECT_DIR']
+ exp_name = params['exp_name']
+ load_epoch = params['load_epoch']
+
+ folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(load_epoch)
+ model_folder = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name)
+
+ if load_type == config.MODE_TRAIN:
+ if 'optimizer' in params:
+ ckpt_obj = tf.train.Checkpoint(optimizer=params['optimizer'], model=model)
+ # ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).assert_existing_objects_matched() # shows errors
+ # ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).assert_consumed()
+ ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).expect_partial()
+ else:
+ print (' - [ERROR][utils.load_model] Optimizer not passed !')
+ pdb.set_trace()
+
+ elif load_type in [config.MODE_VAL, config.MODE_TEST]:
+ ckpt_obj = tf.train.Checkpoint(model=model)
+ ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).expect_partial()
+
+ elif load_type == config.MODE_VAL_NEW:
+ model.load_weights(str(Path(model_folder).joinpath('model.h5')), by_name=True, skip_mismatch=False)
+
+ else:
+ print (' - [ERROR][utils.load_model] It should not be here!')
+ pdb.set_trace()
+ # tf.keras.Model.load_weights
+ # tf.train.list_variables(tf.train.latest_checkpoint(str(model_folder)))
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def get_tensorboard_writer(exp_name, suffix):
+ try:
+ import tensorflow as tf
+
+ logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, suffix)
+ writer = tf.summary.create_file_writer(str(logdir))
+ return writer
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def make_summary(fieldname, epoch, writer1=None, value1=None, writer2=None, value2=None):
+ try:
+ import tensorflow as tf
+
+ if writer1 is not None and value1 is not None:
+ with writer1.as_default():
+ tf.summary.scalar(fieldname, value1, epoch)
+ writer1.flush()
+ if writer2 is not None and value2 is not None:
+ with writer2.as_default():
+ tf.summary.scalar(fieldname, value2, epoch)
+ writer2.flush()
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
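+
+# Example usage (illustrative values):
+#   writer = get_tensorboard_writer('my_exp', config.MODE_TRAIN)
+#   make_summary('Loss/CE', epoch, writer1=writer, value1=0.42)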
+
+def make_summary_hist(fieldname, epoch, writer1=None, value1=None, writer2=None, value2=None):
+ try:
+ import tensorflow as tf
+
+ if writer1 is not None and value1 is not None:
+ with writer1.as_default():
+ tf.summary.histogram(fieldname, value1, epoch)
+ writer1.flush()
+ if writer2 is not None and value2 is not None:
+ with writer2.as_default():
+ tf.summary.histogram(fieldname, value2, epoch)
+ writer2.flush()
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def write_model(model, X, params, suffix='model'):
+ """
+ - Ref:
+ - https://www.tensorflow.org/api_docs/python/tf/summary/trace_on
+ - https://www.tensorflow.org/api_docs/python/tf/summary/trace_export
+ - https://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model
+ - https://stackoverflow.com/questions/56690089/how-to-graph-tf-keras-model-in-tensorflow-2-0
+ """
+
+ # Step 1 - Start trace
+ tf.summary.trace_on(graph=True, profiler=False)
+
+ # Step 2 - Perform operation
+ _ = write_model_trace(model, X)
+
+ # Step 3 - Export trace
+ writer = get_tensorboard_writer(params['exp_name'], suffix)
+ with writer.as_default():
+ tf.summary.trace_export(name=model.name, step=0, profiler_outdir=None)
+ writer.flush()
+
+ # Step 4 - Save as .png
+ # tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, expand_nested=True) # only works for the functional API. :(
+
+@tf.function
+def write_model_trace(model, X):
+ return model(X)
+
+def set_lr(epoch, optimizer):
+ # if epoch == 200:
+ # optimizer.lr.assign(0.0001)
+ pass
+
+############################################################
+# WRITE RELATED #
+############################################################
+
+class NpEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return super(NpEncoder, self).default(obj)
+
+def write_json(json_data, json_filepath):
+
+ Path(json_filepath).parent.absolute().mkdir(parents=True, exist_ok=True)
+
+ with open(str(json_filepath), 'w') as fp:
+ json.dump(json_data, fp, indent=4, cls=NpEncoder)
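+
+# Example usage (illustrative values):
+#   write_json({'dice': np.float32(0.87), 'labels': np.arange(3)}, '/tmp/metrics.json')
+# NpEncoder converts the numpy scalar/array above to native Python types so json.dump() does not raise a TypeError.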
+
+############################################################
+# DEBUG RELATED #
+############################################################
+def print_break_msg():
+ print ('')
+ print (' ========================= break operator applied here =========================')
+ print ('')
+
+def get_memory_usage(filename):
+ mem = os.popen("ps aux | grep %s | awk '{sum=sum+$6}; END {print sum/1024 \" MB\"}'"% (filename)).read()
+ return mem.rstrip()
+
+def print_exp_name(exp_name, epoch):
+ print ('')
+ print (' [ERROR] ========================= {} (epoch={}) ========================='.format(exp_name, epoch))
+ print ('')
+
+def get_memory(pid):
+ try:
+ process = psutil.Process(pid)
+ return humanize.naturalsize(process.memory_info().rss)
+ except:
+ return '-1'
+
+def get_tf_gpu_memory():
+ # https://www.tensorflow.org/api_docs/python/tf/config/experimental/get_memory_usage
+ gpu_devices = tf.config.list_physical_devices('GPU')
+ if len(gpu_devices):
+ memory_bytes = tf.config.experimental.get_memory_usage(gpu_devices[0].name.split('/physical_device:')[-1])
+        return 'GPU: {:.2f}GB'.format(memory_bytes/1024.0/1024.0/1024.0)
+ else:
+ return '-1'
+
+############################################################
+# DATALOADER RELATED #
+############################################################
+def get_info_from_label_id(label_id, label_map, label_colors=None):
+ """
+ The label_id param has to be greater than 0
+ """
+
+ label_name = [label for label in label_map if label_map[label] == label_id]
+ if len(label_name):
+ label_name = label_name[0]
+ else:
+ label_name = None
+
+ if label_colors is not None:
+ label_color = np.array(label_colors[label_id])
+ if np.any(label_color > 0):
+ label_color = label_color/255.0
+ else:
+ label_color = None
+
+ return label_name, label_color
+
+def cmap_for_dataset(label_colors):
+ cmap_me = matplotlib.colors.ListedColormap(np.array([*label_colors.values()])/255.0)
+ norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N)
+
+ return cmap_me, norm
+
+############################################################
+# EVAL RELATED #
+############################################################
+def get_eval_folders(PROJECT_DIR, exp_name, epoch, mode, mc_runs=None, training_bool=None, create=False):
+ folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(epoch)
+
+ if mc_runs is not None and training_bool is not None: # During manual eval
+ if mc_runs == 1 and training_bool == False:
+ model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode + config.SUFFIX_DET)
+ else:
+ model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode + config.SUFFIX_MC.format(mc_runs))
+ else: # During automated training + eval
+ model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode)
+
+ model_folder_epoch_patches = Path(model_folder_epoch_save).joinpath('patches')
+ model_folder_epoch_imgs = Path(model_folder_epoch_save).joinpath('imgs')
+
+ if create:
+ Path(model_folder_epoch_patches).mkdir(parents=True, exist_ok=True)
+ Path(model_folder_epoch_imgs).mkdir(parents=True, exist_ok=True)
+
+ return model_folder_epoch_patches, model_folder_epoch_imgs
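+
+# [Illustrative] For a manual MC evaluation (mc_runs=30, training_bool=True) the save folder resolves to
+#   <PROJECT_DIR>/<MODEL_CHKPOINT_MAINFOLDER>/<exp_name>/<MODEL_CHKPOINT_NAME_FMT(epoch)>/<MODEL_IMGS_FOLDERNAME>/<mode + SUFFIX_MC(30)>
+# with 'patches' and 'imgs' subfolders; mc_runs=1 with training_bool=False uses the deterministic suffix (SUFFIX_DET) instead.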
+
+############################################################
+# 3D #
+############################################################
+
+def viz_model_output_3d_old(X, y_true, y_predict, y_predict_std, patient_id, path_save, label_map, label_colors, VMAX_STD=0.3):
+ """
+ X : [H,W,D,1]
+ y_true: [H,W,D,Labels]
+
+ Takes only a single batch of data
+ """
+ try:
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import skimage
+ import skimage.measure
+
+ Path(path_save).mkdir(exist_ok=True, parents=True)
+
+ labels_ids = sorted(list(label_map.values()))
+ cmap_me, norm_me = cmap_for_dataset(label_colors)
+ cmap_img = 'rainbow'
+ VMIN = 0
+ VMAX_STD = VMAX_STD
+ VMAX_MEAN = 1.0
+
+ slice_ids = list(range(X.shape[2]))
+ # with tqdm.tqdm(total=len(slice_ids), leave=False, desc='ID:' + patient_id) as pbar_slices:
+ with tqdm.tqdm(total=len(slice_ids), leave=False) as pbar_slices:
+ for slice_id in range(X.shape[2]):
+
+ # Data
+ X_slice = X[:,:,slice_id,0]
+ y_true_slice = y_true[:,:,slice_id,:]
+ y_predict_slice = y_predict[:,:,slice_id,:]
+ y_predict_std_slice = y_predict_std[:,:,slice_id,:]
+
+ # Matplotlib figure
+ filename = '\n'.join(patient_id.split('-'))
+                suptitle_str = 'Patient: {}'.format(filename)
+ fig = plt.figure(figsize=(15,15), dpi=200)
+ fig_std = plt.figure(figsize=(15,15), dpi=200)
+ spec = fig.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5)
+ spec_std = fig_std.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5)
+ fig.suptitle(suptitle_str + '\n Predictive Mean')
+ fig_std.suptitle(suptitle_str + '\n Predictive Std')
+
+ # Top two images
+ ax3 = fig.add_subplot(spec[0, 4])
+ ax3.imshow(X_slice, cmap='gray')
+ ax3.axis('off')
+ ax3.set_title('Raw data')
+ ax3_std = fig_std.add_subplot(spec_std[0, 4])
+ ax3_std.imshow(X_slice, cmap='gray')
+ ax3_std.axis('off')
+ ax3_std.set_title('Raw data')
+ img_slice_mask_plot = np.zeros(y_true_slice[:,:,0].shape)
+
+ # Other images
+ i,j = 2,0
+ for label_id in labels_ids:
+ if label_id not in config.IGNORE_LABELS:
+
+ # Get ground-truth and prediction slices
+ img_slice_mask_predict = copy.deepcopy(y_predict_slice[:,:,label_id])
+ img_slice_mask_predict_std = copy.deepcopy(y_predict_std_slice[:,:,label_id])
+ img_slice_mask_gt = copy.deepcopy(y_true_slice[:,:,label_id])
+
+ # Plot prediction heatmap
+ if j >= 5:
+ i = i + 1
+ j = 0
+ ax = fig.add_subplot(spec[i,j])
+ ax_std = fig_std.add_subplot(spec_std[i,j])
+ j += 1
+
+ # img_slice_mask_predict[img_slice_mask_predict > config.PREDICT_THRESHOLD_MASK] = label_id
+ ax.imshow(img_slice_mask_predict, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_MEAN)
+ ax_std.imshow(img_slice_mask_predict_std, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_STD)
+
+ # Plot Gt contours
+ label, color = get_info_from_label_id(label_id, label_map, label_colors)
+ if label_id == 0:
+ ax4 = fig.add_subplot(spec[1, 4])
+ fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img), ax=ax4)
+ ax4.axis('off')
+ ax4.set_title('Colorbar')
+
+ ax4_std = fig_std.add_subplot(spec_std[1, 4])
+ fig_std.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img, norm=matplotlib.colors.Normalize(vmin=0, vmax=VMAX_STD)), ax=ax4_std)
+ ax4_std.axis('off')
+ ax4_std.set_title('Colorbar')
+
+ else:
+ contours_mask = skimage.measure.find_contours(img_slice_mask_gt, level=0.99)
+ for _, contour_mask in enumerate(contours_mask):
+ ax3.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color)
+ ax3_std.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color)
+
+ if label is not None:
+ ax.set_title(label + '(' + str(label_id) + ')', color=color)
+ ax_std.set_title(label + '(' + str(label_id) + ')', color=color)
+ ax.axis('off')
+ ax_std.axis('off')
+
+ # Gather GT mask
+ idxs_gt = np.argwhere(img_slice_mask_gt > 0)
+ img_slice_mask_plot[idxs_gt[:,0], idxs_gt[:,1]] = label_id
+
+ # GT mask
+ ax1 = fig.add_subplot(spec[0:2, 0:2])
+ ax1.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax1.set_title('GT Mask')
+ ax1.tick_params(labelsize=6)
+ ax1_std = fig_std.add_subplot(spec_std[0:2, 0:2])
+ ax1_std.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax1_std.set_title('GT Mask')
+ ax1_std.tick_params(labelsize=6)
+
+ # Predicted Mask
+ ax2 = fig.add_subplot(spec[0:2, 2:4])
+ ax2.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax2.set_title('Predicted Mask (mean)')
+ ax2.tick_params(labelsize=6)
+ ax2_std = fig_std.add_subplot(spec_std[0:2, 2:4])
+ ax2_std.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax2_std.set_title('Predicted Mask (mean)')
+ ax2_std.tick_params(labelsize=6)
+
+ # Show and save
+ # path_savefig = Path(model_folder_epoch_images).joinpath(filename_meta.replace('.npy','.png'))
+ path_savefig = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_mean.png')
+ fig.savefig(str(path_savefig), bbox_inches='tight')
+ path_savefig_std = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_std.png')
+ fig_std.savefig(str(path_savefig_std), bbox_inches='tight')
+ plt.close(fig=fig)
+ plt.close(fig=fig_std)
+ pbar_slices.update(1)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
+def viz_model_output_3d(exp_name, X, y_true, y_predict, y_predict_unc, patient_id, path_save, label_map, label_colors, vmax_unc=0.3, unc_title='', unc_savesufix=''):
+ """
+ X : [H,W,D,1]
+ y_true: [H,W,D,Labels]
+
+ Takes only a single batch of data
+ """
+ try:
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import skimage
+ import skimage.measure
+
+ Path(path_save).mkdir(exist_ok=True, parents=True)
+
+ labels_ids = sorted(list(label_map.values()))
+ cmap_me, norm_me = cmap_for_dataset(label_colors)
+ cmap_img = 'rainbow'
+ VMIN = 0
+ VMAX_UNC = vmax_unc
+ VMAX_MEAN = 1.0
+
+ slice_ids = list(range(X.shape[2]))
+ # with tqdm.tqdm(total=len(slice_ids), leave=False, desc='ID:' + patient_id) as pbar_slices:
+ with tqdm.tqdm(total=len(slice_ids), leave=False) as pbar_slices:
+ for slice_id in range(X.shape[2]):
+
+                if slice_id < 20: # NOTE: plotting skips the first 20 slices
+ pbar_slices.update(1)
+ continue
+
+ # Data
+ X_slice = X[:,:,slice_id,0]
+ y_true_slice = y_true[:,:,slice_id,:]
+ y_predict_slice = y_predict[:,:,slice_id,:]
+ y_predict_unc_slice = y_predict_unc[:,:,slice_id,:]
+
+ # Matplotlib figure
+ filename = '\n'.join(patient_id.split('-'))
+ suptitle_str = 'Exp: {}\nPatient: {}\nSlice: {}'.format(exp_name, filename, slice_id)
+ fig = plt.figure(figsize=(15,15), dpi=200)
+ fig_unc = plt.figure(figsize=(15,15), dpi=200)
+ spec = fig.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5)
+ spec_unc = fig_unc.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5)
+ fig.suptitle(suptitle_str + '\n Predictive Mean')
+ fig_unc.suptitle(suptitle_str + '\n {}'.format(unc_title))
+
+ # Top two images
+ ax3 = fig.add_subplot(spec[0, 4])
+ ax3.imshow(X_slice, cmap='gray')
+ ax3.axis('off')
+ ax3.set_title('Raw data')
+ ax3_unc = fig_unc.add_subplot(spec_unc[0, 4])
+ ax3_unc.imshow(X_slice, cmap='gray')
+ ax3_unc.axis('off')
+ ax3_unc.set_title('Raw data')
+ img_slice_mask_plot = np.zeros(y_true_slice[:,:,0].shape)
+
+ # Other images
+ i,j = 2,0
+ for label_id in labels_ids:
+ if label_id not in config.IGNORE_LABELS:
+
+ # Get ground-truth and prediction slices
+ img_slice_mask_predict = copy.deepcopy(y_predict_slice[:,:,label_id])
+ img_slice_mask_predict_unc = copy.deepcopy(y_predict_unc_slice[:,:,label_id])
+ img_slice_mask_gt = copy.deepcopy(y_true_slice[:,:,label_id])
+
+ # Plot prediction heatmap
+ if j >= 5:
+ i = i + 1
+ j = 0
+ ax = fig.add_subplot(spec[i,j])
+ # ax_unc = fig_unc.add_subplot(spec_unc[i,j])
+ j += 1
+
+ # img_slice_mask_predict[img_slice_mask_predict > config.PREDICT_THRESHOLD_MASK] = label_id
+ ax.imshow(img_slice_mask_predict, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_MEAN)
+ # ax_unc.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC)
+
+ # Plot Gt contours
+ label, color = get_info_from_label_id(label_id, label_map, label_colors)
+ if label_id == 0:
+ ax4 = fig.add_subplot(spec[1, 4])
+ fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img), ax=ax4)
+ ax4.axis('off')
+ ax4.set_title('Colorbar')
+
+ ax4_unc = fig_unc.add_subplot(spec_unc[1, 4])
+ fig_unc.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img, norm=matplotlib.colors.Normalize(vmin=0, vmax=VMAX_UNC)), ax=ax4_unc)
+ ax4_unc.axis('off')
+ ax4_unc.set_title('Colorbar')
+
+ else:
+ contours_mask = skimage.measure.find_contours(img_slice_mask_gt, level=0.99)
+ for _, contour_mask in enumerate(contours_mask):
+ ax3.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color)
+ ax3_unc.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color)
+
+ if label is not None:
+ ax.set_title(label + '(' + str(label_id) + ')', color=color)
+ # ax_unc.set_title(label + '(' + str(label_id) + ')', color=color)
+ ax.axis('off')
+ # ax_unc.axis('off')
+
+ # Gather GT mask
+ idxs_gt = np.argwhere(img_slice_mask_gt > 0)
+ img_slice_mask_plot[idxs_gt[:,0], idxs_gt[:,1]] = label_id
+
+ # if 'ent' in unc_savesufix:
+ # ax_unc.cla()
+
+ # GT mask
+ ax1 = fig.add_subplot(spec[0:2, 0:2])
+ ax1.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax1.set_title('GT Mask')
+ ax1.tick_params(labelsize=6)
+ ax1_unc = fig_unc.add_subplot(spec_unc[0:2, 0:2])
+ ax1_unc.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax1_unc.set_title('GT Mask')
+ ax1_unc.tick_params(labelsize=6)
+
+ # Predicted Mask
+ ax2 = fig.add_subplot(spec[0:2, 2:4])
+ ax2.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax2.set_title('Predicted Mask (mean)')
+ ax2.tick_params(labelsize=6)
+ ax2_unc = fig_unc.add_subplot(spec_unc[0:2, 2:4])
+ ax2_unc.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none')
+ ax2_unc.set_title('Predicted Mask (mean)')
+ ax2_unc.tick_params(labelsize=6)
+
+ # Specifically for entropy
+ if unc_savesufix in ['stdmax', 'ent', 'mif']:
+
+                    ax_unc_gt = fig_unc.add_subplot(spec_unc[2:4, 0:2])
+ ax_unc_gt.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC)
+ slice_binary_gt = img_slice_mask_plot
+ slice_binary_gt[slice_binary_gt > 0] = 1
+ ax_unc_gt.imshow(slice_binary_gt, cmap='gray', interpolation='none', alpha=0.3)
+
+                    ax_unc_pred = fig_unc.add_subplot(spec_unc[2:4, 2:4])
+ ax_unc_pred.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC)
+ slice_binary_pred = np.argmax(y_predict_slice, axis=2)
+ slice_binary_pred[slice_binary_pred > 0] = 1
+ ax_unc_pred.imshow(slice_binary_pred, cmap='gray', interpolation='none', alpha=0.3)
+
+ # Show and save
+ # path_savefig = Path(model_folder_epoch_images).joinpath(filename_meta.replace('.npy','.png'))
+ path_savefig = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_mean.png')
+ fig.savefig(str(path_savefig), bbox_inches='tight')
+ path_savefig_unc = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_{}.png'.format(unc_savesufix))
+ fig_unc.savefig(str(path_savefig_unc), bbox_inches='tight')
+ plt.close(fig=fig)
+ plt.close(fig=fig_unc)
+ pbar_slices.update(1)
+
+ except:
+ traceback.print_exc()
+ pdb.set_trace()
+
\ No newline at end of file