From 25797ad5fa3be4a399a51b81115d8c2ac2c40c80 Mon Sep 17 00:00:00 2001 From: Mody-SHARK Date: Tue, 30 Nov 2021 09:19:13 +0100 Subject: [PATCH] First commit --- .gitignore | 9 + README.md | 49 + __init__.py | 0 demo/demo_dropout_ce.py | 105 + demo/demo_dropout_cebasic.py | 105 + demo/demo_dropout_dice.py | 105 + demo/demo_flipout_ce.py | 109 + src/__init__.py | 0 src/config.py | 401 +++ src/dataloader/__init__.py | 0 src/dataloader/augmentations.py | 592 ++++ src/dataloader/dataset.py | 53 + src/dataloader/extractors/__init__.py | 0 src/dataloader/extractors/han_deepmindtcia.py | 339 +++ src/dataloader/extractors/han_miccai2015.py | 339 +++ src/dataloader/han_deepmindtcia.py | 726 +++++ src/dataloader/han_miccai2015.py | 765 ++++++ src/dataloader/utils.py | 1273 +++++++++ src/dataloader/utils_viz.py | 437 +++ src/model/__init__.py | 0 src/model/losses.py | 488 ++++ src/model/models.py | 572 ++++ src/model/trainer.py | 1718 ++++++++++++ src/model/trainer_flipout.py | 2435 +++++++++++++++++ src/model/utils.py | 585 ++++ 25 files changed, 11205 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 __init__.py create mode 100644 demo/demo_dropout_ce.py create mode 100644 demo/demo_dropout_cebasic.py create mode 100644 demo/demo_dropout_dice.py create mode 100644 demo/demo_flipout_ce.py create mode 100644 src/__init__.py create mode 100644 src/config.py create mode 100644 src/dataloader/__init__.py create mode 100644 src/dataloader/augmentations.py create mode 100644 src/dataloader/dataset.py create mode 100644 src/dataloader/extractors/__init__.py create mode 100644 src/dataloader/extractors/han_deepmindtcia.py create mode 100644 src/dataloader/extractors/han_miccai2015.py create mode 100644 src/dataloader/han_deepmindtcia.py create mode 100644 src/dataloader/han_miccai2015.py create mode 100644 src/dataloader/utils.py create mode 100644 src/dataloader/utils_viz.py create mode 100644 src/model/__init__.py create mode 100644 src/model/losses.py create mode 100644 src/model/models.py create mode 100644 src/model/trainer.py create mode 100644 src/model/trainer_flipout.py create mode 100644 src/model/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ad9170e --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +# Python +__pycache__ +*.pyc + + +# Modeling +_data +_models +_logs \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..496df7e --- /dev/null +++ b/README.md @@ -0,0 +1,49 @@ +# Bayesian Uncertainty for Quality Assessment of Deep Learning Contours +This repository contains TensorFlow 2.4 code for the paper + - Comparing Bayesian Models for Organ Contouring in Head and Neck Radiotherapy + + +## Installation +1. Install [Anaconda](https://docs.anaconda.com/anaconda/install/) with Python 3.8 +2. Install [git](https://git-scm.com/downloads) +3. Open a terminal and run the following commands + - Clone this repository + - `git clone git@github.com:prerakmody/hansegmentation-uncertainty-qa.git` + - Create a conda env + - (Windows only): `conda init powershell` (and restart the terminal) + - (For all platforms) + ``` + cd hansegmentation-uncertainty-qa + conda deactivate + conda create --name hansegmentation-uncertainty-qa python=3.8 + conda activate hansegmentation-uncertainty-qa + conda develop . 
# check for conda.pth file in $ANACONDA_HOME/envs/hansegmentation-uncertainty-qa/lib/python3.8/site-packages + ``` + - Install packages + - TensorFlow (check [here](https://www.tensorflow.org/install/source#tested_build_configurations) for CUDA/cuDNN requirements) + - (stick to the exact commands) + - For TensorFlow 2.4 + ``` + conda install -c nvidia cudnn=8.0.0=cuda11.0_0 + pip install tensorflow==2.4 + ``` + - Check the TensorFlow installation + ``` + python -c "import tensorflow as tf;print('\n\n\n====================== \n GPU Devices: ',tf.config.list_physical_devices('GPU'), '\n======================')" + python -c "import tensorflow as tf;print('\n\n\n====================== \n', tf.reduce_sum(tf.random.normal([1000, 1000])), '\n======================' )" + ``` + - [unix] upon running either of the above commands, TensorFlow will search for library files such as libcudart.so, libcublas.so, libcublasLt.so, libcufft.so, libcurand.so, libcusolver.so, libcusparse.so and libcudnn.so in the location `$ANACONDA_HOME/envs/hansegmentation-uncertainty-qa/lib/` + - [windows] upon running either of the above commands, TensorFlow will search for library files such as cudart64_110.dll ... and so on in the location `$ANACONDA_HOME\envs\hansegmentation-uncertainty-qa\Library\bin` + + - Other TensorFlow packages + ``` + pip install tensorflow-probability==0.12.1 tensorflow-addons==0.12.1 + ``` + - Other packages + ``` + pip install scipy seaborn tqdm psutil humanize pynrrd pydicom SimpleITK itk scikit-image + pip install psutil humanize pynvml + ``` + +# Notes + - All the `src/train{}.py` files are the ones used to train the models, as shown in the `demo/` folder; for example, `demo/demo_dropout_ce.py` trains the dropout model with a cross-entropy loss and then sets up evaluation on the MICCAI 2015 test set \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/demo_dropout_ce.py b/demo/demo_dropout_ce.py new file mode 100644 index 0000000..7ba6001 --- /dev/null +++ b/demo/demo_dropout_ce.py @@ -0,0 +1,105 @@ +# Import private libraries +import src.config as config +from src.model.trainer import Trainer,Validator + +# Import public libraries +import os +import pdb +import traceback +import tensorflow as tf +from pathlib import Path + + + +if __name__ == "__main__": + + exp_name = 'HansegmentationUncertaintyQA-Dropout-CE' + + data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + resampled = True + crop_init = True + grid = True + batch_size = 2 + + model = config.MODEL_FOCUSNET_DROPOUT + + # To train + params = { + 'exp_name': exp_name + , 'random_seed':42 + , 'dataloader':{ + 'data_dir': data_dir + , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD] + , 'resampled' : resampled + , 'crop_init' : crop_init + , 'grid' : grid + , 'random_grid': True + , 'filter_grid': False + , 'centred_prob' : 0.3 + , 'batch_size' : batch_size + , 'shuffle' : 5 + , 'prefetch_batch': 4 + , 'parallel_calls': 3 + } + , 'model': { + 'name': model + , 'optimizer' : config.OPTIMIZER_ADAM + , 'init_lr' : 0.001 + , 'fixed_lr' : True + , 'epochs' : 1500 + , 'epochs_save': 50 + , 'epochs_eval': 50 + , 'epochs_viz' : 500 + , 'load_model':{ + 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None + } + , 'profiler': { + 'profile': False + , 'epochs': [2,3] + , 'steps_per_epoch': 60 + , 'starting_step': 4 + } + , 'model_tboard': False + } + , 'metrics' : { + 'logging_tboard': True + # for full 3D volume + , 'metrics_eval': {'Dice': config.LOSS_DICE} + ## for smaller grid/patch + , 'metrics_loss' : {'CE': 
config.LOSS_CE} # [config.LOSS_CE, config.LOSS_DICE] + , 'loss_weighted' : {'CE': True} + , 'loss_mask' : {'CE': True} + , 'loss_combo' : {'CE': 1.0} + } + , 'others': { + 'epochs_timer': 20 + , 'epochs_memory':5 + } + } + + # Call the trainer + trainer = Trainer(params) + trainer.train() + + # To evaluate on MICCAI2015 + params = { + 'exp_name': exp_name + , 'pid' : os.getpid() + , 'dataloader': { + 'data_dir' : data_dir + , 'resampled' : resampled + , 'grid' : grid + , 'crop_init' : crop_init + , 'batch_size' : batch_size + , 'prefetch_batch': 1 + , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE] + , 'eval_type' : config.MODE_TEST + } + , 'model': { + 'name': model + , 'load_epoch' : 1000 + , 'MC_RUNS' : 30 + , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time] + } + , 'save': True + } \ No newline at end of file diff --git a/demo/demo_dropout_cebasic.py b/demo/demo_dropout_cebasic.py new file mode 100644 index 0000000..d3ff9ad --- /dev/null +++ b/demo/demo_dropout_cebasic.py @@ -0,0 +1,105 @@ +# Import private libraries +import src.config as config +from src.model.trainer import Trainer,Validator + +# Import public libraries +import os +import pdb +import traceback +import tensorflow as tf +from pathlib import Path + + + +if __name__ == "__main__": + + exp_name = 'HansegmentationUncertaintyQA-Dropout-CEBasic' + + data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + resampled = True + crop_init = True + grid = True + batch_size = 2 + + model = config.MODEL_FOCUSNET_DROPOUT + + # To train + params = { + 'exp_name': exp_name + , 'random_seed':42 + , 'dataloader':{ + 'data_dir': data_dir + , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD] + , 'resampled' : resampled + , 'crop_init' : crop_init + , 'grid' : grid + , 'random_grid': True + , 'filter_grid': False + , 'centred_prob' : 0.3 + , 'batch_size' : batch_size + , 'shuffle' : 5 + , 'prefetch_batch': 4 + , 'parallel_calls': 3 + } + , 'model': { + 'name': model + , 'optimizer' : config.OPTIMIZER_ADAM + , 'init_lr' : 0.001 + , 'fixed_lr' : True + , 'epochs' : 1500 + , 'epochs_save': 50 + , 'epochs_eval': 50 + , 'epochs_viz' : 500 + , 'load_model':{ + 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None + } + , 'profiler': { + 'profile': False + , 'epochs': [2,3] + , 'steps_per_epoch': 60 + , 'starting_step': 4 + } + , 'model_tboard': False + } + , 'metrics' : { + 'logging_tboard': True + # for full 3D volume + , 'metrics_eval': {'Dice': config.LOSS_DICE} + ## for smaller grid/patch + , 'metrics_loss' : {'CE-Basic': config.LOSS_CE_BASIC} + , 'loss_weighted' : {'CE-Basic': True} + , 'loss_mask' : {'CE-Basic': True} + , 'loss_combo' : {'CE-Basic': 1.0} + } + , 'others': { + 'epochs_timer': 20 + , 'epochs_memory':5 + } + } + + # Call the trainer + trainer = Trainer(params) + trainer.train() + + # To evaluate on MICCAI2015 + params = { + 'exp_name': exp_name + , 'pid' : os.getpid() + , 'dataloader': { + 'data_dir' : data_dir + , 'resampled' : resampled + , 'grid' : grid + , 'crop_init' : crop_init + , 'batch_size' : batch_size + , 'prefetch_batch': 1 + , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE] + , 'eval_type' : config.MODE_TEST + } + , 'model': { + 'name': model + , 'load_epoch' : 1000 + , 'MC_RUNS' : 30 + , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time] + } + , 'save': True + } \ No newline 
at end of file diff --git a/demo/demo_dropout_dice.py b/demo/demo_dropout_dice.py new file mode 100644 index 0000000..3257769 --- /dev/null +++ b/demo/demo_dropout_dice.py @@ -0,0 +1,105 @@ +# Import private libraries +import src.config as config +from src.model.trainer import Trainer,Validator + +# Import public libraries +import os +import pdb +import traceback +import tensorflow as tf +from pathlib import Path + + + +if __name__ == "__main__": + + exp_name = 'HansegmentationUncertaintyQA-Dropout-DICE' + + data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + resampled = True + crop_init = True + grid = True + batch_size = 2 + + model = config.MODEL_FOCUSNET_DROPOUT + + # To train + params = { + 'exp_name': exp_name + , 'random_seed':42 + , 'dataloader':{ + 'data_dir': data_dir + , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD] + , 'resampled' : resampled + , 'crop_init' : crop_init + , 'grid' : grid + , 'random_grid': True + , 'filter_grid': False + , 'centred_prob' : 0.3 + , 'batch_size' : batch_size + , 'shuffle' : 5 + , 'prefetch_batch': 4 + , 'parallel_calls': 3 + } + , 'model': { + 'name': model + , 'optimizer' : config.OPTIMIZER_ADAM + , 'init_lr' : 0.001 + , 'fixed_lr' : True + , 'epochs' : 1500 + , 'epochs_save': 50 + , 'epochs_eval': 50 + , 'epochs_viz' : 500 + , 'load_model':{ + 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None + } + , 'profiler': { + 'profile': False + , 'epochs': [2,3] + , 'steps_per_epoch': 60 + , 'starting_step': 4 + } + , 'model_tboard': False + } + , 'metrics' : { + 'logging_tboard': True + # for full 3D volume + , 'metrics_eval': {'Dice': config.LOSS_DICE} + ## for smaller grid/patch + , 'metrics_loss' : {'Dice': config.LOSS_DICE} + , 'loss_weighted' : {'Dice': True} + , 'loss_mask' : {'Dice': True} + , 'loss_combo' : {'Dice': 1.0} + } + , 'others': { + 'epochs_timer': 20 + , 'epochs_memory':5 + } + } + + # Call the trainer + trainer = Trainer(params) + trainer.train() + + # To evaluate on MICCAI2015 + params = { + 'exp_name': exp_name + , 'pid' : os.getpid() + , 'dataloader': { + 'data_dir' : data_dir + , 'resampled' : resampled + , 'grid' : grid + , 'crop_init' : crop_init + , 'batch_size' : batch_size + , 'prefetch_batch': 1 + , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE] + , 'eval_type' : config.MODE_TEST + } + , 'model': { + 'name': model + , 'load_epoch' : 1000 + , 'MC_RUNS' : 30 + , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time] + } + , 'save': True + } \ No newline at end of file diff --git a/demo/demo_flipout_ce.py b/demo/demo_flipout_ce.py new file mode 100644 index 0000000..faad657 --- /dev/null +++ b/demo/demo_flipout_ce.py @@ -0,0 +1,109 @@ +# Import private libraries +import src.config as config +from src.model.trainer import Trainer,Validator + +# Import public libraries +import os +import pdb +import traceback +import tensorflow as tf +from pathlib import Path + + + +if __name__ == "__main__": + + exp_name = 'HansegmentationUncertaintyQA-Flipout-CE' + + data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + resampled = True + crop_init = True + grid = True + batch_size = 2 + + model = config.MODEL_FOCUSNET_FLIPOUT + + # To train + params = { + 'exp_name': exp_name + , 'random_seed':42 + , 'dataloader':{ + 'data_dir': data_dir + , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD] + , 'resampled' : 
resampled + , 'crop_init' : crop_init + , 'grid' : grid + , 'random_grid': True + , 'filter_grid': False + , 'centred_prob' : 0.3 + , 'batch_size' : batch_size + , 'shuffle' : 5 + , 'prefetch_batch': 4 + , 'parallel_calls': 3 + } + , 'model': { + 'name': model + , 'kl_alpha_init' : 0.01 + , 'kl_schedule' : config.KL_DIV_FIXED # [config.KL_DIV_FIXED, config.KL_DIV_ANNEALING] + , 'kl_scale_factor' : 154 # distribute the kl_div loss 'kl_scale_factor' times across an epoch i.e. dataset_len/batch_size=> + # (140,140,40)=308/2.0=154 + , 'optimizer' : config.OPTIMIZER_ADAM + , 'init_lr' : 0.001 + , 'fixed_lr' : True + , 'epochs' : 1500 + , 'epochs_save': 50 + , 'epochs_eval': 50 + , 'epochs_viz' : 500 + , 'load_model':{ + 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None + } + , 'profiler': { + 'profile': False + , 'epochs': [2,3] + , 'steps_per_epoch': 60 + , 'starting_step': 4 + } + , 'model_tboard': False + } + , 'metrics' : { + 'logging_tboard': True + # for full 3D volume + , 'metrics_eval': {'Dice': config.LOSS_DICE} + ## for smaller grid/patch + , 'metrics_loss' : {'CE': config.LOSS_CE} # [config.LOSS_CE, config.LOSS_DICE] + , 'loss_weighted' : {'CE': True} + , 'loss_mask' : {'CE': True} + , 'loss_combo' : {'CE': 1.0} + } + , 'others': { + 'epochs_timer': 20 + , 'epochs_memory':5 + } + } + + # Call the trainer + trainer = Trainer(params) + trainer.train() + + # To evaluate on MICCAI2015 + params = { + 'exp_name': exp_name + , 'pid' : os.getpid() + , 'dataloader': { + 'data_dir' : data_dir + , 'resampled' : resampled + , 'grid' : grid + , 'crop_init' : crop_init + , 'batch_size' : batch_size + , 'prefetch_batch': 1 + , 'dir_type' : [config.DATALOADER_MICCAI2015_TEST] # [config.DATALOADER_MICCAI2015_TESTONSITE] + , 'eval_type' : config.MODE_TEST + } + , 'model': { + 'name': model + , 'load_epoch' : 1000 + , 'MC_RUNS' : 30 + , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time] + } + , 'save': True + } \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..8bee63c --- /dev/null +++ b/src/config.py @@ -0,0 +1,401 @@ +############################################################ +# INIT # +############################################################ + +from pathlib import Path +PROJECT_DIR = Path(__file__).parent.absolute().parent.absolute() +MAIN_DIR = Path(PROJECT_DIR).parent.absolute() + +import os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time" + +import tensorflow as tf +try: + if len(tf.config.list_physical_devices('GPU')): + tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) + sys_details = tf.sysconfig.get_build_info() + print (' - [TFlow Build Info] ver: ', tf.__version__, 'CUDA(major.minor):', sys_details["cuda_version"], ' || cuDNN(major): ', sys_details["cudnn_version"]) + + else: + print (' - No GPU present!! 
Exiting ...') + import sys; sys.exit(1) +except: + pass + +############################################################ +# MODEL RELATED # +############################################################ +MODEL_CHKPOINT_MAINFOLDER = '_models' +MODEL_CHKPOINT_NAME_FMT = 'ckpt_epoch{:03d}' +MODEL_LOGS_FOLDERNAME = 'logs' +MODEL_IMGS_FOLDERNAME = 'images' +MODEL_PATCHES_FOLDERNAME = 'patches' + +EXT_NRRD = '.nrrd' + +MODE_TRAIN = 'Train' +MODE_TRAIN_VAL = 'Train_val' +MODE_VAL = 'Val' +MODE_VAL_NEW = 'Val_New' +MODE_TEST = 'Test' +MODE_DEEPMINDTCIA_TEST_ONC = 'DeepMindTCIATestOnc' +MODE_DEEPMINDTCIA_TEST_RAD = 'DeepMindTCIATestRad' +MODE_DEEPMINDTCIA_VAL_ONC = 'DeepMindTCIAValOnc' +MODE_DEEPMINDTCIA_VAL_RAD = 'DeepMindTCIAValRad' + +ACT_SIGMOID = 'sigmoid' +ACT_SOFTMAX = 'softmax' + +MODEL_FOCUSNET_DROPOUT = 'ModelFocusNetDropOut' +MODEL_FOCUSNET_FLIPOUT = 'ModelFocusNetFlipOut' + +OPTIMIZER_ADAM = 'Adam' + +THRESHOLD_SIGMA_IGNORE = 0.3 +MIN_SIZE_COMPONENT = 10 + +KL_DIV_FIXED = 'fixed' +KL_DIV_ANNEALING = 'annealing' + +############################################################ +# EVAL RELATED # +############################################################ +KEY_DICE_AVG = 'dice_avg' +KEY_DICE_LABELS = 'dice_labels' +KEY_HD_AVG = 'hd_avg' +KEY_HD_LABELS = 'hd_labels' +KEY_HD95_AVG = 'hd95_avg' +KEY_HD95_LABELS = 'hd95_labels' +KEY_MSD_AVG = 'msd_avg' +KEY_MSD_LABELS = 'msd_labels' +KEY_ECE_AVG = 'ece_avg' +KEY_ECE_LABELS = 'ece_labels' +KEY_AVU_ENT = 'avu_ent' +KEY_AVU_PAC_ENT = 'avu_pac_ent' +KEY_AVU_PUI_ENT = 'avu_pui_ent' +KEY_THRESH_ENT = 'avu_thresh_ent' +KEY_AVU_MIF = 'avu_mif' +KEY_AVU_PAC_MIF = 'avu_pac_mif' +KEY_AVU_PUI_MIF = 'avu_pui_mif' +KEY_THRESH_MIF = 'avu_thresh_mif' +KEY_AVU_UNC = 'avu_unc' +KEY_AVU_PAC_UNC = 'avu_pac_unc' +KEY_AVU_PUI_UNC = 'avu_pui_unc' +KEY_THRESH_UNC = 'avu_thresh_unc' + +PAVPU_UNC_THRESHOLD = 'adaptive-median' # [0.3, 'adaptive', 'adaptive-median'] +PAVPU_ENT_THRESHOLD = 0.5 +PAVPU_MIF_THRESHOLD = 0.1 +PAVPU_GRID_SIZE = (4,4,2) +PAVPU_RATIO_NEG = 0.9 + +KEY_ENT = 'ent' +KEY_MIF = 'mif' +KEY_STD = 'std' +KEY_PERC = 'perc' + +KEY_SUM = 'sum' +KEY_AVG = 'avg' + +CMAP_MAGMA = 'magma' +CMAP_GRAY = 'gray' + +FILENAME_EVAL3D_JSON = 'res.json' + +FOLDERNAME_TMP = '_tmp' +FOLDERNAME_TMP_BOKEH = 'bokeh-plots' +FOLDERNAME_TMP_ENTMIF = 'entmif' + +VAL_ECE_NAN = -0.1 +VAL_DICE_NAN = -1.0 +VAL_MC_RUNS_DEFAULT = 20 + +KEY_PATIENT_GLOBAL = 'global' + +SUFFIX_DET = '-Det' +SUFFIX_MC = '-MC{}' + +KEY_MC_RUNS = 'MC_RUNS' +KEY_TRAINING_BOOL = 'training_bool' + +############################################################ +# LOSSES RELATED # +############################################################ +LOSS_DICE = 'Dice' +LOSS_CE = 'CE' +LOSS_CE_BASIC = 'CE-Basic' + +############################################################ +# DATALOADER RELATED # +############################################################ + +HEAD_AND_NECK = 'HaN' +PROSTATE = 'Prostrate' +THORACIC = 'Thoracic' + +DIRNAME_PROCESSED = 'processed' +DIRNAME_PROCESSED_SPACING = 'processed_{}' +DIRNAME_RAW = 'raw' +DIRNAME_SAVE_3D = 'data_3D' + +FILENAME_JSON_IMG = 'img.json' +FILENAME_JSON_MASK = 'mask.json' +FILENAME_JSON_IMG_RESAMPLED = 'img_resampled.json' +FILENAME_JSON_MASK_RESAMPLED = 'mask_resampled.json' +FILENAME_CSV_IMG = 'img.csv' +FILENAME_CSV_MASK = 'mask.csv' +FILENAME_CSV_IMG_RESAMPLED = 'img_resampled.csv' +FILENAME_CSV_MASK_RESAMPLED = 'mask_resampled.csv' + +FILENAME_VOXEL_INFO = 'voxelinfo.json' + +KEYNAME_LABEL_OARS = 'labels_oars' +KEYNAME_LABEL_EXTERNAL = 'labels_external' 
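+# The KEYNAME_LABEL_* keys are used by the extractors (see src/dataloader/extractors/han_deepmindtcia.py)
+# as mask-header fields recording which OARs were found or are missing for a patient, e.g.
+# (organ names below are illustrative values only):
+#     voxel_mask_headers[KEYNAME_LABEL_OARS]    = ['BrainStem', 'Parotid_L', ...]
+#     voxel_mask_headers[KEYNAME_LABEL_MISSING] = ['Chiasm']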
+KEYNAME_LABEL_TUMORS = 'labels_tumors' +KEYNAME_LABEL_MISSING = 'labels_missing' + +DATAEXTRACTOR_WORKERS = 8 + +import itk +import numpy as np +import tensorflow as tf +import SimpleITK as sitk + +DATATYPE_VOXEL_IMG = np.int16 +DATATYPE_VOXEL_MASK = np.uint8 + +DATATYPE_SITK_VOXEL_IMG = sitk.sitkInt16 +DATATYPE_SITK_VOXEL_MASK = sitk.sitkUInt8 + +DATATYPE_ITK_VOXEL_MASK = itk.UC + +DATATYPE_NP_INT32 = np.int32 + +DATATYPE_TF_STRING = tf.string +DATATYPE_TF_UINT8 = tf.uint8 +DATATYPE_TF_INT16 = tf.int16 +DATATYPE_TF_INT32 = tf.int32 +DATATYPE_TF_FLOAT32 = tf.float32 + +DUMMY_LABEL = 255 + +# Keys - Volume params +KEYNAME_PIXEL_SPACING = 'pixel_spacing' +KEYNAME_ORIGIN = 'origin' +KEYNAME_SHAPE = 'shape' +KEYNAME_INTERCEPT = 'intercept' +KEYNAME_SLOPE = 'slope' +KEYNAME_ZVALS = 'z_vals' +KEYNAME_MEAN_MIDPOINT = 'mean_midpoint' +KEYNAME_OTHERS = 'others' +KEYNAME_INTERPOLATOR = 'interpolator' +KEYNAME_INTERPOLATOR_IMG = 'interpolator_img' +KEYNAME_INTERPOLATOR_MASK = 'interpolator_mask' +KEYNAME_SHAPE_ORIG = 'shape_orig' +KEYNAME_SHAPE_RESAMPLED = 'shape_resampled' + +# Keys - Dataset params +KEY_VOXELRESO = 'VOXEL_RESO' +KEY_LABEL_MAP = 'LABEL_MAP' +KEY_LABEL_MAP_FULL = 'LABEL_MAP_FULL' +KEY_LABEL_MAP_EXTERNAL = 'LABEL_MAP_EXTERNAL' +KEY_LABEL_COLORS = 'LABEL_COLORS' +KEY_LABEL_WEIGHTS = 'LABEL_WEIGHTS' +KEY_IGNORE_LABELS = 'IGNORE_LABELS' +KEY_LABELID_BACKGROUND = 'LABELID_BACKGROUND' +KEY_LABELID_MIDPOINT = 'LABELID_MIDPOINT' +KEY_HU_MIN = 'HU_MIN' +KEY_HU_MAX = 'HU_MAX' +KEY_PREPROCESS = 'PREPROCESS' +KEY_CROP = 'CROP' +KEY_GRID_3D = 'GRID_3D' + +# Common Labels +LABELNAME_BACKGROUND = 'Background' + +# Keys - Cropping +KEY_MIDPOINT_EXTENSION_W_LEFT = 'MIDPOINT_EXTENSION_W_LEFT' +KEY_MIDPOINT_EXTENSION_W_RIGHT = 'MIDPOINT_EXTENSION_W_RIGHT' +KEY_MIDPOINT_EXTENSION_H_BACK = 'MIDPOINT_EXTENSION_H_BACK' +KEY_MIDPOINT_EXTENSION_H_FRONT = 'MIDPOINT_EXTENSION_H_FRONT' +KEY_MIDPOINT_EXTENSION_D_TOP = 'MIDPOINT_EXTENSION_D_TOP' +KEY_MIDPOINT_EXTENSION_D_BOTTOM = 'MIDPOINT_EXTENSION_D_BOTTOM' + +# Keys - Gridding +KEY_GRID_SIZE = 'GRID_SIZE' +KEY_GRID_OVERLAP = 'GRID_OVERLAP' +KEY_GRID_SAMPLER_PERC = 'GRID_SAMPLER_PERC' +KEY_GRID_RANDOM_SHIFT_MAX = 'GRID_RANDOM_SHIFT_MAX' +KEY_GRID_RANDOM_SHIFT_PERC = 'GRID_RANDOM_SHIFT_PERC' + +# Keys - .nrrd file keys +KEY_NRRD_PIXEL_SPACING = 'space directions' +KEY_NRRD_ORIGIN = 'space origin' +KEY_NRRD_SHAPE = 'sizes' + +TYPE_VOXEL_ORIGSHAPE = 'orig' +TYPE_VOXEL_RESAMPLED = 'resampled' + +MASK_TYPE_ONEHOT = 'one_hot' +MASK_TYPE_COMBINED = 'combined' + +PREFETCH_BUFFER = 5 + +################################### HaN - MICCAI 2015 ################################### + +DATASET_MICCAI2015 = 'miccai2015' +DATALOADER_MICCAI2015_TRAIN = 'train' +DATALOADER_MICCAI2015_TRAIN_ADD = 'train_additional' +DATALOADER_MICCAI2015_TEST = 'test_offsite' +DATALOADER_MICCAI2015_TESTONSITE = 'test_onsite' + +HaN_MICCAI2015 = { + KEY_LABEL_MAP : { + 'Background':0 + , 'BrainStem':1 , 'Chiasm':2, 'Mandible':3 + , 'OpticNerve_L':4, 'OpticNerve_R':5 + , 'Parotid_L':6,'Parotid_R':7 + ,'Submandibular_L':8, 'Submandibular_R':9 + } + , KEY_LABEL_COLORS : { + 0: [255,255,255,10] + , 1:[0,110,254,255], 2: [225,128,128,255], 3:[254,0,128,255] + , 4:[191,50,191,255], 5:[254,128,254,255] + , 6: [182, 74, 74,255], 7:[128,128,0,255] + , 8:[50,105,161,255], 9:[46,194,194,255] + } + , KEY_LABEL_WEIGHTS : [1/4231347, 1/16453, 1/372, 1/32244, 1/444, 1/397, 1/16873, 1/17510, 1/4419, 1/4410] # avg voxel count + , KEY_IGNORE_LABELS : [0] + , KEY_LABELID_BACKGROUND : 0 + , KEY_LABELID_MIDPOINT : 1 + , 
KEY_HU_MIN : -125 # window_levl=50, window_width=350, 50 - (350/2) for soft tissue + , KEY_HU_MAX : 225 # window_levl=50, window_width=350, 50 + (350/2) for soft tissue + , KEY_VOXELRESO : (0.8, 0.8, 2.5) # [(0.8, 0.8, 2.5), (1.0, 1.0, 2.0) , (1,1,1), (1,1,2)] + , KEY_GRID_3D : { + TYPE_VOXEL_RESAMPLED:{ + '(0.8, 0.8, 2.5)':{ + KEY_GRID_SIZE : [140,140,40] + , KEY_GRID_OVERLAP : [20,20,0] + , KEY_GRID_SAMPLER_PERC : 0.90 + , KEY_GRID_RANDOM_SHIFT_MAX : 40 + , KEY_GRID_RANDOM_SHIFT_PERC : 0.5 + } + } + } + , KEY_PREPROCESS:{ + TYPE_VOXEL_RESAMPLED:{ + KEY_CROP: { + '(0.8, 0.8, 2.5)':{ + KEY_MIDPOINT_EXTENSION_W_LEFT : 120 + ,KEY_MIDPOINT_EXTENSION_W_RIGHT : 120 # 240 + ,KEY_MIDPOINT_EXTENSION_H_BACK : 66 + ,KEY_MIDPOINT_EXTENSION_H_FRONT : 174 # 240 + ,KEY_MIDPOINT_EXTENSION_D_TOP : 20 + ,KEY_MIDPOINT_EXTENSION_D_BOTTOM : 60 # 80 [96(20-76) ] + } + } + } + } +} + + +PATIENT_MICCAI2015_TESTOFFSITE = 'HaN_MICCAI2015-test_offsite-{}_resample_True' +FILENAME_SAVE_CT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_img.nrrd' +FILENAME_SAVE_GT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_mask.nrrd' +FILENAME_SAVE_PRED_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpred.nrrd' +FILENAME_SAVE_MIF_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredmif.nrrd' +FILENAME_SAVE_ENT_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredent.nrrd' +FILENAME_SAVE_STD_MICCAI2015 = 'nrrd_HaN_MICCAI2015-test_offsite-{}_resample_True_maskpredstd.nrrd' + +PATIENTIDS_MICCAI2015_TEST = ['0522c0555', '0522c0576', '0522c0598', '0522c0659', '0522c0661', '0522c0667', '0522c0669', '0522c0708', '0522c0727', '0522c0746'] + +################################### HaN - DeepMindTCIA ################################### + +DATALOADER_DEEPMINDTCIA_TEST = 'test' +DATALOADER_DEEPMINDTCIA_VAL = 'validation' +DATALOADER_DEEPMINDTCIA_ONC = 'oncologist' +DATALOADER_DEEPMINDTCIA_RAD = 'radiographer' +DATASET_DEEPMINDTCIA = 'deepmindtcia' + + +KEY_LABELMAP_MICCAI_DEEPMINDTCIA = KEY_LABEL_MAP + 'MICCAI_TCIADEEPMIND' +HaN_DeepMindTCIA = { + KEY_LABELMAP_MICCAI_DEEPMINDTCIA: { + 'Background': 'Background' + , 'Brainstem': 'BrainStem' + , 'Mandible':'Mandible' + , 'Optic-Nerve-Lt': 'OpticNerve_L', 'Optic-Nerve-Rt': 'OpticNerve_R' + , 'Optic_Nerve_Lt': 'OpticNerve_L', 'Optic_Nerve_Rt': 'OpticNerve_R' + , 'Parotid-Lt':'Parotid_L', 'Parotid-Rt':'Parotid_R' + , 'Parotid_Lt':'Parotid_L', 'Parotid_Rt':'Parotid_R' + , 'Submandibular-Lt': 'Submandibular_L', 'Submandibular-Rt':'Submandibular_R' + , 'Submandibular_Lt': 'Submandibular_L', 'Submandibular_Rt':'Submandibular_R' + } + , KEY_LABEL_MAP : { + 'Background':0 + , 'BrainStem':1 , 'Chiasm':2, 'Mandible':3 + , 'OpticNerve_L':4, 'OpticNerve_R':5 + , 'Parotid_L':6,'Parotid_R':7 + ,'Submandibular_L':8, 'Submandibular_R':9 + } + , KEY_LABEL_COLORS : { + 0: [255,255,255,10] + , 1:[0,110,254,255], 2: [225,128,128,255], 3:[254,0,128,255] + , 4:[191,50,191,255], 5:[254,128,254,255] + , 6: [182, 74, 74,255], 7:[128,128,0,255] + , 8:[50,105,161,255], 9:[46,194,194,255] + } + , KEY_LABEL_WEIGHTS : [] + , KEY_IGNORE_LABELS : [0] + , KEY_LABELID_BACKGROUND : 0 + , KEY_LABELID_MIDPOINT : 1 + , KEY_HU_MIN : -125 # window_levl=50, window_width=350, 50 - (350/2) for soft tissue + , KEY_HU_MAX : 225 # window_levl=50, window_width=350, 50 + (350/2) for soft tissue + , KEY_VOXELRESO : (0.8, 0.8, 2.5) # [(0.8, 0.8, 2.5), (1.0, 1.0, 2.0) , (1,1,1), (1,1,2)] + , KEY_GRID_3D : { + TYPE_VOXEL_RESAMPLED:{ + '(0.8, 0.8, 2.5)':{ + 
KEY_GRID_SIZE : [140,140,40] + , KEY_GRID_OVERLAP : [20,20,0] + , KEY_GRID_SAMPLER_PERC : 0.90 + , KEY_GRID_RANDOM_SHIFT_MAX : 40 + , KEY_GRID_RANDOM_SHIFT_PERC : 0.5 + } + } + } + , KEY_PREPROCESS:{ + TYPE_VOXEL_RESAMPLED:{ + KEY_CROP: { + '(0.8, 0.8, 2.5)':{ + KEY_MIDPOINT_EXTENSION_W_LEFT : 120 + ,KEY_MIDPOINT_EXTENSION_W_RIGHT : 120 # 240 + ,KEY_MIDPOINT_EXTENSION_H_BACK : 66 + ,KEY_MIDPOINT_EXTENSION_H_FRONT : 174 # 240 + ,KEY_MIDPOINT_EXTENSION_D_TOP : 20 + ,KEY_MIDPOINT_EXTENSION_D_BOTTOM : 60 # 80 [96(20-76) ] + } + } + } + } +} + +PATIENTIDS_DEEPMINDTCIA_TEST = ['0522c0331', '0522c0416', '0522c0419', '0522c0629', '0522c0768', '0522c0770', '0522c0773', '0522c0845', 'TCGA-CV-7236', 'TCGA-CV-7243', 'TCGA-CV-7245', 'TCGA-CV-A6JO', 'TCGA-CV-A6JY', 'TCGA-CV-A6K0', 'TCGA-CV-A6K1'] + +PATIENT_DEEPMINDTCIA_TEST_ONC = 'HaN_DeepMindTCIA-test-oncologist-{}_resample_True' +FILENAME_SAVE_CT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_img.nrrd' +FILENAME_SAVE_GT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_mask.nrrd' +FILENAME_SAVE_PRED_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpred.nrrd' +FILENAME_SAVE_MIF_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredmif.nrrd' +FILENAME_SAVE_ENT_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredent.nrrd' +FILENAME_SAVE_STD_DEEPMINDTCIA_TEST_ONC = 'nrrd_HaN_DeepMindTCIA-test-oncologist-{}_resample_True_maskpredstd.nrrd' + +############################################################ +# VISUALIZATION # +############################################################ +FIGSIZE=(15,15) +IGNORE_LABELS = [] +PREDICT_THRESHOLD_MASK = 0.6 + +ENT_MIN, ENT_MAX = 0.0, 0.5 +MIF_MIN, MIF_MAX = 0.0, 0.1 \ No newline at end of file diff --git a/src/dataloader/__init__.py b/src/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataloader/augmentations.py b/src/dataloader/augmentations.py new file mode 100644 index 0000000..f941252 --- /dev/null +++ b/src/dataloader/augmentations.py @@ -0,0 +1,592 @@ +# Import private libraries +import src.config as config +import src.dataloader.utils as utils + + +# Import public libraries +import pdb +import math +import traceback +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +import SimpleITK as sitk +import tensorflow as tf +import tensorflow_addons as tfa + +class Rotate3D: + + def __init__(self): + self.name = 'Rotate3D' + + @tf.function + def execute(self, x, y, meta1, meta2): + """ + Rotates a 3D image along the z-axis by some random angle + - Ref: https://www.tensorflow.org/api_docs/python/tf/image/rot90 + + Parameters + ---------- + x: tf.Tensor + This is the 3D image of dtype=tf.int16 and shape=(H,W,C,1) + y: tf.Tensor + This is the 3D mask of dtype=tf.uint8 and shape=(H,W,C,Labels) + meta1 = tf.Tensor + This contains some indexing and meta info. Irrelevant to this function + meta2 = tf.Tensor + This contains some string information on patient identification. 
Irrelevant to this function + """ + + try: + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= 0.2: + rotate_count = tf.random.uniform([], minval=1, maxval=4, dtype=tf.dtypes.int32) + + # k=3 (rot=270) (anti-clockwise) + if rotate_count == 3: + xtmp = tf.transpose(tf.reverse(x, [0]), [1,0,2,3]) + ytmp = tf.transpose(tf.reverse(y, [0]), [1,0,2,3]) + + # k = 1 (rot=90) (anti-clockwise) + elif rotate_count == 1: + xtmp = tf.reverse(tf.transpose(x, [1,0,2,3]), [0]) + ytmp = tf.reverse(tf.transpose(y, [1,0,2,3]), [0]) + + # k=2 (rot=180) (clock-wise) + elif rotate_count == 2: + xtmp = tf.reverse(x, [0,1]) + ytmp = tf.reverse(y, [0,1]) + + else: + xtmp = x + ytmp = y + + return xtmp, ytmp, meta1, meta2 + # return xtmp.read_value(), ytmp.read_value(), meta1, meta2 + + else: + return x, y, meta1, meta2 + except: + traceback.print_exc() + return x, y, meta1, meta2 + + @tf.function + def execute2(self, x_moving, x_fixed, y_moving, y_fixed, meta1, meta2): + + try: + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= 0.5: + rotate_count = tf.random.uniform([], minval=1, maxval=4, dtype=tf.dtypes.int32) + + # k=3 (rot=270) (anti-clockwise) + if rotate_count == 3: + x_moving_tmp = tf.transpose(tf.reverse(x_moving, [0]), [1,0,2,3]) + x_fixed_tmp = tf.transpose(tf.reverse(x_fixed, [0]), [1,0,2,3]) + y_moving_tmp = tf.transpose(tf.reverse(y_moving, [0]), [1,0,2,3]) + y_fixed_tmp = tf.transpose(tf.reverse(y_fixed, [0]), [1,0,2,3]) + + # k = 1 (rot=90) (anti-clockwise) + elif rotate_count == 1: + x_moving_tmp = tf.reverse(tf.transpose(x_moving, [1,0,2,3]), [0]) + x_fixed_tmp = tf.reverse(tf.transpose(x_fixed, [1,0,2,3]), [0]) + y_moving_tmp = tf.reverse(tf.transpose(y_moving, [1,0,2,3]), [0]) + y_fixed_tmp = tf.reverse(tf.transpose(y_fixed, [1,0,2,3]), [0]) + + # k=2 (rot=180) (clock-wise) + elif rotate_count == 2: + x_moving_tmp = tf.reverse(x_moving, [0,1]) + x_fixed_tmp = tf.reverse(x_fixed, [0,1]) + y_moving_tmp = tf.reverse(y_moving, [0,1]) + y_fixed_tmp = tf.reverse(y_fixed, [0,1]) + + else: + x_moving_tmp = x_moving + x_fixed_tmp = x_fixed + y_moving_tmp = y_moving + y_fixed_tmp = y_fixed + + return x_moving_tmp, x_fixed_tmp, y_moving_tmp, y_fixed_tmp, meta1, meta2 + + else: + return x_moving, x_fixed, y_moving, y_fixed, meta1, meta2 + + except: + tf.print(' - [ERROR][Rotate3D][execute2]') + return x_moving, x_fixed, y_moving, y_fixed, meta1, meta2 + +class Rotate3DSmall: + + def __init__(self, label_map, mask_type, prob=0.2, angle_degrees=15, interpolation='bilinear'): + + self.name = 'Rotate3DSmall' + + self.label_ids = label_map.values() + self.class_count = len(label_map) + self.mask_type = mask_type + + self.prob = prob + self.angle_degrees = angle_degrees + self.interpolation = interpolation + + @tf.function + def execute(self, x, y, meta1, meta2): + """ + - Ref: https://www.tensorflow.org/addons/api_docs/python/tfa/image/rotate + + """ + + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob: + + angle_radians = tf.random.uniform([], minval=math.radians(-self.angle_degrees), maxval=math.radians(self.angle_degrees) + , dtype=tf.dtypes.float32) + + + if self.mask_type == config.MASK_TYPE_ONEHOT: + """ + - x: [H,W,D,1] + - y: [H,W,D,class] + """ + + x = tf.expand_dims(x[:,:,:,0], axis=0) # [1,H,W,D] + x = tf.transpose(x) # [D,W,H,1] + x = tfa.image.rotate(x, angle_radians, interpolation='bilinear') # [D,W,H,1] + x = tf.transpose(x) # [1,H,W,D] + x = 
tf.expand_dims(x[0], axis=-1) # [H,W,D,1] + + y = tf.concat([ + tf.expand_dims( + tf.transpose( # [1,H,W,D] + tfa.image.rotate( # [D,W,H,1] + tf.transpose( # [D,W,H,1] + tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D] + ) + , angle_radians, interpolation='bilinear' + ) + )[0] # [H,W,D] + , axis=-1 # [H,W,D,1] + ) for class_id in range(self.class_count) + ], axis=-1) # [H,W,D,10] + y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y) + y = tf.where(tf.math.less(y,0.5), 0.0, y) + + elif self.mask_type == config.MASK_TYPE_COMBINED: + """ + - x: [H,W,D] + - y: [H,W,D] + """ + + x = tf.expand_dims(x,axis=0) # [1,H,W,D] + x = tf.transpose(x) # [D,W,H,1] + x = tfa.image.rotate(x, angle_radians, interpolation=self.interpolation) # [D,W,H,1] + x = tf.transpose(x) # [1,H,W,D] + x = x[0] # [H,W,D] + + y = tf.concat([tf.expand_dims(tf.math.equal(y, label), axis=-1) for label in self.label_ids], axis=-1) # [H,W,D,L] + y = tf.cast(y, dtype=tf.float32) + y = tf.concat([ + tf.expand_dims( + tf.transpose( # [1,H,W,D] + tfa.image.rotate( # [D,W,H,1] + tf.transpose( # [D,W,H,1] + tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D] + ) + , angle_radians, interpolation='bilinear' + ) + )[0] # [H,W,D] + , axis=-1 # [H,W,D,1] + ) for class_id in range(self.class_count) + ], axis=-1) # [H,W,D,L] + y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y) + y = tf.where(tf.math.less(y,0.5), 0.0, y) + y = tf.math.argmax(y, axis=-1) # [H,W,D] + + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + else: + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + return x, y, meta1, meta2 + +class Rotate3DSmallZ: + + def __init__(self, label_map, mask_type, prob=0.2, angle_degrees=5, interpolation='bilinear'): + + self.name = 'Rotate3DSmallZ' + + self.label_ids = label_map.values() + self.class_count = len(label_map) + self.mask_type = mask_type + + self.prob = prob + self.angle_degrees = angle_degrees + self.interpolation = interpolation + + @tf.function + def execute(self, x, y, meta1, meta2): + """ + Params + ------ + x: [H,W,D,1] + y: [H,W,D,C] + - Ref: https://www.tensorflow.org/addons/api_docs/python/tfa/image/rotate + + """ + + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob: + + angle_radians = tf.random.uniform([], minval=math.radians(-self.angle_degrees), maxval=math.radians(self.angle_degrees) + , dtype=tf.dtypes.float32) + + if self.mask_type == config.MASK_TYPE_ONEHOT: + """ + - x: [H,W,D,1] + - y: [H,W,D,class] + """ + + x = tfa.image.rotate(x, angle_radians, interpolation='bilinear') # [H,W,D,1] + + y = tf.concat([ + tfa.image.rotate( + tf.expand_dims(y[:,:,:,class_id], axis=-1) # [H,W,D,1] + , angle_radians, interpolation='bilinear' + ) + for class_id in range(self.class_count) + ], axis=-1) # [H,W,D,10] + y = tf.where(tf.math.greater_equal(y,0.5), 1.0, y) + y = tf.where(tf.math.less(y,0.5), 0.0, y) + + elif self.mask_type == config.MASK_TYPE_COMBINED: + pass + + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + else: + x = tf.cast(x, dtype=tf.float32) + y = tf.cast(y, dtype=tf.float32) + + return x, y, meta1, meta2 + +class Translate: + + def __init__(self, label_map, translations=[40,40], prob=0.2): + + self.translations = translations + self.prob = prob + self.label_ids = label_map.values() + self.class_count = len(label_map) + self.name = 'Translate' + + + + @tf.function + def execute(self,x,y,meta1,meta2): + """ + Params + ------ + x: 
[H,W,D,1] + y: [H,W,D,class] + + Ref + --- + - tfa.image.translate(image): image= (num_images, num_rows, num_columns, num_channels) + """ + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob: + + translate_x = tf.random.uniform([], minval=-self.translations[0], maxval=self.translations[0], dtype=tf.dtypes.int32) + translate_y = tf.random.uniform([], minval=-self.translations[1], maxval=self.translations[1], dtype=tf.dtypes.int32) + + x = tf.expand_dims(x[:,:,:,0], axis=0) # [1,H,W,D] + x = tf.transpose(x) # [D,W,H,1] + x = tfa.image.translate(x, [translate_x, translate_y], interpolation='bilinear') # [D,W,H,1]; x=(num_images, num_rows, num_columns, num_channels) + x = tf.transpose(x) # [1,H,W,D] + x = tf.expand_dims(x[0], axis=-1) # [H,W,D,1] + + y = tf.concat([ + tf.expand_dims( + tf.transpose( # [1,H,W,D] + tfa.image.translate( # [D,W,H,1] + tf.transpose( # [D,W,H,1] + tf.expand_dims(y[:,:,:,class_id], axis=0) # [1,H,W,D] + ) + , [translate_x, translate_y], interpolation='bilinear' + ) + )[0] # [H,W,D] + , axis=-1 # [H,W,D,1] + ) for class_id in range(self.class_count) + ], axis=-1) # [H,W,D,10] + + return x,y,meta1,meta2 + +class Noise: + + def __init__(self, x_shape, mean=0.0, std=0.1, prob=0.2): + + self.mean = mean + self.std = std + self.prob = prob + self.x_shape = x_shape + self.name = 'Noise' + + @tf.function + def execute(self,x,y,meta1,meta2): + """ + Params + ------ + x: [H,W,D,1] + y: [H,W,D,class] + """ + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob: + + x = x + tf.random.normal(self.x_shape, self.mean, self.std) + + return x,y,meta1,meta2 + +class Deform2Punt5D: + + def __init__(self, img_shape, label_map, grid_points=50, stddev=2.0, div_factor=2, prob=0.2, debug=False): + """ + img_shape = [H,W] or [H,W,D] + """ + + self.name = 'Deform2Punt5D' + + if debug: + import os + import psutil + import pynvml + pynvml.nvmlInit() + self.device_id = pynvml.nvmlDeviceGetHandleByIndex(0) + self.process = psutil.Process(os.getpid()) + + self.img_shape = img_shape[:2] # no deformation in z-dimension if img=3D + self.depth = img_shape[-1] + self.grid_points = grid_points + self.stddev = stddev + self.div_factor = int(div_factor) + self.debug = debug + self.label_ids = label_map.values() + self.prob = prob + + self.flattened_grid_locations = [] + + self._get_grid_control_points() + self._get_grid_locations(*self.img_shape) + + def _get_grid_control_points(self): + + # Step 1 - Define grid shape using control grid spacing & final image shape + grid_shape = np.zeros(len(self.img_shape), dtype=int) + + for idx in range(len(self.img_shape)): + num_elem = float(self.img_shape[idx]) + if num_elem % 2 == 0: + grid_shape[idx] = np.ceil( (num_elem - 1) / (2*self.grid_points) + 0.5) * 2 + 2 + else: + grid_shape[idx] = np.ceil((num_elem - 1) / (2*self.grid_points)) * 2 + 3 + + coords = [] + for i, size in enumerate(grid_shape): + coords.append(tf.linspace(-(size - 1) / 2*self.grid_points, (size - 1) / 2*self.grid_points, size)) + permutation = np.roll(np.arange(len(coords) + 1), -1) + self.grid_control_points_orig = tf.cast(tf.transpose(tf.meshgrid(*coords, indexing="ij"), permutation), dtype=tf.float32) + self.grid_control_points_orig += tf.expand_dims(tf.expand_dims(tf.constant(self.img_shape, dtype=tf.float32)/2.0,0),0) + self.grid_control_points = tf.reshape(self.grid_control_points_orig, [-1,2]) + + def _get_grid_locations(self, image_height, image_width): + + image_height = image_height//self.div_factor + image_width = 
image_width//self.div_factor + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing="ij") # grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height)) + grid_locations = np.stack((y_grid, x_grid), -1) + + flattened_grid_locations = np.reshape(grid_locations, [image_height * image_width, 2]) + self.flattened_grid_locations = tf.cast(tf.expand_dims(flattened_grid_locations, 0), dtype=tf.float32) + + return flattened_grid_locations + + @tf.function + def _get_dense_flow(self, grid_control_points, grid_control_points_new): + """ + params + ----- + grid_control_points: [B, points, 2] + grid_control_points_new: [B, points, 2] + + return + dense_flows: [1,H,W,2] + """ + + # Step 1 - Get flows + source_control_point_locations = tf.cast(grid_control_points/self.div_factor, dtype=tf.float32) + dest_control_point_locations = tf.cast(grid_control_points_new/self.div_factor, dtype=tf.float32) + control_point_flows = dest_control_point_locations - source_control_point_locations + + # Step 2 - Get dense flow via bspline interp + if self.debug: + import pynvml + tf.print (' - [_get_dense_flow()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB), ') + tf.print (' - [_get_dense_flow()] img size: ', self.img_shape) + tf.print (' - [_get_dense_flow()] img size for bspline interp: ', self.img_shape[0]//self.div_factor, self.img_shape[1]//self.div_factor) + flattened_flows = tfa.image.interpolate_spline( + dest_control_point_locations, + control_point_flows, + self.flattened_grid_locations, + order=2 + ) + if self.debug: + import pynvml + tf.print (' - [_get_dense_flow()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB), ') + + # Step 3 - Reshape dense flow to original image size + dense_flows = tf.reshape(flattened_flows, [1, self.img_shape[0]//self.div_factor, self.img_shape[1]//self.div_factor, 2]) + dense_flows = tf.image.resize(dense_flows, (self.img_shape[0], self.img_shape[1])) + if self.debug: + print (' - [_get_dense_flow()] dense flow: ', dense_flows.shape) + + return dense_flows + + def execute2D(self, x, y, show=True): + """ + x = [H,W] + y = [H,W,L] + """ + + # Step 1 - Get new control points and dense flow for each slice + grid_control_points_new = self.grid_control_points + tf.random.normal(self.grid_control_points.shape, 0, self.stddev) + x_flow = self._get_dense_flow(tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0)) # [1,H,W,2] + + # Step 2 - Transform + x_tf = tf.expand_dims(tf.expand_dims(x, -1), 0) # [1,H,W,1] + y_tf = tf.expand_dims(y, 0) # [1,H,W,L] + if 0: + x_tf, x_flow = tfa.image.sparse_image_warp(x_tf, tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0)) + else: + x_tf = tfa.image.dense_image_warp(x_tf, x_flow) # [1,H,W,1] + y_tf = tf.concat([ + tfa.image.dense_image_warp( + tf.expand_dims(y_tf[:,:,:,class_id],-1), x_flow + ) for class_id in self.label_ids] + , axis=-1 + ) # [B,H,W,L] + y_tf = tf.where(tf.math.greater_equal(y_tf, 0.5), 1.0, y_tf) + y_tf = tf.where(tf.math.less(y_tf, 0.5), 0.0, y_tf) + + + # Step 99 - Show + if show: + f,axarr = plt.subplots(2,2, figsize=(15,10)) + axarr[0][0].imshow(x, cmap='gray') + y_plot = tf.argmax(y, axis=-1) # [H,W] + axarr[0][0].imshow(y_plot, alpha=0.5) + grid_og = self.grid_control_points_orig + axarr[0][0].plot(grid_og[:,:,1], grid_og[:,:,0], 'y-.', 
alpha=0.2) + axarr[0][0].plot(tf.transpose(grid_og[:,:,1]), tf.transpose(grid_og[:,:,0]), 'y-.', alpha=0.2) + axarr[0][0].set_xlim([0, self.img_shape[1]]) + axarr[0][0].set_ylim([0, self.img_shape[0]]) + + axarr[0][1].imshow(x_tf[0,:,:,0], cmap='gray') + y_tf_plot = tf.argmax(y_tf, axis=-1)[0,:,:] # [H,W] + axarr[0][1].imshow(y_tf_plot, alpha=0.5) + grid_new = tf.reshape(grid_control_points_new, self.grid_control_points_orig.shape) + axarr[0][1].plot(grid_new[:,:,1], grid_new[:,:,0], 'y-.', alpha=0.2) + axarr[0][1].plot(tf.transpose(grid_new[:,:,1]), tf.transpose(grid_new[:,:,0]), 'y-.', alpha=0.2) + axarr[0][1].set_xlim([0, self.img_shape[1]]) + axarr[0][1].set_ylim([0, self.img_shape[0]]) + + axarr_dy = axarr[1][0].imshow(x_flow[0,:,:,1], vmin=-7.5, vmax=7.5, interpolation='none', cmap='rainbow') + f.colorbar(axarr_dy, ax=axarr[1][0], extend='both') + axarr_dx = axarr[1][1].imshow(x_flow[0,:,:,0], vmin=-7.5, vmax=7.5, interpolation='none', cmap='rainbow') + f.colorbar(axarr_dx, ax=axarr[1][1], extend='both') + + plt.suptitle('Image Size: {} \n Control Points: {}\n StdDev: {}\nDiv Factor={}'.format(self.img_shape, self.grid_points, self.stddev, self.div_factor)) + plt.show() + pdb.set_trace() + + return x_tf, y_tf + + @tf.function + def execute(self, x, y, meta1, meta2, show=False): + """ + x = [H,W,D,1] + y = [H,W,D,L] + No deformation in z-axis + """ + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.prob: + if self.debug: + tf.print (' - [Deform2Punt5D()][execute()] img_shape: ', self.img_shape, ' || depth: ', self.depth) + + # Step 1 - Get new control points and dense flow for each slice + grid_control_points_new = self.grid_control_points + tf.random.normal(self.grid_control_points.shape, 0, self.stddev) + x_flow = self._get_dense_flow(tf.expand_dims(self.grid_control_points,0), tf.expand_dims(grid_control_points_new,0)) # [1,H,W,2] + + # Step 2 - Transform + if show: + idx = 70 + import matplotlib.pyplot as plt + f,axarr = plt.subplots(2,2, figsize=(15,10)) + + x_slice = x[:,:,idx,0] # [H,W] + axarr[0][0].imshow(x_slice, cmap='gray') + y_slice = tf.argmax(y[:,:,idx,:], axis=-1) # [H,W] + axarr[0][0].imshow(y_slice, cmap='gray', alpha=0.5) + grid_og = self.grid_control_points_orig + axarr[0][0].plot(grid_og[:,:,1], grid_og[:,:,0], 'y-.', alpha=0.2) + axarr[0][0].plot(tf.transpose(grid_og[:,:,1]), tf.transpose(grid_og[:,:,0]), 'y-.', alpha=0.5) + axarr[0][0].set_xlim([0, self.img_shape[1]]) + axarr[0][0].set_ylim([0, self.img_shape[0]]) + + x = tf.transpose(x, [2,0,1,3]) # [H,W,D,1] -> [D,H,W,1] + y = tf.transpose(y, [2,0,1,3]) # [H,W,D,L] -> [D,H,W,L] + + if self.debug: + import pynvml + tf.print (' - [Deform2Punt5D][execute()] GPU mem: ', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB, ') + x_flow_repeat = tf.repeat(x_flow, self.depth, axis=0) # [B,H,W,2] or [D,H,W,2] + if self.debug: + tf.print (' - [execute()] x_tf: ', x.shape, ' || x_flow_repeat: ', x_flow_repeat.shape) + x = tfa.image.dense_image_warp(x, x_flow_repeat) + x = tf.transpose(x, [1,2,0,3]) # [D,H,W,1] -> [H,W,D,1] + y = tf.concat([ + tfa.image.dense_image_warp( + tf.expand_dims(y[:,:,:,class_id],-1), x_flow_repeat + ) for class_id in self.label_ids] + , axis=-1 + ) # [D,H,W,L] + y = tf.where(tf.math.greater_equal(y, 0.5), 1.0, y) + y = tf.where(tf.math.less(y, 0.5), 0.0, y) + y = tf.transpose(y, [1,2,0,3]) # [D,H,W,L] --> [H,W,D,L] + if self.debug: + import pynvml + tf.print (' - [execute()] GPU mem: ', '%.4f' % 
(pynvml.nvmlDeviceGetMemoryInfo(self.device_id).used/1024/1024/1000),'GB, ') + + if show: + + x_slice = x[:,:,idx,0] # [H,W] + axarr[0][1].imshow(x_slice, cmap='gray') + y_slice = tf.argmax(y[:,:,idx,:], axis=-1) # [H,W] + axarr[0][1].imshow(y_slice, cmap='gray', alpha=0.5) + grid_new = tf.reshape(grid_control_points_new, self.grid_control_points_orig.shape) + axarr[0][1].plot(grid_new[:,:,1], grid_new[:,:,0], 'y-.', alpha=0.2) + axarr[0][1].plot(tf.transpose(grid_new[:,:,1]), tf.transpose(grid_new[:,:,0]), 'y-.', alpha=0.5) + axarr[0][1].set_xlim([0, self.img_shape[1]]) + axarr[0][1].set_ylim([0, self.img_shape[0]]) + + axarr[1][0].imshow(x_flow[0,:,:,1], cmap='gray') + axarr[1][1].imshow(x_flow[0,:,:,0], cmap='gray') + plt.show() + + # return x_tf, y, meta1, meta2 + return x, y, meta1, meta2 diff --git a/src/dataloader/dataset.py b/src/dataloader/dataset.py new file mode 100644 index 0000000..b359813 --- /dev/null +++ b/src/dataloader/dataset.py @@ -0,0 +1,53 @@ +import tensorflow as tf + +class ZipDataset: + + def __init__(self, datasets): + self.datasets = datasets + self.datasets_generators = [] + self._init_constants() + + + def _init_constants(self): + self.HU_MIN = self.datasets[0].HU_MIN + self.HU_MAX = self.datasets[0].HU_MAX + self.dataset_dir_processed = self.datasets[0].dataset_dir_processed + self.grid = self.datasets[0].grid + self.pregridnorm = self.datasets[0].pregridnorm + + def __len__(self): + length = 0 + for dataset in self.datasets: + length += len(dataset) + + return length + + def generator(self): + for dataset in self.datasets: + self.datasets_generators.append(dataset.generator()) + return tf.data.experimental.sample_from_datasets(self.datasets_generators) #<_DirectedInterleaveDataset shapes: (, , , ), types: (tf.float32, tf.float32, tf.int16, tf.string)> + + def get_subdataset(self, param_name): + if type(param_name) == str: + for dataset in self.datasets: + if dataset.name == param_name: + return dataset + else: + print (' - [ERROR][ZipDataset] param_name needs to a str') + + return None + + def get_label_map(self, label_map_full=False): + if label_map_full: + return self.datasets[0].LABEL_MAP_FULL + else: + return self.datasets[0].LABEL_MAP + + def get_label_colors(self): + return self.datasets[0].LABEL_COLORS + + def get_label_weights(self): + return self.datasets[0].LABEL_WEIGHTS + + def get_mask_type(self, idx=0): + return self.datasets[idx].mask_type \ No newline at end of file diff --git a/src/dataloader/extractors/__init__.py b/src/dataloader/extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dataloader/extractors/han_deepmindtcia.py b/src/dataloader/extractors/han_deepmindtcia.py new file mode 100644 index 0000000..a80175a --- /dev/null +++ b/src/dataloader/extractors/han_deepmindtcia.py @@ -0,0 +1,339 @@ +# Import private libraries +import medloader.dataloader.config as config +import medloader.dataloader.utils as utils + +# Import public libraries +import pdb +import tqdm +import nrrd +import copy +import traceback +import numpy as np +from pathlib import Path + +class HaNDeepMindTCIADownloader: + + def __init__(self, dataset_dir_raw, dataset_dir_processed): + + self.class_name = 'HaNDeepMindTCIADownloader' + self.dataset_dir_raw = dataset_dir_raw + self.dataset_dir_processed = dataset_dir_processed + + self.url_zip = 'https://github.com/deepmind/tcia-ct-scan-dataset/archive/refs/heads/master.zip' + self.unzipped_folder_name = self.url_zip.split('/')[-5] + '-' + self.url_zip.split('/')[-1].split('.')[0] + + def 
download(self): + + # Step 1 - Make raw directory + self.dataset_dir_raw.mkdir(parents=True, exist_ok=True) + + # Step 2 - Download .zip and then unzip it + path_zip_folder = Path(self.dataset_dir_raw, self.unzipped_folder_name + '.zip') + + if not Path(path_zip_folder).exists(): + utils.download_zip(self.url_zip, path_zip_folder, self.dataset_dir_raw) + else: + utils.read_zip(path_zip_folder, self.dataset_dir_raw) + + def sort(self): + + path_download_nrrds = Path(self.dataset_dir_raw).joinpath(self.unzipped_folder_name, 'nrrds') + if Path(path_download_nrrds).exists(): + path_download_nrrds_test_src = Path(path_download_nrrds).joinpath('test') + path_download_nrrds_val_src = Path(path_download_nrrds).joinpath('validation') + + path_nrrds_test_dest = Path(self.dataset_dir_raw).joinpath('test') + path_nrrds_val_dest = Path(self.dataset_dir_raw).joinpath('validation') + + utils.move_folder(path_download_nrrds_test_src, path_nrrds_test_dest) + utils.move_folder(path_download_nrrds_val_src, path_nrrds_val_dest) + else: + print (' - [ERROR][{}] Could not find nrrds folder: {}'.format(self.class_name, path_download_nrrds)) + + +class HaNDeepMindTCIAExtractor: + + def __init__(self, name, dataset_dir_raw, dataset_dir_processed, dataset_dir_datatypes): + + self.name = name + self.class_name = 'HaNDeepMindTCIAExtractor' + self.dataset_dir_raw = dataset_dir_raw + self.dataset_dir_processed = dataset_dir_processed + self.dataset_dir_datatypes = dataset_dir_datatypes + + self._preprint() + self._init_constants() + + def _preprint(self): + self.VOXEL_RESO = getattr(config, self.name)[config.KEY_VOXELRESO] + print ('') + print (' - [{}] VOXEL_RESO: {}'.format(self.class_name, self.VOXEL_RESO)) + print ('') + + def _init_constants(self): + + # File names + self.DATATYPE_ORIG = '.nrrd' + self.IMG_VOXEL_FILENAME = 'CT_IMAGE.nrrd' + self.MASK_VOXEL_FILENAME = 'mask.nrrd' + self.MASK_ORGANS_FOLDERNAME = 'segmentations' + + # Label information + self.dataset_config = getattr(config, self.name) + self.LABEL_MAP_MICCAI2015_DEEPMINDTCIA = self.dataset_config[config.KEY_LABELMAP_MICCAI_DEEPMINDTCIA] + self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP] + self.IGNORE_LABELS = self.dataset_config[config.KEY_IGNORE_LABELS] + self.LABELID_MIDPOINT = self.dataset_config[config.KEY_LABELID_MIDPOINT] + + def extract3D(self): + + if 1: + import concurrent + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + for dir_type in self.dataset_dir_datatypes: + path_extract3D = Path(self.dataset_dir_raw).joinpath(dir_type) + if Path(path_extract3D).exists(): + executor.submit(self._extract3D_patients, path_extract3D) + else: + print (' - [ERROR][{}][extract3D()] {} does not exist: '.format(self.class_name, path_extract3D)) + + else: + for dir_type in [Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_ONC)]: + path_extract3D = Path(self.dataset_dir_raw).joinpath(dir_type) + if Path(path_extract3D).exists(): + self._extract3D_patients(path_extract3D) + else: + print (' - [ERROR][{}][extract3D()] {} does not exist: '.format(self.class_name, path_extract3D)) + + print ('') + print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer') + print ('') + + def _extract3D_patients(self, dir_dataset): + + dir_type = Path(Path(dir_dataset).parts[-2], Path(dir_dataset).parts[-1]) + paths_global_voxel_img = [] + paths_global_voxel_mask = [] + + # Step 1 - Loop over patients of dir_type and get their img and mask paths + with 
tqdm.tqdm(total=len(list(dir_dataset.glob('*'))), desc='[3D][{}] Patients: '.format(str(dir_type)), disable=False) as pbar: + for _, patient_dir_path in enumerate(dir_dataset.iterdir()): + try: + if Path(patient_dir_path).is_dir(): + voxel_img_filepath, voxel_mask_filepath, _ = self._extract3D_patient(patient_dir_path) + paths_global_voxel_img.append(voxel_img_filepath) + paths_global_voxel_mask.append(voxel_mask_filepath) + pbar.update(1) + + except: + print ('') + print (' - [ERROR][{}][_extract3D_patients()] Error with patient_id: {}'.format(self.class_name, Path(patient_dir_path).parts[-3:])) + traceback.print_exc() + pdb.set_trace() + + # Step 2 - Save paths in .csvs + if len(paths_global_voxel_img) and len(paths_global_voxel_mask): + paths_global_voxel_img = list(map(lambda x: str(x), paths_global_voxel_img)) + paths_global_voxel_mask = list(map(lambda x: str(x), paths_global_voxel_mask)) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG), paths_global_voxel_img) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK), paths_global_voxel_mask) + + if len(self.VOXEL_RESO): + paths_global_voxel_img_resampled = [] + for path_global_voxel_img in paths_global_voxel_img: + path_global_voxel_img_parts = list(Path(path_global_voxel_img).parts) + patient_id = path_global_voxel_img_parts[-1].split('_')[-1].split('.')[0] + path_global_voxel_img_parts[-1] = config.FILENAME_IMG_RESAMPLED_3D.format(patient_id) + path_global_voxel_img_resampled = Path(*path_global_voxel_img_parts) + paths_global_voxel_img_resampled.append(path_global_voxel_img_resampled) + + paths_global_voxel_mask_resampled = [] + for path_global_voxel_mask in paths_global_voxel_mask: + path_global_voxel_mask_parts = list(Path(path_global_voxel_mask).parts) + patient_id = path_global_voxel_mask_parts[-1].split('_')[-1].split('.')[0] + path_global_voxel_mask_parts[-1] = config.FILENAME_MASK_RESAMPLED_3D.format(patient_id) + path_global_voxel_mask_resampled = Path(*path_global_voxel_mask_parts) + paths_global_voxel_mask_resampled.append(path_global_voxel_mask_resampled) + + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG_RESAMPLED), paths_global_voxel_img_resampled) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK_RESAMPLED), paths_global_voxel_mask_resampled) + + else: + print (' - [ERROR][{}][_extract3D_patients()] Unable to save .csv'.format(self.class_name)) + pdb.set_trace() + print (' - Exiting!') + import sys; sys.exit(1) + + def _extract3D_patient(self, patient_dir): + + try: + voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers = self._get_data3D(patient_dir) + + dir_type = Path(Path(patient_dir).parts[-3], Path(patient_dir).parts[-2]) + patient_id = Path(patient_dir).parts[-1] + return self._save_data3D(dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers) + + except: + print (' - [ERROR][{}][_extract_patient()] path_folder: {}'.format(self.class_name, patient_dir.parts[-3:])) + traceback.print_exc() + pdb.set_trace() + + def _get_data3D(self, patient_dir): + + try: + voxel_img, voxel_mask = [], [] + voxel_img_headers, voxel_mask_headers = {}, {} + + if Path(patient_dir).exists(): + + if Path(patient_dir).is_dir(): + + # Step 1 - Get Voxel Data + path_voxel_img = Path(patient_dir).joinpath(self.IMG_VOXEL_FILENAME) + 
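+                    # _get_voxel_img() below reads CT_IMAGE.nrrd via nrrd.read() and returns the image voxels (shape [H,W,D]) plus the NRRD header dict
+                    # (the mask is handled analogously in Step 2: _get_voxel_mask() merges the per-organ files under segmentations/ into a single labelled volume)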
voxel_img, voxel_img_headers = self._get_voxel_img(path_voxel_img) + + # Step 2 - Get Mask Data + path_voxel_mask = Path(patient_dir).joinpath(self.MASK_VOXEL_FILENAME) + voxel_mask, voxel_mask_headers = self._get_voxel_mask(path_voxel_mask) + + else: + print (' - [ERROR][{}][_get_data3D()]: Path does not exist: patient_dir: {}'.format(self.class_name, patient_dir)) + + return voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers + + except: + print (' - [ERROR][{}][_get_data3D()] patient_dir: {}'.format(self.class_name, patient_dir.parts[-3:])) + traceback.print_exc() + pdb.set_trace() + + def _get_voxel_img(self, path_voxel, histogram=False): + + try: + if Path(path_voxel).exists(): + voxel_img_data, voxel_img_header = nrrd.read(str(path_voxel)) # shape=[H,W,D] + + if histogram: + import matplotlib.pyplot as plt + plt.hist(voxel_img_data.flatten()) + plt.show() + + return voxel_img_data, voxel_img_header + else: + print (' - [ERROR][{}][_get_voxel_img()]: Path does not exist: {}'.format(self.class_name, path_voxel)) + except: + traceback.print_exc() + pdb.set_trace() + + def _get_voxel_mask(self, path_voxel_mask): + + try: + + # Step 1 - Get mask data and headers + # Note: the read-from-mask.nrrd branch below is deliberately disabled ('if 0'), so the mask is always re-merged from the per-organ files in the segmentations/ folder + if 0: #Path(path_voxel_mask).exists(): + voxel_mask_data, voxel_mask_headers = nrrd.read(str(path_voxel_mask)) + + else: + path_mask_folder = Path(*Path(path_voxel_mask).parts[:-1]).joinpath(self.MASK_ORGANS_FOLDERNAME) + voxel_mask_data, voxel_mask_headers = self._merge_masks(path_mask_folder) + + # Step 2 - Make a list of available headers + voxel_mask_headers = dict(voxel_mask_headers) + voxel_mask_headers[config.KEYNAME_LABEL_OARS] = [] + voxel_mask_headers[config.KEYNAME_LABEL_MISSING] = [] + label_map_inverse = {label_id: label_name for label_name, label_id in self.LABEL_MAP.items()} + label_ids_all = self.LABEL_MAP.values() + label_ids_voxel_mask = np.unique(voxel_mask_data) + for label_id in label_ids_all: + label_name = label_map_inverse[label_id] + if label_id not in label_ids_voxel_mask: + voxel_mask_headers[config.KEYNAME_LABEL_MISSING].append(label_name) + else: + voxel_mask_headers[config.KEYNAME_LABEL_OARS].append(label_name) + + return voxel_mask_data, voxel_mask_headers + + except: + traceback.print_exc() + pdb.set_trace() + + def _merge_masks(self, path_mask_folder): + + try: + voxel_mask_full = [] + voxel_mask_headers = {} + labels_oars = [] + labels_missing = [] + patient_id = Path(path_mask_folder).parts[-2] + + if Path(path_mask_folder).exists(): + with tqdm.tqdm(total=len(list(Path(path_mask_folder).glob('*{}'.format(self.DATATYPE_ORIG)))), leave=False, disable=True) as pbar_mask: + for filepath_mask in Path(path_mask_folder).iterdir(): + class_name = Path(filepath_mask).parts[-1].split(self.DATATYPE_ORIG)[0] + class_name_miccai2015 = self.LABEL_MAP_MICCAI2015_DEEPMINDTCIA.get(class_name, None) + class_id = -1 + + if class_name_miccai2015 in self.LABEL_MAP: + class_id = self.LABEL_MAP[class_name_miccai2015] + labels_oars.append(class_name_miccai2015) + voxel_mask, voxel_mask_headers = nrrd.read(str(filepath_mask)) + + if class_id not in self.IGNORE_LABELS and class_id > 0: + if len(voxel_mask_full) == 0: + voxel_mask_full = np.array(voxel_mask, copy=True) + idxs = np.argwhere(voxel_mask > 0) + voxel_mask_full[idxs[:,0], idxs[:,1], idxs[:,2]] = class_id + if 0: + print (' - [merge_masks()] class_id:', class_id, ' || name: ', class_name_miccai2015, ' || idxs: ', len(idxs)) + print (' --- [merge_masks()] label_ids: ', np.unique(voxel_mask_full)) + + pbar_mask.update(1) + + path_mask = 
Path(*Path(path_mask_folder).parts[:-1]).joinpath(self.MASK_VOXEL_FILENAME) + nrrd.write(str(path_mask), voxel_mask_full, voxel_mask_headers) + + else: + print (' - [ERROR][{}][_merge_masks()] path_mask_folder: {} does not exist '.format(self.class_name, path_mask_folder)) + + return voxel_mask_full, voxel_mask_headers + + except: + traceback.print_exc() + pdb.set_trace() + + def _save_data3D(self, dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers): + + try: + + # Step 1 - Create directory + voxel_savedir = Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D) + + # Step 2.1 - Save voxel + voxel_img_headers_new = {config.TYPE_VOXEL_ORIGSHAPE:{}} + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] = voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING][voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING] > 0].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] = voxel_img_headers[config.KEY_NRRD_ORIGIN].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] = voxel_img_headers[config.KEY_NRRD_SHAPE].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_OTHERS] = voxel_img_headers + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] = voxel_mask_headers[config.KEYNAME_LABEL_MISSING] + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS] = voxel_mask_headers[config.KEYNAME_LABEL_OARS] + + if len(voxel_img) and len(voxel_mask): + resample_save = False + if len(self.VOXEL_RESO): + resample_save = True + + # Find midpoint 3D coord of self.LABELID_MIDPOINT + meanpoint_idxs = np.argwhere(voxel_mask == self.LABELID_MIDPOINT) + meanpoint_idxs_mean = np.mean(meanpoint_idxs, axis=0) + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] = meanpoint_idxs_mean.tolist() + + return utils.save_as_mha(voxel_savedir, patient_id, voxel_img, voxel_img_headers_new, voxel_mask + , labelid_midpoint=self.LABELID_MIDPOINT + , resample_spacing=self.VOXEL_RESO) + + else: + print (' - [ERROR][HaNMICCAI2015Extractor] Error with patient_id: ', patient_id) + + except: + traceback.print_exc() + pdb.set_trace() \ No newline at end of file diff --git a/src/dataloader/extractors/han_miccai2015.py b/src/dataloader/extractors/han_miccai2015.py new file mode 100644 index 0000000..71f1a97 --- /dev/null +++ b/src/dataloader/extractors/han_miccai2015.py @@ -0,0 +1,339 @@ +# Import private libraries +import src.config as config +import src.dataloader.utils as utils + +# Import public libraries +import pdb +import tqdm +import nrrd +import copy +import traceback +import numpy as np +from pathlib import Path + + +class HaNMICCAI2015Downloader: + + def __init__(self, dataset_dir_raw, dataset_dir_processed): + self.dataset_dir_raw = dataset_dir_raw + self.dataset_dir_processed = dataset_dir_processed + + def download(self): + self.dataset_dir_raw.mkdir(parents=True, exist_ok=True) + # Step 1 - Download .zips and unzip them + urls_zip = ['http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part1.zip' + , 'http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part2.zip' + , 'http://www.imagenglab.com/data/pddca/PDDCA-1.4.1_part3.zip'] + + import concurrent + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + for url_zip in urls_zip: + filepath_zip = Path(self.dataset_dir_raw, url_zip.split('/')[-1]) + + # Step 1.1 - Download .zip and then unzip it + if not 
Path(filepath_zip).exists(): + executor.submit(utils.download_zip, url_zip, filepath_zip, self.dataset_dir_raw) + else: + executor.submit(utils.read_zip, filepath_zip, self.dataset_dir_raw) + + # Step 1.2 - Unzip .zip + # executor.submit(utils.read_zip(filepath_zip, self.dataset_dir_raw) + + def sort(self, dataset_dir_datatypes, dataset_dir_datatypes_ranges): + + print ('') + import tqdm + import shutil + import numpy as np + + # Step 1 - Make necessay directories + self.dataset_dir_raw.mkdir(parents=True, exist_ok=True) + for each in dataset_dir_datatypes: + path_tmp = Path(self.dataset_dir_raw).joinpath(each) + path_tmp.mkdir(parents=True, exist_ok=True) + + # Step 2 - Sort + with tqdm.tqdm(total=len(list(Path(self.dataset_dir_raw).glob('0522*'))), desc='Sorting', leave=False) as pbar: + for path_patient in self.dataset_dir_raw.iterdir(): + if '.zip' not in path_patient.parts[-1]: #and path_patient.parts[-1] not in dataset_dir_datatypes: + try: + patient_number = Path(path_patient).parts[-1][-3:] + if patient_number.isdigit(): + folder_id = np.digitize(patient_number, dataset_dir_datatypes_ranges) + shutil.move(src=str(path_patient), dst=str(Path(self.dataset_dir_raw).joinpath(dataset_dir_datatypes[folder_id]))) + pbar.update(1) + except: + traceback.print_exc() + pdb.set_trace() + +class HaNMICCAI2015Extractor: + """ + More information on the .nrrd format can be found here: http://teem.sourceforge.net/nrrd/format.html#space + """ + + def __init__(self, name, dataset_dir_raw, dataset_dir_processed, dataset_dir_datatypes): + + self.name = name + self.dataset_dir_raw = dataset_dir_raw + self.dataset_dir_processed = dataset_dir_processed + self.dataset_dir_datatypes = dataset_dir_datatypes + self.folder_prefix = '0522' + + self._preprint() + self._init_constants() + + def _preprint(self): + self.VOXEL_RESO = getattr(config, self.name)[config.KEY_VOXELRESO] + print ('') + print (' - [HaNMICCAI2015Extractor] VOXEL_RESO: ', self.VOXEL_RESO) + print ('') + + def _init_constants(self): + + # File names + self.DATATYPE_ORIG = '.nrrd' + self.IMG_VOXEL_FILENAME = 'img.nrrd' + self.MASK_VOXEL_FILENAME = 'mask.nrrd' + + # Label information + self.LABEL_MAP = getattr(config, self.name)[config.KEY_LABEL_MAP] + self.IGNORE_LABELS = getattr(config,self.name)[config.KEY_IGNORE_LABELS] + self.LABELID_MIDPOINT = getattr(config, self.name)[config.KEY_LABELID_MIDPOINT] + + def extract3D(self): + + import concurrent + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + for dir_type in self.dataset_dir_datatypes: + executor.submit(self._extract3D_patients, Path(self.dataset_dir_raw).joinpath(dir_type)) + # for dir_type in ['train']: + # self._extract3D_patients(Path(self.dataset_dir_raw).joinpath(dir_type)) + + print ('') + print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer') + print ('') + + def _extract3D_patients(self, dir_dataset): + + dir_type = Path(dir_dataset).parts[-1] + paths_global_voxel_img = [] + paths_global_voxel_mask = [] + + # Step 1 - Loop over patients of dir_type and get their img and mask paths + dir_type_idx = self.dataset_dir_datatypes.index(dir_type) + with tqdm.tqdm(total=len(list(dir_dataset.glob('*'))), desc='[3D][{}] Patients: '.format(dir_type), disable=False, position=dir_type_idx) as pbar: + for _, patient_dir_path in enumerate(dir_dataset.iterdir()): + try: + if Path(patient_dir_path).is_dir(): + voxel_img_filepath, voxel_mask_filepath, _ = self._extract3D_patient(patient_dir_path) + 
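# Note (sketch): _extract3D_patient() is expected to return the paths of the .mha image and mask written by _save_data3D()/utils.save_as_mha() for one patient folder (e.g. a hypothetical .../train/0522c0001), plus a third value that is ignored here; the collected paths are dumped to the img/mask .csv files in Step 2 below + 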
paths_global_voxel_img.append(voxel_img_filepath) + paths_global_voxel_mask.append(voxel_mask_filepath) + pbar.update(1) + + except: + print ('') + print (' - [ERROR][HaNMICCAI2015Extractor] Error with patient_id: ', Path(patient_dir_path).parts[-2:]) + traceback.print_exc() + pdb.set_trace() + + # Step 2 - Save paths in .csvs + if len(paths_global_voxel_img) and len(paths_global_voxel_mask): + paths_global_voxel_img = list(map(lambda x: str(x), paths_global_voxel_img)) + paths_global_voxel_mask = list(map(lambda x: str(x), paths_global_voxel_mask)) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG), paths_global_voxel_img) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK), paths_global_voxel_mask) + + if len(self.VOXEL_RESO): + paths_global_voxel_img_resampled = [] + for path_global_voxel_img in paths_global_voxel_img: + path_global_voxel_img_parts = list(Path(path_global_voxel_img).parts) + patient_id = path_global_voxel_img_parts[-1].split('_')[-1].split('.')[0] + path_global_voxel_img_parts[-1] = config.FILENAME_IMG_RESAMPLED_3D.format(patient_id) + path_global_voxel_img_resampled = Path(*path_global_voxel_img_parts) + paths_global_voxel_img_resampled.append(path_global_voxel_img_resampled) + + paths_global_voxel_mask_resampled = [] + for path_global_voxel_mask in paths_global_voxel_mask: + path_global_voxel_mask_parts = list(Path(path_global_voxel_mask).parts) + patient_id = path_global_voxel_mask_parts[-1].split('_')[-1].split('.')[0] + path_global_voxel_mask_parts[-1] = config.FILENAME_MASK_RESAMPLED_3D.format(patient_id) + path_global_voxel_mask_resampled = Path(*path_global_voxel_mask_parts) + paths_global_voxel_mask_resampled.append(path_global_voxel_mask_resampled) + + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_IMG_RESAMPLED), paths_global_voxel_img_resampled) + utils.save_csv(Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D, config.FILENAME_CSV_MASK_RESAMPLED), paths_global_voxel_mask_resampled) + + else: + print (' - [ERROR][HaNMICCAI2015Extractor] Unable to save .csv') + pdb.set_trace() + print (' - Exiting!') + import sys; sys.exit(1) + + def _extract3D_patient(self, patient_dir): + + try: + voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers = self._get_data3D(patient_dir) + + dir_type = Path(patient_dir).parts[-2] + patient_id = Path(patient_dir).parts[-1] + return self._save_data3D(dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers) + + except: + print (' - [ERROR][_extract_patient()] path_folder: ', patient_dir.parts[-1]) + traceback.print_exc() + pdb.set_trace() + + def _get_data3D(self, patient_dir): + try: + voxel_img, voxel_mask = [], [] + voxel_img_headers, voxel_mask_headers = {}, {} + if Path(patient_dir).exists(): + + if Path(patient_dir).is_dir(): + # Step 1 - Get Voxel Data + path_voxel_img = Path(patient_dir).joinpath(self.IMG_VOXEL_FILENAME) + voxel_img, voxel_img_headers = self._get_voxel_img(path_voxel_img) + + # Step 2 - Get Mask Data + path_voxel_mask = Path(patient_dir).joinpath(self.MASK_VOXEL_FILENAME) + voxel_mask, voxel_mask_headers = self._get_voxel_mask(path_voxel_mask) + + else: + print (' - Error: Path does not exist: patient_dir', patient_dir) + + return voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers + + except: + print (' - [ERROR][get_data()] patient_dir: ', 
patient_dir.parts[-1]) + traceback.print_exc() + pdb.set_trace() + + def _get_voxel_img(self, path_voxel, histogram=False): + try: + if path_voxel.exists(): + voxel_img_data, voxel_img_header = nrrd.read(str(path_voxel)) # shape=[H,W,D] + + if histogram: + import matplotlib.pyplot as plt + plt.hist(voxel_img_data.flatten()) + plt.show() + + return voxel_img_data, voxel_img_header + else: + print (' - Error: Path does not exist: ', path_voxel) + except: + traceback.print_exc() + pdb.set_trace() + + def _get_voxel_mask(self, path_voxel_mask): + try: + + # Step 1 - Get mask data and headers + if Path(path_voxel_mask).exists(): + voxel_mask_data, voxel_mask_headers = nrrd.read(str(path_voxel_mask)) + + else: + path_mask_folder = Path(*Path(path_voxel_mask).parts[:-1]).joinpath('structures') + voxel_mask_data, voxel_mask_headers = self._merge_masks(path_mask_folder) + + # Step 2 - Make a list of available headers + voxel_mask_headers = dict(voxel_mask_headers) + voxel_mask_headers[config.KEYNAME_LABEL_OARS] = [] + voxel_mask_headers[config.KEYNAME_LABEL_MISSING] = [] + label_map_inverse = {label_id: label_name for label_name, label_id in self.LABEL_MAP.items()} + label_ids_all = self.LABEL_MAP.values() + label_ids_voxel_mask = np.unique(voxel_mask_data) + for label_id in label_ids_all: + label_name = label_map_inverse[label_id] + if label_id not in label_ids_voxel_mask: + voxel_mask_headers[config.KEYNAME_LABEL_MISSING].append(label_name) + else: + voxel_mask_headers[config.KEYNAME_LABEL_OARS].append(label_name) + + return voxel_mask_data, voxel_mask_headers + + except: + traceback.print_exc() + pdb.set_trace() + + def _merge_masks(self, path_mask_folder): + + try: + voxel_mask_full = [] + voxel_mask_headers = {} + labels_oars = [] + labels_missing = [] + if Path(path_mask_folder).exists(): + with tqdm.tqdm(total=len(list(Path(path_mask_folder).glob('*{}'.format(self.DATATYPE_ORIG)))), leave=False, disable=True) as pbar_mask: + for filepath_mask in Path(path_mask_folder).iterdir(): + class_name = Path(filepath_mask).parts[-1].split(self.DATATYPE_ORIG)[0] + class_id = -1 + + if class_name in self.LABEL_MAP: + class_id = self.LABEL_MAP[class_name] + labels_oars.append(class_name) + voxel_mask, voxel_mask_headers = nrrd.read(str(filepath_mask)) + # else: + # print (' - [ERROR][HaNMICCAI2015Extractor][_merge_masks] Unknown class name: ', class_name) + + if class_id not in self.IGNORE_LABELS and class_id > 0: + if len(voxel_mask_full) == 0: + voxel_mask_full = copy.deepcopy(voxel_mask) + idxs = np.argwhere(voxel_mask > 0) + voxel_mask_full[idxs[:,0], idxs[:,1], idxs[:,2]] = class_id + if 0: + print (' - [merge_masks()] class_id:', class_id, ' || name: ', class_name, ' || idxs: ', len(idxs)) + print (' --- [merge_masks()] label_ids: ', np.unique(voxel_mask_full)) + + pbar_mask.update(1) + + path_mask = Path(*Path(path_mask_folder).parts[:-1]).joinpath(self.MASK_VOXEL_FILENAME) + nrrd.write(str(path_mask), voxel_mask_full, voxel_mask_headers) + + else: + print (' - Error with path_mask_folder: ', path_mask_folder) + + return voxel_mask_full, voxel_mask_headers + + except: + traceback.print_exc() + pdb.set_trace() + + def _save_data3D(self, dir_type, patient_id, voxel_img, voxel_img_headers, voxel_mask, voxel_mask_headers): + + try: + + # Step 1 - Create directory + voxel_savedir = Path(self.dataset_dir_processed).joinpath(dir_type, config.DIRNAME_SAVE_3D) + + # Step 2.1 - Save voxel + voxel_img_headers_new = {config.TYPE_VOXEL_ORIGSHAPE:{}} + 
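# Note (assumed structure, mirroring the keys filled in below and read back by the dataloaders from FILENAME_VOXEL_INFO): roughly {TYPE_VOXEL_ORIGSHAPE: {pixel_spacing: [sx,sy,sz], origin: [...], shape: [H,W,D], others: <raw nrrd header>, label_missing: [...], label_oars: [...], mean_midpoint: [w,h,d]}}; a TYPE_VOXEL_RESAMPLED entry is presumably added by utils.save_as_mha() when VOXEL_RESO is set + 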
voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] = voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING][voxel_img_headers[config.KEY_NRRD_PIXEL_SPACING] > 0].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] = voxel_img_headers[config.KEY_NRRD_ORIGIN].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] = voxel_img_headers[config.KEY_NRRD_SHAPE].tolist() + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_OTHERS] = voxel_img_headers + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] = voxel_mask_headers[config.KEYNAME_LABEL_MISSING] + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS] = voxel_mask_headers[config.KEYNAME_LABEL_OARS] + + if len(voxel_img) and len(voxel_mask): + resample_save = False + if len(self.VOXEL_RESO): + resample_save = True # note: resample_save is not used below; resampling is driven by the resample_spacing argument passed to utils.save_as_mha() + + # Find the midpoint 3D coord (mean voxel index) of self.LABELID_MIDPOINT (the brainstem); the dataloaders use this midpoint for the initial crop + meanpoint_idxs = np.argwhere(voxel_mask == self.LABELID_MIDPOINT) + meanpoint_idxs_mean = np.mean(meanpoint_idxs, axis=0) + voxel_img_headers_new[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] = meanpoint_idxs_mean.tolist() + + return utils.save_as_mha(voxel_savedir, patient_id, voxel_img, voxel_img_headers_new, voxel_mask + , labelid_midpoint=self.LABELID_MIDPOINT + , resample_spacing=self.VOXEL_RESO) + + else: + print (' - [ERROR][HaNMICCAI2015Extractor] Error with patient_id: ', patient_id) + + except: + traceback.print_exc() + pdb.set_trace() diff --git a/src/dataloader/han_deepmindtcia.py b/src/dataloader/han_deepmindtcia.py new file mode 100644 index 0000000..f2169d2 --- /dev/null +++ b/src/dataloader/han_deepmindtcia.py @@ -0,0 +1,726 @@ +# Import private libraries +import src.config as config +import src.dataloader.utils as utils + +# Import public libraries +import pdb +import time +import json +import itertools +import traceback +import numpy as np +from pathlib import Path + +import tensorflow as tf + + +class HaNDeepMindTCIADataset: + """ + Contains data from https://github.com/deepmind/tcia-ct-scan-dataset + - Ref: Deep learning to achieve clinically applicable segmentation of head and neck anatomy for radiotherapy + """ + + def __init__(self, data_dir, dir_type, annotator_type + , grid=True, crop_init=False, resampled=False, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False, pregridnorm=False + , patient_shuffle=False + , centred_dataloader_prob = 0.0 + , parallel_calls=None, deterministic=False + , debug=False): + + self.name = '{}_DeepMindTCIA'.format(config.HEAD_AND_NECK) + self.class_name = 'HaNDeepMindTCIADataset' + + # Params - Source + self.data_dir = data_dir + self.dir_type = dir_type + self.annotator_type = annotator_type + + # Params - Spatial (x,y) + self.grid = grid + self.crop_init = crop_init + self.resampled = resampled + self.mask_type = mask_type + + # Params - Transforms/Filters + self.transforms = transforms + self.filter_grid = filter_grid + self.random_grid = random_grid + self.pregridnorm = pregridnorm + + # Params - Memory related + self.patient_shuffle = patient_shuffle + + # Params - Dataset related + self.centred_dataloader_prob = centred_dataloader_prob + + # Params - TFlow Dataloader related + self.parallel_calls = parallel_calls # [1, tf.data.experimental.AUTOTUNE] + self.deterministic = deterministic + + # Params - Debug + self.debug = debug + + # Data items + self.data = {} + self.paths_img = [] + self.paths_mask = [] + self.cache = {} + 
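# Note: self.cache holds the most recently read (img, mask, spacing) triple via _set_cache_item_old()/_get_cache_item_old() further below (the two-slot _set_cache_item()/_get_cache_item() variants are defined but not used by _generator3D()); self.filter is assumed to be assigned externally (an object exposing an .execute() usable by dataset.filter()) whenever filter_grid=True + 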
self.filter = None + + # Config + self.dataset_config = getattr(config, self.name) + + # Function calls + self._download() + self._init_data() + + def __len__(self): + + if self.grid: + + if self.crop_init: + + if self.voxel_shape_cropped == [240,240,80]: + + if self.grid_size == [240,240,80]: + return len(self.paths_img) * 1 + elif self.grid_size == [240,240,40]: + return len(self.paths_img) * 2 + elif self.grid_size == [140,140,40]: + return len(self.paths_img) * 8 + else: + return None + + else: + return len(self.paths_img) + + def _download(self): + + self.dataset_dir = Path(self.data_dir).joinpath(self.name) + self.dataset_dir_raw = Path(self.dataset_dir).joinpath(config.DIRNAME_RAW) + + self.VOXEL_RESO = self.dataset_config[config.KEY_VOXELRESO] + if self.VOXEL_RESO == (0.8,0.8,2.5): + self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED) + self.dataset_dir_processed_3D = Path(self.dataset_dir_processed).joinpath(self.dir_type, self.annotator_type, config.DIRNAME_SAVE_3D) + self.dataset_dir_datatypes = [ + Path(config.DATALOADER_DEEPMINDTCIA_TEST, config.DATALOADER_DEEPMINDTCIA_ONC) + , Path(config.DATALOADER_DEEPMINDTCIA_TEST, config.DATALOADER_DEEPMINDTCIA_RAD) + , Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_ONC) + , Path(config.DATALOADER_DEEPMINDTCIA_VAL, config.DATALOADER_DEEPMINDTCIA_RAD) + ] + + if not Path(self.dataset_dir_raw).exists() or not Path(self.dataset_dir_processed).exists(): + print ('') + print (' ------------------ {} Dataset ------------------'.format(self.name)) + + if not Path(self.dataset_dir_raw).exists(): + print ('') + print (' ------------------ Download Data ------------------') + from src.dataloader.extractors.han_deepmindtcia import HaNDeepMindTCIADownloader + downloader = HaNDeepMindTCIADownloader(self.dataset_dir_raw, self.dataset_dir_processed) + downloader.download() + downloader.sort() + + if not Path(self.dataset_dir_processed_3D).exists(): + print ('') + print (' ------------------ Process Data (3D) ------------------') + from src.dataloader.extractors.han_deepmindtcia import HaNDeepMindTCIAExtractor + extractor = HaNDeepMindTCIAExtractor(self.name, self.dataset_dir_raw, self.dataset_dir_processed, self.dataset_dir_datatypes) + extractor.extract3D() + print ('') + print (' ------------------------- * -------------------------') + print ('') + + def _init_data(self): + + # Step 0 - Init vars + self.patient_meta_info = {} + self.path_img_csv = '' + self.path_mask_csv = '' + if self.VOXEL_RESO == (0.8,0.8,2.5): + self.patients_z_prob = [] + else: + self.patients_z_prob = [] + + # Step 1 - Define global paths + self.data_dir_processed = Path(self.dataset_dir_processed).joinpath(self.dir_type, self.annotator_type) + if not Path(self.data_dir_processed).exists(): + print (' - [ERROR][{}][_init_data()] Processed Dir Path issue: {}'.format(self.class_name, self.data_dir_processed)) + self.data_dir_processed_3D = Path(self.data_dir_processed).joinpath(config.DIRNAME_SAVE_3D) + + # Step 2.1 - Get paths for 2D/3D + if self.resampled is False: + self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG) + self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK) + else: + self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG_RESAMPLED) + self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK_RESAMPLED) + + # Step 2.2 - Get file paths + if Path(self.path_img_csv).exists() and 
Path(self.path_mask_csv).exists(): + self.paths_img = utils.read_csv(self.path_img_csv) + self.paths_mask = utils.read_csv(self.path_mask_csv) + + exit_condition = False + for path_img in self.paths_img: + if not Path(path_img).exists(): + print (' - [ERROR][{}][_init_data()] path issue: path_img: {}/{}, {}'.format(self.class_name, self.dir_type, self.annotator_type, path_img)) + exit_condition = True + for path_mask in self.paths_mask: + if not Path(path_mask).exists(): + print (' - [ERROR][{}][_init_data()] path issue: path_mask: {}/{}, {}'.format(self.class_name, self.dir_type, self.annotator_type, path_mask)) + exit_condition = True + + if exit_condition: + print ('\n - [ERROR][{}][_init_data()] Exiting due to path issues'.format(self.class_name)) + import sys; sys.exit(1) + + else: + print (' - [ERROR][{}] Issue with path'.format(self.class_name)) + print (' -- [ERROR][{}] self.path_img_csv : ({}) {}'.format(self.class_name, Path(self.path_img_csv).exists(), self.path_img_csv )) + print (' -- [ERROR][{}] self.path_mask_csv: ({}) {}'.format(self.class_name, Path(self.path_mask_csv).exists(), self.path_mask_csv )) + + # Step 3.1 - Meta for labels + self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP] + self.LABEL_COLORS = self.dataset_config[config.KEY_LABEL_COLORS] + if len(self.dataset_config[config.KEY_LABEL_WEIGHTS]): + self.LABEL_WEIGHTS = np.array(self.dataset_config[config.KEY_LABEL_WEIGHTS]) + self.LABEL_WEIGHTS = (self.LABEL_WEIGHTS / np.sum(self.LABEL_WEIGHTS)).tolist() + else: + self.LABEL_WEIGHTS = [] + + # Step 3.2 - Meta for voxel HU + self.HU_MIN = self.dataset_config['HU_MIN'] + self.HU_MAX = self.dataset_config['HU_MAX'] + + # Step 4 - Get patient meta info + if self.resampled is False: + print (' - [Warning][{}]: This dataloader is not extracting 3D Volumes which have been resampled to the same 3D voxel spacing: {}/{}'.format(self.class_name, self.dir_type, self.annotator_type)) + for path_img in self.paths_img: + path_img_json = Path(path_img).parent.absolute().joinpath(config.FILENAME_VOXEL_INFO) + with open(str(path_img_json), 'r') as fp: + patient_config_file = json.load(fp) + + patient_id = Path(path_img).parts[-2] + midpoint_idxs_mean = [] + missing_label_names = [] + voxel_shape_resampled = [] + voxel_shape_orig = [] + if self.resampled is False: + midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] + missing_label_names = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] + voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] + else: + if config.TYPE_VOXEL_RESAMPLED in patient_config_file: + midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_MEAN_MIDPOINT] + missing_label_names = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING] + voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] + voxel_shape_resampled = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE] + + else: + print (' - [ERROR][{}][_init_data()] There is no resampled data and you have set resample=True'.format(self.class_name)) + print (' -- Delete _data/{}/processed directory and set VOXEL_RESO to a tuple of pixel spacing values for x,y,z axes'.format(self.name)) + print (' -- Exiting now! 
') + import sys; sys.exit(1) + + if len(midpoint_idxs_mean): + midpoint_idxs_mean = np.array(midpoint_idxs_mean).astype(config.DATATYPE_VOXEL_IMG).tolist() + self.patient_meta_info[patient_id] = { + config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean + , config.KEYNAME_LABEL_MISSING : missing_label_names + , config.KEYNAME_SHAPE_ORIG : voxel_shape_orig + , config.KEYNAME_SHAPE_RESAMPLED : voxel_shape_resampled + } + else: + if self.crop_init: + print ('') + print (' - [ERROR][{}][_init_data()] Crop is set to true and there is no midpoint mean idx'.format(self.class_name)) + print (' -- Exiting now! ') + import sys; sys.exit(1) + + # Step 6 - Meta for grid sampling + if self.resampled: + + if self.crop_init: + self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_RESAMPLED][config.KEY_CROP][str(self.VOXEL_RESO)] + self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT] + self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT] + self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK] + self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT] + self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP] + self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM] + self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop + , self.h_posterior_crop+self.h_anterior_crop + , self.d_cranial_crop+self.d_caudal_crop] + + if self.grid: + grid_3D_params = self.dataset_config[config.KEY_GRID_3D][config.TYPE_VOXEL_RESAMPLED][str(self.VOXEL_RESO)] + self.grid_size = grid_3D_params[config.KEY_GRID_SIZE] + self.grid_overlap = grid_3D_params[config.KEY_GRID_OVERLAP] + self.SAMPLER_PERC = grid_3D_params[config.KEY_GRID_SAMPLER_PERC] + self.RANDOM_SHIFT_MAX = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_MAX] + self.RANDOM_SHIFT_PERC = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_PERC] + + self.w_grid, self.h_grid, self.d_grid = self.grid_size + self.w_overlap, self.h_overlap, self.d_overlap = self.grid_overlap + + else: + + if self.crop_init: + self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped + else: + print (' - [ERROR][HaNMICCAI2015Dataset] No info present for non-grid cropping') + + else: + if self.crop_init: + self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_ORIGSHAPE][config.KEY_CROP][str(self.VOXEL_RESO)] + self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT] + self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT] + self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK] + self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT] + self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP] + self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM] + self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop + , self.h_posterior_crop+self.h_anterior_crop + , self.d_cranial_crop+self.d_caudal_crop] + + if self.grid: + pass + else: + if self.crop_init: + self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped + else: + print (' - [ERROR][HaNMICCAI2015Dataset] No info present for non-crop size') + + def generator(self): + """ + - Note: + - In general, even when running your model on an accelerator like a GPU or TPU, the tf.data pipelines are run on the CPU + - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow + """ + + try: + + if len(self.paths_img) and 
len(self.paths_mask): + + # Step 1 - Create basic generator + dataset = None + dataset = tf.data.Dataset.from_generator(self._generator3D + , output_types=(config.DATATYPE_TF_FLOAT32, config.DATATYPE_TF_UINT8, config.DATATYPE_TF_INT32, tf.string) + ,args=()) + + # Step 2 - Get 3D data + dataset = dataset.map(self._get_data_3D, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic) + + # Step 3 - Filter function + if self.filter_grid: + dataset = dataset.filter(self.filter.execute) + + # Step 4 - Data augmentations + if len(self.transforms): + for transform in self.transforms: + try: + dataset = dataset.map(transform.execute, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic) + except: + traceback.print_exc() + print (' - [ERROR][{}][generator()] Issue with transform: {}'.format(self.class_name, transform.name)) + else: + print ('') + print (' - [INFO][{}][generator()] No transformations available! - {}, {}'.format(self.class_name, self.dir_type, self.annotator_type)) + print ('') + + # Step 6 - Return + return dataset + + else: + return None + + except: + traceback.print_exc() + pdb.set_trace() + return None + + def _get_paths(self, idx): + patient_id = '' + study_id = '' + path_img, path_mask = '', '' + + if self.debug: + path_img = Path(self.paths_img[0]).absolute() + path_mask = Path(self.paths_mask[0]).absolute() + path_img, path_mask = self.path_debug_3D(path_img, path_mask) # note: path_debug_3D() is assumed to be provided elsewhere; this branch is only hit when debug=True + else: + path_img = Path(self.paths_img[idx]).absolute() + path_mask = Path(self.paths_mask[idx]).absolute() + + if path_img.exists() and path_mask.exists(): + patient_id = Path(path_img).parts[-2] + study_id = Path(path_img).parts[-5] + '-' + Path(path_img).parts[-4] + else: + print (' - [ERROR] Issue with path') + print (' -- [ERROR][{}] path_img : {}'.format(self.class_name, path_img)) + print (' -- [ERROR][{}] path_mask: {}'.format(self.class_name, path_mask)) + + return path_img, path_mask, patient_id, study_id + + def _generator3D(self): + + # Step 0 - Init + res = [] + + # Step 1 - Get patient idxs + idxs = np.arange(len(self.paths_img)).tolist() #[:3] + if getattr(self, 'single_sample', False): idxs = idxs[0:1] # [2:4] (single_sample is not set in __init__, so guard with getattr to avoid an AttributeError) + if self.patient_shuffle: np.random.shuffle(idxs) + + # Step 2 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling + if self.grid: + + # Step 2.1 - Get grid sampler info for each patient-idx + sampler_info = {} + for idx in idxs: + path_img = Path(self.paths_img[idx]).absolute() + patient_id = path_img.parts[-2] + + if config.TYPE_VOXEL_RESAMPLED in str(path_img): + voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_RESAMPLED] + else: + voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_ORIG] + + if self.crop_init: + if patient_id in self.patients_z_prob: + voxel_shape[0] = self.voxel_shape_cropped[0] + voxel_shape[1] = self.voxel_shape_cropped[1] + else: + voxel_shape = self.voxel_shape_cropped + + # e.g. with a cropped shape of [240,240,80] and grid_size=[140,140,40] (see __len__), each axis splits into 2 overlapping windows, i.e. 2*2*2 = 8 grids per patient + grid_idxs_width = utils.split_into_overlapping_grids(voxel_shape[0], len_grid=self.grid_size[0], len_overlap=self.grid_overlap[0]) + grid_idxs_height = utils.split_into_overlapping_grids(voxel_shape[1], len_grid=self.grid_size[1], len_overlap=self.grid_overlap[1]) + grid_idxs_depth = utils.split_into_overlapping_grids(voxel_shape[2], len_grid=self.grid_size[2], len_overlap=self.grid_overlap[2]) + sampler_info[idx] = list(itertools.product(grid_idxs_width,grid_idxs_height,grid_idxs_depth)) + + # Step 2.2 - Loop over all patients and their grids + # Note - Grids of a patient are extracted in order + for i, 
idx in enumerate(idxs): + path_img, path_mask, patient_id, study_id = self._get_paths(idx) + missing_labels = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING] + bgd_mask = 1 # by default + if len(missing_labels): + bgd_mask = 0 + if path_img.exists() and path_mask.exists(): + for sample_info in sampler_info[idx]: + grid_idxs = sample_info + meta1 = [idx] + [grid_idxs[0][0], grid_idxs[1][0], grid_idxs[2][0]] # only include w_start, h_start, d_start + meta2 = '-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)]) + path_img = str(path_img) + path_mask = str(path_mask) + res.append((path_img, path_mask, meta1, meta2, bgd_mask)) + + else: + label_names = list(self.LABEL_MAP.keys()) + for i, idx in enumerate(idxs): + path_img, path_mask, patient_id, study_id = self._get_paths(idx) + missing_label_names = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING] + bgd_mask = 1 + + # if len(missing_labels): bgd_mask = 0 + if len(missing_label_names): + if len(set(label_names) - set(missing_label_names)): + bgd_mask = 0 + + if path_img.exists() and path_mask.exists(): + meta1 = [idx] + [0,0,0] # dummy for w_start, h_start, d_start + meta2 ='-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)]) + path_img = str(path_img) + path_mask = str(path_mask) + res.append((path_img, path_mask, meta1, meta2, bgd_mask)) + + # Step 3 - Yield + for each in res: + path_img, path_mask, meta1, meta2, bgd_mask = each + + vol_img_npy, vol_mask_npy, spacing = self._get_cache_item_old(path_img, path_mask) + if vol_img_npy is None and vol_mask_npy is None: + vol_img_npy, vol_mask_npy, spacing = self._get_volume_from_path(path_img, path_mask) + self._set_cache_item_old(path_img, path_mask, vol_img_npy, vol_mask_npy, spacing) + + spacing = tf.constant(spacing, dtype=tf.int32) + vol_img_npy_shape = tf.constant(vol_img_npy.shape, dtype=tf.int32) + meta1 = tf.concat([meta1, spacing, vol_img_npy_shape, [bgd_mask]], axis=0) # [idx,[grid_idxs],[spacing],[shape]] + + yield (vol_img_npy, vol_mask_npy, meta1, meta2) + + def _get_cache_item_old(self, path_img, path_mask): + if 'img' in self.cache and 'mask' in self.cache: + if path_img in self.cache['img'] and path_mask in self.cache['mask']: + # print (' - [_get_cache_item()] ') + return self.cache['img'][path_img], self.cache['mask'][path_mask], self.cache['spacing'] + else: + return None, None, None + else: + return None, None, None + + def _set_cache_item_old(self, path_img, path_mask, vol_img, vol_mask, spacing): + + self.cache = { + 'img': {path_img: vol_img} + , 'mask': {path_mask: vol_mask} + , 'spacing': spacing + } + + def _set_cache_item(self, path_img, path_mask, vol_img, vol_mask, spacing): + if len(self.cache) == 0: + self.cache = {path_img: [vol_img, vol_mask, spacing]} + self.cache_id = {path_img:0} + elif len(self.cache) == 1: + self.cache[path_img] = [vol_img, vol_mask, spacing] + self.cache_id[path_img] = 1 + elif len(self.cache) == 2: + max_order_id = max(self.cache_id.values()) + for path_img_ in self.cache_id: + if self.cache_id[path_img_] == max_order_id - 1: + self.cache.pop(path_img_) + self.cache[path_img] = [vol_img, vol_mask, spacing] + self.cache_id[path_img] = max_order_id+1 + + def _get_cache_item(self, path_img, path_mask): + + if path_img in self.cache: + return self.cache[path_img] + else: + return None, None, None + + def _get_volume_from_path(self, path_img, path_mask, verbose=False): + + # Step 1 - Get volumes + if verbose: t0 = time.time() + vol_img_sitk = 
utils.read_mha(path_img) + vol_img_npy = utils.sitk_to_array(vol_img_sitk) + vol_mask_sitk = utils.read_mha(path_mask) + vol_mask_npy = utils.sitk_to_array(vol_mask_sitk) + spacing = np.array(vol_img_sitk.GetSpacing()) + + # Step 2 - Perform init crop on volumes + if self.crop_init: + patient_id = str(Path(path_img).parts[-2]) + mean_point = np.array(self.patient_meta_info[patient_id][config.KEYNAME_MEAN_MIDPOINT]).astype(np.uint16).tolist() + vol_img_npy_shape_prev = vol_img_npy.shape + + # Step 2.1 - Perform crops in H,W region + vol_img_npy = vol_img_npy[ + mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop + , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop + , : + ] + vol_mask_npy = vol_mask_npy[ + mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop + , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop + , : + ] + + # Step 2.2 - Perform crops in D region + if self.grid: + if self.VOXEL_RESO == (0.8,0.8,2.5): + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + if self.resampled: + if self.VOXEL_RESO == (0.8,0.8,2.5): + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + + # Step 3 - Pad (with=0) if volume is less in z-dimension + (vol_img_npy_x, vol_img_npy_y, vol_img_npy_z) = vol_img_npy.shape + if vol_img_npy_z < self.d_caudal_crop + self.d_cranial_crop: + del_z = self.d_caudal_crop + self.d_cranial_crop - vol_img_npy_z + vol_img_npy = np.concatenate((vol_img_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2) + vol_mask_npy = np.concatenate((vol_mask_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2) + + if verbose: print (' - [HaNMICCAI2015Dataset._get_volume_from_path()] Time: ({}):{}s'.format(Path(path_img).parts[-2], round(time.time() - t0,2))) + if self.pregridnorm: + vol_img_npy[vol_img_npy <= self.HU_MIN] = self.HU_MIN + vol_img_npy[vol_img_npy >= self.HU_MAX] = self.HU_MAX + vol_img_npy = (vol_img_npy -np.mean(vol_img_npy))/np.std(vol_img_npy) #Standardize (z-scoring) + + return tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32), tf.cast(vol_mask_npy, dtype=config.DATATYPE_TF_UINT8), tf.constant(spacing*100, dtype=config.DATATYPE_TF_INT32) + + @tf.function + def _get_new_grid_idx(self, start, end, max): + + start_prev = start + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC: + + delta_left = start + delta_right = max - end + shift_voxels = tf.random.uniform([], minval=0, maxval=self.RANDOM_SHIFT_MAX, 
dtype=tf.dtypes.int32) + + if delta_left > self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX: + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC: + start = start - shift_voxels + end = end - shift_voxels + else: + start = start + shift_voxels + end = end + shift_voxels + + elif delta_left > self.RANDOM_SHIFT_MAX and delta_right <= self.RANDOM_SHIFT_MAX: + start = start - shift_voxels + end = end - shift_voxels + + elif delta_left <= self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX: + start = start + shift_voxels + end = end + shift_voxels + + return start_prev, start, end + + @tf.function + def _get_new_grid_idx_centred(self, grid_size_half, max_pt, mid_pt): + + # Step 1 - Return vars + start, end = 0,0 + + # Step 2 - Define margin on either side of mid point + margin_left = mid_pt + margin_right = max_pt - mid_pt + + # Ste p2 - Calculate vars + if margin_left >= grid_size_half and margin_right >= grid_size_half: + start = mid_pt - grid_size_half + end = mid_pt + grid_size_half + elif margin_right < grid_size_half: + if margin_left >= grid_size_half + (grid_size_half - margin_right): + end = mid_pt + margin_right + start = mid_pt - grid_size_half - (grid_size_half - margin_right) + else: + tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 2 problem') + elif margin_left < grid_size_half: + if margin_right >= grid_size_half + (grid_size_half - margin_left): + start = mid_pt - margin_left + end = mid_pt + grid_size_half + (grid_size_half-margin_left) + else: + tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 3 problem') + + return start, end + + @tf.function + def _get_data_3D(self, vol_img, vol_mask, meta1, meta2): + """ + Params + ------ + meta1: [idx, [w_start, h_start, d_start], [spacing_x, spacing_y, spacing_z], [shape_x, shape_y, shape_z], [bgd_mask]] + """ + + vol_img_npy = None + vol_mask_npy = None + + # Step 1 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling + if self.grid: + + if 0: #tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.centred_dataloader_prob: + + # Step 1.1 - Get a label_id and its mean + label_id = tf.cast(tf.random.categorical(tf.math.log([self.LABEL_WEIGHTS]), 1), dtype=config.DATATYPE_TF_UINT8)[0][0] # the LABEL_WEGIHTS sum to 1; the log is added as tf.random.categorical expects logits + label_id_idxs = tf.where(tf.math.equal(label_id, vol_mask)) + label_id_idxs_mean = tf.math.reduce_mean(label_id_idxs, axis=0) + label_id_idxs_mean = tf.cast(label_id_idxs_mean, dtype=config.DATATYPE_TF_INT32) + + # Step 1.2 - Create a grid around that mid-point + w_start_prev = meta1[1] + h_start_prev = meta1[2] + d_start_prev = meta1[3] + w_max = meta1[7] + h_max = meta1[8] + d_max = meta1[9] + w_grid = self.grid_size[0] + h_grid = self.grid_size[1] + d_grid = self.grid_size[2] + w_mid = label_id_idxs_mean[0] + h_mid = label_id_idxs_mean[1] + d_mid = label_id_idxs_mean[2] + + w_start, w_end = self._get_new_grid_idx_centred(w_grid//2, w_max, w_mid) + h_start, h_end = self._get_new_grid_idx_centred(h_grid//2, h_max, h_mid) + d_start, d_end = self._get_new_grid_idx_centred(d_grid//2, d_max, d_mid) + + meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0]) + meta1 = meta1 + meta1_diff + + else: + + # tf.print(' - [INFO] regular dataloader: ', self.dir_type) + # Step 1.1 - Get raw images/masks and extract grid + w_start = meta1[1] + w_end = w_start + self.grid_size[0] + h_start 
= meta1[2] + h_end = h_start + self.grid_size[1] + d_start = meta1[3] + d_end = d_start + self.grid_size[2] + + # Step 1.2 - Randomization of grid + if self.random_grid: + w_max = meta1[7] + h_max = meta1[8] + d_max = meta1[9] + + w_start_prev = w_start + d_start_prev = d_start + w_start_prev, w_start, w_end = self._get_new_grid_idx(w_start, w_end, w_max) + h_start_prev, h_start, h_end = self._get_new_grid_idx(h_start, h_end, h_max) + d_start_prev, d_start, d_end = self._get_new_grid_idx(d_start, d_end, d_max) + + meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0]) + meta1 = meta1 + meta1_diff + + # Step 1.3 - Extracting grid + vol_img_npy = tf.identity(vol_img[w_start:w_end, h_start:h_end, d_start:d_end]) + vol_mask_npy = tf.identity(vol_mask[w_start:w_end, h_start:h_end, d_start:d_end]) + + + else: + vol_img_npy = vol_img + vol_mask_npy = vol_mask + + # Step 2 - One-hot or not + vol_mask_classes = [] + label_ids_mask = [] + label_ids = sorted(list(self.LABEL_MAP.values())) + if self.mask_type == config.MASK_TYPE_ONEHOT: + vol_mask_classes = tf.concat([tf.expand_dims(tf.math.equal(vol_mask_npy, label), axis=-1) for label in label_ids], axis=-1) # [H,W,D,L] + for label_id in label_ids: + label_ids_mask.append(tf.cast(tf.math.reduce_any(vol_mask_classes[:,:,:,label_id]), dtype=config.DATATYPE_TF_INT32)) + + elif self.mask_type == config.MASK_TYPE_COMBINED: + vol_mask_classes = vol_mask_npy + unique_classes, _ = tf.unique(tf.reshape(vol_mask_npy,[-1])) + unique_classes = tf.cast(unique_classes, config.DATATYPE_TF_INT32) + for label_id in label_ids: + label_ids_mask.append(tf.cast(tf.math.reduce_any(tf.math.equal(unique_classes, label_id)), dtype=config.DATATYPE_TF_INT32)) + + # Step 2.2 - Handling background mask explicitly if there is a missing label + bgd_mask = meta1[-1] + label_ids_mask[0] = bgd_mask + meta1 = meta1[:-1] # removes the bgd mask index + + # Step 3 - Dtype conversion and expading dimensions + if self.mask_type == config.MASK_TYPE_ONEHOT: + x = tf.cast(tf.expand_dims(vol_img_npy, axis=3), dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,1] + else: + x = tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D] + y = tf.cast(vol_mask_classes, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,L] + + # Step 4 - Append info to meta1 + meta1 = tf.concat([meta1, label_ids_mask], axis=0) + + # Step 5 - return + return (x, y, meta1, meta2) + \ No newline at end of file diff --git a/src/dataloader/han_miccai2015.py b/src/dataloader/han_miccai2015.py new file mode 100644 index 0000000..3b0c9c8 --- /dev/null +++ b/src/dataloader/han_miccai2015.py @@ -0,0 +1,765 @@ +# Import internal libraries +import src.config as config +import src.dataloader.utils as utils + +# Import external libraries +import pdb +import time +import json +import itertools +import traceback +import numpy as np +from pathlib import Path +import tensorflow as tf + + + +class HaNMICCAI2015Dataset: + """ + The 2015 MICCAI Challenge contains CT scans of the head and neck along with annotations for 9 organs. 
+ It contains train, train_additional, test_onsite and test_offsite folders + + Dataset link: http://www.imagenglab.com/wiki/mediawiki/index.php?title=2015_MICCAI_Challenge + """ + + def __init__(self, data_dir, dir_type + , dimension=3, grid=True, crop_init=False, resampled=False, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False, pregridnorm=False + , parallel_calls=None, deterministic=False + , patient_shuffle=False + , centred_dataloader_prob = 0.0 + , debug=False): + + self.name = '{}_MICCAI2015'.format(config.HEAD_AND_NECK) + + # Params - Source + self.data_dir = data_dir + self.dir_type = dir_type + + # Params - Spatial (x,y) + self.dimension = dimension + self.grid = grid + self.crop_init = crop_init + self.resampled = resampled + self.mask_type = mask_type + + # Params - Transforms/Filters + self.transforms = transforms + self.filter_grid = filter_grid + self.random_grid = random_grid + self.pregridnorm = pregridnorm + + # Params - Memory related + self.patient_shuffle = patient_shuffle + + # Params - Dataset related + self.centred_dataloader_prob = centred_dataloader_prob + + # Params - TFlow Dataloader related + self.parallel_calls = parallel_calls # [1, tf.data.experimental.AUTOTUNE] + self.deterministic = deterministic + + # Params - Debug + self.debug = debug + + # Data items + self.data = {} + self.paths_img = [] + self.paths_mask = [] + self.cache = {} + self.filter = None + + # Config + self.dataset_config = getattr(config, self.name) + + # Function calls + self._download() + self._init_data() + + def __len__(self): + + if self.grid: + + if self.crop_init: + if self.voxel_shape_cropped == [240,240,80]: + if self.grid_size == [96,96,40]: + return len(self.paths_img)*(18) + elif self.grid_size == [140,140,40]: + return len(self.paths_img)*8 + elif self.grid_size == [240,240,40]: + return len(self.paths_img)*2 + elif self.grid_size == [240,240,80]: + return len(self.paths_img)*1 + + elif self.voxel_shape_cropped == [240,240,100]: + if self.grid_size == [140,140,60]: + return len(self.paths_img)*8 + + if self.filter is None and self.filter_grid is False: + if self.crop_init: + return 10*len(self.paths_img) + else: + return 200*len(self.paths_img) # i.e. 
approx 200 grids per volume + else: + sampler_perc_data = 1.0 - getattr(config, self.name)['GRID_3D']['SAMPLER_PERC'] + 0.1 + return int(200*len(self.paths_img)*sampler_perc_data) + + else: + return len(self.paths_img) + + def _download(self): + self.dataset_dir = Path(self.data_dir).joinpath(self.name) + self.dataset_dir_raw = Path(self.dataset_dir).joinpath(config.DIRNAME_RAW) + self.VOXEL_RESO = self.dataset_config[config.KEY_VOXELRESO] + if self.VOXEL_RESO == (0.8,0.8,2.5): + self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED) + elif self.VOXEL_RESO == (1.0,1.0,2.0): + self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED + 'v2') + else: + self.dataset_dir_processed = Path(self.dataset_dir).joinpath(config.DIRNAME_PROCESSED + '_v3') + + self.dataset_dir_datatypes = ['train', 'train_additional', 'test_offsite', 'test_onsite'] + self.dataset_dir_datatypes_ranges = [328+1,479+1,746+1,878+1] + self.dataset_dir_processed_3D = Path(self.dataset_dir_processed).joinpath('train', config.DIRNAME_SAVE_3D) + + if not Path(self.dataset_dir_raw).exists() or not Path(self.dataset_dir_processed).exists(): + print ('') + print (' ------------------ HaNMICCAI2015 Dataset ------------------') + + if not Path(self.dataset_dir_raw).exists(): + print ('') + print (' ------------------ Download Data ------------------') + from src.dataloader.extractors.han_miccai2015 import HaNMICCAI2015Downloader + downloader = HaNMICCAI2015Downloader(self.dataset_dir_raw, self.dataset_dir_processed) + downloader.download() + downloader.sort(self.dataset_dir_datatypes, self.dataset_dir_datatypes_ranges) + + if not Path(self.dataset_dir_processed_3D).exists(): + print ('') + print (' ------------------ Process Data (3D) ------------------') + from src.dataloader.extractors.han_miccai2015 import HaNMICCAI2015Extractor + extractor = HaNMICCAI2015Extractor(self.name, self.dataset_dir_raw, self.dataset_dir_processed, self.dataset_dir_datatypes) + extractor.extract3D() + print ('') + print (' ------------------------- * -------------------------') + print ('') + + def _init_data(self): + + # Step 0 - Init vars + self.patient_meta_info = {} + self.path_img_csv = '' + self.path_mask_csv = '' + if self.VOXEL_RESO == (0.8,0.8,2.4): + self.patients_z_prob = ['0522c0125'] + else: + self.patients_z_prob = [] + + # Step 1 - Define global paths + self.data_dir_processed = Path(self.dataset_dir_processed).joinpath(self.dir_type) + if not Path(self.data_dir_processed).exists(): + print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] Processed Dir Path issue: ', self.dir_type, self.data_dir_processed) + self.data_dir_processed_3D = Path(self.data_dir_processed).joinpath(config.DIRNAME_SAVE_3D) + + # Step 2.1 - Get paths for 3D volumes + if self.dimension == 3: + if self.resampled is False: + self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG) + self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK) + else: + self.path_img_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_IMG_RESAMPLED) + self.path_mask_csv = Path(self.data_dir_processed_3D).joinpath(config.FILENAME_CSV_MASK_RESAMPLED) + + # Step 2.2 - Get file paths + if Path(self.path_img_csv).exists() and Path(self.path_mask_csv).exists(): + self.paths_img = utils.read_csv(self.path_img_csv) + self.paths_mask = utils.read_csv(self.path_mask_csv) + + exit_condition = False + for path_img in self.paths_img: + if not 
Path(path_img).exists(): + print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] path issue: path_img: ',self.dir_type, path_img) + exit_condition = True + for path_mask in self.paths_mask: + if not Path(path_mask).exists(): + print (' - [ERROR][HaNMICCAI2015Dataset][_init_data()] path issue: path_mask: ',self.dir_type, path_mask) + exit_condition = True + + if exit_condition: + print ('\n - [_init_data()] Exiting due to path issues') + import sys; sys.exit(1) + + else: + print (' - [ERROR] Issue with path') + print (' -- [ERROR] self.path_img_csv : ({}) {}'.format(Path(self.path_img_csv).exists(), self.path_img_csv )) + print (' -- [ERROR] self.path_mask_csv: ({}) {}'.format(Path(self.path_mask_csv).exists(), self.path_mask_csv )) + + # Step 3.1 - Meta for labels + self.LABEL_MAP = self.dataset_config[config.KEY_LABEL_MAP] + self.LABEL_COLORS = self.dataset_config[config.KEY_LABEL_COLORS] + self.LABEL_WEIGHTS = np.array(self.dataset_config[config.KEY_LABEL_WEIGHTS]) + self.LABEL_WEIGHTS = (self.LABEL_WEIGHTS / np.sum(self.LABEL_WEIGHTS)).tolist() + + # Step 3.2 - Meta for voxel HU + self.HU_MIN = self.dataset_config['HU_MIN'] + self.HU_MAX = self.dataset_config['HU_MAX'] + + # Step 4 - Get patient meta info + if self.resampled is False: + print (' - [Warning][HaNMICCAI2015Dataset]: This dataloader is not extracting 3D Volumes which have been resampled to the same 3D voxel spacing: ', self.dir_type) + for path_img in self.paths_img: + path_img_json = Path(path_img).parent.absolute().joinpath(config.FILENAME_VOXEL_INFO) + with open(str(path_img_json), 'r') as fp: + patient_config_file = json.load(fp) + + patient_id = Path(path_img).parts[-2] + midpoint_idxs_mean = [] + missing_label_names = [] + voxel_shape_resampled = [] + voxel_shape_orig = [] + if self.resampled is False: + midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_MEAN_MIDPOINT] + missing_label_names = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] + voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] + else: + if config.TYPE_VOXEL_RESAMPLED in patient_config_file: + midpoint_idxs_mean = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_MEAN_MIDPOINT] + missing_label_names = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING] + voxel_shape_orig = patient_config_file[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_SHAPE] + voxel_shape_resampled = patient_config_file[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE] + else: + print (' - [ERROR][HaNMICCAI2015Dataset] There is no resampled data and you have set resample=True') + print (' -- Delete the data/HaNMICCAI2015Dataset/processed directory and set VOXEL_RESO to a tuple of pixel spacing values for x,y,z axes') + print (' -- Exiting now! ') + import sys; sys.exit(1) + + if len(midpoint_idxs_mean): + midpoint_idxs_mean = np.array(midpoint_idxs_mean).astype(config.DATATYPE_VOXEL_IMG).tolist() + self.patient_meta_info[patient_id] = { + config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean + , config.KEYNAME_LABEL_MISSING : missing_label_names + , config.KEYNAME_SHAPE_ORIG : voxel_shape_orig + , config.KEYNAME_SHAPE_RESAMPLED : voxel_shape_resampled + } + else: + if self.crop_init: + print ('') + print (' - [ERROR][HaNMICCAI2015Dataset] Crop is set to true and there is no midpoint mean idx') + print (' -- Exiting now! 
') + import sys; sys.exit(1) + + # Step 6 - Meta for grid sampling + if self.dimension == 3: + if self.resampled: + + if self.crop_init: + self.crop_info = self.dataset_config[config.KEY_PREPROCESS][config.TYPE_VOXEL_RESAMPLED][config.KEY_CROP][str(self.VOXEL_RESO)] + self.w_lateral_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_LEFT] + self.w_medial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_W_RIGHT] + self.h_posterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_BACK] + self.h_anterior_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_H_FRONT] + self.d_cranial_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_TOP] + self.d_caudal_crop = self.crop_info[config.KEY_MIDPOINT_EXTENSION_D_BOTTOM] + self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop + , self.h_posterior_crop+self.h_anterior_crop + , self.d_cranial_crop+self.d_caudal_crop] + + if self.grid: + grid_3D_params = self.dataset_config[config.KEY_GRID_3D][config.TYPE_VOXEL_RESAMPLED][str(self.VOXEL_RESO)] + self.grid_size = grid_3D_params[config.KEY_GRID_SIZE] + self.grid_overlap = grid_3D_params[config.KEY_GRID_OVERLAP] + self.SAMPLER_PERC = grid_3D_params[config.KEY_GRID_SAMPLER_PERC] + self.RANDOM_SHIFT_MAX = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_MAX] + self.RANDOM_SHIFT_PERC = grid_3D_params[config.KEY_GRID_RANDOM_SHIFT_PERC] + + self.w_grid, self.h_grid, self.d_grid = self.grid_size + self.w_overlap, self.h_overlap, self.d_overlap = self.grid_overlap + + else: + + if self.crop_init: + self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped + else: + print (' - [ERROR][HaNMICCAI2015Dataset] No info present for non-grid cropping') + + else: + if self.crop_init: + preprocess_obj = getattr(config, self.name)['PREPROCESS'][config.TYPE_VOXEL_ORIGSHAPE] + self.crop_info = preprocess_obj['CROP'] + self.w_lateral_crop = self.crop_info['MIDPOINT_EXTENSION_W_LEFT'] + self.w_medial_crop = self.crop_info['MIDPOINT_EXTENSION_W_RIGHT'] + self.h_posterior_crop = self.crop_info['MIDPOINT_EXTENSION_H_BACK'] + self.h_anterior_crop = self.crop_info['MIDPOINT_EXTENSION_H_FRONT'] + self.d_cranial_crop = self.crop_info['MIDPOINT_EXTENSION_D_TOP'] + self.d_caudal_crop = self.crop_info['MIDPOINT_EXTENSION_D_BOTTOM'] + self.voxel_shape_cropped = [self.w_lateral_crop+self.w_medial_crop + , self.h_posterior_crop+self.h_anterior_crop + , self.d_cranial_crop+self.d_caudal_crop] + + if self.grid: + pass + else: + if self.crop_init: + self.w_grid, self.h_grid, self.d_grid = self.voxel_shape_cropped + else: + print (' - [ERROR][HaNMICCAI2015Dataset] No info present for non-crop size') + + else: + print (' - [ERROR][HaNMICCAI2015Dataset] No info present for 2D data') + + def get_voxel_stats(self, show=False): + + spacing_x = [] + spacing_y = [] + spacing_z = [] + + info_img_path = Path(self.data_dir).joinpath(self.name, config.DIRNAME_PROCESSED, self.dir_type, config.DIRNAME_SAVE_2D, config.FILENAME_VOXEL_INFO) + + if Path(info_img_path).exists(): + import json + with open(str(info_img_path), 'r') as fp: + data = json.load(fp) + for patient_id in data: + spacing_info = data[patient_id][config.TYPE_VOXEL_ORIGSHAPE]['spacing'] + spacing_x.append(spacing_info[0]) + spacing_y.append(spacing_info[1]) + spacing_z.append(spacing_info[2]) + + if show: + if len(spacing_x) and len(spacing_y) and len(spacing_z): + import matplotlib.pyplot as plt + f,axarr = plt.subplots(1,3) + axarr[0].hist(spacing_x); axarr[0].set_title('Voxel Spacing (X)') + axarr[1].hist(spacing_y); axarr[1].set_title('Voxel Spacing (Y)') + 
axarr[2].hist(spacing_z); axarr[2].set_title('Voxel Spacing (Z)') + plt.suptitle(self.name) + plt.show() + + else: + print (' - [ERROR][get_voxel_stats()] Path issue: info_img_path: ', info_img_path) + + return spacing_x, spacing_y, spacing_z + + def generator(self): + """ + - Note: + - In general, even when running your model on an accelerator like a GPU or TPU, the tf.data pipelines are run on the CPU + - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow + """ + + try: + + if len(self.paths_img) and len(self.paths_mask): + + # Step 1 - Create basic generator + dataset = None + if self.dimension == 3: + dataset = tf.data.Dataset.from_generator(self._generator3D + , output_types=(config.DATATYPE_TF_FLOAT32, config.DATATYPE_TF_UINT8, config.DATATYPE_TF_INT32, tf.string) + ,args=()) + + # Step 2 - Get 3D data + if self.dimension == 3: + dataset = dataset.map(self._get_data_3D, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic) + # dataset = dataset.apply(tf.data.experimental.copy_to_device(target_device='/GPU:0')) + + # Step 3 - Filter function + if self.filter_grid: + dataset = dataset.filter(self.filter.execute) + + # Step 4 - Data augmentations + if len(self.transforms): + for transform in self.transforms: + try: + dataset = dataset.map(transform.execute, num_parallel_calls=self.parallel_calls, deterministic=self.deterministic) + except: + traceback.print_exc() + print (' - [ERROR][HaNMICCAI2015Dataset] Issue with transform: ', transform.name) + else: + print ('') + print (' - [INFO][HaNMICCAI2015Dataset] No transformations available!', self.dir_type) + print ('') + + # Step 6 - Return + return dataset + + else: + return None + + except: + traceback.print_exc() + pdb.set_trace() + return None + + def _get_paths(self, idx): + patient_id = '' + study_id = '' + path_img, path_mask = '', '' + + if self.debug: + path_img = Path(self.paths_img[0]).absolute() + path_mask = Path(self.paths_mask[0]).absolute() + path_img, path_mask = self.path_debug_3D(path_img, path_mask) + else: + path_img = Path(self.paths_img[idx]).absolute() + path_mask = Path(self.paths_mask[idx]).absolute() + + if path_img.exists() and path_mask.exists(): + patient_id = Path(path_img).parts[-2] + study_id = Path(path_img).parts[-4] + else: + print (' - [ERROR] Issue with path') + print (' -- [ERROR][HaNMICCAI2015] path_img : ', path_img) + print (' -- [ERROR][HaNMICCAI2015] path_mask: ', path_mask) + + return path_img, path_mask, patient_id, study_id + + def _generator3D(self): + + # Step 0 - Init + res = [] + + # Step 1 - Get patient idxs + idxs = np.arange(len(self.paths_img)).tolist() #[:3] + if self.patient_shuffle: np.random.shuffle(idxs) + + # Step 2 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling + if self.grid: + + # Step 2.1 - Get grid sampler info for each patient-idx + sampler_info = {} + for idx in idxs: + path_img = Path(self.paths_img[idx]).absolute() + patient_id = path_img.parts[-2] + + if config.TYPE_VOXEL_RESAMPLED in str(path_img): + voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_RESAMPLED] + else: + voxel_shape = self.patient_meta_info[patient_id][config.KEYNAME_SHAPE_ORIG] + + if self.crop_init: + if patient_id in self.patients_z_prob: + voxel_shape[0] = self.voxel_shape_cropped[0] + voxel_shape[1] = self.voxel_shape_cropped[1] + else: + voxel_shape = self.voxel_shape_cropped + + grid_idxs_width = utils.split_into_overlapping_grids(voxel_shape[0], len_grid=self.grid_size[0], 
len_overlap=self.grid_overlap[0]) + grid_idxs_height = utils.split_into_overlapping_grids(voxel_shape[1], len_grid=self.grid_size[1], len_overlap=self.grid_overlap[1]) + grid_idxs_depth = utils.split_into_overlapping_grids(voxel_shape[2], len_grid=self.grid_size[2], len_overlap=self.grid_overlap[2]) + sampler_info[idx] = list(itertools.product(grid_idxs_width,grid_idxs_height,grid_idxs_depth)) + + if 0: #patient_id in self.patients_z_prob: + print (' - [DEBUG] patient_id: ', patient_id, ' || voxel_shape: ', voxel_shape) + print (' - [DEBUG] sampler_info: ', sampler_info[idx]) + + # Step 2.2 - Loop over all patients and their grids + # Note - Grids of a patient are extracted in order + for i, idx in enumerate(idxs): + path_img, path_mask, patient_id, study_id = self._get_paths(idx) + missing_labels = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING] + bgd_mask = 1 # by default + if len(missing_labels): + bgd_mask = 0 + if path_img.exists() and path_mask.exists(): + for sample_info in sampler_info[idx]: + grid_idxs = sample_info + meta1 = [idx] + [grid_idxs[0][0], grid_idxs[1][0], grid_idxs[2][0]] # only include w_start, h_start, d_start + meta2 = '-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)]) + path_img = str(path_img) + path_mask = str(path_mask) + res.append((path_img, path_mask, meta1, meta2, bgd_mask)) + + else: + label_names = list(self.LABEL_MAP.keys()) + for i, idx in enumerate(idxs): + path_img, path_mask, patient_id, study_id = self._get_paths(idx) + missing_label_names = self.patient_meta_info[patient_id][config.KEYNAME_LABEL_MISSING] + bgd_mask = 1 + + # if len(missing_labels): bgd_mask = 0 + if len(missing_label_names): + if len(set(label_names) - set(missing_label_names)): + bgd_mask = 0 + + if path_img.exists() and path_mask.exists(): + meta1 = [idx] + [0,0,0] # dummy for w_start, h_start, d_start + meta2 ='-'.join([self.name, study_id, patient_id + '_resample_' + str(self.resampled)]) + path_img = str(path_img) + path_mask = str(path_mask) + res.append((path_img, path_mask, meta1, meta2, bgd_mask)) + + # Step 3 - Yield + for each in res: + path_img, path_mask, meta1, meta2, bgd_mask = each + + vol_img_npy, vol_mask_npy, spacing = self._get_cache_item(path_img, path_mask) + if vol_img_npy is None and vol_mask_npy is None: + vol_img_npy, vol_mask_npy, spacing = self._get_volume_from_path(path_img, path_mask) + self._set_cache_item(path_img, path_mask, vol_img_npy, vol_mask_npy, spacing) + + spacing = tf.constant(spacing, dtype=tf.int32) + vol_img_npy_shape = tf.constant(vol_img_npy.shape, dtype=tf.int32) + meta1 = tf.concat([meta1, spacing, vol_img_npy_shape, [bgd_mask]], axis=0) # [idx,[grid_idxs],[spacing],[shape]] + + yield (vol_img_npy, vol_mask_npy, meta1, meta2) + + def _get_cache_item(self, path_img, path_mask): + if 'img' in self.cache and 'mask' in self.cache: + if path_img in self.cache['img'] and path_mask in self.cache['mask']: + return self.cache['img'][path_img], self.cache['mask'][path_mask], self.cache['spacing'] + else: + return None, None, None + else: + return None, None, None + + def _set_cache_item(self, path_img, path_mask, vol_img, vol_mask, spacing): + # print (' - [_set_cache_item() ]: ', vol_img.shape, vol_mask.shape) + self.cache = { + 'img': {path_img: vol_img} + , 'mask': {path_mask: vol_mask} + , 'spacing': spacing + } + + def _get_volume_from_path(self, path_img, path_mask, verbose=False): + + # Step 1 - Get volumes + if verbose: t0 = time.time() + vol_img_sitk = utils.read_mha(path_img) + 
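+        # Note: SimpleITK returns numpy arrays with the depth axis first;
+        # utils.sitk_to_array() reorders them to the [H,W,D] layout used throughout this dataloader.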
vol_img_npy = utils.sitk_to_array(vol_img_sitk) + vol_mask_sitk = utils.read_mha(path_mask) + vol_mask_npy = utils.sitk_to_array(vol_mask_sitk) + spacing = np.array(vol_img_sitk.GetSpacing()) + + # Step 2 - Perform init crop on volumes + if self.crop_init: + patient_id = str(Path(path_img).parts[-2]) + mean_point = np.array(self.patient_meta_info[patient_id][config.KEYNAME_MEAN_MIDPOINT]).astype(np.uint16).tolist() + + # Step 2.1 - Perform crops in H,W region + vol_img_npy = vol_img_npy[ + mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop + , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop + , : + ] + vol_mask_npy = vol_mask_npy[ + mean_point[0] - self.w_lateral_crop : mean_point[0] + self.w_medial_crop + , mean_point[1] - self.h_anterior_crop : mean_point[1] + self.h_posterior_crop + , : + ] + + # Step 2.2 - Perform crops in D region + if self.grid: + if self.VOXEL_RESO == (0.8,0.8,2.4): + if '0522c0125' not in patient_id: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + if self.resampled: + if self.VOXEL_RESO == (0.8,0.8,2.4): + if '0522c0125' in patient_id: + # vol_img_npy = vol_img_npy[:, :, 0:self.d_caudal_crop + self.d_cranial_crop] + # vol_mask_npy = vol_mask_npy[:, :, 0:self.d_caudal_crop + self.d_cranial_crop] + vol_img_npy = vol_img_npy[:, :, -(self.d_caudal_crop + self.d_cranial_crop):] + vol_mask_npy = vol_mask_npy[:, :, -(self.d_caudal_crop + self.d_cranial_crop):] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + else: + vol_img_npy = vol_img_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + vol_mask_npy = vol_mask_npy[:, :, mean_point[2] - self.d_caudal_crop : mean_point[2] + self.d_cranial_crop] + + # Step 3 - Pad (with=0) if volume is less in z-dimension + (vol_img_npy_x, vol_img_npy_y, vol_img_npy_z) = vol_img_npy.shape + if vol_img_npy_z < self.d_caudal_crop + self.d_cranial_crop: + del_z = self.d_caudal_crop + self.d_cranial_crop - vol_img_npy_z + vol_img_npy = np.concatenate((vol_img_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2) + vol_mask_npy = np.concatenate((vol_mask_npy, np.zeros((vol_img_npy_x, vol_img_npy_y, del_z))), axis=2) + + # print (' - [cropping] z_mean: ', mean_point[2], ' || -', self.d_caudal_crop, ' || + ', self.d_cranial_crop) + # print (' - [cropping] || shape: ', vol_img_npy.shape) + # print (' - [DEBUG] change: ', vol_img_npy_shape_prev, vol_img_npy.shape) + + if verbose: print (' - [HaNMICCAI2015Dataset._get_volume_from_path()] Time: ({}):{}s'.format(Path(path_img).parts[-2], round(time.time() - t0,2))) + if self.pregridnorm: + vol_img_npy[vol_img_npy <= self.HU_MIN] = self.HU_MIN + vol_img_npy[vol_img_npy >= self.HU_MAX] = self.HU_MAX + vol_img_npy = 
(vol_img_npy -np.mean(vol_img_npy))/np.std(vol_img_npy) #Standardize (z-scoring) + + return tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32), tf.cast(vol_mask_npy, dtype=config.DATATYPE_TF_UINT8), tf.constant(spacing*100, dtype=config.DATATYPE_TF_INT32) + + @tf.function + def _get_new_grid_idx(self, start, end, max): + + start_prev = start + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC: + + delta_left = start + delta_right = max - end + shift_voxels = tf.random.uniform([], minval=0, maxval=self.RANDOM_SHIFT_MAX, dtype=tf.dtypes.int32) + + if delta_left > self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX: + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.RANDOM_SHIFT_PERC: + start = start - shift_voxels + end = end - shift_voxels + else: + start = start + shift_voxels + end = end + shift_voxels + + elif delta_left > self.RANDOM_SHIFT_MAX and delta_right <= self.RANDOM_SHIFT_MAX: + start = start - shift_voxels + end = end - shift_voxels + + elif delta_left <= self.RANDOM_SHIFT_MAX and delta_right > self.RANDOM_SHIFT_MAX: + start = start + shift_voxels + end = end + shift_voxels + + return start_prev, start, end + + @tf.function + def _get_new_grid_idx_centred(self, grid_size_half, max_pt, mid_pt): + + # Step 1 - Return vars + start, end = 0,0 + + # Step 2 - Define margin on either side of mid point + margin_left = mid_pt + margin_right = max_pt - mid_pt + + # Ste p2 - Calculate vars + if margin_left >= grid_size_half and margin_right >= grid_size_half: + start = mid_pt - grid_size_half + end = mid_pt + grid_size_half + elif margin_right < grid_size_half: + if margin_left >= grid_size_half + (grid_size_half - margin_right): + end = mid_pt + margin_right + start = mid_pt - grid_size_half - (grid_size_half - margin_right) + else: + tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 2 problem') + elif margin_left < grid_size_half: + if margin_right >= grid_size_half + (grid_size_half - margin_left): + start = mid_pt - margin_left + end = mid_pt + grid_size_half + (grid_size_half-margin_left) + else: + tf.print(' - [ERROR][_get_new_grid_idx_centred()] Cond 3 problem') + + return start, end + + @tf.function + def _get_data_3D(self, vol_img, vol_mask, meta1, meta2): + """ + Params + ------ + meta1: [idx, [w_start, h_start, d_start], [spacing_x, spacing_y, spacing_z], [shape_x, shape_y, shape_z], [bgd_mask]] + """ + + vol_img_npy = None + vol_mask_npy = None + + # Step 1 - Proceed on the basis of grid sampling or full-volume (self.grid=False) sampling + if self.grid: + + if tf.random.uniform([], minval=0, maxval=1, dtype=tf.dtypes.float32) <= self.centred_dataloader_prob: + + # Step 1.1 - Get a label_id and its mean + label_id = tf.cast(tf.random.categorical(tf.math.log([self.LABEL_WEIGHTS]), 1), dtype=config.DATATYPE_TF_UINT8)[0][0] # the LABEL_WEGIHTS sum to 1; the log is added as tf.random.categorical expects logits + label_id_idxs = tf.where(tf.math.equal(label_id, vol_mask)) + label_id_idxs_mean = tf.math.reduce_mean(label_id_idxs, axis=0) + label_id_idxs_mean = tf.cast(label_id_idxs_mean, dtype=config.DATATYPE_TF_INT32) + + # Step 1.2 - Create a grid around that mid-point + w_start_prev = meta1[1] + h_start_prev = meta1[2] + d_start_prev = meta1[3] + w_max = meta1[7] + h_max = meta1[8] + d_max = meta1[9] + w_grid = self.grid_size[0] + h_grid = self.grid_size[1] + d_grid = self.grid_size[2] + w_mid = label_id_idxs_mean[0] + h_mid = label_id_idxs_mean[1] + d_mid = label_id_idxs_mean[2] + + 
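+                # Centre a patch of size grid_size on the mean voxel index of the sampled label;
+                # _get_new_grid_idx_centred() shifts the window inwards when that mean lies closer
+                # than half a grid to the volume border.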
w_start, w_end = self._get_new_grid_idx_centred(w_grid//2, w_max, w_mid) + h_start, h_end = self._get_new_grid_idx_centred(h_grid//2, h_max, h_mid) + d_start, d_end = self._get_new_grid_idx_centred(d_grid//2, d_max, d_mid) + + meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0]) + meta1 = meta1 + meta1_diff + + else: + + # tf.print(' - [INFO] regular dataloader: ', self.dir_type) + # Step 1.1 - Get raw images/masks and extract grid + w_start = meta1[1] + w_end = w_start + self.grid_size[0] + h_start = meta1[2] + h_end = h_start + self.grid_size[1] + d_start = meta1[3] + d_end = d_start + self.grid_size[2] + + # Step 1.2 - Randomization of grid + if self.random_grid: + w_max = meta1[7] + h_max = meta1[8] + d_max = meta1[9] + + w_start_prev = w_start + d_start_prev = d_start + w_start_prev, w_start, w_end = self._get_new_grid_idx(w_start, w_end, w_max) + h_start_prev, h_start, h_end = self._get_new_grid_idx(h_start, h_end, h_max) + d_start_prev, d_start, d_end = self._get_new_grid_idx(d_start, d_end, d_max) + + meta1_diff = tf.convert_to_tensor([0,w_start - w_start_prev, h_start - h_start_prev, d_start - d_start_prev,0,0,0,0,0,0,0]) + meta1 = meta1 + meta1_diff + + # Step 1.3 - Extracting grid + vol_img_npy = tf.identity(vol_img[w_start:w_end, h_start:h_end, d_start:d_end]) + vol_mask_npy = tf.identity(vol_mask[w_start:w_end, h_start:h_end, d_start:d_end]) + + + else: + vol_img_npy = vol_img + vol_mask_npy = vol_mask + + # Step 2 - One-hot or not + vol_mask_classes = [] + label_ids_mask = [] + label_ids = sorted(list(self.LABEL_MAP.values())) + if self.mask_type == config.MASK_TYPE_ONEHOT: + vol_mask_classes = tf.concat([tf.expand_dims(tf.math.equal(vol_mask_npy, label), axis=-1) for label in label_ids], axis=-1) # [H,W,D,L] + for label_id in label_ids: + label_ids_mask.append(tf.cast(tf.math.reduce_any(vol_mask_classes[:,:,:,label_id]), dtype=config.DATATYPE_TF_INT32)) + + elif self.mask_type == config.MASK_TYPE_COMBINED: + vol_mask_classes = vol_mask_npy + unique_classes, _ = tf.unique(tf.reshape(vol_mask_npy,[-1])) + unique_classes = tf.cast(unique_classes, config.DATATYPE_TF_INT32) + for label_id in label_ids: + label_ids_mask.append(tf.cast(tf.math.reduce_any(tf.math.equal(unique_classes, label_id)), dtype=config.DATATYPE_TF_INT32)) + + # Step 2.2 - Handling background mask explicitly if there is a missing label + bgd_mask = meta1[-1] + label_ids_mask[0] = bgd_mask + meta1 = meta1[:-1] # removes the bgd mask index + + # Step 3 - Dtype conversion and expading dimensions + if self.mask_type == config.MASK_TYPE_ONEHOT: + x = tf.cast(tf.expand_dims(vol_img_npy, axis=3), dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,1] + else: + x = tf.cast(vol_img_npy, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D] + y = tf.cast(vol_mask_classes, dtype=config.DATATYPE_TF_FLOAT32) # [H,W,D,L] + + # Step 4 - Append info to meta1 + meta1 = tf.concat([meta1, label_ids_mask], axis=0) + + # Step 5 - return + return (x, y, meta1, meta2) + \ No newline at end of file diff --git a/src/dataloader/utils.py b/src/dataloader/utils.py new file mode 100644 index 0000000..88c6741 --- /dev/null +++ b/src/dataloader/utils.py @@ -0,0 +1,1273 @@ +# Import internal libraries +import src.config as config +import src.dataloader.augmentations as aug +from src.dataloader.dataset import ZipDataset +from src.dataloader.han_miccai2015 import HaNMICCAI2015Dataset +from src.dataloader.han_deepmindtcia import HaNDeepMindTCIADataset + +# Import external libraries +import os 
+import pdb +import itk +import copy +import time +import tqdm +import json +import urllib +import psutil +import pydicom +import humanize +import traceback +import numpy as np +import tensorflow as tf +from pathlib import Path +import SimpleITK as sitk # sitk.Version.ExtendedVersionString() + +if len(tf.config.list_physical_devices('GPU')): + tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) + + + +############################################################ +# DOWNLOAD RELATED # +############################################################ + +class DownloadProgressBar(tqdm.tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + +def download_zip(url_zip, filepath_zip, filepath_output): + import urllib + with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=' - [Download]' + str(url_zip.split('/')[-1]) ) as pbar: + urllib.request.urlretrieve(url_zip, filename=filepath_zip, reporthook=pbar.update_to) + read_zip(filepath_zip, filepath_output) + +def read_zip(filepath_zip, filepath_output=None, leave=False): + + import zipfile + + # Step 0 - Init + if Path(filepath_zip).exists(): + if filepath_output is None: + filepath_zip_parts = list(Path(filepath_zip).parts) + filepath_zip_name = filepath_zip_parts[-1].split(config.EXT_ZIP)[0] + filepath_zip_parts[-1] = filepath_zip_name + filepath_output = Path(*filepath_zip_parts) + + zip_fp = zipfile.ZipFile(filepath_zip, 'r') + zip_fp_members = zip_fp.namelist() + with tqdm.tqdm(total=len(zip_fp_members), desc=' - [Unzip] ' + str(filepath_zip.parts[-1]), leave=leave) as pbar_zip: + for member in zip_fp_members: + zip_fp.extract(member, filepath_output) + pbar_zip.update(1) + + return filepath_output + else: + print (' - [ERROR][read_zip()] Path does not exist: ', filepath_zip) + return None + +def move_folder(path_src_folder, path_dest_folder): + + import shutil + if Path(path_src_folder).exists(): + shutil.move(str(path_src_folder), str(path_dest_folder)) + else: + print (' - [ERROR][utils.move_folder()] path_src_folder:{} does not exist'.format(path_src_folder)) + +############################################################ +# ITK/SITK RELATED # +############################################################ + +def write_mha(filepath, data_array, spacing=[], origin=[]): + # data_array = numpy array + if len(spacing) and len(origin): + img_sitk = array_to_sitk(data_array, spacing=spacing, origin=origin) + elif len(spacing) and not len(origin): + img_sitk = array_to_sitk(data_array, spacing=spacing) + + sitk.WriteImage(img_sitk, str(filepath), useCompression=True) + +def write_mha(path_save, img_data, img_headers, img_dtype): + + # Step 0 - Path related + path_save_parents = Path(path_save).parent.absolute() + path_save_parents.mkdir(exist_ok=True, parents=True) + + # Step 1 - Create ITK volume + orig_origin = img_headers[config.KEYNAME_ORIGIN] + orig_pixel_spacing = img_headers[config.KEYNAME_PIXEL_SPACING] + img_sitk = array_to_sitk(img_data, origin=orig_origin, spacing=orig_pixel_spacing) + img_sitk = sitk.Cast(img_sitk, img_dtype) + sitk.WriteImage(img_sitk, str(path_save), useCompression=True) + + return img_sitk + +def imwrite_sitk(img, filepath, dtype, compression=True): + """ + img: itk () or sitk image + """ + import itk + + def convert_itk_to_sitk(image_itk, dtype): + + img_array = itk.GetArrayFromImage(image_itk) + if dtype in ['short', 'int16', config.DATATYPE_VOXEL_IMG]: + img_array = np.array(img_array, 
dtype=config.DATATYPE_VOXEL_IMG) + elif dtype in ['unsigned int', 'uint8', config.DATATYPE_VOXEL_MASK]: + img_array = np.array(img_array, dtype=config.DATATYPE_VOXEL_MASK) + + image_sitk = sitk.GetImageFromArray(img_array, isVector=image_itk.GetNumberOfComponentsPerPixel()>1) + image_sitk.SetOrigin(tuple(image_itk.GetOrigin())) + image_sitk.SetSpacing(tuple(image_itk.GetSpacing())) + image_sitk.SetDirection(itk.GetArrayFromMatrix(image_itk.GetDirection()).flatten()) + return image_sitk + + writer = sitk.ImageFileWriter() + writer.SetFileName(str(filepath)) + writer.SetUseCompression(compression) + if 'SimpleITK' not in str(type(img)): + writer.Execute(convert_itk_to_sitk(img, dtype)) + else: + writer.Execute(img) + +def read_itk(img_url, fixed_ct_skip_slices=None): + + if Path(img_url).exists(): + img = itk.imread(str(img_url), itk.F) + + if fixed_ct_skip_slices is not None: + img_array = itk.GetArrayFromImage(img) # [D,H,W] + img_array = img_array[fixed_ct_skip_slices:,:,:] + + img_ = itk.GetImageFromArray(img_array) + img_.SetOrigin(tuple(img.GetOrigin())) + img_.SetSpacing(tuple(img.GetSpacing())) + img_.SetDirection(img.GetDirection()) + img = img_ + + return img + else: + print (' - [read_itk()] Path does not exist: ', img_url) + return None + +def read_itk_mask(mask_url, fixed_ct_skip_slices=None): + + if Path(mask_url).exists(): + img = itk.imread(str(mask_url), itk.UC) + + if fixed_ct_skip_slices is not None: + img_array = itk.GetArrayFromImage(img) # [D,H,W] + img_array = img_array[fixed_ct_skip_slices:,:,:] + + img_ = itk.GetImageFromArray(img_array) + img_.SetOrigin(tuple(img.GetOrigin())) + img_.SetSpacing(tuple(img.GetSpacing())) + img_.SetDirection(img.GetDirection()) + img = img_ + + return img + + else: + print (' - [read_itk()] Path does not exist: ', mask_url) + return None + +def array_to_itk(array, origin, spacing): + """ + array = [H,W,D] + origin = [x,y,z] + spacing = [x,y,z] + """ + + try: + import itk + + img_itk = itk.GetImageFromArray(np.moveaxis(array, [0,1,2], [2,1,0]).copy()) + img_itk.SetOrigin(tuple(origin)) + img_itk.SetSpacing(tuple(spacing)) + # dir = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(np.float64) + # img_itk.SetDirection(itk.GetMatrixFromArray(dir)) + + return img_itk + + except: + traceback.print_exc() + pdb.set_trace() + +def array_to_sitk(array_input, size=None, origin=None, spacing=None, direction=None, is_vector=False, im_ref=None): + """ + This function takes an array and converts it into a SimpleITK image. + + Parameters + ---------- + array_input: numpy + The numpy array to convert to a SimpleITK image + size: tuple, optional + The size of the array + origin: tuple, optional + The origin of the data in physical space + spacing: tuple, optional + Spacing describes the physical sie of each pixel + direction: tuple, optional + A [nxn] matrix passed as a 1D in a row-major form for a nD matrix (n=[2,3]) to infer the orientation of the data + is_vector: bool, optional + If isVector is True, then the Image will have a Vector pixel type, and the last dimension of the array will be considered the component index. 
+ im_ref: sitk image + An empty image with meta information + + Ref: https://github.com/hsokooti/RegNet/blob/46f345d25cd6a1e0ee6f230f64c32bd15b7650d3/functions/image/image_processing.py#L86 + """ + import SimpleITK as sitk + verbose = False + + if origin is None: + origin = [0.0, 0.0, 0.0] + if spacing is None: + spacing = [1, 1, 1] # the voxel spacing + if direction is None: + direction = [1, 0, 0, 0, 1, 0, 0, 0, 1] + if size is None: + size = np.array(array_input).shape + + """ + ITK has a GetPixel which takes an ITK Index object as an argument, which is ordered as (x,y,z). + This is the convention that SimpleITK's Image class uses for the GetPixel method and slicing operator as well. + In numpy, an array is indexed in the opposite order (z,y,x) + """ + sitk_output = sitk.GetImageFromArray(np.moveaxis(array_input, [0,1,2], [2,1,0]), isVector=is_vector) # np([H,W,D]) -> np([D,W,H]) -> sitk([H,W,D]) + + if im_ref is None: + sitk_output.SetOrigin(origin) + sitk_output.SetSpacing(spacing) + sitk_output.SetDirection(direction) + else: + sitk_output.SetOrigin(im_ref.GetOrigin()) + sitk_output.SetSpacing(im_ref.GetSpacing()) + sitk_output.SetDirection(im_ref.GetDirection()) + + return sitk_output + +def sitk_to_array(sitk_image): + array = sitk.GetArrayFromImage(sitk_image) + array = np.moveaxis(array, [0,1,2], [2,1,0]) # [D,W,H] --> [H,W,D] + return array + +def itk_to_array(sitk_image): + import itk + array = itk.GetArrayFromImage(sitk_image) + array = np.moveaxis(array, [0,1,2], [2,1,0]) # [D,W,H] --> [H,W,D] + return array + +def resampler_sitk(image_sitk, spacing=None, scale=None, im_ref=None, im_ref_size=None, default_pixel_value=0, interpolator=None, dimension=3): + """ + :param image_sitk: input image + :param spacing: desired spacing to set + :param scale: if greater than 1 means downsampling, less than 1 means upsampling + :param im_ref: if im_ref available, the spacing will be overwritten by the im_ref.GetSpacing() + :param im_ref_size: in sikt order: x, y, z + :param default_pixel_value: + :param interpolator: + :param dimension: + :return: + """ + + import math + import SimpleITK as sitk + + if spacing is None and scale is None: + raise ValueError('spacing and scale cannot be both None') + if interpolator is None: + interpolator = sitk.sitkBSpline # sitk.Linear, sitk.Nearest + + if spacing is None: + spacing = tuple(i * scale for i in image_sitk.GetSpacing()) + if im_ref_size is None: + im_ref_size = tuple(round(i / scale) for i in image_sitk.GetSize()) + + elif scale is None: + ratio = [spacing_dim / spacing[i] for i, spacing_dim in enumerate(image_sitk.GetSpacing())] + if im_ref_size is None: + im_ref_size = tuple(math.ceil(size_dim * ratio[i]) - 1 for i, size_dim in enumerate(image_sitk.GetSize())) + else: + raise ValueError('spacing and scale cannot both have values') + + if im_ref is None: + im_ref = sitk.Image(im_ref_size, sitk.sitkInt8) + im_ref.SetOrigin(image_sitk.GetOrigin()) + im_ref.SetDirection(image_sitk.GetDirection()) + im_ref.SetSpacing(spacing) + + + identity = sitk.Transform(dimension, sitk.sitkIdentity) + resampled_sitk = resampler_by_transform(image_sitk, identity, im_ref=im_ref, + default_pixel_value=default_pixel_value, + interpolator=interpolator) + return resampled_sitk + +def resampler_by_transform(im_sitk, dvf_t, im_ref=None, default_pixel_value=0, interpolator=None): + import SimpleITK as sitk + + if im_ref is None: + im_ref = sitk.Image(dvf_t.GetDisplacementField().GetSize(), sitk.sitkInt8) + im_ref.SetOrigin(dvf_t.GetDisplacementField().GetOrigin()) + 
im_ref.SetSpacing(dvf_t.GetDisplacementField().GetSpacing()) + im_ref.SetDirection(dvf_t.GetDisplacementField().GetDirection()) + + if interpolator is None: + interpolator = sitk.sitkBSpline + + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(im_ref) + resampler.SetInterpolator(interpolator) + resampler.SetDefaultPixelValue(default_pixel_value) + resampler.SetTransform(dvf_t) + + # [DEBUG] + resampler.SetOutputPixelType(sitk.sitkFloat32) + + out_im = resampler.Execute(im_sitk) + return out_im + +def save_as_mha_mask(data_dir, patient_id, voxel_mask, voxel_img_headers): + + try: + voxel_save_folder = Path(data_dir).joinpath(patient_id) + Path(voxel_save_folder).mkdir(parents=True, exist_ok=True) + study_id = Path(voxel_save_folder).parts[-1] + if config.STR_ACCESSION_PREFIX in str(study_id): + study_id = 'acc_' + str(study_id.split(config.STR_ACCESSION_PREFIX)[-1]) + study_id = study_id.replace('.', '') + + + orig_origin = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] + orig_pixel_spacing = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] + + if len(voxel_mask): + voxel_mask_sitk = array_to_sitk(voxel_mask.astype(config.DATATYPE_VOXEL_MASK) + , origin=orig_origin, spacing=orig_pixel_spacing) + path_voxel_mask = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_TUMOR_3D.format(study_id)) + sitk.WriteImage(voxel_mask_sitk, str(path_voxel_mask), useCompression=True) + else: + print (' - [ERROR][save_as_mha_mask()] Patient: ', patient_id) + except: + traceback.print_exc() + pdb.set_trace() + pass + +def save_as_mha(data_dir, patient_id, voxel_img, voxel_img_headers, voxel_mask + , voxel_img_reg_dict={}, labelid_midpoint=None + , resample_spacing=[]): + + try: + """ + Thi function converts the raw numpy data into a SimpleITK image and saves as .mha + + Parameters + ---------- + data_dir: Path + The path where you would like to save the data + patient_id: str + A reference to the patient + voxel_img: numpy + A nD numpy array with [H,W,D] format containing radiodensity data in Hounsfield units + voxel_img_headers: dict + A python dictionary containing information on 'origin' and 'pixel_spacing' + voxel_mask: numpy + A nD array with labels on each nD voxel + resample_save: bool + A boolean variable to indicate whether the function should resample + """ + + # Step 1 - Original Voxel resolution + ## Step 1.1 - Create save dir + voxel_save_folder = Path(data_dir).joinpath(patient_id) + Path(voxel_save_folder).mkdir(parents=True, exist_ok=True) + study_id = Path(voxel_save_folder).parts[-1] + if config.STR_ACCESSION_PREFIX in str(study_id): + study_id = config.FILEPREFIX_ACCENSION + get_acc_id_from_str(study_id) + + ## Step 1.2 - Save img voxel + orig_origin = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_ORIGIN] + orig_pixel_spacing = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_PIXEL_SPACING] + voxel_img_sitk = array_to_sitk(voxel_img.astype(config.DATATYPE_VOXEL_IMG) + , origin=orig_origin, spacing=orig_pixel_spacing) + path_voxel_img = Path(voxel_save_folder).joinpath(config.FILENAME_IMG_3D.format(study_id)) + sitk.WriteImage(voxel_img_sitk, str(path_voxel_img), useCompression=True) + + ## Step 1.3 - Save mask voxel + voxel_mask_sitk = array_to_sitk(voxel_mask.astype(config.DATATYPE_VOXEL_MASK) + , origin=orig_origin, spacing=orig_pixel_spacing) + path_voxel_mask = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_3D.format(study_id)) + sitk.WriteImage(voxel_mask_sitk, 
str(path_voxel_mask), useCompression=True) + + ## Step 1.4 - Save registration params + paths_voxel_reg = {} + for study_id in voxel_img_reg_dict: + if type(voxel_img_reg_dict[study_id]) == sitk.AffineTransform: + path_voxel_reg = str(Path(voxel_save_folder).joinpath('{}.tfm'.format(study_id))) + sitk.WriteTransform(voxel_img_reg_dict[study_id], path_voxel_reg) + paths_voxel_reg[study_id] = str(path_voxel_reg) + + # Step 2 - Resampled Voxel + if len(resample_spacing): + new_spacing = resample_spacing + + ## Step 2.1 - Save resampled img + voxel_img_sitk_resampled = resampler_sitk(voxel_img_sitk, spacing=new_spacing, interpolator=sitk.sitkBSpline) + voxel_img_sitk_resampled = sitk.Cast(voxel_img_sitk_resampled, sitk.sitkInt16) + path_voxel_img_resampled = Path(voxel_save_folder).joinpath(config.FILENAME_IMG_RESAMPLED_3D.format(study_id)) + sitk.WriteImage(voxel_img_sitk_resampled, str(path_voxel_img_resampled), useCompression=True) + interpolator_img = 'sitk.sitkBSpline' + + ## Step 2.2 - Save resampled mask + voxel_mask_sitk_resampled = [] + interpolator_mask = '' + if 0: + voxel_mask_sitk_resampled = resampler_sitk(voxel_mask_sitk, spacing=new_spacing, interpolator=sitk.sitkNearestNeighbor) + interpolator_mask = 'sitk.sitkNearestNeighbor' + + elif 1: + interpolator_mask = 'sitk.sitkLinear' + new_size = voxel_img_sitk_resampled.GetSize() + voxel_mask_resampled = np.zeros(new_size) + for label_id in np.unique(voxel_mask): + if label_id != 0: + voxel_mask_singlelabel = copy.deepcopy(voxel_mask).astype(config.DATATYPE_VOXEL_MASK) + voxel_mask_singlelabel[voxel_mask_singlelabel != label_id] = 0 + voxel_mask_singlelabel[voxel_mask_singlelabel == label_id] = 1 + voxel_mask_singlelabel_sitk = array_to_sitk(voxel_mask_singlelabel + , origin=orig_origin, spacing=orig_pixel_spacing) + voxel_mask_singlelabel_sitk_resampled = resampler_sitk(voxel_mask_singlelabel_sitk, spacing=new_spacing + , interpolator=sitk.sitkLinear) + if 0: + voxel_mask_singlelabel_sitk_resampled = sitk.Cast(voxel_mask_singlelabel_sitk_resampled, sitk.sitkUInt8) + voxel_mask_singlelabel_array_resampled = sitk_to_array(voxel_mask_singlelabel_sitk_resampled) + idxs = np.argwhere(voxel_mask_singlelabel_array_resampled > 0) + else: + voxel_mask_singlelabel_array_resampled = sitk_to_array(voxel_mask_singlelabel_sitk_resampled) + idxs = np.argwhere(voxel_mask_singlelabel_array_resampled >= 0.5) + voxel_mask_resampled[idxs[:,0], idxs[:,1], idxs[:,2]] = label_id + + voxel_mask_sitk_resampled = array_to_sitk(voxel_mask_resampled + , origin=orig_origin, spacing=new_spacing) + + voxel_mask_sitk_resampled = sitk.Cast(voxel_mask_sitk_resampled, sitk.sitkUInt8) + path_voxel_mask_resampled = Path(voxel_save_folder).joinpath(config.FILENAME_MASK_RESAMPLED_3D.format(study_id)) + sitk.WriteImage(voxel_mask_sitk_resampled, str(path_voxel_mask_resampled), useCompression=True) + + # Step 2.3 - Save voxel info for resampled data + midpoint_idxs_mean = [] + if labelid_midpoint is not None: + voxel_mask_resampled_data = sitk_to_array(voxel_mask_sitk_resampled) + midpoint_idxs = np.argwhere(voxel_mask_resampled_data == labelid_midpoint) + midpoint_idxs_mean = np.mean(midpoint_idxs, axis=0) + + path_voxel_headers = Path(voxel_save_folder).joinpath(config.FILENAME_VOXEL_INFO) + + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED] = {config.KEYNAME_MEAN_MIDPOINT : midpoint_idxs_mean.tolist()} + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_PIXEL_SPACING] = voxel_img_sitk_resampled.GetSpacing() + 
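+            # The mean-midpoint and shape entries recorded in these resampled headers are read back
+            # later by the dataset loaders (e.g. HaNMICCAI2015Dataset._init_data()) when resampled=True.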
voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_ORIGIN] = voxel_img_sitk_resampled.GetOrigin() + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_SHAPE] = voxel_img_sitk_resampled.GetSize() + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.TYPE_VOXEL_ORIGSHAPE] = { + config.KEYNAME_INTERPOLATOR_IMG: interpolator_img + , config.KEYNAME_INTERPOLATOR_MASK: interpolator_mask + } + if config.KEYNAME_LABEL_OARS in voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE]: + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_OARS] = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_OARS] + if config.KEYNAME_LABEL_MISSING in voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE]: + voxel_img_headers[config.TYPE_VOXEL_RESAMPLED][config.KEYNAME_LABEL_MISSING] = voxel_img_headers[config.TYPE_VOXEL_ORIGSHAPE][config.KEYNAME_LABEL_MISSING] + + write_json(voxel_img_headers, path_voxel_headers) + + ## Step 3 - Save img headers + path_voxel_headers = Path(voxel_save_folder).joinpath(config.FILENAME_VOXEL_INFO) + write_json(voxel_img_headers, path_voxel_headers) + + if len(voxel_img_reg_dict): + return str(path_voxel_img), str(path_voxel_mask), paths_voxel_reg + else: + return str(path_voxel_img), str(path_voxel_mask), {} + + except: + print ('\n --------------- [ERROR][save_as_mha()]') + traceback.print_exc() + print (' --------------- [ERROR][save_as_mha()]\n') + pdb.set_trace() + +def read_mha(path_file): + try: + + if Path(path_file).exists(): + img_mha = sitk.ReadImage(str(path_file)) + return img_mha + else: + print (' - [ERROR][read_mha()] Path issue: path_file: ', path_file) + pdb.set_trace() + + except: + traceback.print_exc() + pdb.set_trace() + +def write_itk(data, filepath): + filepath_parent = Path(filepath).parent.absolute() + filepath_parent.mkdir(parents=True, exist_ok=True) + itk.imwrite(data, str(filepath), compression=True) + +def resample_img(path_img, path_new_img, spacing): + + try: + + img_sitk = sitk.ReadImage(str(path_img)) + img_resampled = resampler_sitk(img_sitk, spacing=spacing, interpolator=sitk.sitkBSpline) + img_resampled = sitk.Cast(img_resampled, sitk.sitkInt16) + + path_new_parent = Path(path_new_img).parent.absolute() + Path(path_new_parent).mkdir(exist_ok=True, parents=True) + sitk.WriteImage(img_resampled, str(path_new_img), useCompression=True) + + return img_resampled + + except: + traceback.print_exc() + +def resample_mask(path_mask, path_new_mask, spacing, size, labels_allowed = []): + + try: + + # Step 0 - Init + mask_array_resampled = np.zeros(size) + + # Step 1 - Read mask + mask_sitk = sitk.ReadImage(str(path_mask)) + mask_spacing = mask_sitk.GetSpacing() + mask_origin = mask_sitk.GetOrigin() + mask_array = sitk_to_array(mask_sitk) + + # Step 2 - Loop over mask labels + for label_id in np.unique(mask_array): + if label_id != 0 and label_id in labels_allowed: + mask_array_label = np.array(mask_array, copy=True) + mask_array_label[mask_array_label != label_id] = 0 + mask_array_label[mask_array_label == label_id] = 1 + mask_sitk_label = array_to_sitk(mask_array_label, origin=mask_origin, spacing=mask_spacing) + mask_sitk_label_resampled = resampler_sitk(mask_sitk_label, spacing=spacing, interpolator=sitk.sitkLinear) + mask_array_label_resampled = sitk_to_array(mask_sitk_label_resampled) + idxs = np.argwhere(mask_array_label_resampled >= 0.5) + mask_array_resampled[idxs[:,0], idxs[:,1], idxs[:,2]] = label_id + + mask_sitk_resampled = array_to_sitk(mask_array_resampled , origin=mask_origin, spacing=spacing) + 
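+        # Design note: each label is resampled separately with linear interpolation and then
+        # thresholded at 0.5 (instead of nearest-neighbour resampling of the combined integer mask),
+        # which avoids interpolating between different label ids and tends to give smoother organ boundaries.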
mask_sitk_resampled = sitk.Cast(mask_sitk_resampled, sitk.sitkUInt8) + + path_new_parent = Path(path_new_mask).parent.absolute() + Path(path_new_parent).mkdir(exist_ok=True, parents=True) + sitk.WriteImage(mask_sitk_resampled, str(path_new_mask), useCompression=True) + + return mask_sitk_resampled + + except: + traceback.print_exc() + +############################################################ +# 3D VOXEL RELATED # +############################################################ + +def split_into_overlapping_grids(len_total, len_grid, len_overlap, res_type='boundary'): + res_range = [] + res_boundary = [] + + A = np.arange(len_total) + l_start = 0 + l_end = len_grid + while(l_end < len(A)): + res_range.append(np.arange(l_start, l_end)) + res_boundary.append([l_start, l_end]) + l_start = l_start + len_grid - len_overlap + l_end = l_start + len_grid + + res_range.append(np.arange(len(A)-len_grid, len(A))) + res_boundary.append([len(A)-len_grid, len(A)]) + if res_type == 'boundary': + return res_boundary + elif res_type == 'range': + return res_range + +def extract_numpy_from_dcm(patient_dir, skip_slices=None): + """ + Given the path of the folder containing the .dcm files, this function extracts + a numpy array by combining them + + Parameters + ---------- + patient_dir: Path + The path of the folder containing the .dcm files + """ + slices = [] + voxel_img_data = [] + try: + import pydicom + from pathlib import Path + + if Path(patient_dir).exists(): + for path_ct in Path(patient_dir).iterdir(): + try: + ds = pydicom.filereader.dcmread(path_ct) + slices.append(ds) + except: + pass + + if len(slices): + slices = list(filter(lambda x: 'ImagePositionPatient' in x, slices)) + slices.sort(key = lambda x: float(x.ImagePositionPatient[2])) # inferior to superior + slices_data = [] + for s_id, s in enumerate(slices): + try: + if skip_slices is not None: + if s_id < skip_slices: + continue + slices_data.append(s.pixel_array.T) + except: + pass + + if len(slices_data): + voxel_img_data = np.stack(slices_data, axis=-1) # [row, col, plane], [H,W,D] + + except: + pass + traceback.print_exc() + # pdb.set_trace() + + return slices, voxel_img_data + +def perform_hu(voxel_img, intercept, slope): + """ + Rescale Intercept != 0, Rescale Slope != 1 or Dose Grid Scaling != 1. + The pixel data has not been transformed according to these values. + Consider using the module ApplyDicomPixelModifiers after + importing the volume to transform the image data appropriately. 
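+
+    Example (illustrative values): with slope=1 and intercept=-1024, a stored pixel
+    value of 1024 maps to 0 HU (water) and a stored value of 0 maps to -1024 HU (air),
+    i.e. HU = slope * pixel_value + intercept.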
+ """ + try: + import copy + + slope = np.float(slope) + intercept = np.float(intercept) + voxel_img_hu = copy.deepcopy(voxel_img).astype(np.float64) + + # Convert to Hounsfield units (HU) + if slope != 1: + voxel_img_hu = slope * voxel_img_hu + + voxel_img_hu += intercept + + return voxel_img_hu.astype(config.DATATYPE_VOXEL_IMG) + except: + traceback.print_exc() + pdb.set_trace() + +def print_final_message(): + print ('') + print (' - Note: You can view the 3D data in visualizers like MeVisLab or 3DSlicer') + print (' - Note: Raw Voxel Data ({}/{}) is in Hounsfield units (HU) with int16 datatype'.format(config.FILENAME_IMG_3D, config.FILENAME_IMG_RESAMPLED_3D)) + print ('') + +def volumize_ct(patient_ct_dir, skip_slices=None): + """ + Params + ------ + patient_ct_dir: Path - contains .dcm files for CT scans + """ + voxel_ct_data = [] + voxel_ct_headers = {} + + try: + + if Path(patient_ct_dir).exists(): + + # Step1 - Accumulate all slices and sort + slices, voxel_ct_data = extract_numpy_from_dcm(Path(patient_ct_dir), skip_slices=skip_slices) + + # Step 2 - Get parameters related to the 3D scan + if len(voxel_ct_data): + slice_thickness = 0 + if slices[0].SliceThickness < 0.01: + slice_thickness = float(slices[1].ImagePositionPatient[2] - slices[0].ImagePositionPatient[2]) + else: + slice_thickness = float(slices[0].SliceThickness) + + voxel_ct_headers[config.KEYNAME_PIXEL_SPACING] = np.array(slices[0].PixelSpacing).tolist() + [slice_thickness] + voxel_ct_headers[config.KEYNAME_ORIGIN] = np.array(slices[0].ImagePositionPatient).tolist() + voxel_ct_headers[config.KEYNAME_SHAPE] = voxel_ct_data.shape + voxel_ct_headers[config.KEYNAME_ZVALS] = [round(s.ImagePositionPatient[2]) for s in slices] + voxel_ct_headers[config.KEYNAME_INTERCEPT] = slices[0].RescaleIntercept + voxel_ct_headers[config.KEYNAME_SLOPE] = slices[0].RescaleSlope + + # Step 3 - Postprocess the 3D voxel data (data in HU + reso=[config.VOXEL_DATA_RESO]) + voxel_ct_data = perform_hu(voxel_ct_data, intercept=slices[0].RescaleIntercept, slope=slices[0].RescaleSlope) + + else: + print (' - [ERROR][volumize_ct()] Error with numpy extraction of CT volume: ', patient_ct_dir) + + else: + print (' - [ERROR][volumize_ct()] Path issue: ', patient_ct_dir) + except: + traceback.print_exc() + pdb.set_trace() + + return voxel_ct_data, voxel_ct_headers + +############################################################ +# 3D VOXEL RELATED - RTSTRUCT # +############################################################ + +def get_self_label_id(label_name_orig, RAW_TO_SELF_LABEL_MAPPING, LABEL_MAP): + + try: + if label_name_orig in RAW_TO_SELF_LABEL_MAPPING: + label_name_self = RAW_TO_SELF_LABEL_MAPPING[label_name_orig] + if label_name_self in LABEL_MAP: + return LABEL_MAP[label_name_self], label_name_self + else: + return 0, '' + else: + return 0, '' + except: + traceback.print_exc() + print (' - [ERROR][get_global_label_id()] label_name_orig: ', label_name_orig) + # pdb.set_trace() + return 0, '' + +def extract_contours_from_rtstruct(rtstruct_ds, RAW_TO_SELF_LABEL_MAPPING=None, LABEL_MAP=None, verbose=False): + """ + Params + ------ + rtstruct_ds: pydicom.dataset.FileDataset + - .StructureSetROISequence: contains identifying info on all contours + - .ROIContourSequence : contains a set of contours for particular ROI + """ + + # Step 0 - Init + contours = [] + labels_debug = {} + + # Step 1 - Loop and extract all different contours + for i in range(len(rtstruct_ds.ROIContourSequence)): + + try: + + # Step 1.1 - Get contour + contour = {} + 
contour['name'] = str(rtstruct_ds.StructureSetROISequence[i].ROIName) + contour['color'] = list(rtstruct_ds.ROIContourSequence[i].ROIDisplayColor) + contour['contours'] = [s.ContourData for s in rtstruct_ds.ROIContourSequence[i].ContourSequence] + + if RAW_TO_SELF_LABEL_MAPPING is None and LABEL_MAP is None: + contour['number'] = rtstruct_ds.ROIContourSequence[i].ReferencedROINumber + assert contour['number'] == rtstruct_ds.StructureSetROISequence[i].ROINumber + elif RAW_TO_SELF_LABEL_MAPPING is None and LABEL_MAP is not None: + contour['number'] = int(LABEL_MAP.get(contour['name'], -1)) + else: + label_id, _ = get_self_label_id(contour['name'], RAW_TO_SELF_LABEL_MAPPING, LABEL_MAP) + contour['number'] = label_id + + if verbose: print (' - name: ', contour['name'], LABEL_MAP.get(contour['name'], -1)) + + # Step 1.2 - Keep or not condition + if contour['number'] > 0: + contours.append(contour) + + # Step 1.3 - Some debugging + labels_debug[contour['name']] = {'id': len(labels_debug) + 1} + + except: + if verbose: print ('\n ---------- [ERROR][utils.extract_contours_from_rtstruct()] name: ', rtstruct_ds.StructureSetROISequence[i].ROIName) + + # Step 2 - Order your contours + if len(contours): + contours = list(sorted(contours, key = lambda obj: obj['number'])) + + return contours, labels_debug + +def process_contours(contour_obj_list, params, voxel_mask_data, special_processing_ids = []): + """ + Goal: Convert contours to voxel mask + Params + ------ + contour_obj_list: [{'name':'', 'contours': [[], [], ... ,[]]}, {}, ..., {}] + special_processing_ids: For donut shaped contours + """ + import skimage + import skimage.draw + + # Step 1 - Get some position and spacing params + z = params[config.KEYNAME_ZVALS] + pos_r = params[config.KEYNAME_ORIGIN][1] + spacing_r = params[config.KEYNAME_PIXEL_SPACING][1] + pos_c = params[config.KEYNAME_ORIGIN][0] + spacing_c = params[config.KEYNAME_PIXEL_SPACING][0] + shape = params[config.KEYNAME_SHAPE] + + # Step 2 - Loop over contour objects + for contour_obj in contour_obj_list: + + try: + + class_id = int(contour_obj['number']) + + if class_id not in special_processing_ids: + + # Step 2.1 - Pick a contour for a particular ROI + for c_id, contour in enumerate(contour_obj['contours']): + coords = np.array(contour).reshape((-1, 3)) + if len(coords) > 1: + + # Step 2.2 - Get the z-index of a particular z-value + assert np.amax(np.abs(np.diff(coords[:, 2]))) == 0 + z_index = z.index(pydicom.valuerep.DSfloat(float(round(coords[0, 2])))) + + # Step 2.4 - Polygonize + rows = (coords[:, 1] - pos_r) / spacing_r #pixel_idx = f(real_world_idx, ct_resolution) + cols = (coords[:, 0] - pos_c) / spacing_c + if 1: + rr, cc = skimage.draw.polygon(rows, cols) # rr --> y-axis, cc --> x-axis + voxel_mask_data[cc, rr, z_index] = class_id + else: + contour_mask = skimage.draw.polygon2mask(voxel_mask_data.shape[:2], np.hstack((rows[np.newaxis].T, cols[np.newaxis].T))) + contour_idxs = np.argwhere(contour_mask > 0) + rr, cc = contour_idxs[:,0], contour_idxs[:,1] + voxel_mask_data[cc, rr, z_index] = class_id + + # Step 2.99 - Debug + # f,axarr = plt.subplots(1,2); axarr[0].scatter(cols, rows); axarr[0].invert_yaxis(); axarr[1].imshow(contour_mask);plt.suptitle('Z={:f}'.format(contour[0, 2])); plt.show() + + else: + + # Step 2.1 - Gather all contours for a particular ROI + contours_all = [] + for contour in contour_obj['contours']: contours_all.extend(contour) + contours_all = np.array(contours_all).reshape((-1,3)) + + # Step 2.2 - Split contours on the basis of z-value + for c_id, 
contour in enumerate([contours_all[contours_all[:,2]==z_pos] for z_pos in np.unique(contours_all[:,2])]): + + # Step 2.3 - Get the z-index of a particular z-value + assert np.amax(np.abs(np.diff(contour[:, 2]))) == 0 + z_index = z.index(pydicom.valuerep.DSfloat(float(round(contour[0, 2])))) + + # Step 2.4 - Polygonize + rows = (contour[:, 1] - pos_r) / spacing_r # pixel_idx = f(real_world_idx, ct_resolution) + cols = (contour[:, 0] - pos_c) / spacing_c + rr, cc = skimage.draw.polygon(rows, cols) + voxel_mask_data[cc, rr, z_index] = class_id + + # Step 2.99 - Debug + # if class_id == 8 and c_id > 3 and c_id < 7: # larynx + # import matplotlib.pyplot as plt + # f,axarr = plt.subplots(1,2); axarr[0].scatter(cols, rows); axarr[0].invert_yaxis(); axarr[1].scatter(cc, rr);axarr[1].invert_yaxis();plt.suptitle('Z={:f}'.format(contour[0, 2])); plt.show() + # pdb.set_trace() + + except: + print (' --- [ERROR][utils.process_contours()] contour-number:', contour_obj['number'], ' || contour-label:', contour_obj['name']) + traceback.print_exc() + + return voxel_mask_data + +def volumize_rtstruct(patient_rtstruct_path, params, params_labelinfo): + """ + This function takes a .dcm file (modality=RTSTRUCT) in a folder and converts it into a numpy mask + + Parameters + ---------- + patient_rtstruct_path: Path + path to the .dcm file (modality=RTSTRUCT) + params: dictionary + A python dictionary containing the following keys - ['pixel_spacing', 'origin', 'z_vals']. + 'z_vals' is a list of all the depth values of the raw .dcm slices + """ + + # Step 0 - Init + mask_data_oars = [] + mask_data_external = [] + mask_headers = {config.KEYNAME_LABEL_OARS:[], config.KEYNAME_LABEL_EXTERNAL:[]} + LABEL_MAP_DCM = params_labelinfo.get(config.KEY_LABELMAP_DCM, None) + LABEL_MAP_FULL = params_labelinfo.get(config.KEY_LABEL_MAP_FULL, None) + LABEL_IDS_SPECIAL = params_labelinfo.get(config.KEY_ARTFORCE_DONUTSHAPED_IDS, None) + LABEL_MAP_EXTERNAL = params_labelinfo.get(config.KEY_LABEL_MAP_EXTERNAL, None) + + try: + + ds = pydicom.filereader.dcmread(patient_rtstruct_path) + + if ds.Modality == config.MODALITY_RTSTRUCT: + + if config.KEYNAME_SHAPE in params: + + # Step 1 - Extract all different contours (for OARs) + if LABEL_MAP_DCM is not None or LABEL_MAP_FULL is not None: + contours_oars, _ = extract_contours_from_rtstruct(ds, LABEL_MAP_DCM, LABEL_MAP_FULL) + if len(contours_oars): + mask_headers[config.KEYNAME_LABEL_OARS] = [contour['name'] for contour in contours_oars] + mask_data_oars = np.zeros(params[config.KEYNAME_SHAPE], dtype=np.uint8) + mask_data_oars = process_contours(contours_oars, params, mask_data_oars, special_processing_ids=LABEL_IDS_SPECIAL) + else: + print (' - [ERROR][volumize_rtstruct()] Len of OAR contours is 0') + + # Step 2 - Extract contour for "External" + if LABEL_MAP_EXTERNAL is not None: + contours_external, _ = extract_contours_from_rtstruct(ds, LABEL_MAP_DCM, LABEL_MAP_EXTERNAL) + if len(contours_external): + mask_headers[config.KEYNAME_LABEL_EXTERNAL] = [contour['name'] for contour in contours_external] + mask_data_external = np.zeros(params[config.KEYNAME_SHAPE], dtype=np.uint8) + mask_data_external = process_contours(contours_external, params, mask_data_external) + else: + print (' - [ERROR][volumize_rtstruct()] Len of External contours is 0') + + else: + print (' - [ERROR][volumize_rtstruct()] Issue with voxel params: ', params) + + else: + print (' - [ERROR][volumize_rtstruct()] Could not capture RTSTRUCT file') + + except: + traceback.print_exc() + pdb.set_trace() + + return 
mask_data_oars, mask_data_external, mask_headers + +############################################################ +# SAVING/READING RELATED # +############################################################ + +def save_csv(filepath, data_array): + Path(filepath).parent.absolute().mkdir(parents=True, exist_ok=True) + np.savetxt(filepath, data_array, fmt='%s') + +def read_csv(filepath): + data = np.loadtxt(filepath, dtype='str') + return data + +def write_json(json_data, json_filepath): + + Path(json_filepath).parent.absolute().mkdir(parents=True, exist_ok=True) + + with open(str(json_filepath), 'w') as fp: + json.dump(json_data, fp, indent=4, cls=NpEncoder) + +def read_json(json_filepath, verbose=True): + + if Path(json_filepath).exists(): + with open(str(json_filepath), 'r') as fp: + data = json.load(fp) + return data + else: + if verbose: print (' - [ERROR][read_json()] json_filepath does not exist: ', json_filepath) + return None + +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NpEncoder, self).default(obj) + +def write_nrrd(filepath, data, spacing): + nrrd_headers = {'space':'left-posterior-superior', 'kinds': ['domain', 'domain', 'domain'], 'encoding':'gzip'} + space_directions = np.zeros((3,3), dtype=np.float32) + space_directions[[0,1,2],[0,1,2]] = np.array(spacing) + nrrd_headers['space directions'] = space_directions + + import nrrd + nrrd.write(str(filepath), data, nrrd_headers) + +############################################################ +# RANDOM # +############################################################ +def get_name_patient_study_id(meta): + try: + meta = np.array(meta) + meta = str(meta.astype(str)) + + meta_split = meta.split('-') + name = None + patient_id = None + study_id = None + + if len(meta_split) == 2: + name = meta_split[0] + patient_id = meta_split[1] + elif len(meta_split) == 3: + name = meta_split[0] + patient_id = meta_split[1] + study_id = meta_split[2] + + return name, patient_id, study_id + + except: + traceback.print_exc() + pdb.set_trace() + +def get_numbers(string): + return ''.join([s for s in string if s.isdigit()]) + +def create_folds(idxs, folds=4): + + # Step 0 - Init + res = {} + idxs = np.array(idxs) + + try: + for fold in range(folds): + print (' - [create_crossvalfolds()] fold: ', fold) + val_idxs = list(np.random.choice(np.arange(len((idxs))), size=int(len(idxs)/folds), replace=False)) + train_idxs = list( set(np.arange(len(idxs))) - set(val_idxs) ) + val_patients = idxs[val_idxs] + train_patients = idxs[train_idxs] + + res[fold+1] = {config.MODE_TRAIN: train_patients, config.MODE_VAL: val_patients} + + except: + traceback.print_exc() + pdb.set_trace() + + return res + +############################################################ +# DEBUG RELATED # +############################################################ + +def benchmark_model(model_time): + time.sleep(model_time) + +def benchmark(dataset_generator, model_time=0.05): + + import psutil + import humanize + import pynvml + pynvml.nvmlInit() + + device_id = pynvml.nvmlDeviceGetHandleByIndex(0) + process = psutil.Process(os.getpid()) + + print ('\n - [benchmark()]') + t99 = time.time() + steps = 0 + t0 = time.time() + for X,_,meta1,meta2 in dataset_generator: + t1 = time.time() + benchmark_model(model_time) + t2 = time.time() + steps += 1 + # print (' - Data Time: ', round(t1 - t0,5),'s || 
Model Time: ', round(t2-t1,2),'s', '(',X.shape,'), (',meta2.numpy(),')') + print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s' \ + # , '(', humanize.naturalsize( process.memory_info().rss),'), ' \ + # , '(', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(device_id).used/1024/1024/1000),'GB), ' \ + , '(', meta2.numpy(),')' + # , np.sum(meta1[:,-9:].numpy(), axis=1) + # , meta1[:,1:4].numpy() + ) + t0 = time.time() + t99 = time.time() - t99 + print ('\n - Total steps: ', steps) + print (' - Total time for dataset: ', round(t99,2), 's') + +def benchmark2(dataset_generator, model_time=0.05): + + import psutil + import humanize + import pynvml + pynvml.nvmlInit() + + device_id = pynvml.nvmlDeviceGetHandleByIndex(0) + process = psutil.Process(os.getpid()) + + print ('\n - [benchmark2()]') + t99 = time.time() + steps = 0 + t0 = time.time() + for X_moving,_,_,_,meta1,meta2 in dataset_generator: + t1 = time.time() + benchmark_model(model_time) + t2 = time.time() + steps += 1 + # print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s', '(',X_moving.shape,'), (',meta2.numpy(),')') + print (' - Data Time: ', round(t1 - t0,5),'s || Model Time: ', round(t2-t1,2),'s' \ + , '(', humanize.naturalsize( process.memory_info().rss),'), ' \ + , '(', '%.4f' % (pynvml.nvmlDeviceGetMemoryInfo(device_id).used/1024/1024/1000),'GB), ' \ + # , '(', meta2.numpy(),')' + # , np.sum(meta1[:,-9:].numpy(), axis=1) + # , meta1[:,1:4].numpy() + ) + t0 = time.time() + t99 = time.time() - t99 + print ('\n - Total steps: ', steps) + print (' - Total time for dataset: ', round(t99,2), 's') + +def print_debug_header(): + print (' ============================================== ') + print (' DEBUG ') + print (' ============================================== ') + +def get_memory(pid): + try: + process = psutil.Process(pid) + return humanize.naturalsize(process.memory_info().rss) + except: + return '-1' + +############################################################ +# DATALOADER RELATED # +############################################################ + +def get_dataloader_3D_train(data_dir, dir_type=['train', 'train_additional'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=True + , parallel_calls=None, deterministic=False + , patient_shuffle=True + , centred_dataloader_prob=0.0 + , debug=False + , pregridnorm=True): + + debug = False + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug + , pregridnorm=pregridnorm) + + # Step 2 - Training transforms + x_shape_w = dataset_han_miccai2015.w_grid + x_shape_h = dataset_han_miccai2015.h_grid + x_shape_d = dataset_han_miccai2015.d_grid + label_map = dataset_han_miccai2015.LABEL_MAP + transforms = [ + aug.Rotate3DSmall(label_map, mask_type) + , aug.Deform2Punt5D((x_shape_h, x_shape_w, x_shape_d), label_map, grid_points=50, stddev=4, div_factor=2, debug=False) + , aug.Translate(label_map, translations=[40,40]) + , aug.Noise(x_shape=(x_shape_h, x_shape_w, x_shape_d, 1), mean=0.0, std=0.1) + 
] + dataset_han_miccai2015.transforms = transforms + + # Step 3 - Training filters for background-only grids + if filter_grid: + dataset_han_miccai2015.filter = aug.FilterByMask(len(dataset_han_miccai2015.LABEL_MAP), dataset_han_miccai2015.SAMPLER_PERC) + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_3D_train_eval(data_dir, dir_type=['train_train_additional'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False + , parallel_calls=None, deterministic=True + , patient_shuffle=False + , centred_dataloader_prob=0.0 + , debug=False + , pregridnorm=True): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug + , pregridnorm=pregridnorm) + + # Step 2 - Training transforms + # None + + # Step 3 - Training filters for background-only grids + # None + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_3D_test_eval(data_dir, dir_type=['test_offsite'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False + , parallel_calls=None, deterministic=True + , patient_shuffle=False + , debug=False + , pregridnorm=True): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , debug=debug + , pregridnorm=pregridnorm) + + # Step 2 - Testing transforms + # None + + # Step 3 - Testing filters for background-only grids + # None + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_deepmindtcia(data_dir + , dir_type=[config.DATALOADER_DEEPMINDTCIA_TEST] + , annotator_type=[config.DATALOADER_DEEPMINDTCIA_ONC] + , grid=True, crop_init=True, resampled=True, mask_type=config.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False, pregridnorm=True + , parallel_calls=None, deterministic=False + , patient_shuffle=True + , centred_dataloader_prob = 0.0 + , debug=False): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + for anno_type_ in annotator_type: + + # Step 1 - Get dataset class + dataset_han_deepmindtcia = HaNDeepMindTCIADataset(data_dir=data_dir, dir_type=dir_type_, annotator_type=anno_type_ + , grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type, pregridnorm=pregridnorm + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , 
patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug) + + # Step 2 - Append to list + datasets.append(dataset_han_deepmindtcia) + + dataset = ZipDataset(datasets) + return dataset + diff --git a/src/dataloader/utils_viz.py b/src/dataloader/utils_viz.py new file mode 100644 index 0000000..54a0bc4 --- /dev/null +++ b/src/dataloader/utils_viz.py @@ -0,0 +1,437 @@ +import pdb +import copy +import traceback +import numpy as np +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt + +import medloader.dataloader.utils as utils +import medloader.dataloader.config as config + +def cmap_for_dataset(dataset): + # dataset = ZipDataset + LABEL_COLORS = dataset.get_label_colors() + cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0) + norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N) + + return cmap_me, norm + +############################################################ +# 2D # +############################################################ + +def get_info_from_label_id(label_id, LABEL_MAP, LABEL_COLORS=None): + """ + The label_id param has to be greater than 0 + """ + + label_name = [label for label in LABEL_MAP if LABEL_MAP[label] == label_id] + if len(label_name): + label_name = label_name[0] + else: + label_name = None + + label_color = None + if LABEL_COLORS is not None: + label_color = np.array(LABEL_COLORS[label_id]) + if np.any(label_color > 1): + label_color = label_color/255.0 + + return label_name, label_color + +def viz_slice_raw_batch(slices_raw, slices_mask, meta, dataset): + try: + + slices_raw = slices_raw[:,:,:,0] + batch_size = slices_raw.shape[0] + + LABEL_COLORS = getattr(config, dataset.name)['LABEL_COLORS'] + cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0) + norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N) + # for i in range(cmap_me.N): print (i, ':', norm.__call__(i)) + + f, axarr = plt.subplots(2,batch_size, figsize=config.FIGSIZE) + if batch_size == 1: + axarr = axarr.reshape(2,batch_size) + + # Loop over all batches + for batch_id in range(batch_size): + slice_raw= slices_raw[batch_id].numpy() + idxs_update = np.argwhere(slice_raw < -200) + slice_raw[idxs_update[:,0], idxs_update[:,1]] = -200 + idxs_update = np.argwhere(slice_raw > 400) + slice_raw[idxs_update[:,0], idxs_update[:,1]] = 400 + axarr[1,batch_id].imshow(slice_raw, cmap='gray') + axarr[0,batch_id].imshow(slice_raw, cmap='gray') + + # Create a mask + label_mask_show = np.zeros(slices_mask[batch_id,:,:,0].shape) + for label_id in range(slices_mask.shape[-1]): + label_mask = slices_mask[batch_id,:,:,label_id].numpy() + if np.sum(label_mask) > 0: + label_idxs = np.argwhere(label_mask > 0) + if np.any(label_mask_show[label_idxs[:,0], label_idxs[:,1]] > 0): + print (' - [utils.viz_slice_raw_batch()] Label overload by label_id:{}'.format(label_id)) + + label_mask_show[label_idxs[:,0], label_idxs[:,1]] = label_id + + if len(label_idxs) < 20: + label_name, _ = get_info_from_label_id(label_id, dataset) + print (' - [utils.viz_slice_raw_batch()] label_id:{} || label: {} || count: {}'.format(label_id+1, label_name, len(label_idxs))) + + + # Show the mask + axarr[0,batch_id].imshow(label_mask_show.astype(np.float32), cmap=cmap_me, norm=norm, interpolation='none', alpha=0.3) + + # Add legend + for label_id_mask in np.unique(label_mask_show): + if label_id_mask != 0: + label, color = 
get_info_from_label_id(label_id_mask, dataset) + axarr[0,batch_id].scatter(0,0, color=color, label=label + '(' + str(int(label_id_mask))+')') + leg = axarr[0,batch_id].legend(fontsize=8) + for legobj in leg.legendHandles: + legobj.set_linewidth(5.0) + + # Add title + if meta is not None and dataset is not None: + filename = '\n'.join(meta[batch_id].numpy().decode('utf-8').split('-')) + axarr[0,batch_id].set_title(dataset.name + '\n' + filename + '\n(BatchId={})'.format(batch_id)) + + mng = plt.get_current_fig_manager() + try:mng.window.showMaximized() + except:pass + try:mng.window.maxsize() + except:pass + plt.show() + + + except: + traceback.print_exc() + pdb.set_trace() + +def viz_slice_raw_batch_datasets(slices_raw, slices_mask, meta1, meta2, datasets): + try: + + slices_raw = slices_raw[:,:,:,0] + batch_size = slices_raw.shape[0] + + f, axarr = plt.subplots(2,batch_size, figsize=config.FIGSIZE) + if batch_size == 1: + axarr = axarr.reshape(2,batch_size) + + # Loop over all batches + for batch_id in range(batch_size): + axarr[1,batch_id].imshow(slices_raw[batch_id], cmap='gray') + axarr[0,batch_id].imshow(slices_raw[batch_id], cmap='gray') + + meta2_batchid = meta2[batch_id].numpy().decode('utf-8').split('-')[0] + dataset_batch = '' + for dataset in datasets: + if dataset.name == meta2_batchid: + dataset_batch = dataset + + LABELID_BACKGROUND = getattr(config, dataset_batch.name)['LABELID_BACKGROUND'] + LABEL_COLORS = getattr(config, dataset_batch.name)['LABEL_COLORS'] + cmap_me = matplotlib.colors.ListedColormap(np.array([*LABEL_COLORS.values()])/255.0) + norm = matplotlib.colors.BoundaryNorm(boundaries=range(1,12), ncolors=cmap_me.N) + + # Create a mask + label_mask_show = np.zeros(slices_mask[batch_id,:,:,0].shape) + for label_id in range(slices_mask.shape[-1]): + label_mask = slices_mask[batch_id,:,:,label_id].numpy() + if np.sum(label_mask) > 0: + label_idxs = np.argwhere(label_mask > 0) + if label_id == slices_mask.shape[-1] - 1: + label_id_actual = LABELID_BACKGROUND + else: + label_id_actual = label_id + 1 + + if np.any(label_mask_show[label_idxs[:,0], label_idxs[:,1]] > 0): + print (' - [utils.viz_slice_raw_batch()] Label overload by label_id:{}'.format(label_id)) + if len(label_idxs) < 20: + label_name, _ = get_info_from_label_id(label_id_actual, dataset_batch) + print (' - [utils.viz_slice_raw_batch()] label_id:{} || label: {} || count: {}'.format(label_id+1, label_name, len(label_idxs))) + label_mask_show[label_idxs[:,0], label_idxs[:,1]] = label_id_actual + # label_name, label_color = get_info_from_label_id(label_id+1) + + # Show the mask + # axarr[0,batch_id].imshow(label_mask_show, alpha=0.5, cmap=cmap_me, norm=norm) + axarr[0,batch_id].imshow(label_mask_show.astype(np.float32), cmap=cmap_me, norm=norm) + + # Add legend + for label_id_mask in np.unique(label_mask_show): + if label_id_mask != 0: + label, color = get_info_from_label_id(label_id_mask, dataset_batch) + axarr[0,batch_id].scatter(0,0, color=color, label=label + '(' + str(int(label_id_mask))+')') + leg = axarr[0,batch_id].legend(fontsize=8) + for legobj in leg.legendHandles: + legobj.set_linewidth(5.0) + + # Add title + filepath = dataset_batch.paths_raw[meta1[batch_id]] + if len(filepath) == 2: + filepath = filepath[0] + ' ' + filepath[1] + filename = Path(filepath).parts[-1] + axarr[0,batch_id].set_title(dataset_batch.name + '\n' + filename + '\n(BatchId={})'.format(batch_id)) + + mng = plt.get_current_fig_manager() + mng.window.showMaximized() + plt.show() + + + except: + traceback.print_exc() + 
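+        # NOTE: pause in the debugger so the failing batch and figure state can be inspected interactively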
pdb.set_trace() + +def viz_slice(slice_raw, slice_mask, meta=None, dataset=None): + """ + - slice_raw: [H,W] + - slice_mask: [H,W] + """ + try: + import matplotlib.pyplot as plt + f, axarr = plt.subplots(2,2, figsize=config.FIGSIZE) + axarr[0,0].imshow(slice_raw, cmap='gray') + axarr[0,1].imshow(slice_mask) + axarr[1,0].hist(slice_raw) + axarr[1,1].imshow(slice_raw, cmap='gray') + axarr[1,1].imshow(slice_mask, alpha=0.5) + + axarr[0,0].set_title('Raw image') + axarr[0,1].set_title('Raw mask') + axarr[1,0].set_title('Raw Value Histogram') + axarr[1,1].set_title('Raw + Mask') + + if meta is not None and dataset is not None: + filename = Path(dataset.paths_raw[meta[0]]).parts[-1] + plt.suptitle(filename) + + try: + mng = plt.get_current_fig_manager() + mng.window.showMaximized() + except: pass + + plt.show() + + except: + traceback.print_exc() + pdb.set_trace() + +############################################################ +# 3D # +############################################################ + +def viz_3d_slices(voxel_img, voxel_mask, dataset, meta1, meta2, plots=4): + """ + voxel_img : [B,H,W,D, C=1] + voxel_mask: [B,H,W,D, C] + """ + try: + cmap_me, norm = cmap_for_dataset(dataset) + + for batch_id in range(voxel_img.shape[0]): + voxel_img_batch = voxel_img[batch_id,:,:,:,0]*(dataset.HU_MAX - dataset.HU_MIN) + dataset.HU_MIN + voxel_mask_batch = np.argmax(voxel_mask[batch_id,:,:,:,:], axis=-1) + height = voxel_img_batch.shape[-1] + + f,axarr = plt.subplots(2,plots) + for plt_idx, z_idx in enumerate(np.random.choice(height, plots, replace=False)): + axarr[0][plt_idx].imshow(voxel_img_batch[:,:,z_idx], cmap='gray') + axarr[1][plt_idx].imshow(voxel_img_batch[:,:,z_idx], cmap='gray', alpha=0.2) + axarr[1][plt_idx].imshow(voxel_mask_batch[:,:,z_idx], cmap=cmap_me, norm=norm, interpolation='none') + axarr[0][plt_idx].set_title('Slice: {}/{}'.format(z_idx+1, height)) + + name, patient_id, study_id = utils.get_name_patient_study_id(meta2[batch_id]) + if study_id is not None: + filename = patient_id + '\n' + study_id + else: + filename = patient_id + plt.suptitle(filename) + plt.show() + + + except: + traceback.print_exc() + pdb.set_trace() + +def viz_3d_data(voxel_imgs, voxel_masks, meta1, meta2, dataset): + """ + - voxel_imgs: [B,H,W,D] + """ + + try: + """ + Ref: https://plotly.com/python/visualizing-mri-volume-slices/ + """ + import plotly.graph_objects as go + + voxel_imgs = voxel_imgs.numpy() + for batch_id, voxel_img in enumerate(voxel_imgs): + + # Plot Volume + r,c,d = voxel_img.shape + d_ = d - 1 + + fig = go.Figure( + frames=[ + go.Frame(data=go.Surface( + z=(d_ - k) * np.ones((r, c)), + surfacecolor=np.flipud(voxel_img[:,:,d_ - k]), + # surfacecolor=voxel_img[:,:,d_ - k], + cmin=-1000, cmax=3000 + ) + , name=str(k) # you need to name the frame for the animation to behave properly + ) for k in range(d) + ] + ) + + # Add data to be displayed before animation starts + fig.add_trace(go.Surface( + z=d_ * np.ones((r, c)) + , surfacecolor=np.flipud(voxel_img[:,:,d_]) + # , surfacecolor=voxel_img[:,:,d_] + , colorscale='Gray' + # , cmin=config.HU_MIN, cmax=config.HU_MIN + , cmin=-1000, cmax=3000 + , colorbar=dict(thickness=20, ticklen=4) + ) + ) + + def frame_args(duration): + return { + "frame": {"duration": duration}, + "mode": "immediate", + "fromcurrent": True, + "transition": {"duration": duration, "easing": "linear"}, + } + + sliders = [ + { + "pad": {"b": 10, "t": 60}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [ + { + "args": [[f.name], frame_args(0)], + "label": str(k), + "method": 
"animate", + } + for k, f in enumerate(fig.frames) + ], + } + ] + + fig.update_layout( + title='Slices in volumetric data', + width=600, + height=600, + scene=dict( + # zaxis=dict(range=[0, d*0.1], autorange=False), + zaxis=dict(range=[0, d], autorange=False), + aspectratio=dict(x=1, y=1, z=1), + ), + updatemenus = [ + { + "buttons": [ + { + "args": [None, frame_args(50)], + "label": "▶", # play symbol + "method": "animate", + }, + { + "args": [[None], frame_args(0)], + "label": "◼", # pause symbol + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 70}, + "type": "buttons", + "x": 0.1, + "y": 0, + } + ], + sliders=sliders + ) + fig.show() + + # Plot Mask (3D) + + except: + traceback.print_exc() + pdb.set_trace() + +def viz_3d_mask(voxel_masks, dataset, meta1, meta2, label_map_full=False): + """ + Expects a [B,H,W,D] shaped mask + """ + try: + import plotly.graph_objects as go + import skimage + import skimage.measure + + LABEL_MAP = dataset.get_label_map(label_map_full=label_map_full) + LABEL_COLORS = dataset.get_label_colors() + label_ids = meta1[:,-len(LABEL_MAP):].numpy() + if np.sum(label_ids) < len(meta1): # if <= batch_size: return 0 + print (' - [viz_3d_mask()] No labels present: np.sum(label_ids): ', np.sum(label_ids)) + return 0 + + # Step 1 - Loop over all batch_ids + print (' ------------------------ VIZ 3D ------------------------') + for batch_id, voxel_mask in enumerate(voxel_masks): + fig = go.Figure() + label_ids = np.unique(voxel_mask) + print (' - label_ids: ', label_ids) + + import tensorflow as tf + # voxel_mask_ = tf.image.rot90(voxel_mask, k=3) + # voxel_mask_ = tf.transpose(tf.reverse(voxel_mask, [0]), [1,0,2]) + voxel_mask_ = voxel_mask + print (' -voxel_mask: ', voxel_mask.shape) + + # Step 2 - Loop over all label_ids + for i_, label_id in enumerate(label_ids): + + if label_id == 0 : continue + name, color = get_info_from_label_id(label_id, LABEL_MAP, LABEL_COLORS) + print (' - label_id: ', label_id, '(',name,')') + + # Get surface information + voxel_mask_tmp = np.array(copy.deepcopy(voxel_mask_)).astype(config.DATATYPE_VOXEL_MASK) + voxel_mask_tmp[voxel_mask_tmp != label_id] = 0 + verts, faces, _, _ = skimage.measure.marching_cubes(voxel_mask_tmp, step_size=1) + + # https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.Mesh3d.html + visible=True + fig.add_trace( + go.Mesh3d( + x=verts[:,0], y=verts[:,1], z=verts[:,2] + , i=faces[:,0],j=faces[:,1],k=faces[:,2] + , color='rgb({},{},{})'.format(*color) + , name=name, showlegend=True + , visible=visible + # , lighting=go.mesh3d.Lighting(ambient=0.5) + ) + ) + + fig.update_layout( + scene = dict( + xaxis = dict(nticks=10, range=[0,voxel_mask.shape[0]], title='X-axis'), + yaxis = dict(nticks=10, range=[0,voxel_mask.shape[1]]), + zaxis = dict(nticks=10, range=[0,voxel_mask.shape[2]]), + ) + ,width=700, + margin=dict(r=20, l=50, b=10, t=50) + ) + fig.update_layout(legend_title_text='Labels', showlegend=True) + fig.update_layout(scene_aspectmode='cube') + fig.update_layout(title_text='{} (BatchId={})'.format(meta2, batch_id)) + fig.show() + + except: + traceback.print_exc() + pdb.set_trace() diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/losses.py b/src/model/losses.py new file mode 100644 index 0000000..61b4833 --- /dev/null +++ b/src/model/losses.py @@ -0,0 +1,488 @@ +# Import internal libraries +import src.config as config + +# Import external libraries +import pdb +import copy +import time 
+import skimage +import skimage.util +import traceback +import numpy as np +from pathlib import Path +import SimpleITK as sitk +import tensorflow as tf +import tensorflow_addons as tfa +try: import tensorflow_probability as tfp +except: pass +import matplotlib.pyplot as plt + +_EPSILON = tf.keras.backend.epsilon() +MAX_FUNC_TIME = 300 + +############################################################ +# UTILS # +############################################################ + +@tf.function +def get_mask(mask_1D, Y): + # mask_1D: [[1,0,0,0, ...., 1]] - [B,L] something like this + # Y : [B,H,W,D,L] + if 0: + mask = tf.expand_dims(tf.expand_dims(tf.expand_dims(mask_1D, axis=1),axis=1),axis=1) # mask.shape=[B,1,1,1,L] + mask = tf.tile(mask, multiples=[1,Y.shape[1],Y.shape[2],Y.shape[3],1]) # mask.shape = [B,H,W,D,L] + mask = tf.cast(mask, tf.float32) + return mask + else: + return mask_1D + +def get_largest_component(y, verbose=True): + """ + Takes as input predicted predicted probs and returns a binary mask with the largest component for each label + + Params + ------ + y: [H,W,D,L] --> predicted probabilities of np.float32 + - Ref: https://simpleitk.readthedocs.io/en/v1.2.4/Documentation/docs/source/filters.html + """ + + try: + + label_count = y.shape[-1] + y = np.argmax(y, axis=-1) + y = np.concatenate([np.expand_dims(y == label_id, axis=-1) for label_id in range(label_count)],axis=-1) + ccFilter = sitk.ConnectedComponentImageFilter() + + for label_id in range(y.shape[-1]): + + if label_id > 0 and label_id not in []: # Note: pointless to do it for background + if verbose: t0 = time.time() + y_label = y[:,:,:,label_id] # [H,W,D] + + component_img = ccFilter.Execute(sitk.GetImageFromArray(y_label.astype(np.uint8))) + component_array = sitk.GetArrayFromImage(component_img) # will contain pseduo-labels given to different components + component_count = ccFilter.GetObjectCount() + + if component_count >= 2: # background and foreground + component_sizes = np.bincount(component_array.flatten()) # count the voxels belong to different components + component_sizes_sorted = np.asarray(sorted(component_sizes, reverse=True)) + if verbose: print ('\n - [INFO][losses.get_largest_component()] label_id: ', label_id, ' || sizes: ', component_sizes_sorted) + + component_largest_sortedidx = np.argwhere(component_sizes == component_sizes_sorted[1])[0][0] # Note: idx=1 as idx=0 is background # Risk: for optic nerves + y_label_mask = (component_array == component_largest_sortedidx).astype(np.float32) + y[:,:,:,label_id] = y_label_mask + if verbose: print (' - [INFO][losses.get_largest_component()]: label_id: ', label_id, '(',round(time.time() - t0,3),'s)') + else: + if verbose: print (' - [INFO][losses.get_largest_component()] label_id: ', label_id, ' has only background!!') + + # [TODO]: set other components as background (i.e. 
label=0) + + y = y.astype(np.float32) + return y + + except: + traceback.print_exc() + pdb.set_trace() + +def remove_smaller_components(y_true, y_pred, meta='', label_ids_small = [], verbose=False): + """ + Takes as input predicted probs and returns a binary mask by removing some of the smallest components for each label + + Params + ------ + y: [H,W,D,L] --> predicted probabilities of np.float32 + - Ref: https://simpleitk.readthedocs.io/en/v1.2.4/Documentation/docs/source/filters.html + """ + t0 = time.time() + + try: + + # Step 0 - Preprocess by selecting one voxel per class + y_pred_copy = copy.deepcopy(y_pred) # [H,W,D,L] with probs + label_count = y_pred_copy.shape[-1] + y_pred_copy = np.argmax(y_pred_copy, axis=-1) # [H,W,D] + y_pred_copy = np.concatenate([np.expand_dims(y_pred_copy == label_id, axis=-1) for label_id in range(label_count)],axis=-1) # [H,W,D,L] as a binary mask + + for label_id in range(y_pred_copy.shape[-1]): + + if label_id > 0: # Note: pointless to do it for background + if verbose: t0 = time.time() + y_label = y_pred_copy[:,:,:,label_id] # [H,W,D] + + # Step 1 - Get different components + ccFilter = sitk.ConnectedComponentImageFilter() + component_img = ccFilter.Execute(sitk.GetImageFromArray(y_label.astype(np.uint8))) + component_array = sitk.GetArrayFromImage(component_img) # will contain pseduo-labels given to different components + component_count = ccFilter.GetObjectCount() + component_sizes = np.bincount(component_array.flatten()) # count the voxels belong to different components + + # Step 2 - Evaluate each component + if component_count >= 1: # at least a foreground (along with background) + + # Step 2.1 - Sort them on the basis of voxel count + component_sizes_sorted = np.asarray(sorted(component_sizes, reverse=True)) + if verbose: + print ('\n - [INFO][losses.get_largest_component()] label_id: ', label_id, ' || sizes: ', component_sizes_sorted) + print (' - [INFO][losses.get_largest_component()] unique_comp_labels: ', np.unique(component_array)) + + # Step 2.1 - Remove really small components for good Hausdorff calculation + component_sizes_sorted_unique = np.unique(component_sizes_sorted[::-1]) # ascending order + for comp_size_id, comp_size in enumerate(component_sizes_sorted_unique): + if label_id in label_ids_small: + if comp_size <= config.MIN_SIZE_COMPONENT: + components_labels = [each[0] for each in np.argwhere(component_sizes == comp_size)] + for component_label in components_labels: + component_array[component_array == component_label] = 0 + else: + if comp_size_id < len(component_sizes_sorted_unique) - 2: # set to 0 except background and foreground + components_labels = [each[0] for each in np.argwhere(component_sizes == comp_size)] + for component_label in components_labels: + component_array[component_array == component_label] = 0 + if verbose: print (' - [INFO][losses.get_largest_component()] unique_comp_labels: ', np.unique(component_array)) + y_pred_copy[:,:,:,label_id] = component_array.astype(np.bool).astype(np.float32) + if verbose: print (' - [INFO][losses.get_largest_component()] label_id: ', label_id, '(',round(time.time() - t0,3),'s)') + + if 0: + # Step 1 - Hausdorff + y_true_label = y_true[:,:,:,label_id] + y_pred_label = y_pred_copy[:,:,:,label_id] + + hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter() + hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_label.astype(np.uint8)), sitk.GetImageFromArray(y_pred_label.astype(np.uint8))) + print (' - hausdorff: ', 
hausdorff_distance_filter.GetHausdorffDistance()) + + # Step 2 - 95% Hausdorff + y_true_contour = sitk.LabelContour(sitk.GetImageFromArray(y_true_label.astype(np.uint8)), False) + y_pred_contour = sitk.LabelContour(sitk.GetImageFromArray(y_pred_label.astype(np.uint8)), False) + y_true_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_true_contour, squaredDistance=False, useImageSpacing=True)) # i.e. euclidean distance + y_pred_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_pred_contour, squaredDistance=False, useImageSpacing=True)) + dist_y_pred = sitk.GetArrayViewFromImage(y_pred_distance_map)[sitk.GetArrayViewFromImage(y_true_distance_map)==0] # pointless? + dist_y_true = sitk.GetArrayViewFromImage(y_true_distance_map)[sitk.GetArrayViewFromImage(y_pred_distance_map)==0] + print (' - 95 hausdorff:', np.percentile(dist_y_true,95), np.percentile(dist_y_pred,95)) + + else: + print (' - [INFO][losses.get_largest_component()] for meta: {} || label_id: {} has only background!! ({}) '.format(meta, label_id, component_sizes)) + + if time.time() - t0 > MAX_FUNC_TIME: + print (' - [INFO][losses.get_largest_component()] Taking too long: ', round(time.time() - t0,2),'s') + + y = y_pred_copy.astype(np.float32) + return y + + except: + traceback.print_exc() + pdb.set_trace() + +def get_hausdorff(y_true, y_pred, spacing, verbose=False): + """ + :param y_true: [H, W, D, L] + :param y_pred: [H, W, D, L] + - Ref: https://simpleitk.readthedocs.io/en/master/filters.html?highlight=%20HausdorffDistanceImageFilter()#simpleitk-filters + - Ref: http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/34_Segmentation_Evaluation.html + """ + + try: + hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter() + + hausdorff_labels = [] + for label_id in range(y_pred.shape[-1]): + + y_true_label = y_true[:,:,:,label_id] # [H,W,D] + y_pred_label = y_pred[:,:,:,label_id] # [H,W,D] + + # Calculate loss (over all pixels) + if np.sum(y_true_label) > 0: + if label_id > 0: + try: + if np.sum(y_true_label) > 0: + y_true_sitk = sitk.GetImageFromArray(y_true_label.astype(np.uint8)) + y_pred_sitk = sitk.GetImageFromArray(y_pred_label.astype(np.uint8)) + y_true_sitk.SetSpacing(tuple(spacing)) + y_pred_sitk.SetSpacing(tuple(spacing)) + hausdorff_distance_filter.Execute(y_true_sitk, y_pred_sitk) + hausdorff_labels.append(hausdorff_distance_filter.GetHausdorffDistance()) + if verbose: print (' - [INFO][get_hausdorff()] label_id: {} || hausdorff: {}'.format(label_id, hausdorff_labels[-1])) + else: + hausdorff_labels.append(-1) + except: + print (' - [ERROR][get_hausdorff()] label_id: {}'.format(label_id)) + hausdorff_labels.append(-1) + else: + hausdorff_labels.append(0) + else: + hausdorff_labels.append(0) + + hausdorff_labels = np.array(hausdorff_labels) + hausdorff = np.mean(hausdorff_labels[hausdorff_labels>0]) + return hausdorff, hausdorff_labels + + except: + traceback.print_exc() + pdb.set_trace() + +def get_surface_distances(y_true, y_pred, spacing, verbose=False): + + """ + :param y_true: [H, W, D, L] --> binary mask of np.float32 + :param y_pred: [H, W, D, L] --> also, binary mask of np.float32 + - Ref: https://discourse.itk.org/t/computing-95-hausdorff-distance/3832/3 + - Ref: https://git.lumc.nl/mselbiallyelmahdy/jointregistrationsegmentation-via-crossstetch/-/blob/master/lib/label_eval.py + """ + + try: + hausdorff_labels = [] + hausdorff95_labels = [] + msd_labels = [] + for label_id in range(y_pred.shape[-1]): + + y_true_label = y_true[:,:,:,label_id] # [H,W,D] + y_pred_label = 
y_pred[:,:,:,label_id] # [H,W,D] + + # Calculate loss (over all pixels) + if np.sum(y_true_label) > 0: + + if label_id > 0: + if np.sum(y_pred_label) > 0: + y_true_sitk = sitk.GetImageFromArray(y_true_label.astype(np.uint8)) + y_pred_sitk = sitk.GetImageFromArray(y_pred_label.astype(np.uint8)) + y_true_sitk.SetSpacing(tuple(spacing)) + y_pred_sitk.SetSpacing(tuple(spacing)) + y_true_contour = sitk.LabelContour(y_true_sitk, False, backgroundValue=0) + y_pred_contour = sitk.LabelContour(y_pred_sitk, False, backgroundValue=0) + + y_true_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_true_sitk, squaredDistance=False, useImageSpacing=True)) + y_pred_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(y_pred_sitk, squaredDistance=False, useImageSpacing=True)) + dist_y_true = sitk.GetArrayFromImage(y_true_distance_map*sitk.Cast(y_pred_contour, sitk.sitkFloat32)) + dist_y_pred = sitk.GetArrayFromImage(y_pred_distance_map*sitk.Cast(y_true_contour, sitk.sitkFloat32)) + dist_y_true = dist_y_true[dist_y_true != 0] + dist_y_pred = dist_y_pred[dist_y_pred != 0] + + if len(dist_y_true): + msd_labels.append(np.mean(np.array(list(dist_y_true) + list(dist_y_pred)))) + if len(dist_y_true) and len(dist_y_pred): + hausdorff_labels.append( np.max( [np.max(dist_y_true), np.max(dist_y_pred)] ) ) + hausdorff95_labels.append(np.max([np.percentile(dist_y_true, 95), np.percentile(dist_y_pred, 95)])) + elif len(dist_y_true) and not len(dist_y_pred): + hausdorff_labels.append(np.max(dist_y_true)) + hausdorff95_labels.append(np.percentile(dist_y_true, 95)) + elif not len(dist_y_true) and not len(dist_y_pred): + hausdorff_labels.append(np.max(dist_y_pred)) + hausdorff95_labels.append(np.percentile(dist_y_pred, 95)) + else: + hausdorff_labels.append(-1) + hausdorff95_labels.append(-1) + msd_labels.append(-1) + + else: + hausdorff_labels.append(-1) + hausdorff95_labels.append(-1) + msd_labels.append(-1) + + else: + hausdorff_labels.append(0) + hausdorff95_labels.append(0) + msd_labels.append(0) + + else: + hausdorff_labels.append(0) + hausdorff95_labels.append(0) + msd_labels.append(0) + + hausdorff_labels = np.array(hausdorff_labels) + hausdorff_mean = np.mean(hausdorff_labels[hausdorff_labels > 0]) + hausdorff95_labels = np.array(hausdorff95_labels) + hausdorff95_mean = np.mean(hausdorff95_labels[hausdorff95_labels > 0]) + msd_labels = np.array(msd_labels) + msd_mean = np.mean(msd_labels[msd_labels > 0]) + return hausdorff_mean, hausdorff_labels, hausdorff95_mean, hausdorff95_labels, msd_mean, msd_labels + + except: + traceback.print_exc() + return -1, [], -1, [] + +def dice_numpy_slice(y_true_slice, y_pred_slice): + """ + Specifically designed for 2D slices + + Params + ------ + y_true_slice: [H,W] + y_pred_slice: [H,W] + """ + + sum_true = np.sum(y_true_slice) + sum_pred = np.sum(y_pred_slice) + if sum_true > 0 and sum_pred > 0: + num = 2 * np.sum(y_true_slice *y_pred_slice) + den = sum_true + sum_pred + return num/den + elif sum_true > 0 and sum_pred == 0: + return 0 + elif sum_true == 0 and sum_pred > 0: + return -0.1 + else: + return -1 + +def dice_numpy(y_true, y_pred): + """ + :param y_true: [H, W, D, L] + :param y_pred: [H, W, D, L] + - Ref: V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + - Ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py + """ + + dice_labels = [] + for label_id in range(y_pred.shape[-1]): + + y_true_label = y_true[:,:,:,label_id] # [H,W,D] + y_pred_label = y_pred[:,:,:,label_id] + 1e-8 # [H,W,D] + + 
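+        # The small 1e-8 offset keeps the Dice denominator strictly positive, even when the prediction for this label is empty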
# Calculate loss (over all pixels) + if np.sum(y_true_label) > 0: + num = 2*np.sum(y_true_label * y_pred_label) + den = np.sum(y_true_label + y_pred_label) + dice_label = num/den + else: + dice_label = -1.0 + + dice_labels.append(dice_label) + + dice_labels = np.array(dice_labels) + dice = np.mean(dice_labels[dice_labels>0]) + return dice, dice_labels + +############################################################ +# LOSSES # +############################################################ + +@tf.function +def loss_dice_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False): + + """ + Calculates soft-DICE loss + + :param y_true: [B, H, W, C, L] + :param y_pred: [B, H, W, C, L] + :param label_mask: [B,L] + - Ref: V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation + - Ref: https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08#file-soft_dice_loss-py + """ + + # Step 0 - Init + dice_labels = [] + label_mask = tf.cast(label_mask, dtype=tf.float32) # [B,L] + + # Step 1 - Get DICE of each sample in each label + y_pred = y_pred + _EPSILON + dice_labels = (2*tf.math.reduce_sum(y_true * y_pred, axis=[1,2,3]))/(tf.math.reduce_sum(y_true + y_pred, axis=[1,2,3])) # [B,H,W,D,L] -> [B,L] + dice_labels = dice_labels*label_mask # if mask of a label (e.g. background) has been explicitly set to 0, do not consider its loss + + # Step 2 - Mask results on the basis of ground truth availability + label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # to handle division by 0 + dice_for_train = None + dice_labels_for_train = None + dice_labels_for_report = tf.math.reduce_sum(dice_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + dice_for_report = tf.math.reduce_mean(tf.math.reduce_sum(dice_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + + # Step 3 - Weighted DICE + if len(weights): + label_weights = weights / tf.math.reduce_sum(weights) # nomalized + dice_labels_w = dice_labels * label_weights + dice_labels_for_train = tf.math.reduce_sum(dice_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + dice_for_train = tf.math.reduce_mean(tf.math.reduce_sum(dice_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + else: + dice_labels_for_train = dice_labels_for_report + dice_for_train = dice_for_report + + # Step 4 - Return results + return 1.0 - dice_for_train, 1.0 - dice_labels_for_train, dice_for_report, dice_labels_for_report + +@tf.function +def loss_ce_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False): + """ + Calculates cross entropy loss + + Params + ------ + y_true : [B, H, W, C, L] + y_pred : [B, H, W, C, L] + label_mask: [B,L] + - Ref: https://www.dlology.com/blog/multi-class-classification-with-focal-loss-for-imbalanced-datasets/ + """ + + # Step 0 - Init + loss_labels = [] + label_mask = tf.cast(label_mask, dtype=tf.float32) + y_pred = y_pred + _EPSILON + + # Step 1.1 - Foreground loss + loss_labels = -1.0 * y_true * tf.math.log(y_pred) # [B,H,W,D,L] + loss_labels = label_mask * tf.math.reduce_sum(loss_labels, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L] + + # Step 1.2 - Background loss + y_pred_neg = 1 - y_pred + _EPSILON + loss_labels_neg = -1.0 * (1 - y_true) * tf.math.log(y_pred_neg) # [B,H,W,D,L] + loss_labels_neg = label_mask * tf.math.reduce_sum(loss_labels_neg, axis=[1,2,3]) + loss_labels = loss_labels + loss_labels_neg + + # Step 2 - Mask results on the basis of ground truth availability + label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # for reasons of 
division + loss_for_train = None + loss_labels_for_train = None + loss_labels_for_report = tf.math.reduce_sum(loss_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + + # Step 3 - Weighted DICE + if len(weights): + label_weights = weights / tf.math.reduce_sum(weights) # nomalized + loss_labels_w = loss_labels * label_weights + loss_labels_for_train = tf.math.reduce_sum(loss_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + else: + loss_labels_for_train = loss_labels_for_report + loss_for_train = loss_for_report + + # Step 4 - Return results + return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report + +@tf.function +def loss_cebasic_3d_tf_func(y_true, y_pred, label_mask, weights=[], verbose=False): + """ + Calculates cross entropy loss + + Params + ------ + y_true : [B, H, W, C, L] + y_pred : [B, H, W, C, L] + label_mask: [B,L] + - Ref: https://www.dlology.com/blog/multi-class-classification-with-focal-loss-for-imbalanced-datasets/ + """ + + # Step 0 - Init + loss_labels = [] + label_mask = tf.cast(label_mask, dtype=tf.float32) + y_pred = y_pred + _EPSILON + + # Step 1 - Foreground loss + loss_labels = -1.0 * y_true * tf.math.log(y_pred) # [B,H,W,D,L] + loss_labels = label_mask * tf.math.reduce_sum(loss_labels, axis=[1,2,3]) # [B,H,W,D,L] --> [B,L] + + # Step 2 - Mask results on the basis of ground truth availability + label_mask = tf.where(tf.math.greater(label_mask,0), label_mask, _EPSILON) # for reasons of division + loss_for_train = None + loss_labels_for_train = None + loss_labels_for_report = tf.math.reduce_sum(loss_labels,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + loss_for_report = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + + # Step 3 - Weighted DICE + if len(weights): + label_weights = weights / tf.math.reduce_sum(weights) # nomalized + loss_labels_w = loss_labels * label_weights + loss_labels_for_train = tf.math.reduce_sum(loss_labels_w,axis=0) / tf.math.reduce_sum(label_mask, axis=0) + loss_for_train = tf.math.reduce_mean(tf.math.reduce_sum(loss_labels_w,axis=1) / tf.math.reduce_sum(label_mask, axis=1)) + else: + loss_labels_for_train = loss_labels_for_report + loss_for_train = loss_for_report + + # Step 4 - Return results + return loss_for_train, loss_labels_for_train, loss_for_report, loss_labels_for_report \ No newline at end of file diff --git a/src/model/models.py b/src/model/models.py new file mode 100644 index 0000000..dbb27cb --- /dev/null +++ b/src/model/models.py @@ -0,0 +1,572 @@ +# Import internal libraries +import src.config as config + +# Import external libraries +import pdb +import traceback +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +import tensorflow_probability as tfp + +############################################################ +# 3D MODEL BLOCKS # +############################################################ + +class ConvBlock3D(tf.keras.layers.Layer): + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , pool=False + , name=''): + super(ConvBlock3D, self).__init__(name='{}_ConvBlock3D'.format(name)) + + # Step 0 - Init + self.pool = pool + self.filters = filters + 
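+        # `trainable` is forwarded to the BatchNormalization layers added below, so their statistics can be frozen (e.g. during testing)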
self.trainable = trainable + if type(filters) == int: + filters = [filters] + + # Step 1 - Conv Blocks + self.conv_layer = tf.keras.Sequential() + for filter_id, filter_count in enumerate(filters): + self.conv_layer.add( + tf.keras.layers.Conv3D(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + , name='Conv_{}'.format(filter_id)) + ) + self.conv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable)) + # https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization + ## with the argument training=False (which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training + # if filter_id == 0 and dropout is not None: + # self.conv_layer.add(tf.keras.layers.Dropout(rate=dropout, name='DropOut')) + + # Step 2 - Pooling Blocks + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool') + + @tf.function + def call(self, x): + + x = self.conv_layer(x) + + if self.pool: + return x, self.pool_layer(x) + else: + return x + +class ConvBlock3DSERes(tf.keras.layers.Layer): + """ + For channel-wise attention + """ + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , pool=False + , squeeze_ratio=None + , init=False + , name=''): + + super(ConvBlock3DSERes, self).__init__(name='{}_ConvBlock3DSERes'.format(name)) + + # Step 0 - Init + self.init = init + self.trainable = trainable + + # Step 1 - Init (to get equivalent feature map count) + if self.init: + self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu' + ) + + # Step 2- Conv Block + self.convblock_res = ConvBlock3D(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + , trainable=trainable + , pool=False + , name=name + ) + + # Step 3 - Squeeze Block + """ + Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py + """ + self.squeeze_ratio = squeeze_ratio + if self.squeeze_ratio is not None: + self.seblock = tf.keras.Sequential() + self.seblock.add(tf.keras.layers.GlobalAveragePooling3D()) + self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0]))) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu')) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='sigmoid')) + + self.pool = pool + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name)) + + @tf.function + def call(self, x): + + if self.init: + x = self.convblock_filterequalizer(x) + + x_res = self.convblock_res(x) + + if self.squeeze_ratio is not None: + x_se = self.seblock(x_res) # squeeze and then get excitation factor + x_res = tf.math.multiply(x_res, x_se) # excited block + + y = x + x_res + + if self.pool: + return y, self.pool_layer(y) + else: + return y + + +class ConvBlock3DDropOut(tf.keras.layers.Layer): + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , dropout=None + , pool=False + , name=''): 
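+        # Monte-Carlo dropout variant of ConvBlock3D: a Dropout layer is placed before every Conv3D,
+        # so repeated stochastic forward passes remain possible when dropout is kept active at test time.
+        # Minimal usage sketch (hypothetical values, not taken from this repo):
+        #   block = ConvBlock3DDropOut(filters=[16, 16], dropout=0.25, pool=True, name='Demo')
+        #   y, y_pooled = block(tf.random.normal((1, 64, 64, 32, 1)))  # with pool=True the block returns (features, pooled features)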
+ super(ConvBlock3DDropOut, self).__init__(name='{}_ConvBlock3DDropOut'.format(name)) + + self.pool = pool + self.filters = filters + self.trainable = trainable + + if type(filters) == int: + filters = [filters] + + self.conv_layer = tf.keras.Sequential() + for filter_id, filter_count in enumerate(filters): + + if dropout is not None: + self.conv_layer.add(tf.keras.layers.Dropout(rate=dropout, name='DropOut_{}'.format(filter_id))) # before every conv layer (could also be after every layer?) + + self.conv_layer.add( + tf.keras.layers.Conv3D(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + , name='Conv_{}'.format(filter_id)) + ) + self.conv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable)) + + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool') + + @tf.function + def call(self, x): + + x = self.conv_layer(x) + + if self.pool: + return x, self.pool_layer(x) + else: + return x + +class ConvBlock3DSEResDropout(tf.keras.layers.Layer): + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , dropout=None + , pool=False + , squeeze_ratio=None + , init=False + , name=''): + + super(ConvBlock3DSEResDropout, self).__init__(name='{}_ConvBlock3DSEResDropout'.format(name)) + + self.init = init + self.trainable = trainable + + if self.init: + self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu' + ) + + self.convblock_res = ConvBlock3DDropOut(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + , trainable=trainable + , dropout=dropout + , pool=False + , name=name + ) + + """ + Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py + """ + self.squeeze_ratio = squeeze_ratio + if self.squeeze_ratio is not None: + self.seblock = tf.keras.Sequential() + self.seblock.add(tf.keras.layers.GlobalAveragePooling3D()) + self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0]))) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu')) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='sigmoid')) + + self.pool = pool + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name)) + + @tf.function + def call(self, x): + + if self.init: + x = self.convblock_filterequalizer(x) + + x_res = self.convblock_res(x) + + if self.squeeze_ratio is not None: + x_se = self.seblock(x_res) # squeeze and then get excitation factor + x_res = tf.math.multiply(x_res, x_se) # excited block + + y = x + x_res + + if self.pool: + return y, self.pool_layer(y) + else: + return y + + +class UpConvBlock3D(tf.keras.layers.Layer): + + def __init__(self, filters, kernel_size=(2,2,2), strides=(2, 2, 2), padding='same', trainable=False, name=''): + super(UpConvBlock3D, self).__init__(name='{}_UpConv3D'.format(name)) + + self.trainable = trainable + self.upconv_layer = tf.keras.Sequential() + self.upconv_layer.add(tf.keras.layers.Conv3DTranspose(filters, kernel_size, strides, padding=padding + , activation='relu' + , kernel_regularizer=None + , 
name='UpConv_{}'.format(self.name)) + ) + # self.upconv_layer.add(tf.keras.layers.BatchNormalization(trainable=trainable)) + + @tf.function + def call(self, x): + return self.upconv_layer(x) + + +class ConvBlock3DFlipOut(tf.keras.layers.Layer): + """ + Ref + - https://www.tensorflow.org/probability/api_docs/python/tfp/layers/Convolution3DFlipout + """ + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , pool=False + , name=''): + super(ConvBlock3DFlipOut, self).__init__(name='{}ConvBlock3DFlipOut'.format(name)) + + self.pool = pool + self.filters = filters + + if type(filters) == int: + filters = [filters] + + self.conv_layer = tf.keras.Sequential() + for filter_id, filter_count in enumerate(filters): + self.conv_layer.add( + tfp.layers.Convolution3DFlipout(filters=filter_count, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + # , kernel_prior_fn=? + , name='Conv3DFlip_{}'.format(filter_id)) + ) + self.conv_layer.add(tfa.layers.GroupNormalization(groups=filter_count//2, trainable=trainable)) + + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='Pool') + + def call(self, x): + + x = self.conv_layer(x) + + if self.pool: + return x, self.pool_layer(x) + else: + return x + +class ConvBlock3DSEResFlipOut(tf.keras.layers.Layer): + """ + For channel-wise attention + """ + + def __init__(self, filters, kernel_size=(3,3,3), strides=(1, 1, 1), padding='same' + , dilation_rate=(1,1,1) + , activation='relu' + , trainable=False + , pool=False + , squeeze_ratio=None + , init=False + , name=''): + + super(ConvBlock3DSEResFlipOut, self).__init__(name='{}ConvBlock3DSEResFlipOut'.format(name)) + + self.init = init + if self.init: + self.convblock_filterequalizer = tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu' + , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None) + + self.convblock_res = ConvBlock3DFlipOut(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding + , dilation_rate=dilation_rate + , activation=activation + , trainable=trainable + , pool=False + , name=name + ) + + """ + Ref: https://github.com/imkhan2/se-resnet/blob/master/se_resnet.py + """ + self.squeeze_ratio = squeeze_ratio + if self.squeeze_ratio is not None: + self.seblock = tf.keras.Sequential() + self.seblock.add(tf.keras.layers.GlobalAveragePooling3D()) + self.seblock.add(tf.keras.layers.Reshape(target_shape=(1,1,1,filters[0]))) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0]//squeeze_ratio, kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='relu' + , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None)) + self.seblock.add(tf.keras.layers.Conv3D(filters=filters[0], kernel_size=(1,1,1), strides=(1,1,1), padding='same' + , activation='sigmoid' + , kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None)) + + self.pool = pool + if self.pool: + self.pool_layer = tf.keras.layers.MaxPooling3D((2,2,2), strides=(2,2,2), name='{}_Pool'.format(name)) + + def call(self, x): + + if self.init: + x = self.convblock_filterequalizer(x) + + x_res = self.convblock_res(x) + + if self.squeeze_ratio is not None: + x_se = self.seblock(x_res) # squeeze and then get excitation factor + x_res = tf.math.multiply(x_res, x_se) # excited block + + y = x + x_res + + 
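+        # Residual connection: add the (optionally SE-rescaled) flipout features back onto the (possibly filter-equalized) input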
if self.pool: + return y, self.pool_layer(y) + else: + return y + + +############################################################ +# 3D MODELS # +############################################################ + +class ModelFocusNetDropOut(tf.keras.Model): + + def __init__(self, class_count, trainable=False, verbose=False): + """ + Params + ------ + class_count: to know how many class activation maps are needed + trainable: set to False when BNorm does not need to have its parameters recalculated e.g. during testing + """ + super(ModelFocusNetDropOut, self).__init__(name='ModelFocusNetDropOut') + + # Step 0 - Init + self.verbose = verbose + self.trainable = trainable + + dropout = [None, 0.25, 0.25, 0.25, 0.25, 0.25, None, None] + filters = [[16,16], [32,32]] + dilation_xy = [1, 2, 3, 6, 12, 18] + dilation_z = [1, 1, 1, 1, 1 , 1] + + + # Se-Res Blocks + self.convblock1 = ConvBlock3DSEResDropout(filters=filters[0], kernel_size=(3,3,1), dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[0], pool=True , squeeze_ratio=2, name='Block1') # Dim/2 (e.g. 240/2=120)(rp=(3,5,10),(1,1,2)) + self.convblock2 = ConvBlock3DSEResDropout(filters=filters[0], kernel_size=(3,3,3), dilation_rate=(dilation_xy[1], dilation_xy[1], dilation_z[1]), trainable=trainable, dropout=dropout[0], pool=False, squeeze_ratio=2, name='Block2') # Dim/2 (e.g. 240/2=120)(rp=(14,18) ,(4,6)) + + # Dense ASPP + self.convblock3 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[2], dilation_xy[2], dilation_z[2]), trainable=trainable, dropout=dropout[1], pool=False, name='Block3_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(24,30),(8,10)) + self.convblock4 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[3], dilation_xy[3], dilation_z[3]), trainable=trainable, dropout=dropout[2], pool=False, name='Block4_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(42,54),(12,14)) + self.convblock5 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[4], dilation_xy[4], dilation_z[4]), trainable=trainable, dropout=dropout[3], pool=False, name='Block5_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(78,102),(16,18)) + self.convblock6 = ConvBlock3DDropOut(filters=filters[1], dilation_rate=(dilation_xy[5], dilation_xy[5], dilation_z[5]), trainable=trainable, dropout=dropout[4], pool=False, name='Block6_ASPP') # Dim/2 (e.g. 240/2=120) (rp=(138,176),(20,22)) + self.convblock7 = ConvBlock3DSEResDropout(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[5], pool=False, squeeze_ratio=2, init=True, name='Block7') # Dim/2 (e.g. 240/2=120) (rp=(178,180),(24,26)) + + # Upstream + self.convblock8 = ConvBlock3DSEResDropout(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[6], pool=False, squeeze_ratio=2, init=True, name='Block8') # Dim/2 (e.g. 240/2=120) + + self.upconvblock9 = UpConvBlock3D(filters=filters[0][0], trainable=trainable, name='Block9_1') # Dim/1 (e.g. 240/1=240) + self.convblock9 = ConvBlock3DSEResDropout(filters=filters[0], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, dropout=dropout[7], pool=False, squeeze_ratio=2, init=True, name='Block9') # Dim/1 (e.g. 
240/1=240) + + # Final + self.convblock10 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(3,3,3), padding='same' + , dilation_rate=(1,1,1) + , activation='softmax' + , name='Block10') + + @tf.function + def call(self, x): + + # SE-Res Blocks + conv1, pool1 = self.convblock1(x) + conv2 = self.convblock2(pool1) + + # Dense ASPP + conv3 = self.convblock3(conv2) + conv3_op = tf.concat([conv2, conv3], axis=-1) + + conv4 = self.convblock4(conv3_op) + conv4_op = tf.concat([conv3_op, conv4], axis=-1) + + conv5 = self.convblock5(conv4_op) + conv5_op = tf.concat([conv4_op, conv5], axis=-1) + + conv6 = self.convblock6(conv5_op) + conv6_op = tf.concat([conv5_op, conv6], axis=-1) + + conv7 = self.convblock7(conv6_op) + + # Upstream + conv8 = self.convblock8(tf.concat([pool1, conv7], axis=-1)) + + up9 = self.upconvblock9(conv8) + conv9 = self.convblock9(tf.concat([conv1, up9], axis=-1)) + + # Final + conv10 = self.convblock10(conv9) + + if self.verbose: + print (' ---------- Model: ', self.name) + print (' - x: ', x.shape) + print (' - conv1: ', conv1.shape) + print (' - conv2: ', conv2.shape) + print (' - conv3_op: ', conv3_op.shape) + print (' - conv4_op: ', conv4_op.shape) + print (' - conv5_op: ', conv5_op.shape) + print (' - conv6_op: ', conv6_op.shape) + print (' - conv7: ', conv7.shape) + print (' - conv8: ', conv8.shape) + print (' - conv9: ', conv9.shape) + print (' - conv10: ', conv10.shape) + + + return conv10 + +class ModelFocusNetFlipOut(tf.keras.Model): + + def __init__(self, class_count, trainable=False, verbose=False): + super(ModelFocusNetFlipOut, self).__init__(name='ModelFocusNetFlipOut') + + # Step 0 - Init + self.verbose = verbose + + filters = [[16,16], [32,32]] + dilation_xy = [1, 2, 3, 6, 12, 18] + dilation_z = [1, 1, 1, 1, 1 , 1] + + # Se-Res Blocks + self.convblock1 = ConvBlock3DSERes(filters=filters[0], kernel_size=(3,3,1), dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=True , squeeze_ratio=2, name='Block1') # Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(3,5,10),(1,1,2)) + self.convblock2 = ConvBlock3DSERes(filters=filters[0] , dilation_rate=(dilation_xy[1], dilation_xy[1], dilation_z[1]), trainable=trainable, pool=False, squeeze_ratio=2, name='Block2') # Dim/2 (e.g. 96/2=48, 240/2=120)(rp=(14,18) ,(4,6)) + + # Dense ASPP + self.convblock3 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[2], dilation_xy[2], dilation_z[2]), trainable=trainable, pool=False, name='Block3_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(24,30),(16,18)) + self.convblock4 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[3], dilation_xy[3], dilation_z[3]), trainable=trainable, pool=False, name='Block4_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(42,54),(20,22)) + self.convblock5 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[4], dilation_xy[4], dilation_z[4]), trainable=trainable, pool=False, name='Block5_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(78,102),(24,26)) + self.convblock6 = ConvBlock3DFlipOut(filters=filters[1], dilation_rate=(dilation_xy[5], dilation_xy[5], dilation_z[5]), trainable=trainable, pool=False, name='Block6_ASPP') # Dim/2 (e.g. 96/2=48, 240/2=120) (rp=(138,176),(28,30)) + self.convblock7 = ConvBlock3DSEResFlipOut(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block7') # Dim/2 (e.g. 
96/2=48) (rp=(176,180),(32,34)) + + # Upstream + self.convblock8 = ConvBlock3DSERes(filters=filters[1], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block8') # Dim/2 (e.g. 96/2=48) + + self.upconvblock9 = UpConvBlock3D(filters=filters[0][0], trainable=trainable, name='Block9_1') # Dim/1 (e.g. 96/1 = 96) + self.convblock9 = ConvBlock3DSERes(filters=filters[0], dilation_rate=(dilation_xy[0], dilation_xy[0], dilation_z[0]), trainable=trainable, pool=False, squeeze_ratio=2, init=True, name='Block9') # Dim/1 (e.g. 96/1 = 96) + + self.convblock10 = tf.keras.layers.Conv3D(filters=class_count, strides=(1,1,1), kernel_size=(3,3,3), padding='same' + , dilation_rate=(1,1,1) + , activation='softmax' + , name='Block10') + + # @tf.function (cant call model.losses if this is enabled) + def call(self, x): + + # SE-Res Blocks + conv1, pool1 = self.convblock1(x) + conv2 = self.convblock2(pool1) + + # Dense ASPP + conv3 = self.convblock3(conv2) + conv3_op = tf.concat([conv2, conv3], axis=-1) + + conv4 = self.convblock4(conv3_op) + conv4_op = tf.concat([conv3_op, conv4], axis=-1) + + conv5 = self.convblock5(conv4_op) + conv5_op = tf.concat([conv4_op, conv5], axis=-1) + + conv6 = self.convblock6(conv5_op) + conv6_op = tf.concat([conv5_op, conv6], axis=-1) + + conv7 = self.convblock7(conv6_op) + + # Upstream + conv8 = self.convblock8(tf.concat([pool1, conv7], axis=-1)) + + up9 = self.upconvblock9(conv8) + conv9 = self.convblock9(tf.concat([conv1, up9], axis=-1)) + + # Final + conv10 = self.convblock10(conv9) + + if self.verbose: + print (' ---------- Model: ', self.name) + print (' - x: ', x.shape) + print (' - conv1: ', conv1.shape) + print (' - conv2: ', conv2.shape) + print (' - conv3_op: ', conv3_op.shape) + print (' - conv4_op: ', conv4_op.shape) + print (' - conv5_op: ', conv5_op.shape) + print (' - conv6_op: ', conv6_op.shape) + print (' - conv7: ', conv7.shape) + print (' - conv8: ', conv8.shape) + print (' - conv9: ', conv9.shape) + print (' - conv10: ', conv10.shape) + + + return conv10 + +############################################################ +# MAIN # +############################################################ + +if __name__ == "__main__": + + X = tf.random.normal((2,140,140,40,1)) + + print ('\n ------------------- ModelFocusNetDropOut ------------------- ') + model = ModelFocusNetDropOut(class_count=10, trainable=True) + y_predict = model(X, training=True) + model.summary() + + print ('\n ------------------- ModelFocusNetFlipOut ------------------- ') + model = ModelFocusNetFlipOut(class_count=10, trainable=True) + y_predict = model(X, training=True) + model.summary() \ No newline at end of file diff --git a/src/model/trainer.py b/src/model/trainer.py new file mode 100644 index 0000000..aba24a3 --- /dev/null +++ b/src/model/trainer.py @@ -0,0 +1,1718 @@ +# Import internal libraries +import src.config as config +import src.model.models as models +import src.model.utils as modutils +import src.model.losses as losses +import src.dataloader.utils as datautils + +# Import external libraries +import os +import gc +import pdb +import copy +import time +import tqdm +import datetime +import traceback +import numpy as np +import tensorflow as tf +from pathlib import Path +import tensorflow_probability as tfp + +import matplotlib.pyplot as plt; +import seaborn as sns; + +_EPSILON = tf.keras.backend.epsilon() +DEBUG = True + +############################################################ +# MODEL RELATED # 
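+# The set_lr() helper defined below multiplies the learning rate by 0.98 once every
+# 20 epochs, i.e. after `epoch` epochs the rate is approximately
+# init_lr * 0.98 ** (epoch // 20). Illustrative usage (epoch loop assumed):
+#   optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
+#   for epoch in range(1, 101):
+#       set_lr(epoch, optimizer, init_lr=0.001)   # at epoch 100: ~0.001 * 0.98**5
+#       # ... run one training epoch ...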
+############################################################ + +def set_lr(epoch, optimizer, init_lr): + + if epoch > 1 and epoch % 20 == 0: + optimizer.lr.assign(optimizer.lr * 0.98) + +############################################################ +# METRICS RELATED # +############################################################ +class ModelMetrics(): + + def __init__(self, metric_type, params): + + self.params = params + + self.label_map = params['internal']['label_map'] + self.label_ids = params['internal']['label_ids'] + self.logging_tboard = params['metrics']['logging_tboard'] + self.metric_type = metric_type + + self.losses_obj = self.get_losses_obj(params) + + self.init_metrics(params) + if self.logging_tboard: + self.init_tboard_writers(params) + + self.reset_metrics(params) + self.init_epoch0(params) + self.reset_metrics(params) + + def get_losses_obj(self, params): + losses_obj = {} + for loss_key in params['metrics']['metrics_loss']: + if config.LOSS_DICE == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_dice_3d_tf_func + if config.LOSS_CE == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_ce_3d_tf_func + if config.LOSS_CE_BASIC == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_cebasic_3d_tf_func + + return losses_obj + + def init_metrics(self, params): + """ + These are metrics derived from tensorflows library + """ + # Metrics for losses (during training for smaller grids) + self.metrics_loss_obj = {} + for metric_key in params['metrics']['metrics_loss']: + self.metrics_loss_obj[metric_key] = {} + self.metrics_loss_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type)) + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + for label_id in self.label_ids: + self.metrics_loss_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type)) + + # Metrics for eval (for full 3D volume) + self.metrics_eval_obj = {} + for metric_key in params['metrics']['metrics_eval']: + self.metrics_eval_obj[metric_key] = {} + self.metrics_eval_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type)) + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.metrics_eval_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type)) + + # Time Metrics + self.metric_time_dataloader = tf.keras.metrics.Mean(name='AvgTime-Dataloader-{}'.format(self.metric_type)) + self.metric_time_model_predict = tf.keras.metrics.Mean(name='AvgTime-ModelPredict-{}'.format(self.metric_type)) + self.metric_time_model_loss = tf.keras.metrics.Mean(name='AvgTime-ModelLoss-{}'.format(self.metric_type)) + self.metric_time_model_backprop = tf.keras.metrics.Mean(name='AvgTime-ModelBackProp-{}'.format(self.metric_type)) + + def reset_metrics(self, params): + + # Metrics for losses (during training for smaller grids) + for metric_key in params['metrics']['metrics_loss']: + self.metrics_loss_obj[metric_key]['total'].reset_states() + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + for label_id in self.label_ids: + self.metrics_loss_obj[metric_key][label_id].reset_states() + + # Metrics for eval (for full 3D volume) + for metric_key in 
params['metrics']['metrics_eval']: + self.metrics_eval_obj[metric_key]['total'].reset_states() + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.metrics_eval_obj[metric_key][label_id].reset_states() + + # Time Metrics + self.metric_time_dataloader.reset_states() + self.metric_time_model_predict.reset_states() + self.metric_time_model_loss.reset_states() + self.metric_time_model_backprop.reset_states() + + def init_tboard_writers(self, params): + """ + These are tensorboard writer + """ + # Writers for loss (during training for smaller grids) + self.writers_loss_obj = {} + for metric_key in params['metrics']['metrics_loss']: + self.writers_loss_obj[metric_key] = {} + self.writers_loss_obj[metric_key]['total'] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss') + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + for label_id in self.label_ids: + self.writers_loss_obj[metric_key][label_id] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss-' + str(label_id)) + + # Writers for eval (for full 3D volume) + self.writers_eval_obj = {} + for metric_key in params['metrics']['metrics_eval']: + self.writers_eval_obj[metric_key] = {} + self.writers_eval_obj[metric_key]['total'] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval') + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.writers_eval_obj[metric_key][label_id] = modutils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval-' + str(label_id)) + + # Time and other writers + self.writer_lr = modutils.get_tensorboard_writer(params['exp_name'], suffix='LR') + self.writer_time_dataloader = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Dataloader') + self.writer_time_model_predict = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Predict') + self.writer_time_model_loss = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Loss') + self.writer_time_model_backprop = modutils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Backprop') + + def init_epoch0(self, params): + + for metric_str in self.metrics_loss_obj: + self.update_metric_loss(metric_str, 1e-6) + if params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + self.update_metric_loss_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids}) + + for metric_str in self.metrics_eval_obj: + self.update_metric_eval_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids}) + + self.update_metrics_time(time_dataloader=1e-6, time_predict=1e-6, time_loss=1e-6, time_backprop=1e-6) + + self.write_epoch_summary(epoch=0, label_map=self.label_map, params=None, eval_condition=True) + + def update_metrics_time(self, time_dataloader, time_predict, time_loss, time_backprop): + if time_dataloader is not None: + self.metric_time_dataloader.update_state(time_dataloader) + if time_predict is not None: + self.metric_time_model_predict.update_state(time_predict) + if time_loss is not None: + self.metric_time_model_loss.update_state(time_loss) + if time_backprop is not None: + self.metric_time_model_backprop.update_state(time_backprop) + + def update_metric_loss(self, metric_str, metric_val): + # Metrics for losses (during training for smaller grids) + 
self.metrics_loss_obj[metric_str]['total'].update_state(metric_val) + + @tf.function + def update_metric_loss_labels(self, metric_str, metric_vals_labels): + # Metrics for losses (during training for smaller grids) + + for label_id in self.label_ids: + if metric_vals_labels[label_id] > 0.0: + self.metrics_loss_obj[metric_str][label_id].update_state(metric_vals_labels[label_id]) + + def update_metric_eval(self, metric_str, metric_val): + # Metrics for eval (for full 3D volume) + self.metrics_eval_obj[metric_str]['total'].update_state(metric_val) + + def update_metric_eval_labels(self, metric_str, metric_vals_labels, do_average=False): + # Metrics for eval (for full 3D volume) + + metric_avg = [] + for label_id in self.label_ids: + if label_id in metric_vals_labels: + if metric_vals_labels[label_id] > 0: + self.metrics_eval_obj[metric_str][label_id].update_state(metric_vals_labels[label_id]) + if do_average: + if label_id > 0: + metric_avg.append(metric_vals_labels[label_id]) + + if do_average: + if len(metric_avg): + self.metrics_eval_obj[metric_str]['total'].update_state(np.mean(metric_avg)) + + def write_epoch_summary(self, epoch, label_map, params=None, eval_condition=False): + + if self.logging_tboard: + # Metrics for losses (during training for smaller grids) + for metric_str in self.metrics_loss_obj: + modutils.make_summary('Loss/{}'.format(metric_str), epoch, writer1=self.writers_loss_obj[metric_str]['total'], value1=self.metrics_loss_obj[metric_str]['total'].result()) + if self.params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + if len(self.metrics_loss_obj[metric_str]) > 1: # i.e. has label ids + for label_id in self.label_ids: + label_name, _ = modutils.get_info_from_label_id(label_id, label_map) + modutils.make_summary('Loss/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_loss_obj[metric_str][label_id], value1=self.metrics_loss_obj[metric_str][label_id].result()) + + # Metrics for eval (for full 3D volume) + if eval_condition: + for metric_str in self.metrics_eval_obj: + modutils.make_summary('Eval3D/{}'.format(metric_str), epoch, writer1=self.writers_eval_obj[metric_str]['total'], value1=self.metrics_eval_obj[metric_str]['total'].result()) + if len(self.metrics_eval_obj[metric_str]) > 1: # i.e. 
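+                    # Typical per-epoch usage of this class (illustrative sketch; the metric
+                    # key 'CE', the metric_type value and the eval interval are assumptions here):
+                    #   metrics_train = ModelMetrics(metric_type=config.MODE_TRAIN, params=params)
+                    #   for epoch in range(1, epochs + 1):
+                    #       for X, Y, meta1, meta2 in datagen_train:
+                    #           loss = ...                                     # per-batch loss value
+                    #           metrics_train.update_metric_loss('CE', loss)
+                    #       metrics_train.write_epoch_summary(epoch, label_map,
+                    #               params={'optimizer': optimizer}, eval_condition=(epoch % 50 == 0))
+                    #       metrics_train.reset_metrics(params)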
has label ids + for label_id in self.label_ids: + label_name, _ = modutils.get_info_from_label_id(label_id, label_map) + modutils.make_summary('Eval3D/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_eval_obj[metric_str][label_id], value1=self.metrics_eval_obj[metric_str][label_id].result()) + + # Time Metrics + modutils.make_summary('Info/Time/Dataloader' , epoch, writer1=self.writer_time_dataloader , value1=self.metric_time_dataloader.result()) + modutils.make_summary('Info/Time/ModelPredict' , epoch, writer1=self.writer_time_model_predict , value1=self.metric_time_model_predict.result()) + modutils.make_summary('Info/Time/ModelLoss' , epoch, writer1=self.writer_time_model_loss , value1=self.metric_time_model_loss.result()) + modutils.make_summary('Info/Time/ModelBackProp', epoch, writer1=self.writer_time_model_backprop, value1=self.metric_time_model_backprop.result()) + + # Learning Rate + if params is not None: + if 'optimizer' in params: + modutils.make_summary('Info/LR', epoch, writer1=self.writer_lr, value1=params['optimizer'].lr) + + def update_pbar(self, pbar): + desc_str = '' + + # Metrics for losses (during training for smaller grids) + for metric_str in self.metrics_loss_obj: + if len(desc_str): desc_str += ',' + + if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + metric_avg = [] + for label_id in self.label_ids: + if label_id > 0: + label_result = self.metrics_loss_obj[metric_str][label_id].result().numpy() + if label_result > 0: + metric_avg.append(label_result) + + mean_val = 0 + if len(metric_avg): + mean_val = np.mean(metric_avg) + loss_text = '{}Loss:{:2f}'.format(metric_str, mean_val) + desc_str += loss_text + + # GPU Memory + if 1: + try: + if len(self.metrics_loss_obj) > 1: + desc_str = desc_str[:-1] # to remove the extra ',' + desc_str += ',' + str(modutils.get_tf_gpu_memory()) + except: + pass + + pbar.set_description(desc=desc_str, refresh=True) + +def eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_processed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , patient_id_curr + , model_folder_epoch_patches + , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals + , spacing, label_map + , save=False): + + try: + + # Step 1 - Save 3D grid to visualize in 3D Slicer (drag-and-drop mechanism) + if save: + + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_img.nrrd' , patient_img[:,:,:,0], spacing) + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_mask.nrrd', np.argmax(patient_gt, axis=3),spacing) + maskpred_labels = np.argmax(patient_pred_processed, axis=3) # not "np.argmax(patient_pred, axis=3)" since it does not contain any postprocessing + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpred.nrrd', maskpred_labels, spacing) + + maskpred_labels_probmean = np.take_along_axis(patient_pred, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmean.nrrd', maskpred_labels_probmean, spacing) + + if np.sum(patient_pred_std): + maskpred_labels_std = np.take_along_axis(patient_pred_std, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstd.nrrd', maskpred_labels_std, 
spacing) + + maskpred_std_max = np.max(patient_pred_std, axis=-1) + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstdmax.nrrd', maskpred_std_max, spacing) + + if np.sum(patient_pred_ent): + maskpred_ent = patient_pred_ent # [H,W,D] + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredent.nrrd', maskpred_ent, spacing) + + if np.sum(patient_pred_mif): + maskpred_mif = patient_pred_mif + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmif.nrrd', maskpred_mif, spacing) + + if np.sum(patient_pred_unc): + if len(patient_pred_unc.shape) == 4: + maskpred_labels_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] # [H,W,D,C] --> [H,W,D] + else: + maskpred_labels_unc = patient_pred_unc + datautils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredunc.nrrd', maskpred_labels_unc, spacing) + + try: + # Step 3.1.3.2 - PLot results for that patient + f, axarr = plt.subplots(3,1, figsize=(15,10)) + boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95 = {}, {}, {} + boxplot_dice_mean_list = [] + for label_id in range(len(loss_labels_val)): + label_name, _ = modutils.get_info_from_label_id(label_id, label_map) + boxplot_dice[label_name] = [loss_labels_val[label_id]] + boxplot_hausdorff[label_name] = [hausdorff_labels_val[label_id]] + boxplot_hausdorff95[label_name] = [hausdorff95_labels_val[label_id]] + if label_id > 0 and loss_labels_val[label_id] > 0: + boxplot_dice_mean_list.append(loss_labels_val[label_id]) + + axarr[0].boxplot(boxplot_dice.values()) + axarr[0].set_xticks(range(1, len(boxplot_dice)+1)) + axarr[0].set_xticklabels(boxplot_dice.keys()) + axarr[0].set_ylim([0.0,1.1]) + axarr[0].grid() + axarr[0].set_title('DICE - Avg: {} \n w/o chiasm: {}'.format( + '%.4f' % (np.mean(boxplot_dice_mean_list)) + , '%.4f' % (np.mean(boxplot_dice_mean_list[0:1] + boxplot_dice_mean_list[2:])) # avoid label_id=2 + ) + ) + + axarr[1].boxplot(boxplot_hausdorff.values()) + axarr[1].set_xticks(range(1,len(boxplot_hausdorff)+1)) + axarr[1].set_xticklabels(boxplot_hausdorff.keys()) + axarr[1].grid() + axarr[1].set_title('Hausdorff') + + axarr[2].boxplot(boxplot_hausdorff95.values()) + axarr[2].set_xticks(range(1,len(boxplot_hausdorff95)+1)) + axarr[2].set_xticklabels(boxplot_hausdorff95.keys()) + axarr[2].set_title('95% Hausdorff') + axarr[2].grid() + + plt.savefig(str(Path(model_folder_epoch_patches).joinpath('results_' + patient_id_curr + '.png')), bbox_inches='tight') # , bbox_inches='tight' + plt.close() + + except: + traceback.print_exc() + + except: + traceback.print_exc() + pdb.set_trace() + +def get_ece(y_true, y_predict, patient_id, res_global, verbose=False): + """ + Params + ------ + y_true : [H,W,D,C], np.array, binary + y_predict: [H,W,D,C], np.array, with softmax probability values + - Ref: https://github.com/sirius8050/Expected-Calibration-Error/blob/master/ECE.py + : On Calibration of Modern Neural Networks + - Ref(future): https://github.com/yding5/AdaptiveBinning + : Revisiting the evaluation of uncertainty estimation and its application to explore model complexity-uncertainty trade-off + """ + res = {} + nan_value = -0.1 + + if verbose: print (' - [get_ece()] patient_id: ', patient_id) + + try: + + # Step 0 - Init + label_count = y_true.shape[-1] + + # Step 1 - Calculate o_predict + o_true = np.argmax(y_true, axis=-1) + o_predict = 
np.argmax(y_predict, axis=-1) + + # Step 2 - Loop over different classes + for label_id in range(label_count): + + if label_id > -1: + + if verbose: print (' --- [get_ece()] label_id: ', label_id) + + if label_id not in res_global: res_global[label_id] = {'o_predict_label':[], 'y_predict_label':[], 'o_true_label':[]} + + # Step 2.1 - Get o_predict_label(label_ids), o_true_label(label_ids), y_predict_label(probs) [and append to global list] + ## NB: You are considering TP + FP here + o_true_label = o_true[o_predict == label_id] + o_predict_label = o_predict[o_predict == label_id] + y_predict_label = y_predict[:,:,:,label_id][o_predict == label_id] + res_global[label_id]['o_true_label'].extend(o_true_label.flatten().tolist()) + res_global[label_id]['o_predict_label'].extend(o_predict_label.flatten().tolist()) + res_global[label_id]['y_predict_label'].extend(y_predict_label.flatten().tolist()) + + if len(o_true_label) and len(y_predict_label): + + # Step 2.2 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step 2.3 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + # Step 2.4 - Wrapup + N = np.prod(y_predict_label.shape) + ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + ce[ce == 0] = nan_value # i.e. 
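+                    # Toy 1-D sketch of the same binning idea (illustrative only; the arrays
+                    # conf/hit below are made up and not the per-label masks used above):
+                    #   conf = np.array([0.95, 0.62, 0.81, 0.55])   # prob assigned to the predicted class
+                    #   hit  = np.array([1, 0, 1, 1])               # 1 where the prediction was correct
+                    #   bins = np.digitize(conf, np.arange(0.0, 1.01, 0.1), right=False) - 1
+                    #   ece  = sum(abs(hit[bins == b].mean() - conf[bins == b].mean()) * np.mean(bins == b)
+                    #              for b in np.unique(bins))        # |accuracy - confidence| weighted by bin size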
y_predict_bins_accuracy[bin_id] == y_predict_bins_mean[bind_id] = nan_value + res[label_id] = ce + + else: + res[label_id] = -1 + + if 0: + if label_id == 1: + diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + print (' - [get_ece()][BStem] diff: ', ['%.4f' % each for each in diff]) + + # NB: This considers the whole volume + o_true_label = y_true[:,:,:,label_id] # [1=this label, 0=other label] + o_predict_label = np.array(o_predict, copy=True) + o_predict_label[o_predict_label != label_id] = 0 + o_predict_label[o_predict_label == label_id] = 1 # [1 - predicted this label, 0 = predicted other label] + y_predict_label = y_predict[:,:,:,label_id] + + # Step x.2 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step x.3 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + N_new = np.prod(y_predict_label.shape) + diff_new = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + ce_new = np.array((np.array(y_predict_bins_len)/N_new)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + print (' - [get_ece()][BStem] diff_new: ', ['%.4f' % each for each in diff_new]) + + pdb.set_trace() + + if verbose: + print (' --- [get_ece()] y_predict_bins_accuracy: ', ['%.4f' % (each) for each in np.array(y_predict_bins_accuracy)]) + print (' --- [get_ece()] CE : ', ['%.4f' % (each) for each in np.array(res[label_id])]) + print (' --- [get_ece()] ECE: ', np.sum(np.abs(res[label_id][res[label_id] != nan_value]))) + + # Prob bars + # plt.hist(y_predict[:,:,:,label_id].flatten(), bins=10) + # plt.title('Softmax Probs (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.show() + + # GT Prob bars + # plt.bar(np.arange(len(y_predict_bins_len))/10.0 + 0.1, y_predict_bins_len, width=0.05) + # plt.title('Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.xlabel('Probabilities') + # plt.show() + + # GT Probs (sorted) in plt.plot (with equally-spaced bins) + # from collections import Counter + # tmp = np.sort(y_predict_label) + # plt.plot(range(len(tmp)), tmp, color='orange') + # tmp_bins = np.digitize(tmp, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])) - 1 + # tmp_bins_len = Counter(tmp_bins) + # boundary_start = 0 + # plt.plot([0,0],[0.0,1.0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-spaced)') + # for boundary in np.arange(0,len(tmp_bins_len)): plt.plot([boundary_start+tmp_bins_len[boundary], boundary_start+tmp_bins_len[boundary]], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed'); boundary_start+=tmp_bins_len[boundary] + # plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.legend() + 
# plt.show() + + # GT Probs (sorted) in plt.plot (with equally-sized bins) + if label_id == 1: + Path('tmp').mkdir(parents=True, exist_ok=True) + tmp = np.sort(y_predict_label) + tmp_len = len(tmp) + plt.plot(range(len(tmp)), tmp, color='orange') + for boundary in np.arange(0,tmp_len, int(tmp_len//10)): plt.plot([boundary, boundary], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed') + plt.plot([0,0],[0,0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-sized)') + plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + plt.legend() + # plt.show() + plt.savefig('./tmp/ECE_SortedProbs_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close() + + # ECE plot + plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8) + plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred') + plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy') + diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink') + plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE') + plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0) + plt.title('CE (label={})\nPatient:{}'.format(label_id, patient_id)) + plt.xlabel('Probability') + plt.ylabel('Accuracy') + plt.legend() + # plt.show() + plt.savefig('./tmp/ECE_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close() + + # pdb.set_trace() + + except: + traceback.print_exc() + pdb.set_trace() + + return res_global, res + +def eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=False, verbose=False): + + try: + + pid = os.getpid() + + ############################################################################### + # Summarize # + ############################################################################### + + # Step 1 - Summarize DICE + Surface Distances + if 1: + loss_labels_avg, loss_labels_std = [], [] + hausdorff_labels_avg, hausdorff_labels_std = [], [] + hausdorff95_labels_avg, hausdorff95_labels_std = [], [] + msd_labels_avg, msd_labels_std = [], [] + + loss_labels_list = np.array([res[patient_id][config.KEY_DICE_LABELS] for patient_id in res]) + hausdorff_labels_list = np.array([res[patient_id][config.KEY_HD_LABELS] for patient_id in res]) + hausdorff95_labels_list = np.array([res[patient_id][config.KEY_HD95_LABELS] for patient_id in res]) + msd_labels_list = np.array([res[patient_id][config.KEY_MSD_LABELS] for patient_id in res]) + + for label_id in range(loss_labels_list.shape[1]): + tmp_vals = loss_labels_list[:,label_id] + loss_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) + loss_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + if label_id > 0: + tmp_vals = hausdorff_labels_list[:,label_id] + hausdorff_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) # avoids -1 for "erroneous" HD, and 0 for "not to be calculated" HD + hausdorff_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + tmp_vals = hausdorff95_labels_list[:,label_id] + hausdorff95_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) + 
hausdorff95_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + tmp_vals = msd_labels_list[:,label_id] + msd_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) + msd_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + else: + hausdorff_labels_avg.append(0) + hausdorff_labels_std.append(0) + hausdorff95_labels_avg.append(0) + hausdorff95_labels_std.append(0) + msd_labels_avg.append(0) + msd_labels_std.append(0) + + loss_avg = np.mean([res[patient_id][config.KEY_DICE_AVG] for patient_id in res]) + print (' --------------------------- eval_type: ', eval_type) + print (' - dice_labels_3D : ', ['%.4f' % each for each in loss_labels_avg]) + print (' - dice_labels_3D : ', ['%.4f' % each for each in loss_labels_std]) + print (' - dice_3D : %.4f' % np.mean(loss_labels_avg)) + print (' - dice_3D (w/o bgd): %.4f' % np.mean(loss_labels_avg[1:])) + print (' - dice_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:])) + print ('') + print (' - hausdorff_labels_3D : ', ['%.4f' % each for each in hausdorff_labels_avg]) + print (' - hausdorff_labels_3D : ', ['%.4f' % each for each in hausdorff_labels_std]) + print (' - hausdorff_3D (w/o bgd): %.4f' % np.mean(hausdorff_labels_avg[1:])) + print (' - hausdorff_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff_labels_avg[1:2] + hausdorff_labels_avg[3:])) + print ('') + print (' - hausdorff95_labels_3D : ', ['%.4f' % each for each in hausdorff95_labels_avg]) + print (' - hausdorff95_labels_3D : ', ['%.4f' % each for each in hausdorff95_labels_std]) + print (' - hausdorff95_3D (w/o bgd): %.4f' % np.mean(hausdorff95_labels_avg[1:])) + print (' - hausdorff95_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff95_labels_avg[1:2] + hausdorff95_labels_avg[3:])) + print ('') + print (' - msd_labels_3D : ', ['%.4f' % each for each in msd_labels_avg]) + print (' - msd_labels_3D : ', ['%.4f' % each for each in msd_labels_std]) + print (' - msd_3D (w/o bgd): %.4f' % np.mean(msd_labels_avg[1:])) + print (' - msd_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(msd_labels_avg[1:2] + msd_labels_avg[3:])) + + # Step 2 - Summarize ECE + if 1: + print ('') + gc.collect() + nan_value = config.VAL_ECE_NAN + ece_labels_obj = {} + ece_labels = [] + label_count = len(ece_global_obj) + pbar_desc_prefix = '[ECE]' + ece_global_obj_keys = list(ece_global_obj.keys()) + res[config.KEY_PATIENT_GLOBAL] = {} + with tqdm.tqdm(total=label_count, desc=pbar_desc_prefix, disable=True) as pbar_ece: + for label_id in ece_global_obj_keys: + o_true_label = np.array(ece_global_obj[label_id]['o_true_label']) + o_predict_label = np.array(ece_global_obj[label_id]['o_predict_label']) + y_predict_label = np.array(ece_global_obj[label_id]['y_predict_label']) + if label_id in ece_global_obj: del ece_global_obj[label_id] + gc.collect() + + # Step 1.1 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step 1.2 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == 
o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + # Step 1.3 - Wrapup + N = np.prod(y_predict_label.shape) + ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + ce[ce == 0] = nan_value + ece_label = np.sum(np.abs(ce[ce != nan_value])) + ece_labels.append(ece_label) + ece_labels_obj[label_id] = {'y_predict_bins_mean':y_predict_bins_mean, 'y_predict_bins_accuracy':y_predict_bins_accuracy, 'ce':ce, 'ece':ece_label} + + pbar_ece.update(1) + memory = pbar_desc_prefix + ' [' + str(modutils.get_memory(pid)) + ']' + pbar_ece.set_description(desc=memory, refresh=True) + + res[config.KEY_PATIENT_GLOBAL][label_id] = {'ce':ce, 'ece':ece_label} + + print (' - ece_labels : ', ['%.4f' % each for each in ece_labels]) + print (' - ece : %.4f' % np.mean(ece_labels)) + print (' - ece (w/o bgd): %.4f' % np.mean(ece_labels[1:])) + print (' - ece (w/o bgd, w/o chiasm): %.4f' % np.mean(ece_labels[1:2] + ece_labels[3:])) + print ('') + + del ece_global_obj + gc.collect() + + # Step 3 - Plot + if 1: + if save and not deepsup_eval: + f, axarr = plt.subplots(3,1, figsize=(15,10)) + boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95, boxplot_msd = {}, {}, {}, {} + for label_id in range(len(loss_labels_list[0])): + label_name, _ = modutils.get_info_from_label_id(label_id, label_map) + boxplot_dice[label_name] = loss_labels_list[:,label_id] + boxplot_hausdorff[label_name] = hausdorff_labels_list[:,label_id] + boxplot_hausdorff95[label_name] = hausdorff95_labels_list[:,label_id] + boxplot_msd[label_name] = msd_labels_list[:,label_id] + + axarr[0].boxplot(boxplot_dice.values()) + axarr[0].set_xticks(range(1, len(boxplot_dice)+1)) + axarr[0].set_xticklabels(boxplot_dice.keys()) + axarr[0].set_ylim([0.0,1.1]) + axarr[0].set_title('DICE (Avg: {}) \n w/o chiasm:{}'.format( + '%.4f' % np.mean(loss_labels_avg[1:]) + , '%.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:]) + ) + ) + + axarr[1].boxplot(boxplot_hausdorff.values()) + axarr[1].set_xticks(range(1, len(boxplot_hausdorff)+1)) + axarr[1].set_xticklabels(boxplot_hausdorff.keys()) + axarr[1].set_ylim([0.0,10.0]) + axarr[1].set_title('Hausdorff') + + axarr[2].boxplot(boxplot_hausdorff95.values()) + axarr[2].set_xticks(range(1, len(boxplot_hausdorff95)+1)) + axarr[2].set_xticklabels(boxplot_hausdorff95.keys()) + axarr[2].set_ylim([0.0,6.0]) + axarr[2].set_title('95% Hausdorff') + + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_dice.values()) + axarr.set_xticks(range(1, len(boxplot_dice)+1)) + axarr.set_xticklabels(boxplot_dice.keys()) + axarr.set_ylim([0.0,1.1]) + axarr.set_yticks(np.arange(0,1.1,0.05)) + axarr.set_title('DICE') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_dice.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_hausdorff95.values()) + axarr.set_xticks(range(1, len(boxplot_hausdorff95)+1)) + axarr.set_xticklabels(boxplot_hausdorff95.keys()) + axarr.set_title('95% HD') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd95.png')) + plt.savefig(path_results, 
bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_hausdorff.values()) + axarr.set_xticks(range(1, len(boxplot_hausdorff)+1)) + axarr.set_xticklabels(boxplot_hausdorff.keys()) + axarr.set_title('Hausdorff Distance') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_msd.values()) + axarr.set_xticks(range(1, len(boxplot_msd)+1)) + axarr.set_xticklabels(boxplot_msd.keys()) + axarr.set_title('MSD') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_msd.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + # ECE + for label_id in ece_labels_obj: + y_predict_bins_mean = ece_labels_obj[label_id]['y_predict_bins_mean'] + y_predict_bins_accuracy = ece_labels_obj[label_id]['y_predict_bins_accuracy'] + ece = ece_labels_obj[label_id]['ece'] + + plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8) + plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred') + plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy') + for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink') + plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE') + plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0) + plt.title('CE (label={})\nECE: {}'.format(label_id, '%.5f' % (ece))) + plt.xlabel('Probability') + plt.ylabel('Accuracy') + plt.ylim([-0.15, 1.05]) + plt.legend() + + # plt.show() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_ece_label{}.png'.format(label_id))) + plt.savefig(str(path_results), bbox_inches='tight') + plt.close() + + # Step 5 - Save data as .json + if 1: + try: + filename = str(Path(model_folder_epoch_patches).joinpath(config.FILENAME_EVAL3D_JSON)) + modutils.write_json(res, filename) + + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + + model.trainable=True + print ('\n - [eval_3D] Avg of times_mcruns : {:f} +- {:f}'.format(np.mean(times_mcruns), np.std(times_mcruns))) + print (' - [eval_3D()] Total time passed (save={}) : {}s \n'.format(save, round(time.time() - ttotal, 2))) + if verbose: pdb.set_trace() + + return loss_avg, {i:loss_labels_avg[i] for i in range(len(loss_labels_avg))} + + except: + model.trainable=True + traceback.print_exc() + if DEBUG: pdb.set_trace() + return -1, {} + +def eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, save=False, verbose=False): + + try: + + # Step 3.1.1 - Get stitched patient grid + if verbose: t0 = time.time() + patient_pred_ent = patient_pred_ent/patient_pred_overlap # [H,W,D]/[H,W,D] + patient_pred_mif = patient_pred_mif/patient_pred_overlap + patient_pred_overlap = np.expand_dims(patient_pred_overlap, -1) + patient_pred = patient_pred_vals/patient_pred_overlap # [H,W,D,C]/[H,W,D,1] + 
patient_pred_std = patient_pred_std/patient_pred_overlap + patient_pred_unc = patient_pred_unc/patient_pred_overlap + del patient_pred_vals + del patient_pred_overlap + patient_pred_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(np.argmax(patient_pred, axis=-1),axis=-1), axis=-1)[:,:,:,0] + + gc.collect() # returns number of unreachable objects collected by GC + patient_pred_postprocessed = losses.remove_smaller_components(patient_gt, patient_pred, meta=patient_id_curr, label_ids_small = [2,4,5]) + if verbose: print (' - [eval_3D_process_outputs()] Post-Process time : ', time.time() - t0,'s') + + # Step 3.1.2 - Loss Calculation + spacing = np.array([meta1_batch[4], meta1_batch[5], meta1_batch[6]])/100.0 + try: + if verbose: t0 = time.time() + loss_avg_val, loss_labels_val = losses.dice_numpy(patient_gt, patient_pred_postprocessed) + hausdorff_avg_val, hausdorff_labels_val, hausdorff95_avg_val, hausdorff95_labels_val, msd_avg_val, msd_labels_vals = losses.get_surface_distances(patient_gt, patient_pred_postprocessed, spacing) + if verbose: + print (' - [eval_3D_process_outputs()] DICE : ', ['%.4f' % (each) for each in loss_labels_val]) + print (' - [eval_3D_process_outputs()] HD95 : ', ['%.4f' % (each) for each in hausdorff95_labels_val]) + + if loss_avg_val != -1 and len(loss_labels_val): + res[patient_id_curr] = { + config.KEY_DICE_AVG : loss_avg_val + , config.KEY_DICE_LABELS : loss_labels_val + , config.KEY_HD_AVG : hausdorff95_avg_val + , config.KEY_HD_LABELS : hausdorff_labels_val + , config.KEY_HD95_AVG : hausdorff95_avg_val + , config.KEY_HD95_LABELS : hausdorff95_labels_val + , config.KEY_MSD_AVG : msd_avg_val + , config.KEY_MSD_LABELS : msd_labels_vals + } + else: + print (' - [INFO][eval_3D_process_outputs()] patient_id: ', patient_id_curr) + if verbose: print (' - [eval_3D_process_outputs()] Loss calculation time: ', time.time() - t0,'s') + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + + # Step 3.1.3 - ECE calculation + if verbose: t0 = time.time() + ece_global_obj, ece_patient_obj = get_ece(patient_gt, patient_pred, patient_id_curr, ece_global_obj) + res[patient_id_curr][config.KEY_ECE_LABELS] = ece_patient_obj + if verbose: print (' - [eval_3D_process_outputs()] ECE time : ', time.time() - t0,'s') + + # Step 3.1.5 - Save/Visualize + if not deepsup_eval: + if verbose: t0 = time.time() + eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_postprocessed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , patient_id_curr + , model_folder_epoch_patches + , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals + , spacing, label_map + , save=save) + if verbose: print (' - [eval_3D_process_outputs()] Save as .nrrd time : ', time.time() - t0,'s') + + if verbose: print (' - [eval_3D_process_outputs()] Total patient time : ', time.time() - t99,'s') + + # Step 3.1.6 + del patient_img + del patient_gt + del patient_pred + del patient_pred_std + del patient_pred_ent + del patient_pred_postprocessed + del patient_pred_mif + del patient_pred_unc + gc.collect() + + return res, ece_global_obj + + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + return res, ece_global_obj + +def eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval): + + # Step 0 - Init + DO_KEYS = [config.KEY_MIF, config.KEY_ENT] + # DO_KEYS = [config.KEY_STD] + + # Step 1 - Warm up model + _ = model(X, training=training_bool) + + # Step 2 - Run Monte-Carlo predictions + error_mcruns = 
False + try: + tic_mcruns = time.time() + y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + toc_mcruns = time.time() + except tf.errors.ResourceExhaustedError as e: + print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS={}'.format(MC_RUNS)) + error_mcruns = True + + if error_mcruns: + try: + MC_RUNS = 5 + tic_mcruns = time.time() + if deepsup: + if deepsup_eval: + y_predict = tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + X = X[:,::2,::2,::2,:] + Y = Y[:,::2,::2,::2,:] + else: + y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + else: + y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + toc_mcruns = time.time() + except tf.errors.ResourceExhaustedError as e: + print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS=5') + import sys; sys.exit(1) + + # Step 3 - Calculate different metrics + if config.KEY_MIF in DO_KEYS: + y_predict_mif = y_predict * tf.math.log(y_predict + _EPSILON) # [MC,B,H,W,D,C] + y_predict_mif = tf.math.reduce_sum(y_predict_mif, axis=[0,-1])/MC_RUNS # [MC,B,H,W,D,C] -> [B,H,W,D] + else: + y_predict_mif = [] + + if config.KEY_STD in DO_KEYS: + y_predict_std = tf.math.reduce_std(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C] + else: + y_predict_std = [] + + if config.KEY_PERC in DO_KEYS: + y_predict_perc = tfp.stats.percentile(y_predict, q=[30,70], axis=0, interpolation='nearest') + y_predict_unc = y_predict_perc[1] - y_predict_perc[0] + del y_predict_perc + gc.collect() + else: + y_predict_unc = [] + + y_predict = tf.math.reduce_mean(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C] + + if config.KEY_ENT in DO_KEYS: + y_predict_ent = -1*tf.math.reduce_sum(y_predict * tf.math.log(y_predict + _EPSILON), axis=-1) # [B,H,W,D,C] -> # [B,H,W,D] ent = -p.log(p) + y_predict_mif = y_predict_ent + y_predict_mif # [B,H,W,D] + [B,H,W,D] = [B,H,W,D]; MI = ent + expectation(ent) + else: + y_predict_ent = [] + y_predict_mif = [] + + return Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, toc_mcruns-tic_mcruns + +def eval_3D(model, dataset_eval, dataset_eval_gen, params, show=False, save=False, verbose=False): + + try: + + # Step 0.0 - Variables under debugging + pass + + # Step 0.1 - Extract params + PROJECT_DIR = config.PROJECT_DIR + exp_name = params['exp_name'] + pid = params['pid'] + eval_type = params['eval_type'] + batch_size = 2 + epoch = params['epoch'] + deepsup = params['deepsup'] + deepsup_eval = params['deepsup_eval'] + label_map = dict(dataset_eval.get_label_map()) + label_colors = dict(dataset_eval.get_label_colors()) + + if verbose: print (''); print (' --------------------- eval_3D({}) ---------------------'.format(eval_type)) + + # Step 0.2 - Init results array + res = {} + ece_global_obj = {} + patient_grid_count = {} + + # Step 0.3 - Init temp variables + patient_id_curr = None + w_grid, h_grid, d_grid = None, None, None + meta1_batch = None + patient_gt = None + patient_img = None + patient_pred_overlap = None + patient_pred_vals = None + model_folder_epoch_patches = None + 
model_folder_epoch_imgs = None + + mc_runs = params.get(config.KEY_MC_RUNS, None) + training_bool = params.get(config.KEY_TRAINING_BOOL, None) + model_folder_epoch_patches, model_folder_epoch_imgs = modutils.get_eval_folders(PROJECT_DIR, exp_name, epoch, eval_type, mc_runs, training_bool, create=True) + + # Step 0.4 - Debug vars + ttotal,t0, t99 = time.time(), None, None + times_mcruns = [] + + # Step 1 - Loop over dataset_eval (which provides patients & grids in an ordered manner) + print ('') + model.trainable = False + pbar_desc_prefix = 'Eval3D_{} [batch={}]'.format(eval_type, batch_size) + training_bool = params.get(config.KEY_TRAINING_BOOL,True) # [True, False] + with tqdm.tqdm(total=len(dataset_eval), desc=pbar_desc_prefix, leave=False) as pbar_eval: + for (X,Y,meta1,meta2) in dataset_eval_gen.repeat(1): + + MC_RUNS = params.get(config.KEY_MC_RUNS,config.VAL_MC_RUNS_DEFAULT) + Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, mcruns_time = eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval) + times_mcruns.append(mcruns_time) + + for batch_id in range(X.shape[0]): + + # Step 2 - Get grid info + patient_id_running = meta2[batch_id].numpy().decode('utf-8') + if patient_id_running in patient_grid_count: patient_grid_count[patient_id_running] += 1 + else: patient_grid_count[patient_id_running] = 1 + + meta1_batch = meta1[batch_id].numpy() + w_start, h_start, d_start = meta1_batch[1], meta1_batch[2], meta1_batch[3] + + # Step 3 - Check if its a new patient + if patient_id_running != patient_id_curr: + + # Step 3.1 - Sort out old patient (patient_id_curr) + if patient_id_curr != None: + + res, ece_global_obj = eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose) + + # Step 3.2 - Create variables for new patient + if verbose: t99 = time.time() + patient_id_curr = patient_id_running + patient_scan_size = meta1_batch[7:10] + dataset_name = patient_id_curr.split('-')[0] + dataset_this = dataset_eval.get_subdataset(param_name=dataset_name) + w_grid, h_grid, d_grid = dataset_this.w_grid, dataset_this.h_grid, dataset_this.d_grid + if deepsup_eval: + patient_scan_size = patient_scan_size//2 + w_grid, h_grid, d_grid = w_grid//2, h_grid//2, d_grid//2 + + patient_pred_size = list(patient_scan_size) + [len(dataset_this.LABEL_MAP)] + patient_pred_overlap = np.zeros(patient_scan_size, dtype=np.uint8) + patient_pred_ent = np.zeros(patient_scan_size, dtype=np.float32) + patient_pred_mif = np.zeros(patient_scan_size, dtype=np.float32) + patient_pred_vals = np.zeros(patient_pred_size, dtype=np.float32) + patient_pred_std = np.zeros(patient_pred_size, dtype=np.float32) + patient_pred_unc = np.zeros(patient_pred_size, dtype=np.float32) + patient_gt = np.zeros(patient_pred_size, dtype=np.float32) + if show or save: + patient_img = np.zeros(list(patient_scan_size) + [1], dtype=np.float32) + else: + patient_img = [] + + # Step 4 - If not new patient anymore, fill up data + if deepsup_eval: + w_start, h_start, d_start = w_start//2, h_start//2, d_start//2 + patient_pred_vals[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict[batch_id] + if len(y_predict_std): + patient_pred_std[w_start:w_start + w_grid, 
h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_std[batch_id] + if len(y_predict_ent): + patient_pred_ent[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_ent[batch_id] + if len(y_predict_mif): + patient_pred_mif[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_mif[batch_id] + if len(y_predict_unc): + patient_pred_unc[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_unc[batch_id] + + patient_pred_overlap[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += np.ones(y_predict[batch_id].shape[:-1], dtype=np.uint8) + patient_gt[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = Y[batch_id] + if show or save: + patient_img[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = X[batch_id] + + pbar_eval.update(batch_size) + mem_used = modutils.get_memory(pid) + memory = pbar_desc_prefix + ' [' + mem_used + ']' + pbar_eval.set_description(desc=memory, refresh=True) + + # Step 3 - For last patient + res, ece_global_obj = eval_3D_process_outputs(exp_name, res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose) + print ('\n - [eval_3D()] Time passed to accumulate grids & process patients: ', round(time.time() - ttotal, 2), 's') + + return eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=save, show=show, verbose=verbose) + + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + model.trainable=True + return -1, {} + +############################################################ +# VAL # +############################################################ + +class Validator: + + def __init__(self, params): + + self.params = params + + # Dataloader Params + self.data_dir = self.params['dataloader']['data_dir'] + self.dir_type = self.params['dataloader']['dir_type'] + self.grid = self.params['dataloader']['grid'] + self.crop_init = self.params['dataloader']['crop_init'] + self.resampled = self.params['dataloader']['resampled'] + + self._get_dataloader() + self._get_model() + + def _set_dataloader(self): + + # Step 1 - Init params + batch_size = self.params['dataloader']['batch_size'] + prefetch_batch = self.params['dataloader']['prefetch_batch'] + + # Step 2 - Load dataset + if self.dir_type == config.DATALOADER_MICCAI2015_TEST: + + self.dataset_test_eval = datautils.get_dataloader_3D_test_eval(self.data_dir, dir_type=self.dir_type + , grid=self.grid, crop_init=self.crop_init + , resampled=self.resampled + ) + + elif self.dir_type == config.DATALOADER_DEEPMINDTCIA_TEST: + + self.annotator_type = self.params['dataloader']['annotator_type'] + self.dataset_test_eval = datautils.get_dataloader_deepmindtcia(self.data_dir, dir_type=self.dir_type, annotator_type=self.annotator_type + , grid=self.grid, crop_init=self.crop_init, resampled=self.resampled + ) + + # Step 3 - Load dataloader/datagenerator + self.datagen_test_eval = self.dataset_test_eval.generator().batch(batch_size).prefetch(prefetch_batch) + + def _set_model(self): + + # Step 1 - Init params + exp_name = self.params['exp_name'] + load_epoch = self.params['model']['load_epoch'] + class_count = 
len(self.dataset_test_eval.datasets[0].LABEL_MAP.values()) + + # Step 2 - Get model arch + if self.params['model']['name'] == config.MODEL_FOCUSNET_DROPOUT: + print (' - [Trainer][_models()] ModelFocusNetDropOut') + self.model = models.ModelFocusNetDropOut(class_count=class_count, trainable=False) + + # Step 3 - Load model + load_model_params = {'PROJECT_DIR': config.PROJECT_DIR + , 'exp_name': exp_name + , 'load_epoch': load_epoch + , 'optimizer': tf.keras.optimizers.Adam() + } + modutils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params) + print ('') + print (' - [train.py][val()] Model({}) Loaded for {} at epoch-{} (validation purposes) !'.format(str(self.model), exp_name, load_epoch)) + print ('') + + def validate(self, verbose=False): + + save = self.params['save'] + loss_avg, loss_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.datagen_test_eval, self.params, save=save, verbose=verbose) + +############################################################ +# TRAINER # +############################################################ + +class Trainer: + + def __init__(self, params): + + # Init + self.params = params + + # Print + self._train_preprint() + + # Random Seeds + self._set_seed() + + # Set the dataloaders + self._set_dataloaders() + + # Set the model + self._set_model() + + # Set Metrics + self._set_metrics() + + # Other flags + self.write_model_done = False + + def _train_preprint(self): + print ('') + print (' -------------- {} ({})'.format(self.params['exp_name'], str(datetime.datetime.now()))) + + print ('') + print (' DATALOADER ') + print (' ---------- ') + print (' - dir_type: ', self.params['dataloader']['dir_type']) + + print (' -- resampled: ', self.params['dataloader']['resampled']) + print (' -- crop_init: ', self.params['dataloader']['crop_init']) + print (' -- grid: ', self.params['dataloader']['grid']) + print (' --- filter_grid : ', self.params['dataloader']['filter_grid']) + print (' --- random_grid : ', self.params['dataloader']['random_grid']) + print (' --- centred_prob : ', self.params['dataloader']['centred_prob']) + + print (' -- batch_size: ', self.params['dataloader']['batch_size']) + print (' -- prefetch_batch : ', self.params['dataloader']['prefetch_batch']) + print (' -- parallel_calls : ', self.params['dataloader']['parallel_calls']) + print (' -- shuffle : ', self.params['dataloader']['shuffle']) + + print ('') + print (' MODEL ') + print (' ----- ') + print (' - Model: ', str(self.params['model']['name'])) + print (' -- Model TBoard: ', self.params['model']['model_tboard']) + print (' -- Profiler: ', self.params['model']['profiler']['profile']) + if self.params['model']['profiler']['profile']: + print (' ---- Profiler Epochs: ', self.params['model']['profiler']['epochs']) + print (' ---- Step Per Epochs: ', self.params['model']['profiler']['steps_per_epoch']) + print (' - Optimizer: ', str(self.params['model']['optimizer'])) + print (' -- Init LR: ', self.params['model']['init_lr']) + print (' -- Fixed LR: ', self.params['model']['fixed_lr']) + + print (' - Epochs: ', self.params['model']['epochs']) + print (' -- Save : every {} epochs'.format(self.params['model']['epochs_save'])) + print (' -- Eval3D : every {} epochs '.format(self.params['model']['epochs_eval'])) + print (' -- Viz3D : every {} epochs '.format(self.params['model']['epochs_viz'])) + + print ('') + print (' METRICS ') + print (' ------- ') + print (' - Logging-TBoard : ', self.params['metrics']['logging_tboard']) + if not 
self.params['metrics']['logging_tboard']: + print (' !!!!!!!!!!!!!!!!!!! NO LOGGING-TBOARD !!!!!!!!!!!!!!!!!!!') + print ('') + print (' - Eval : ', self.params['metrics']['metrics_eval']) + print (' - Loss : ', self.params['metrics']['metrics_loss']) + print (' -- Weighted Loss : ', self.params['metrics']['loss_weighted']) + print (' -- Masked Loss : ', self.params['metrics']['loss_mask']) + print (' -- Combo : ', self.params['metrics']['loss_combo']) + + print ('') + print (' DEVOPS ') + print (' ------ ') + self.pid = os.getpid() + print (' - OS-PID: ', self.pid) + print (' - Seed: ', self.params['random_seed']) + + print ('') + + def _set_seed(self): + np.random.seed(self.params['random_seed']) + tf.random.set_seed(self.params['random_seed']) + + def _set_dataloaders(self): + + print ('') + print (' DATALOADER OBJECTS') + print (' ---------- ') + + # Params - Directories + data_dir = self.params['dataloader']['data_dir'] + dir_type = self.params['dataloader']['dir_type'] + dir_type_eval = ['_'.join(dir_type)] + + # Params - Single volume + resampled = self.params['dataloader']['resampled'] + crop_init = self.params['dataloader']['crop_init'] + grid = self.params['dataloader']['grid'] + filter_grid = self.params['dataloader']['filter_grid'] + random_grid = self.params['dataloader']['random_grid'] + centred_prob = self.params['dataloader']['centred_prob'] + + # Params - Dataloader + batch_size = self.params['dataloader']['batch_size'] + prefetch_batch = self.params['dataloader']['prefetch_batch'] + parallel_calls = self.params['dataloader']['parallel_calls'] + shuffle_size = self.params['dataloader']['shuffle'] + + # Params - Debug + pass + + # Datasets + self.dataset_train = datautils.get_dataloader_3D_train(data_dir, dir_type=dir_type + , grid=grid, crop_init=crop_init, filter_grid=filter_grid + , random_grid=random_grid + , resampled=resampled + , parallel_calls=parallel_calls + , centred_dataloader_prob=centred_prob + ) + self.dataset_train_eval = datautils.get_dataloader_3D_train_eval(data_dir, dir_type=dir_type_eval + , grid=grid, crop_init=crop_init + , resampled=resampled + ) + self.dataset_test_eval = datautils.get_dataloader_3D_test_eval(data_dir + , grid=grid, crop_init=crop_init + , resampled=resampled + ) + + # Get labels Ids + self.label_map = dict(self.dataset_train.get_label_map()) + self.label_ids = self.label_map.values() + self.params['internal'] = {} + self.params['internal']['label_map'] = self.label_map # for use in Metrics + self.params['internal']['label_ids'] = self.label_ids # for use in Metrics + self.label_weights = list(self.dataset_train.get_label_weights()) + + # Generators + # repeat() -> shuffle() --> batch() --> prefetch(): endlessly-output-data --> shuffle-within-buffer --> create-batch-from-shuffle --> prefetch-enough-batches + self.dataset_train_gen = self.dataset_train.generator().repeat().shuffle(shuffle_size).batch(batch_size).apply(tf.data.experimental.prefetch_to_device(device='/GPU:0', buffer_size=prefetch_batch)) + self.dataset_train_eval_gen = self.dataset_train_eval.generator().batch(2).prefetch(4) + self.dataset_test_eval_gen = self.dataset_test_eval.generator().batch(2).prefetch(4) + + def _set_model(self): + + print ('') + print (' MODEL OBJECTS') + print (' ---------- ') + + # Step 1 - Get class ids + class_count = len(self.label_ids) + + # Step 2 - Get model arch + if self.params['model']['name'] == config.MODEL_FOCUSNET_DROPOUT: + print (' - [Trainer][_models()] ModelFocusNetDropOut') + self.model = 
models.ModelFocusNetDropOut(class_count=class_count, trainable=True) + self.kl_alpha_init = 0.0 + + elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT: + print (' - [Trainer][_models()] ModelFocusNetFlipOut') + self.model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=True) + self.kl_schedule = self.params['model']['kl_schedule'] + self.kl_alpha_init = 0.1 + self.kl_alpha_increase_per_epoch = 0.0001 + self.initial_epoch = 250 + + # Step 3 - Get optimizer + if self.params['model']['optimizer'] == config.OPTIMIZER_ADAM: + self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['model']['init_lr']) + + # Step 4 - Load model (if needed) + epochs = self.params['model']['epochs'] + if not self.params['model']['load_model']['load']: + # Step 4.1 - Set epoch range under non-loading situations + self.epoch_range = range(1,epochs+1) + else: + + # Step 4.2.1 - Some model-loading params + load_epoch = self.params['model']['load_model']['load_epoch'] + load_exp_name = self.params['model']['load_model']['load_exp_name'] + load_optimizer_lr = self.params['model']['load_model']['load_optimizer_lr'] + load_model_params = {'PROJECT_DIR': config.PROJECT_DIR, 'load_epoch': load_epoch, 'optimizer':self.optimizer} + + print ('') + print (' - [Trainer][_set_model()] Loading pretrained model') + print (' - [Trainer][_set_model()] Model: ', self.model) + + # Step 4.2.2.1 - If loading is done from the same exp_name + if load_exp_name is None: + load_model_params['exp_name'] = self.params['exp_name'] + self.epoch_range = range(load_epoch+1, epochs+1) + print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(load_epoch, epochs)) + # Step 4.2.2.1 - If loading is done from another exp_name + else: + self.epoch_range = range(1, epochs+1) + load_model_params['exp_name'] = load_exp_name + print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(1, epochs)) + + print (' - [Trainer][_set_model()] exp_name: ', load_model_params['exp_name']) + + # Step 4.3 - Finally, load model from the checkpoint + modutils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params) + print (' - [Trainer][_set_model()] Model Loaded at epoch-{} !'.format(load_epoch)) + print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy()) + if load_optimizer_lr is not None: + self.optimizer.lr.assign(load_optimizer_lr) + print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy()) + + # Step 5 - Create model weights + _ = self.model(tf.random.normal((1,140,140,40,1))) # [NOTE: This needs to be the same as the input grid size in FlipOut] + print (' -- [Trainer][_set_model()] Created model weights ') + try: + print (' --------------------------------------- ') + print (self.model.summary(line_length=150)) + print (' --------------------------------------- ') + count = 0 + for var in self.model.trainable_variables: + print (' - var: ', var.name) + count += 1 + if count > 20: + print (' ... 
') + break + except: + print (' - [Trainer][_set_model()] model.summary() failed') + pass + + def _set_metrics(self): + self.metrics = {} + self.metrics[config.MODE_TRAIN] = ModelMetrics(metric_type=config.MODE_TRAIN, params=self.params) + self.metrics[config.MODE_TEST] = ModelMetrics(metric_type=config.MODE_TEST, params=self.params) + + def _set_profiler(self, epoch, epoch_step): + """ + - Ref: https://www.tensorflow.org/guide/data_performance_analysis#analysis_workflow + """ + exp_name = self.params['exp_name'] + + if self.params['model']['profiler']['profile']: + if epoch in self.params['model']['profiler']['epochs']: + if epoch_step == self.params['model']['profiler']['starting_step']: + self.logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, 'profiler', str(epoch)) + tf.profiler.experimental.start(str(self.logdir)) + print (' - tf.profiler.experimental.start(logdir)') + print ('') + elif epoch_step == self.params['model']['profiler']['starting_step'] + self.params['model']['profiler']['steps_per_epoch']: + print (' - tf.profiler.experimental.stop()') + tf.profiler.experimental.stop() + print ('') + + @tf.function + def _train_loss(self, Y, y_predict, meta1, mode): + + trainMetrics = self.metrics[mode] + metrics_loss = self.params['metrics']['metrics_loss'] + loss_weighted = self.params['metrics']['loss_weighted'] + loss_mask = self.params['metrics']['loss_mask'] + loss_combo = self.params['metrics']['loss_combo'] + + label_ids = self.label_ids + label_weights = tf.cast(self.label_weights, dtype=tf.float32) + + loss_vals = tf.constant(0.0, dtype=tf.float32) + mask = losses.get_mask(meta1[:,-len(label_ids):], Y) + + inf_flag = False + nan_flag = False + + for metric_str in metrics_loss: + + weights = [] + if loss_weighted[metric_str]: + weights = label_weights + + if not loss_mask[metric_str]: + mask = tf.cast(tf.cast(mask + 1, dtype=tf.bool), dtype=tf.float32) + + if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + loss_val_train, loss_labellist_train, metric_val_report, metric_labellist_report = trainMetrics.losses_obj[metric_str](Y, y_predict, mask, weights=weights) + if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_CE_BASIC]: + nan_list = tf.math.is_nan(loss_labellist_train) + nan_val = tf.math.is_nan(loss_val_train) + inf_list = tf.math.is_inf(loss_labellist_train) + inf_val = tf.math.is_inf(loss_val_train) + if nan_val or tf.math.reduce_any(nan_list): + nan_flag = True + tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || nan_list: ', nan_list, ' || nan_val: ', nan_val) + tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || mask: ', mask) + tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || loss_vals: ', loss_vals) + elif inf_val or tf.math.reduce_any(inf_list): + inf_flag = True + tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_val_train: ', loss_val_train) + tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || inf_list: ', inf_list, ' || inf_val: ', inf_val) + tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || mask: ', mask) + tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || loss_vals: ', loss_vals) + else: + if len(metric_labellist_report): + trainMetrics.update_metric_loss_labels(metric_str, metric_labellist_report) # in sub-3D 
settings, this value is only indicative of performance + trainMetrics.update_metric_loss(metric_str, loss_val_train) + + if 0: # [DEBUG] + tf.print(' - metric_str: ', metric_str, ' || loss_val_train: ', loss_val_train) + + if metric_str in loss_combo: + loss_val_train = loss_val_train*loss_combo[metric_str] + loss_vals = tf.math.add(loss_vals, loss_val_train) # Averaged loss + + + if nan_flag or inf_flag: + loss_vals = 0.0 # no backprop when something was wrong + + return loss_vals + + @tf.function + def _train_step(self, X, Y, meta1, meta2, kl_alpha): + + try: + model = self.model + optimizer = self.optimizer + if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT: + kl_scale_fac = self.params['model']['kl_scale_factor'] + else: + kl_scale_fac = 0.0 + + y_predict = None + loss_vals = None + gradients = None + + # Step 1 - Calculate loss + with tf.GradientTape() as tape: + + t2 = tf.timestamp() + y_predict = model(X, training=True) + t2_ = tf.timestamp() + + loss_vals = self._train_loss(Y, y_predict, meta1, mode=config.MODE_TRAIN) + + if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT: + print (' - [Trainer][_train_step()] Model FlipOut') + kl = tf.math.add_n(model.losses) + kl_loss = kl*kl_alpha/kl_scale_fac + loss_vals = loss_vals + kl_loss + + # Step 2 - Calculate gradients + t3 = tf.timestamp() + if not tf.math.reduce_any(tf.math.is_nan(loss_vals)): + all_vars = model.trainable_variables + + gradients = tape.gradient(loss_vals, all_vars) # dL/dW + + # Step 3 - Apply gradients + optimizer.apply_gradients(zip(gradients, all_vars)) + + else: + tf.print('\n ====================== [NaN Error] ====================== ') + tf.print(' - [ERROR][Trainer][_train_step()] Loss NaN spotted || loss_vals: ', loss_vals) + tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1) + + t3_ = tf.timestamp() + return t2_-t2, t3-t2_, t3_-t3 + + except tf.errors.ResourceExhaustedError as e: + tf.print('\n ====================== [OOM Error] ====================== ') + tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1) + traceback.print_exc() + return None, None, None + + except: + tf.print('\n ====================== [Some Error] ====================== ') + tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1) + traceback.print_exc() + return None, None, None + + def train(self): + + # Global params + exp_name = self.params['exp_name'] + + # Dataloader params + batch_size = self.params['dataloader']['batch_size'] + + # Model/Training params + fixed_lr = self.params['model']['fixed_lr'] + init_lr = self.params['model']['init_lr'] + max_epoch = self.params['model']['epochs'] + epoch_range = iter(self.epoch_range) + epoch_length = len(self.dataset_train) + params_save_model = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'optimizer':self.optimizer} + + # Metrics params + metrics_eval = self.params['metrics']['metrics_eval'] + trainMetrics = self.metrics[config.MODE_TRAIN] + + # KL Divergence Params + kl_alpha = self.kl_alpha_init # [0.0, self.kl_alpha_init] + + # Eval Params + params_eval = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'pid': self.pid + , 'eval_type': config.MODE_TRAIN, 'batch_size': batch_size} + + # Viz params + epochs_save = self.params['model']['epochs_save'] + epochs_viz = self.params['model']['epochs_viz'] + epochs_eval = self.params['model']['epochs_eval'] + + # Random vars + t_start_time = time.time() + + try: + + epoch_step = 0 + epoch = None + pbar = 
None + t1 = time.time() + for (X,Y,meta1,meta2) in self.dataset_train_gen: + t1_ = time.time() + + try: + # Epoch starter code + if epoch_step == 0: + + # Get Epoch + epoch = next(epoch_range) + + # Metrics + trainMetrics.reset_metrics(self.params) + + # LR + if not fixed_lr: + set_lr(epoch, self.optimizer, init_lr) + self.model.trainable = True + + # Calculate kl_alpha (only for FlipOut model) + if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT: + if self.kl_schedule == config.KL_DIV_ANNEALING: + if epoch > self.initial_epoch: + if epoch % self.kl_epochs_change == 0: + kl_alpha = tf.math.minimum(self.kl_alpha_max, self.kl_alpha_init + (epoch - self.initial_epoch)/float(self.kl_epochs_change) * self.kl_alpha_increase_per_epoch) + else: + kl_alpha = 0.0 + + # Pretty print + print ('') + print (' ===== [{}] EPOCH:{}/{} (LR={:3f}) =================='.format(exp_name, epoch, max_epoch, self.optimizer.lr.numpy())) + + # Start a fresh pbar + pbar = tqdm.tqdm(total=epoch_length, desc='') + + # Model Writing to tensorboard + if self.params['model']['model_tboard'] and self.write_model_done is False : + self.write_model_done = True + modutils.write_model(self.model, X, self.params) + + # Start/Stop Profiling (after dataloader is kicked off) + self._set_profiler(epoch, epoch_step) + + # Calculate loss and gradients from them + time_predict, time_loss, time_backprop = self._train_step(X, Y, meta1, meta2, kl_alpha) + + # Update metrics (time + eval + plots) + time_dataloader = t1_ - t1 + trainMetrics.update_metrics_time(time_dataloader, time_predict, time_loss, time_backprop) + + # Update looping stuff + epoch_step += batch_size + pbar.update(batch_size) + trainMetrics.update_pbar(pbar) + + except: + modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + params_save_model['epoch'] = epoch + modutils.save_model(self.model, params_save_model) + traceback.print_exc() + pdb.set_trace() + + if epoch_step >= epoch_length: + + # Reset epoch-loop params + pbar.close() + epoch_step = 0 + + try: + # Model save + if epoch % epochs_save == 0: + params_save_model['epoch'] = epoch + modutils.save_model(self.model, params_save_model) + + # Eval on full 3D + if epoch % epochs_eval == 0: + self.params['epoch'] = epoch + save=False + if epoch > 0 and epoch % epochs_viz == 0: + save=True + + self.model.trainable = False + for metric_str in metrics_eval: + if metrics_eval[metric_str] in [config.LOSS_DICE]: + params_eval['epoch'] = epoch + params_eval['eval_type'] = config.MODE_TRAIN + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save) + trainMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + # Test + if epoch % epochs_eval == 0: + self._test() + self.model.trainable = True + + # Epochs summary/wrapup + eval_condition = epoch % epochs_eval == 0 + trainMetrics.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition) + + if epoch > 0 and epoch % self.params['others']['epochs_timer'] == 0: + elapsed_seconds = time.time() - t_start_time + print (' - Total time elapsed : {}'.format( str(datetime.timedelta(seconds=elapsed_seconds)) )) + if epoch % self.params['others']['epochs_memory'] == 0: + mem_before = modutils.get_memory(self.pid) + gc_n = gc.collect() + mem_after = modutils.get_memory(self.pid) + print(' - Unreachable objects collected by GC: {} || ({}) -> ({})'.format(gc_n, mem_before, mem_after)) + + # Break out of loop at end of all epochs + if epoch == 
max_epoch: + print ('\n\n - [Trainer][train()] All epochs finished') + break + + except: + modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + params_save_model['epoch'] = epoch + modutils.save_model(self.model, params_save_model) + traceback.print_exc() + pdb.set_trace() + + t1 = time.time() # reset dataloader time calculator + + except: + modutils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + traceback.print_exc() + pdb.set_trace() + + def _test(self): + + exp_name = None + epoch = None + try: + + # Step 1.1 - Params + exp_name = self.params['exp_name'] + epoch = self.params['epoch'] + deepsup = self.params['model']['deepsup'] + + metrics_eval = self.params['metrics']['metrics_eval'] + epochs_viz = self.params['model']['epochs_viz'] + batch_size = self.params['dataloader']['batch_size'] + + # vars + testMetrics = self.metrics[config.MODE_TEST] + testMetrics.reset_metrics(self.params) + params_eval = {'PROJECT_DIR': config.PROJECT_DIR, 'exp_name': exp_name, 'pid': self.pid + , 'eval_type': config.MODE_TEST, 'batch_size': batch_size + , 'epoch':epoch} + + # Step 2 - Eval on full 3D + save=False + if epoch > 0 and epoch % epochs_viz == 0: + save=True + for metric_str in metrics_eval: + if metrics_eval[metric_str] in [config.LOSS_DICE]: + params_eval['eval_type'] = config.MODE_TEST + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save) + testMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + testMetrics.write_epoch_summary(epoch, self.label_map, {}, True) + + except: + modutils.print_exp_name(exp_name + '-' + config.MODE_TEST, epoch) + traceback.print_exc() + pdb.set_trace() + diff --git a/src/model/trainer_flipout.py b/src/model/trainer_flipout.py new file mode 100644 index 0000000..2b45341 --- /dev/null +++ b/src/model/trainer_flipout.py @@ -0,0 +1,2435 @@ +# Import internal libraries +import src.config as config +import src.utils as utils +import src.models as models +import src.losses as losses + +import medloader.dataloader.config as medconfig +import medloader.dataloader.tensorflow.augmentations as aug +from medloader.dataloader.tensorflow.dataset import ZipDataset +from medloader.dataloader.tensorflow.han_miccai2015v2 import HaNMICCAI2015Dataset +from medloader.dataloader.tensorflow.han_deepmindtcia import HaNDeepMindTCIADataset +import medloader.dataloader.utils as medutils + +# Import external libraries +import os +import gc +import pdb +import copy +import time +import tqdm +import datetime +import traceback +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp +from pathlib import Path + +import matplotlib.pyplot as plt; +import seaborn as sns; + +_EPSILON = tf.keras.backend.epsilon() +DEBUG = False + +############################################################ +# DATALOADER RELATED # +############################################################ + +def get_dataloader_3D_train(data_dir, dir_type=['train', 'train_additional'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=True + , parallel_calls=None, deterministic=False + , patient_shuffle=True + , centred_dataloader_prob=0.0 + , debug=False, single_sample=False + , pregridnorm=True): + + debug = False + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, 
dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug, single_sample=single_sample + , pregridnorm=pregridnorm) + + # Step 2 - Training transforms + x_shape_w = dataset_han_miccai2015.w_grid + x_shape_h = dataset_han_miccai2015.h_grid + x_shape_d = dataset_han_miccai2015.d_grid + label_map = dataset_han_miccai2015.LABEL_MAP + HU_MIN = dataset_han_miccai2015.HU_MIN + HU_MAX = dataset_han_miccai2015.HU_MAX + transforms = [ + # minmax + aug.Rotate3DSmall(label_map, mask_type) + , aug.Deform2Punt5D((x_shape_h, x_shape_w, x_shape_d), label_map, grid_points=50, stddev=4, div_factor=2, debug=False) + , aug.Translate(label_map, translations=[40,40]) + , aug.Noise(x_shape=(x_shape_h, x_shape_w, x_shape_d, 1), mean=0.0, std=0.1) + ] + dataset_han_miccai2015.transforms = transforms + + # Step 3 - Training filters for background-only grids + if filter_grid: + dataset_han_miccai2015.filter = aug.FilterByMask(len(dataset_han_miccai2015.LABEL_MAP), dataset_han_miccai2015.SAMPLER_PERC) + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_3D_train_eval(data_dir, dir_type=['train_train_additional'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False + , parallel_calls=None, deterministic=True + , patient_shuffle=False + , centred_dataloader_prob=0.0 + , debug=False, single_sample=False + , pregridnorm=True): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug, single_sample=single_sample + , pregridnorm=pregridnorm) + + # Step 2 - Training transforms + # None + + # Step 3 - Training filters for background-only grids + # None + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_3D_test_eval(data_dir, dir_type=['test_offsite'] + , dimension=3, grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False + , parallel_calls=None, deterministic=True + , patient_shuffle=False + , debug=False, single_sample=False + , pregridnorm=True): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + # Step 1 - Get dataset class + dataset_han_miccai2015 = HaNMICCAI2015Dataset(data_dir=data_dir, dir_type=dir_type_ + , dimension=dimension, grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , debug=debug, single_sample=single_sample + , pregridnorm=pregridnorm) + + # Step 2 - Testing transforms + # None 
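+        # NOTE (illustrative): no augmentation transforms are applied at eval time. Unlike the train
+        # dataloader above, this builder defaults to random_grid=False, deterministic=True and
+        # patient_shuffle=False, so full-volume evaluation is reproducible across epochs. A minimal
+        # usage sketch, mirroring how trainer.py consumes these datasets:
+        #   dataset_eval = get_dataloader_3D_test_eval(data_dir)
+        #   datagen_eval = dataset_eval.generator().batch(2).prefetch(4)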
+ + # Step 3 - Testing filters for background-only grids + # None + + # Step 4 - Append to list + datasets.append(dataset_han_miccai2015) + + dataset = ZipDataset(datasets) + return dataset + +def get_dataloader_deepmindtcia(data_dir + , dir_type=[medconfig.DATALOADER_DEEPMINDTCIA_TEST] + , annotator_type=[medconfig.DATALOADER_DEEPMINDTCIA_ONC] + , grid=True, crop_init=True, resampled=True, mask_type=medconfig.MASK_TYPE_ONEHOT + , transforms=[], filter_grid=False, random_grid=False, pregridnorm=True + , parallel_calls=None, deterministic=False + , patient_shuffle=True + , centred_dataloader_prob = 0.0 + , debug=False, single_sample=False): + + datasets = [] + + # Dataset 1 + for dir_type_ in dir_type: + + for anno_type_ in annotator_type: + + # Step 1 - Get dataset class + dataset_han_deepmindtcia = HaNDeepMindTCIADataset(data_dir=data_dir, dir_type=dir_type_, annotator_type=anno_type_ + , grid=grid, crop_init=crop_init, resampled=resampled, mask_type=mask_type, pregridnorm=pregridnorm + , transforms=transforms, filter_grid=filter_grid, random_grid=random_grid + , parallel_calls=parallel_calls, deterministic=deterministic + , patient_shuffle=patient_shuffle + , centred_dataloader_prob=centred_dataloader_prob + , debug=debug, single_sample=single_sample) + + # Step 2 - Append to list + datasets.append(dataset_han_deepmindtcia) + + dataset = ZipDataset(datasets) + return dataset + +############################################################ +# MODEL RELATED # +############################################################ + +def set_lr(epoch, optimizer, init_lr): + + # if epoch == 1: # for models that are preloaded from another model + # print (' - [set_lr()] Setting optimizer lr to ', init_lr) + # optimizer.lr.assign(init_lr) + + if epoch > 1 and epoch % 20 == 0: + optimizer.lr.assign(optimizer.lr * 0.98) + +############################################################ +# METRICS RELATED # +############################################################ +class ModelMetrics(): + + def __init__(self, metric_type, params): + + self.params = params + + self.label_map = params['internal']['label_map'] + self.label_ids = params['internal']['label_ids'] + self.logging_tboard = params['metrics']['logging_tboard'] + self.metric_type = metric_type + + self.losses_obj = self.get_losses_obj(params) + self.metrics_layers_kl_divergence = {} # empty for now + + self.init_metrics(params) + if self.logging_tboard: + self.init_tboard_writers(params) + + self.reset_metrics(params) + self.init_epoch0() + self.reset_metrics(params) + + def get_losses_obj(self, params): + losses_obj = {} + for loss_key in params['metrics']['metrics_loss']: + if config.LOSS_DICE == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_dice_3d_tf_func + if config.LOSS_FOCAL == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_focal_3d_tf_func + if config.LOSS_CE == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_ce_3d_tf_func + if config.LOSS_NCC == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_ncc_numpy + if config.LOSS_PAVPU == params['metrics']['metrics_loss'][loss_key]: + losses_obj[loss_key] = losses.loss_avu_3d_tf_func + + return losses_obj + + def init_metrics(self, params): + """ + These are metrics derived from tensorflows library + """ + # Metrics for losses (during training for smaller grids) + self.metrics_loss_obj = {} + for metric_key in params['metrics']['metrics_loss']: + 
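+            # NOTE: for each configured loss this builds a running tf.keras.metrics.Mean under the
+            # 'total' key, plus one Mean per label id for Dice/CE/Focal losses, so per-organ losses
+            # can be averaged over an epoch and written to TensorBoard via write_epoch_summary().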
self.metrics_loss_obj[metric_key] = {} + self.metrics_loss_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type)) + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]: + for label_id in self.label_ids: + self.metrics_loss_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type)) + + # Metrics for eval (for full 3D volume) + self.metrics_eval_obj = {} + for metric_key in params['metrics']['metrics_eval']: + self.metrics_eval_obj[metric_key] = {} + self.metrics_eval_obj[metric_key]['total'] = tf.keras.metrics.Mean(name='Avg{}-{}'.format(metric_key, self.metric_type)) + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.metrics_eval_obj[metric_key][label_id] = tf.keras.metrics.Mean(name='Avg{}-Label-{}-{}'.format(metric_key, label_id, self.metric_type)) + + # Time Metrics + self.metric_time_dataloader = tf.keras.metrics.Mean(name='AvgTime-Dataloader-{}'.format(self.metric_type)) + self.metric_time_model_predict = tf.keras.metrics.Mean(name='AvgTime-ModelPredict-{}'.format(self.metric_type)) + self.metric_time_model_loss = tf.keras.metrics.Mean(name='AvgTime-ModelLoss-{}'.format(self.metric_type)) + self.metric_time_model_backprop = tf.keras.metrics.Mean(name='AvgTime-ModelBackProp-{}'.format(self.metric_type)) + + # FlipOut Metrics + self.metric_kl_alpha = tf.keras.metrics.Mean(name='KL-Alpha') + self.metric_kl_divergence = tf.keras.metrics.Mean(name='KL-Divergence') + + # Scalar Losses + self.metric_scalarloss_data = tf.keras.metrics.Mean(name='ScalarLoss-Data') + self.metric_scalarloss_reg = tf.keras.metrics.Mean(name='ScalarLoss-Reg') + + def init_metrics_layers_kl_std(self, params, layers_kl_std): + + self.metrics_layers_kl_divergence = {} + self.tboard_layers_kl_divergence = {} + self.writer_tboard_layers_std = {} + self.writer_tboard_layers_mean = {} + for layer_name in layers_kl_std: + self.metrics_layers_kl_divergence[layer_name] = tf.keras.metrics.Mean(name='KL-Divergence-{}'.format(layer_name)) + self.metrics_layers_kl_divergence[layer_name].update_state(0) + self.tboard_layers_kl_divergence[layer_name] = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Divergence-Layer-{}'.format(layer_name)) + utils.make_summary('BayesLossExtras/FlipOut/KLDivergence-{}'.format(layer_name), 0, writer1=self.tboard_layers_kl_divergence[layer_name], value1=self.metrics_layers_kl_divergence[layer_name].result()) + self.metrics_layers_kl_divergence[layer_name].reset_states() + + if 'std' in layers_kl_std[layer_name]: + keyname = layer_name + '-std' + self.writer_tboard_layers_std[keyname] = utils.get_tensorboard_writer(params['exp_name'], suffix='Std-Layer-{}'.format(keyname)) + utils.make_summary_hist('Std/{}'.format(keyname), 0, writer1=self.writer_tboard_layers_std[keyname], value1=layers_kl_std[layer_name]['std']) + + if 'mean' in layers_kl_std[layer_name]: # keep this between [-2,+2] for better visualization in tf.summary.histogram() + keyname = layer_name + '-mean' + self.writer_tboard_layers_mean[keyname] = utils.get_tensorboard_writer(params['exp_name'], suffix='Mean-Layer-{}'.format(keyname)) + mean_vals = layers_kl_std[layer_name]['mean'].numpy() + mean_vals = mean_vals[mean_vals >= -2] + mean_vals = mean_vals[mean_vals <= 2] + utils.make_summary_hist('Mean/{}'.format(keyname), 0, writer1=self.writer_tboard_layers_mean[keyname], value1=mean_vals) + + def 
reset_metrics(self, params): + + # Metrics for losses (during training for smaller grids) + for metric_key in params['metrics']['metrics_loss']: + self.metrics_loss_obj[metric_key]['total'].reset_states() + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]: + for label_id in self.label_ids: + self.metrics_loss_obj[metric_key][label_id].reset_states() + + # Metrics for eval (for full 3D volume) + for metric_key in params['metrics']['metrics_eval']: + self.metrics_eval_obj[metric_key]['total'].reset_states() + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.metrics_eval_obj[metric_key][label_id].reset_states() + + # Time Metrics + self.metric_time_dataloader.reset_states() + self.metric_time_model_predict.reset_states() + self.metric_time_model_loss.reset_states() + self.metric_time_model_backprop.reset_states() + + # FlipOut Metrics + self.metric_kl_alpha.reset_states() + self.metric_kl_divergence.reset_states() + + # Scalar Losses + self.metric_scalarloss_data.reset_states() + self.metric_scalarloss_reg.reset_states() + + # FlipOut-Layers + for layer_name in self.metrics_layers_kl_divergence: + self.metrics_layers_kl_divergence[layer_name].reset_states() + + def init_tboard_writers(self, params): + """ + These are tensorboard writer + """ + # Writers for loss (during training for smaller grids) + self.writers_loss_obj = {} + for metric_key in params['metrics']['metrics_loss']: + self.writers_loss_obj[metric_key] = {} + self.writers_loss_obj[metric_key]['total'] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss') + if params['metrics']['metrics_loss'][metric_key] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]: + for label_id in self.label_ids: + self.writers_loss_obj[metric_key][label_id] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Loss-' + str(label_id)) + + # Writers for eval (for full 3D volume) + self.writers_eval_obj = {} + for metric_key in params['metrics']['metrics_eval']: + self.writers_eval_obj[metric_key] = {} + self.writers_eval_obj[metric_key]['total'] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval') + if params['metrics']['metrics_eval'][metric_key] in [config.LOSS_DICE]: + for label_id in self.label_ids: + self.writers_eval_obj[metric_key][label_id] = utils.get_tensorboard_writer(params['exp_name'], suffix=self.metric_type + '-Eval-' + str(label_id)) + + # Time and other writers + self.writer_lr = utils.get_tensorboard_writer(params['exp_name'], suffix='LR') + self.writer_time_dataloader = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Dataloader') + self.writer_time_model_predict = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Predict') + self.writer_time_model_loss = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Loss') + self.writer_time_model_backprop = utils.get_tensorboard_writer(params['exp_name'], suffix='Time-Model-Backprop') + + # FlipOut writers + self.writer_kl_alpha = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Alpha') + self.writer_kl_divergence = utils.get_tensorboard_writer(params['exp_name'], suffix='KL-Divergence') + + # Scalar Losses + self.writer_scalarloss_data = utils.get_tensorboard_writer(params['exp_name'], suffix='ScalarLoss-Data') + self.writer_scalarloss_reg = utils.get_tensorboard_writer(params['exp_name'], suffix='ScalarLoss-Reg') + + def 
init_epoch0(self): + + for metric_str in self.metrics_loss_obj: + self.update_metric_loss(metric_str, 1e-6) + if params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]: + self.update_metric_loss_labels(metric_str, {label_id: 1e-6 for label_id in self.label_ids}) + + for metric_str in self.metrics_eval_obj: + self.update_metric_eval_labels(metric_str, {label_id: 0 for label_id in self.label_ids}) + + self.update_metrics_time(time_dataloader=0, time_predict=0, time_loss=0, time_backprop=0) + self.update_metrics_kl(kl_alpha=0, kl_divergence=0, kl_divergence_layers={}) + self.update_metrics_scalarloss(loss_data=0, loss_reg=0) + + self.write_epoch_summary(epoch=0, label_map=self.label_map, params=None, eval_condition=True) + + @tf.function + def update_metrics_kl(self, kl_alpha, kl_divergence, kl_divergence_layers): + self.metric_kl_alpha.update_state(kl_alpha) + self.metric_kl_divergence.update_state(kl_divergence) + + for layer_name in kl_divergence_layers: + if layer_name in self.metrics_layers_kl_divergence: + self.metrics_layers_kl_divergence[layer_name].update_state(kl_divergence_layers[layer_name]['kl']) + + @tf.function + def update_metrics_scalarloss(self, loss_data, loss_reg): + self.metric_scalarloss_data.update_state(loss_data) + self.metric_scalarloss_reg.update_state(loss_reg) + + def update_metrics_time(self, time_dataloader, time_predict, time_loss, time_backprop): + if time_dataloader is not None: + self.metric_time_dataloader.update_state(time_dataloader) + if time_predict is not None: + self.metric_time_model_predict.update_state(time_predict) + if time_loss is not None: + self.metric_time_model_loss.update_state(time_loss) + if time_backprop is not None: + self.metric_time_model_backprop.update_state(time_backprop) + + def update_metric_loss(self, metric_str, metric_val): + # Metrics for losses (during training for smaller grids) + self.metrics_loss_obj[metric_str]['total'].update_state(metric_val) + + @tf.function + def update_metric_loss_labels(self, metric_str, metric_vals_labels): + # Metrics for losses (during training for smaller grids) + + for label_id in self.label_ids: + if metric_vals_labels[label_id] > 0.0: + self.metrics_loss_obj[metric_str][label_id].update_state(metric_vals_labels[label_id]) + + def update_metric_eval(self, metric_str, metric_val): + # Metrics for eval (for full 3D volume) + self.metrics_eval_obj[metric_str]['total'].update_state(metric_val) + + def update_metric_eval_labels(self, metric_str, metric_vals_labels, do_average=False): + # Metrics for eval (for full 3D volume) + + try: + metric_avg = [] + for label_id in self.label_ids: + if label_id in metric_vals_labels: + if metric_vals_labels[label_id] > 0: + self.metrics_eval_obj[metric_str][label_id].update_state(metric_vals_labels[label_id]) + if do_average: + if label_id > 0: + metric_avg.append(metric_vals_labels[label_id]) + + if do_average: + if len(metric_avg): + self.metrics_eval_obj[metric_str]['total'].update_state(np.mean(metric_avg)) + + except: + traceback.print_exc() + + def write_epoch_summary(self, epoch, label_map, params=None, eval_condition=False): + + if self.logging_tboard: + # Metrics for losses (during training for smaller grids) + for metric_str in self.metrics_loss_obj: + utils.make_summary('Loss/{}'.format(metric_str), epoch, writer1=self.writers_loss_obj[metric_str]['total'], value1=self.metrics_loss_obj[metric_str]['total'].result()) + if self.params['metrics']['metrics_loss'][metric_str] in [config.LOSS_DICE, 
config.LOSS_CE, config.LOSS_FOCAL]: + if len(self.metrics_loss_obj[metric_str]) > 1: # i.e. has label ids + for label_id in self.label_ids: + label_name, _ = utils.get_info_from_label_id(label_id, label_map) + utils.make_summary('Loss/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_loss_obj[metric_str][label_id], value1=self.metrics_loss_obj[metric_str][label_id].result()) + + # Metrics for eval (for full 3D volume) + if eval_condition: + for metric_str in self.metrics_eval_obj: + utils.make_summary('Eval3D/{}'.format(metric_str), epoch, writer1=self.writers_eval_obj[metric_str]['total'], value1=self.metrics_eval_obj[metric_str]['total'].result()) + if len(self.metrics_eval_obj[metric_str]) > 1: # i.e. has label ids + for label_id in self.label_ids: + label_name, _ = utils.get_info_from_label_id(label_id, label_map) + utils.make_summary('Eval3D/{}/{}'.format(metric_str, label_name), epoch, writer1=self.writers_eval_obj[metric_str][label_id], value1=self.metrics_eval_obj[metric_str][label_id].result()) + + # Time Metrics + utils.make_summary('Info/Time/Dataloader' , epoch, writer1=self.writer_time_dataloader , value1=self.metric_time_dataloader.result()) + utils.make_summary('Info/Time/ModelPredict' , epoch, writer1=self.writer_time_model_predict , value1=self.metric_time_model_predict.result()) + utils.make_summary('Info/Time/ModelLoss' , epoch, writer1=self.writer_time_model_loss , value1=self.metric_time_model_loss.result()) + utils.make_summary('Info/Time/ModelBackProp', epoch, writer1=self.writer_time_model_backprop, value1=self.metric_time_model_backprop.result()) + + # FlipOut Metrics + utils.make_summary('BayesLoss/FlipOut/KLAlpha' , epoch, writer1=self.writer_kl_alpha , value1=self.metric_kl_alpha.result()) + utils.make_summary('BayesLoss/FlipOut/KLDivergence' , epoch, writer1=self.writer_kl_divergence , value1=self.metric_kl_divergence.result()) + for layer_name in self.metrics_layers_kl_divergence: + utils.make_summary('BayesLossExtras/FlipOut/KLDivergence-{}'.format(layer_name), epoch, writer1=self.tboard_layers_kl_divergence[layer_name], value1=self.metrics_layers_kl_divergence[layer_name].result()) + + # Scalar Loss Metrics + utils.make_summary('BayesLoss/FlipOut/ScalarLossData' , epoch, writer1=self.writer_scalarloss_data , value1=self.metric_scalarloss_data.result()) + utils.make_summary('BayesLoss/FlipOut/ScalarLossReg' , epoch, writer1=self.writer_scalarloss_reg , value1=self.metric_scalarloss_reg.result()) + + # Learning Rate + if params is not None: + if 'optimizer' in params: + utils.make_summary('Info/LR', epoch, writer1=self.writer_lr, value1=params['optimizer'].lr) + + def write_epoch_summary_std(self, layers_kl_std, epoch): + + for layer_name in layers_kl_std: + + if 'std' in layers_kl_std[layer_name]: + keyname = layer_name + '-std' + utils.make_summary_hist('Std/{}'.format(keyname), epoch, writer1=self.writer_tboard_layers_std[keyname], value1=layers_kl_std[layer_name]['std']) + + if 'mean' in layers_kl_std[layer_name]: + keyname = layer_name + '-mean' + mean_vals = layers_kl_std[layer_name]['mean'].numpy() + mean_vals = mean_vals[mean_vals >= -2] + mean_vals = mean_vals[mean_vals <= 2] + utils.make_summary_hist('Mean/{}'.format(keyname), epoch, writer1=self.writer_tboard_layers_mean[keyname], value1=mean_vals) + + def update_pbar(self, pbar): + desc_str = '' + + # Metrics for losses (during training for smaller grids) + # Metrics for losses (during training for smaller grids) + if config.LOSS_PURATIO in self.metrics_loss_obj or 
config.LOSS_PAVPU in self.metrics_loss_obj: + for metric_str in self.metrics_loss_obj: + if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PURATIO, config.LOSS_PAVPU]: + result = self.metrics_loss_obj[metric_str]['total'].result().numpy() + loss_text = '{}:{:2f},'.format(metric_str, result) + desc_str += loss_text + else: + for metric_str in self.metrics_loss_obj: + if len(desc_str): desc_str += ',' + + if metric_str in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL]: + metric_avg = [] + for label_id in self.label_ids: + if label_id > 0: + label_result = self.metrics_loss_obj[metric_str][label_id].result().numpy() + if label_result > 0: + metric_avg.append(label_result) + + mean_val = 0 + if len(metric_avg): + mean_val = np.mean(metric_avg) + loss_text = '{}Loss:{:2f}'.format(metric_str, mean_val) + desc_str += loss_text + + # GPU Memory + if 1: + try: + if len(self.metrics_loss_obj) > 1: + desc_str = desc_str[:-1] # to remove the extra ',' + desc_str += ',' + str(utils.get_tf_gpu_memory()) + except: + pass + + pbar.set_description(desc=desc_str, refresh=True) + +def eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_processed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc, patient_pred_error + , patient_id_curr + , model_folder_epoch_imgs, model_folder_epoch_patches + , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals + , spacing, label_map, label_colors + , show=False, save=False): + + try: + + # Step 3.1.2 - Vizualize + if show: + if len(patient_pred_std): + maskpred_std = np.max(patient_pred_std, axis=-1) + maskpred_std = np.expand_dims(maskpred_std, axis=-1) + maskpred_std = np.repeat(maskpred_std, repeats=10, axis=-1) + + maskpred_ent = np.expand_dims(patient_pred_ent, axis=-1) # [H,W,D] --> [H,W,D,1] + maskpred_ent = np.repeat(maskpred_ent, repeats=10, axis=-1) # [H,W,D,1] --> [H,W,D,10] + + maskpred_mif = np.expand_dims(patient_pred_mif, axis=-1) # [H,W,D] --> [H,W,D,1] + maskpred_mif = np.repeat(maskpred_mif, repeats=10, axis=-1) # [H,W,D,1] --> [H,W,D,10] + + if 1: + print (' - patient_id_curr: ', patient_id_curr) + f,axarr = plt.subplots(1,2) + axarr[0].hist(maskpred_ent[:,:,:,0].flatten(), bins=30) + axarr[0].set_title('Entropy') + axarr[1].hist(maskpred_mif[:,:,:,0].flatten(), bins=30) + axarr[1].set_title('MutInf') + plt.suptitle('Exp: {}\nPatient:{}'.format(exp_name, patient_id_curr)) + plt.show() + pdb.set_trace() + + utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_std, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors + , vmax_unc=0.06, unc_title='Predictive Std', unc_savesufix='stdmax') + + utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_ent, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors + , vmax_unc=1.2, unc_title='Predictive Entropy', unc_savesufix='ent') + + utils.viz_model_output_3d(exp_name, patient_img, patient_gt, patient_pred, maskpred_mif, patient_id_curr, model_folder_epoch_imgs, label_map, label_colors + , vmax_unc=0.06, unc_title='Mutual Information', unc_savesufix='mif') + + # Step 3.1.3 - Save 3D grid to visualize in 3D Slicer (drag-and-drop mechanism) + if save: + + import medloader.dataloader.utils as medutils + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_img.nrrd' , patient_img[:,:,:,0], spacing) + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + 
patient_id_curr)) + '_mask.nrrd', np.argmax(patient_gt, axis=3),spacing) + maskpred_labels = np.argmax(patient_pred_processed, axis=3) # not "np.argmax(patient_pred, axis=3)" since it does not contain any postprocessing + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpred.nrrd', maskpred_labels, spacing) + + maskpred_labels_probmean = np.take_along_axis(patient_pred, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmean.nrrd', maskpred_labels_probmean, spacing) + + if np.sum(patient_pred_std): + maskpred_labels_std = np.take_along_axis(patient_pred_std, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstd.nrrd', maskpred_labels_std, spacing) + + maskpred_std_max = np.max(patient_pred_std, axis=-1) + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredstdmax.nrrd', maskpred_std_max, spacing) + + if np.sum(patient_pred_ent): + maskpred_ent = patient_pred_ent # [H,W,D] + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredent.nrrd', maskpred_ent, spacing) + + if np.sum(patient_pred_mif): + maskpred_mif = patient_pred_mif + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredmif.nrrd', maskpred_mif, spacing) + + if np.sum(patient_pred_unc): + if len(patient_pred_unc.shape) == 4: + maskpred_labels_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(maskpred_labels,axis=-1), axis=-1)[:,:,:,0] # [H,W,D,C] --> [H,W,D] + else: + maskpred_labels_unc = patient_pred_unc + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskpredunc.nrrd', maskpred_labels_unc, spacing) + + if np.sum(patient_pred_error): + medutils.write_nrrd(str(Path(model_folder_epoch_patches).joinpath('nrrd_' + patient_id_curr)) + '_maskprederror.nrrd', patient_pred_error, spacing) + + try: + # Step 3.1.3.2 - PLot results for that patient + f, axarr = plt.subplots(3,1, figsize=(15,10)) + boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95 = {}, {}, {} + boxplot_dice_mean_list = [] + for label_id in range(len(loss_labels_val)): + label_name, _ = utils.get_info_from_label_id(label_id, label_map) + boxplot_dice[label_name] = [loss_labels_val[label_id]] + boxplot_hausdorff[label_name] = [hausdorff_labels_val[label_id]] + boxplot_hausdorff95[label_name] = [hausdorff95_labels_val[label_id]] + if label_id > 0 and loss_labels_val[label_id] > 0: + boxplot_dice_mean_list.append(loss_labels_val[label_id]) + + axarr[0].boxplot(boxplot_dice.values()) + axarr[0].set_xticks(range(1, len(boxplot_dice)+1)) + axarr[0].set_xticklabels(boxplot_dice.keys()) + axarr[0].set_ylim([0.0,1.1]) + axarr[0].grid() + axarr[0].set_title('DICE - Avg: {} \n w/o chiasm: {}'.format( + '%.4f' % (np.mean(boxplot_dice_mean_list)) + , '%.4f' % (np.mean(boxplot_dice_mean_list[0:1] + boxplot_dice_mean_list[2:])) # avoid label_id=2 + ) + ) + + axarr[1].boxplot(boxplot_hausdorff.values()) + axarr[1].set_xticks(range(1,len(boxplot_hausdorff)+1)) + axarr[1].set_xticklabels(boxplot_hausdorff.keys()) + axarr[1].grid() + axarr[1].set_title('Hausdorff') + + axarr[2].boxplot(boxplot_hausdorff95.values()) + axarr[2].set_xticks(range(1,len(boxplot_hausdorff95)+1)) + 
axarr[2].set_xticklabels(boxplot_hausdorff95.keys()) + axarr[2].set_title('95% Hausdorff') + axarr[2].grid() + + plt.savefig(str(Path(model_folder_epoch_patches).joinpath('results_' + patient_id_curr + '.png')), bbox_inches='tight') # , bbox_inches='tight' + plt.close() + + except: + traceback.print_exc() + + except: + pdb.set_trace() + traceback.print_exc() + +def get_ece(y_true, y_predict, patient_id, res_global, verbose=False): + """ + Params + ------ + y_true : [H,W,D,C], np.array, binary + y_predict: [H,W,D,C], np.array, with softmax probability values + - Ref: https://github.com/sirius8050/Expected-Calibration-Error/blob/master/ECE.py + : On Calibration of Modern Neural Networks + - Ref(future): https://github.com/yding5/AdaptiveBinning + : Revisiting the evaluation of uncertainty estimation and its application to explore model complexity-uncertainty trade-off + """ + res = {} + nan_value = -0.1 + + if verbose: print (' - [get_ece()] patient_id: ', patient_id) + + try: + + # Step 0 - Init + label_count = y_true.shape[-1] + + # Step 1 - Calculate o_predict + o_true = np.argmax(y_true, axis=-1) + o_predict = np.argmax(y_predict, axis=-1) + + # Step 2 - Loop over different classes + for label_id in range(label_count): + + if label_id > -1: + + if verbose: print (' --- [get_ece()] label_id: ', label_id) + + if label_id not in res_global: res_global[label_id] = {'o_predict_label':[], 'y_predict_label':[], 'o_true_label':[]} + + # Step 2.1 - Get o_predict_label(label_ids), o_true_label(label_ids), y_predict_label(probs) [and append to global list] + ## NB: You are considering TP + FP here + o_true_label = o_true[o_predict == label_id] + o_predict_label = o_predict[o_predict == label_id] + y_predict_label = y_predict[:,:,:,label_id][o_predict == label_id] + res_global[label_id]['o_true_label'].extend(o_true_label.flatten().tolist()) + res_global[label_id]['o_predict_label'].extend(o_predict_label.flatten().tolist()) + res_global[label_id]['y_predict_label'].extend(y_predict_label.flatten().tolist()) + + if len(o_true_label) and len(y_predict_label): + + # Step 2.2 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step 2.3 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + # Step 2.4 - Wrapup + N = np.prod(y_predict_label.shape) + ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + ce[ce == 0] = nan_value # i.e. 
y_predict_bins_accuracy[bin_id] == y_predict_bins_mean[bind_id] = nan_value + res[label_id] = ce + + else: + res[label_id] = -1 + + if 0: + if label_id == 1: + diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + print (' - [get_ece()][BStem] diff: ', ['%.4f' % each for each in diff]) + + # NB: This considers the whole volume + o_true_label = y_true[:,:,:,label_id] # [1=this label, 0=other label] + o_predict_label = np.array(o_predict, copy=True) + o_predict_label[o_predict_label != label_id] = 0 + o_predict_label[o_predict_label == label_id] = 1 # [1 - predicted this label, 0 = predicted other label] + y_predict_label = y_predict[:,:,:,label_id] + + # Step x.2 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step x.3 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + N_new = np.prod(y_predict_label.shape) + diff_new = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + ce_new = np.array((np.array(y_predict_bins_len)/N_new)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + print (' - [get_ece()][BStem] diff_new: ', ['%.4f' % each for each in diff_new]) + + pdb.set_trace() + + if verbose: + print (' --- [get_ece()] y_predict_bins_accuracy: ', ['%.4f' % (each) for each in np.array(y_predict_bins_accuracy)]) + print (' --- [get_ece()] CE : ', ['%.4f' % (each) for each in np.array(res[label_id])]) + print (' --- [get_ece()] ECE: ', np.sum(np.abs(res[label_id][res[label_id] != nan_value]))) + + # Prob bars + # plt.hist(y_predict[:,:,:,label_id].flatten(), bins=10) + # plt.title('Softmax Probs (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.show() + + # GT Prob bars + # plt.bar(np.arange(len(y_predict_bins_len))/10.0 + 0.1, y_predict_bins_len, width=0.05) + # plt.title('Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.xlabel('Probabilities') + # plt.show() + + # GT Probs (sorted) in plt.plot (with equally-spaced bins) + # from collections import Counter + # tmp = np.sort(y_predict_label) + # plt.plot(range(len(tmp)), tmp, color='orange') + # tmp_bins = np.digitize(tmp, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01])) - 1 + # tmp_bins_len = Counter(tmp_bins) + # boundary_start = 0 + # plt.plot([0,0],[0.0,1.0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-spaced)') + # for boundary in np.arange(0,len(tmp_bins_len)): plt.plot([boundary_start+tmp_bins_len[boundary], boundary_start+tmp_bins_len[boundary]], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed'); boundary_start+=tmp_bins_len[boundary] + # plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + # plt.legend() + 
# plt.show() + + # GT Probs (sorted) in plt.plot (with equally-sized bins) + if label_id == 1: + Path('tmp').mkdir(parents=True, exist_ok=True) + tmp = np.sort(y_predict_label) + tmp_len = len(tmp) + plt.plot(range(len(tmp)), tmp, color='orange') + for boundary in np.arange(0,tmp_len, int(tmp_len//10)): plt.plot([boundary, boundary], [0.0,1.0], color='black', alpha=0.5, linestyle='dashed') + plt.plot([0,0],[0,0], color='black', alpha=0.5, linestyle='dashed', label='Bins(equally-sized)') + plt.title('Sorted Softmax Probs (GT) (label={})\nPatient:{}'.format(label_id, patient_id)) + plt.legend() + # plt.show() + plt.savefig('./tmp/ECE_SortedProbs_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close() + + # ECE plot + plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8) + plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred') + plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy') + diff = np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean) + for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink') + plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE') + plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0) + plt.title('CE (label={})\nPatient:{}'.format(label_id, patient_id)) + plt.xlabel('Probability') + plt.ylabel('Accuracy') + plt.legend() + # plt.show() + plt.savefig('./tmp/ECE_label_{}_{}.png'.format(label_id, patient_id, ), bbox_inches='tight');plt.close() + + # pdb.set_trace() + + except: + traceback.print_exc() + pdb.set_trace() + + return res_global, res + +def eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=False, show=False, verbose=False): + + try: + + pid = os.getpid() + + ############################################################################### + # Summarize # + ############################################################################### + + # Step 1 - Summarize DICE + Surface Distances + if 1: + loss_labels_avg, loss_labels_std = [], [] + hausdorff_labels_avg, hausdorff_labels_std = [], [] + hausdorff95_labels_avg, hausdorff95_labels_std = [], [] + msd_labels_avg, msd_labels_std = [], [] + + loss_labels_list = np.array([res[patient_id][config.KEY_DICE_LABELS] for patient_id in res]) + hausdorff_labels_list = np.array([res[patient_id][config.KEY_HD_LABELS] for patient_id in res]) + hausdorff95_labels_list = np.array([res[patient_id][config.KEY_HD95_LABELS] for patient_id in res]) + msd_labels_list = np.array([res[patient_id][config.KEY_MSD_LABELS] for patient_id in res]) + + for label_id in range(loss_labels_list.shape[1]): + tmp_vals = loss_labels_list[:,label_id] + loss_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) + loss_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + if label_id > 0: + tmp_vals = hausdorff_labels_list[:,label_id] + hausdorff_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) # avoids -1 for "erroneous" HD, and 0 for "not to be calculated" HD + hausdorff_labels_std.append(np.std(tmp_vals[tmp_vals > 0])) + + tmp_vals = hausdorff95_labels_list[:,label_id] + hausdorff95_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0])) + 
                    hausdorff95_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+                    tmp_vals = msd_labels_list[:,label_id]
+                    msd_labels_avg.append(np.mean(tmp_vals[tmp_vals > 0]))
+                    msd_labels_std.append(np.std(tmp_vals[tmp_vals > 0]))
+
+                else:
+                    hausdorff_labels_avg.append(0)
+                    hausdorff_labels_std.append(0)
+                    hausdorff95_labels_avg.append(0)
+                    hausdorff95_labels_std.append(0)
+                    msd_labels_avg.append(0)
+                    msd_labels_std.append(0)
+
+            loss_avg = np.mean([res[patient_id][config.KEY_DICE_AVG] for patient_id in res])
+            print (' --------------------------- eval_type: ', eval_type)
+            print (' - dice_labels_3D (avg) : ', ['%.4f' % each for each in loss_labels_avg])
+            print (' - dice_labels_3D (std) : ', ['%.4f' % each for each in loss_labels_std])
+            print (' - dice_3D : %.4f' % np.mean(loss_labels_avg))
+            print (' - dice_3D (w/o bgd): %.4f' % np.mean(loss_labels_avg[1:]))
+            print (' - dice_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:]))
+            print ('')
+            print (' - hausdorff_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff_labels_avg])
+            print (' - hausdorff_labels_3D (std) : ', ['%.4f' % each for each in hausdorff_labels_std])
+            print (' - hausdorff_3D (w/o bgd): %.4f' % np.mean(hausdorff_labels_avg[1:]))
+            print (' - hausdorff_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff_labels_avg[1:2] + hausdorff_labels_avg[3:]))
+            print ('')
+            print (' - hausdorff95_labels_3D (avg) : ', ['%.4f' % each for each in hausdorff95_labels_avg])
+            print (' - hausdorff95_labels_3D (std) : ', ['%.4f' % each for each in hausdorff95_labels_std])
+            print (' - hausdorff95_3D (w/o bgd): %.4f' % np.mean(hausdorff95_labels_avg[1:]))
+            print (' - hausdorff95_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(hausdorff95_labels_avg[1:2] + hausdorff95_labels_avg[3:]))
+            print ('')
+            print (' - msd_labels_3D (avg) : ', ['%.4f' % each for each in msd_labels_avg])
+            print (' - msd_labels_3D (std) : ', ['%.4f' % each for each in msd_labels_std])
+            print (' - msd_3D (w/o bgd): %.4f' % np.mean(msd_labels_avg[1:]))
+            print (' - msd_3D (w/o bgd, w/o chiasm): %.4f' % np.mean(msd_labels_avg[1:2] + msd_labels_avg[3:]))
+
+        # Step 2 - Summarize AvU (accuracy-vs-uncertainty statistics)
+        if 1:
+            try:
+                print ('')
+                if config.KEY_AVU_PAC_ENT in res[list(res.keys())[0]]:
+                    p_ac_list_ent = np.array([res[patient_id][config.KEY_AVU_PAC_ENT] for patient_id in res])
+                    p_ui_list_ent = np.array([res[patient_id][config.KEY_AVU_PUI_ENT] for patient_id in res])
+                    pavpu_list_ent = np.array([res[patient_id][config.KEY_AVU_ENT] for patient_id in res])
+                    pavpu_unc_threshold_list_ent = np.array([res[patient_id][config.KEY_THRESH_ENT] for patient_id in res])
+                    print (' - AvU values for entropy')
+                    print (' - p(acc|cer)    : %.4f +- %.4f' % ( np.mean(p_ac_list_ent[p_ac_list_ent > -1])  , np.std(p_ac_list_ent[p_ac_list_ent > -1]) ))
+                    print (' - p(unc|inac)   : %.4f +- %.4f' % ( np.mean(p_ui_list_ent[p_ui_list_ent > -1])  , np.std(p_ui_list_ent[p_ui_list_ent > -1]) ))
+                    print (' - pavpu_3D      : %.4f +- %.4f' % ( np.mean(pavpu_list_ent[pavpu_list_ent > -1]), np.std(pavpu_list_ent[pavpu_list_ent > -1]) ))
+                    print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_ent[pavpu_unc_threshold_list_ent > -1]), np.std(pavpu_unc_threshold_list_ent[pavpu_unc_threshold_list_ent > -1]) ))
+
+                if config.KEY_AVU_PAC_MIF in res[list(res.keys())[0]]:
+                    p_ac_list_mif = np.array([res[patient_id][config.KEY_AVU_PAC_MIF] for patient_id in res])
+                    p_ui_list_mif = np.array([res[patient_id][config.KEY_AVU_PUI_MIF] for patient_id in res])
+                    pavpu_list_mif = np.array([res[patient_id][config.KEY_AVU_MIF] for patient_id 
in res]) + pavpu_unc_threshold_list_mif = np.array([res[patient_id][config.KEY_THRESH_MIF] for patient_id in res]) + print (' - AvU values for mutual info') + print (' - p(acc|cer) : %.4f +- %.4f' % ( np.mean([p_ac_list_mif[p_ac_list_mif > -1]]) , np.std([p_ac_list_mif[p_ac_list_mif > -1]]) )) + print (' - p(unc|inac) : %.4f +- %.4f' % ( np.mean([p_ui_list_mif[p_ui_list_mif > -1]]) , np.std([p_ui_list_mif[p_ui_list_mif > -1]]) )) + print (' - pavpu_3D : %.4f +- %.4f' % ( np.mean([pavpu_list_mif[pavpu_list_mif > -1]]) , np.std([pavpu_list_mif[pavpu_list_mif > -1]]) )) + print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_mif[pavpu_unc_threshold_list_mif > -1]), np.std(pavpu_unc_threshold_list_mif[pavpu_unc_threshold_list_mif > -1]) )) + + if config.KEY_AVU_PAC_UNC in res[list(res.keys())[0]]: + p_ac_list_unc = np.array([res[patient_id][config.KEY_AVU_PAC_UNC] for patient_id in res]) + p_ui_list_unc = np.array([res[patient_id][config.KEY_AVU_PUI_UNC] for patient_id in res]) + pavpu_list_unc = np.array([res[patient_id][config.KEY_AVU_UNC] for patient_id in res]) + pavpu_unc_threshold_list_unc = np.array([res[patient_id][config.KEY_THRESH_UNC] for patient_id in res]) + print (' - AvU values for percentile subtracts') + print (' - p(acc|cer) : %.4f +- %.4f' % ( np.mean([p_ac_list_unc[p_ac_list_unc > -1]]) , np.std([p_ac_list_unc[p_ac_list_unc > -1]]) )) + print (' - p(unc|inac) : %.4f +- %.4f' % ( np.mean([p_ui_list_unc[p_ui_list_unc > -1]]) , np.std([p_ui_list_unc[p_ui_list_unc > -1]]) )) + print (' - pavpu_3D : %.4f +- %.4f' % ( np.mean([pavpu_list_unc[pavpu_list_unc > -1]]) , np.std([pavpu_list_unc[pavpu_list_unc > -1]]) )) + print (' - unc_threshold : %.4f +- %.4f' % ( np.mean(pavpu_unc_threshold_list_unc[pavpu_unc_threshold_list_unc > -1]), np.std(pavpu_unc_threshold_list_unc[pavpu_unc_threshold_list_unc > -1]) )) + + print (' - pavpu params: PAVPU_UNC_THRESHOLD: ', config.PAVPU_UNC_THRESHOLD) + # print (' - pavpu params: PAVPU_GRID_SIZE : ', PAVPU_GRID_SIZE) + # print (' - pavpu params: PAVPU_RATIO_NEG : ', PAVPU_RATIO_NEG) + + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + + # Step 3 - Summarize ECE + if 1: + print ('') + gc.collect() + + nan_value = config.VAL_ECE_NAN + ece_labels_obj = {} + ece_labels = [] + label_count = len(ece_global_obj) + pbar_desc_prefix = '[ECE]' + ece_global_obj_keys = list(ece_global_obj.keys()) + res[config.KEY_PATIENT_GLOBAL] = {} + with tqdm.tqdm(total=label_count, desc=pbar_desc_prefix, disable=True) as pbar_ece: + for label_id in ece_global_obj_keys: + o_true_label = np.array(ece_global_obj[label_id]['o_true_label']) + o_predict_label = np.array(ece_global_obj[label_id]['o_predict_label']) + y_predict_label = np.array(ece_global_obj[label_id]['y_predict_label']) + if label_id in ece_global_obj: del ece_global_obj[label_id] + gc.collect() + + # Step 1.1 - Bin the probs and calculate their mean + y_predict_label_bin_ids = np.digitize(y_predict_label, np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.01]), right=False) - 1 + y_predict_binned_vals = [y_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + y_predict_bins_mean = [np.mean(vals) if len(vals) else nan_value for vals in y_predict_binned_vals] + + # Step 1.2 - Calculate the accuracy of each bin + o_predict_label_bins = [o_predict_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + o_true_label_bins = [o_true_label[y_predict_label_bin_ids == bin_id] for bin_id in range(label_count)] + 
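+                    # [Note] Steps 1.1-1.3 implement standard reliability binning. For reference, the same
+                    #        expected-calibration-error computation as a minimal stand-alone sketch
+                    #        (hypothetical helper, kept as a comment so this patch's behaviour is unchanged;
+                    #        10 equal-width probability bins assumed, mirroring the bin edges used here):
+                    #
+                    #   def ece_reference(y_prob, y_correct, n_bins=10):
+                    #       # y_prob   : (N,) predicted probability of the class of interest
+                    #       # y_correct: (N,) 1 where the hard prediction matches the ground truth, else 0
+                    #       bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
+                    #       bin_edges[-1] = 1.01                                   # so that prob == 1.0 falls in the top bin
+                    #       bin_ids = np.digitize(y_prob, bin_edges, right=False) - 1
+                    #       ece = 0.0
+                    #       for b in range(n_bins):
+                    #           in_bin = (bin_ids == b)
+                    #           if in_bin.any():
+                    #               conf, acc = y_prob[in_bin].mean(), y_correct[in_bin].mean()
+                    #               ece += (in_bin.sum() / y_prob.size) * abs(acc - conf)   # ECE = sum_b (n_b/N)*|acc_b - conf_b|
+                    #       return ece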
y_predict_bins_accuracy = [np.sum(o_predict_label_bins[bin_id] == o_true_label_bins[bin_id])/len(o_predict_label_bins[bin_id]) if len(o_predict_label_bins[bin_id]) else nan_value for bin_id in range(label_count)] + y_predict_bins_len = [len(o_predict_label_bins[bin_id]) for bin_id in range(label_count)] + + # Step 1.3 - Wrapup + N = np.prod(y_predict_label.shape) + ce = np.array((np.array(y_predict_bins_len)/N)*(np.array(y_predict_bins_accuracy)-np.array(y_predict_bins_mean))) + ce[ce == 0] = nan_value + ece_label = np.sum(np.abs(ce[ce != nan_value])) + ece_labels.append(ece_label) + ece_labels_obj[label_id] = {'y_predict_bins_mean':y_predict_bins_mean, 'y_predict_bins_accuracy':y_predict_bins_accuracy, 'ce':ce, 'ece':ece_label} + + pbar_ece.update(1) + memory = pbar_desc_prefix + ' [' + str(utils.get_memory(pid)) + ']' + pbar_ece.set_description(desc=memory, refresh=True) + + res[config.KEY_PATIENT_GLOBAL][label_id] = {'ce':ce, 'ece':ece_label} + + print (' - ece_labels : ', ['%.4f' % each for each in ece_labels]) + print (' - ece : %.4f' % np.mean(ece_labels)) + print (' - ece (w/o bgd): %.4f' % np.mean(ece_labels[1:])) + print (' - ece (w/o bgd, w/o chiasm): %.4f' % np.mean(ece_labels[1:2] + ece_labels[3:])) + print ('') + + del ece_global_obj + gc.collect() + + # Step 4 - Plot + if 1: + if save and not deepsup_eval: + f, axarr = plt.subplots(3,1, figsize=(15,10)) + boxplot_dice, boxplot_hausdorff, boxplot_hausdorff95, boxplot_msd = {}, {}, {}, {} + for label_id in range(len(loss_labels_list[0])): + label_name, _ = utils.get_info_from_label_id(label_id, label_map) + boxplot_dice[label_name] = loss_labels_list[:,label_id] + boxplot_hausdorff[label_name] = hausdorff_labels_list[:,label_id] + boxplot_hausdorff95[label_name] = hausdorff95_labels_list[:,label_id] + boxplot_msd[label_name] = msd_labels_list[:,label_id] + + axarr[0].boxplot(boxplot_dice.values()) + axarr[0].set_xticks(range(1, len(boxplot_dice)+1)) + axarr[0].set_xticklabels(boxplot_dice.keys()) + axarr[0].set_ylim([0.0,1.1]) + axarr[0].set_title('DICE (Avg: {}) \n w/o chiasm:{}'.format( + '%.4f' % np.mean(loss_labels_avg[1:]) + , '%.4f' % np.mean(loss_labels_avg[1:2] + loss_labels_avg[3:]) + ) + ) + + axarr[1].boxplot(boxplot_hausdorff.values()) + axarr[1].set_xticks(range(1, len(boxplot_hausdorff)+1)) + axarr[1].set_xticklabels(boxplot_hausdorff.keys()) + axarr[1].set_ylim([0.0,10.0]) + axarr[1].set_title('Hausdorff') + + axarr[2].boxplot(boxplot_hausdorff95.values()) + axarr[2].set_xticks(range(1, len(boxplot_hausdorff95)+1)) + axarr[2].set_xticklabels(boxplot_hausdorff95.keys()) + axarr[2].set_ylim([0.0,6.0]) + axarr[2].set_title('95% Hausdorff') + + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_dice.values()) + axarr.set_xticks(range(1, len(boxplot_dice)+1)) + axarr.set_xticklabels(boxplot_dice.keys()) + axarr.set_ylim([0.0,1.1]) + axarr.set_yticks(np.arange(0,1.1,0.05)) + axarr.set_title('DICE') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_dice.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_hausdorff95.values()) + axarr.set_xticks(range(1, len(boxplot_hausdorff95)+1)) + axarr.set_xticklabels(boxplot_hausdorff95.keys()) + axarr.set_title('95% HD') + axarr.grid() + path_results = 
str(Path(model_folder_epoch_patches).joinpath('results_all_hd95.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_hausdorff.values()) + axarr.set_xticks(range(1, len(boxplot_hausdorff)+1)) + axarr.set_xticklabels(boxplot_hausdorff.keys()) + axarr.set_title('Hausdorff Distance') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_hd.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + f, axarr = plt.subplots(1, figsize=(15,10)) + axarr.boxplot(boxplot_msd.values()) + axarr.set_xticks(range(1, len(boxplot_msd)+1)) + axarr.set_xticklabels(boxplot_msd.keys()) + axarr.set_title('MSD') + axarr.grid() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_all_msd.png')) + plt.savefig(path_results, bbox_inches='tight') + plt.close() + + # ECE + for label_id in ece_labels_obj: + y_predict_bins_mean = ece_labels_obj[label_id]['y_predict_bins_mean'] + y_predict_bins_accuracy = ece_labels_obj[label_id]['y_predict_bins_accuracy'] + ece = ece_labels_obj[label_id]['ece'] + + plt.plot(np.arange(11), np.arange(11)/10.0, linestyle='dashed', color='black', alpha=0.8) + plt.scatter(np.arange(len(y_predict_bins_mean)) + 0.5 , y_predict_bins_mean, alpha=0.5, color='g', marker='s', label='Mean Pred') + plt.scatter(np.arange(len(y_predict_bins_accuracy)) + 0.5 , y_predict_bins_accuracy, alpha=0.5, color='b', marker='x', label='Accuracy') + for bin_id in range(len(y_predict_bins_accuracy)): plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink') + plt.plot([bin_id + 0.5, bin_id + 0.5],[y_predict_bins_accuracy[bin_id], y_predict_bins_mean[bin_id]], color='pink', label='CE') + plt.xticks(ticks=np.arange(11), labels=np.arange(11)/10.0) + plt.title('CE (label={})\nECE: {}'.format(label_id, '%.5f' % (ece))) + plt.xlabel('Probability') + plt.ylabel('Accuracy') + plt.ylim([-0.15, 1.05]) + plt.legend() + + # plt.show() + path_results = str(Path(model_folder_epoch_patches).joinpath('results_ece_label{}.png'.format(label_id))) + plt.savefig(str(path_results), bbox_inches='tight') + plt.close() + + # Step 5 - Save data as .json + if 1: + try: + + filename = str(Path(model_folder_epoch_patches).joinpath(config.FILENAME_EVAL3D_JSON)) + utils.write_json(res, filename) + + except: + traceback.print_exc() + pdb.set_trace() + + model.trainable=True + print ('\n - [eval_3D] Avg of times_mcruns : {:f} +- {:f}'.format(np.mean(times_mcruns), np.std(times_mcruns))) + print (' - [eval_3D()] Total time passed (save={}) : {}s \n'.format(save, round(time.time() - ttotal, 2))) + if verbose: pdb.set_trace() + + return loss_avg, {i:loss_labels_avg[i] for i in range(len(loss_labels_avg))} + + except: + model.trainable=True + traceback.print_exc() + return -1, {} + +def eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=False, save=False, verbose=False): + + try: + + # Step 3.1.1 - Get stitched patient grid + if verbose: t0 = time.time() + patient_pred_ent = patient_pred_ent/patient_pred_overlap # [H,W,D]/[H,W,D] + patient_pred_mif = patient_pred_mif/patient_pred_overlap + patient_pred_overlap = np.expand_dims(patient_pred_overlap, -1) + patient_pred = 
        patient_pred_vals/patient_pred_overlap # [H,W,D,C]/[H,W,D,1]
+        patient_pred_std = patient_pred_std/patient_pred_overlap
+        patient_pred_unc = patient_pred_unc/patient_pred_overlap
+        del patient_pred_vals
+        del patient_pred_overlap
+        if 1:
+            # keep only the uncertainty value of the predicted class at each voxel
+            patient_pred_unc = np.take_along_axis(patient_pred_unc, np.expand_dims(np.argmax(patient_pred, axis=-1),axis=-1), axis=-1)[:,:,:,0]
+
+        gc.collect() # returns number of unreachable objects collected by GC
+        patient_pred_postprocessed = losses.remove_smaller_components(patient_gt, patient_pred, meta=patient_id_curr, label_ids_small=[2,4,5])
+        if verbose: print (' - [eval_3D()] Post-Process time : ', time.time() - t0,'s')
+
+        # Step 3.1.2 - Loss Calculation
+        spacing = np.array([meta1_batch[4], meta1_batch[5], meta1_batch[6]])/100.0
+        try:
+            if verbose: t0 = time.time()
+            loss_avg_val, loss_labels_val = losses.dice_numpy(patient_gt, patient_pred_postprocessed)
+            hausdorff_avg_val, hausdorff_labels_val, hausdorff95_avg_val, hausdorff95_labels_val, msd_avg_val, msd_labels_vals = losses.get_surface_distances(patient_gt, patient_pred_postprocessed, spacing)
+            if verbose:
+                print (' - [eval_3D()] DICE : ', ['%.4f' % (each) for each in loss_labels_val])
+                print (' - [eval_3D()] HD95 : ', ['%.4f' % (each) for each in hausdorff95_labels_val])
+
+            if loss_avg_val != -1 and len(loss_labels_val):
+                res[patient_id_curr] = {
+                    config.KEY_DICE_AVG      : loss_avg_val
+                    , config.KEY_DICE_LABELS : loss_labels_val
+                    , config.KEY_HD_AVG      : hausdorff_avg_val
+                    , config.KEY_HD_LABELS   : hausdorff_labels_val
+                    , config.KEY_HD95_AVG    : hausdorff95_avg_val
+                    , config.KEY_HD95_LABELS : hausdorff95_labels_val
+                    , config.KEY_MSD_AVG     : msd_avg_val
+                    , config.KEY_MSD_LABELS  : msd_labels_vals
+                }
+            else:
+                print (' - [ERROR][eval_3D()] patient_id: ', patient_id_curr)
+            if verbose: print (' - [eval_3D()] Loss calculation time: ', time.time() - t0,'s')
+        except:
+            traceback.print_exc()
+            if DEBUG: pdb.set_trace()
+
+        # Step 3.1.3 - ECE calculation
+        if verbose: t0 = time.time()
+        ece_global_obj, ece_patient_obj = get_ece(patient_gt, patient_pred, patient_id_curr, ece_global_obj)
+        res[patient_id_curr][config.KEY_ECE_LABELS] = ece_patient_obj
+        if verbose: print (' - [eval_3D()] ECE time : ', time.time() - t0,'s')
+
+        # Step 3.1.4 - Uncertainty Quantification PAvPU
+        if verbose: t0 = time.time()
+        if 0:
+            prob_acc_cer_ent, prob_unc_inacc_ent, pavpu_ent, thresh_ent, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_ent, unc_threshold=config.PAVPU_ENT_THRESHOLD, unc_type='entropy')
+            res[patient_id_curr][config.KEY_AVU_PAC_ENT] = prob_acc_cer_ent
+            res[patient_id_curr][config.KEY_AVU_PUI_ENT] = prob_unc_inacc_ent
+            res[patient_id_curr][config.KEY_AVU_ENT]     = pavpu_ent
+            res[patient_id_curr][config.KEY_THRESH_ENT]  = thresh_ent
+
+            prob_acc_cer_mif, prob_unc_inacc_mif, pavpu_mif, thresh_mif, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_mif, unc_threshold=config.PAVPU_MIF_THRESHOLD, unc_type='mutual info')
+            res[patient_id_curr][config.KEY_AVU_PAC_MIF] = prob_acc_cer_mif
+            res[patient_id_curr][config.KEY_AVU_PUI_MIF] = prob_unc_inacc_mif
+            res[patient_id_curr][config.KEY_AVU_MIF]     = pavpu_mif
+            res[patient_id_curr][config.KEY_THRESH_MIF]  = thresh_mif
+
+            # prob_acc_cer_unc, prob_unc_inacc_unc, pavpu_unc, thresh_unc, patient_pred_error = losses.get_pavpu_errorareas(patient_gt, patient_pred, patient_pred_unc, unc_threshold=config.PAVPU_UNC_THRESHOLD, unc_type='percentile subtracts')
+            # 
res[patient_id_curr][config.KEY_AVU_PAC_UNC] = prob_acc_cer_unc + # res[patient_id_curr][config.KEY_AVU_PUI_UNC] = prob_unc_inacc_unc + # res[patient_id_curr][config.KEY_AVU_UNC] = pavpu_unc + # res[patient_id_curr][config.KEY_THRESH_UNC] = thresh_unc + elif 1: + prob_acc_cer_ent, prob_unc_inacc_ent, pavpu_ent, patient_pred_error = losses.get_pavpu_gtai(patient_gt, patient_pred, patient_pred_ent, unc_threshold=config.PAVPU_ENT_THRESHOLD, unc_type='entropy') + res[patient_id_curr][config.KEY_AVU_PAC_ENT] = prob_acc_cer_ent + res[patient_id_curr][config.KEY_AVU_PUI_ENT] = prob_unc_inacc_ent + res[patient_id_curr][config.KEY_AVU_ENT] = pavpu_ent + res[patient_id_curr][config.KEY_THRESH_ENT] = config.PAVPU_ENT_THRESHOLD + + prob_acc_cer_mif, prob_unc_inacc_mif, pavpu_mif, patient_pred_error = losses.get_pavpu_gtai(patient_gt, patient_pred, patient_pred_mif, unc_threshold=config.PAVPU_MIF_THRESHOLD, unc_type='mutual info') + res[patient_id_curr][config.KEY_AVU_PAC_MIF] = prob_acc_cer_mif + res[patient_id_curr][config.KEY_AVU_PUI_MIF] = prob_unc_inacc_mif + res[patient_id_curr][config.KEY_AVU_MIF] = pavpu_mif + res[patient_id_curr][config.KEY_THRESH_MIF] = config.PAVPU_MIF_THRESHOLD + + if verbose: print (' - [eval_3D()] PAvPU time : ', time.time() - t0,'s') + + # Step 3.1.5 - Save/Visualize + if not deepsup_eval: + if verbose: t0 = time.time() + eval_3D_finalize(exp_name, patient_img, patient_gt, patient_pred_postprocessed, patient_pred, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc, patient_pred_error + , patient_id_curr + , model_folder_epoch_imgs, model_folder_epoch_patches + , loss_labels_val, hausdorff_labels_val, hausdorff95_labels_val, msd_labels_vals + , spacing, label_map, label_colors + , show=show, save=save) + if verbose: print (' - [eval_3D()] Save as .nrrd time : ', time.time() - t0,'s') + + if verbose: print (' - [eval_3D()] Total patient time : ', time.time() - t99,'s') + + # Step 3.1.6 + del patient_img + del patient_gt + del patient_pred + del patient_pred_std + del patient_pred_ent + del patient_pred_postprocessed + del patient_pred_mif + del patient_pred_unc + gc.collect() + + return res, ece_global_obj + + except: + traceback.print_exc() + if DEBUG: pdb.set_trace() + return res, ece_global_obj + +def eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval): + + # Step 0 - Init + # DO_KEYS = [config.KEY_PERC] + DO_KEYS = [config.KEY_MIF, config.KEY_ENT] + + # Step 1 - Warm up model + _ = model(X, training=training_bool) + + # Step 2 - Run Monte-Carlo predictions + try: + tic_mcruns = time.time() + if deepsup: + if deepsup_eval: + y_predict = tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + X = X[:,::2,::2,::2,:] + Y = Y[:,::2,::2,::2,:] + else: + y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + else: + y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time + toc_mcruns = time.time() + except tf.errors.ResourceExhaustedError as e: + print (' - [eval_3D_get_outputs()] OOM error for MC_RUNS={}'.format(MC_RUNS)) + + try: + MC_RUNS = 5 + tic_mcruns = time.time() + if deepsup: + if deepsup_eval: + y_predict = 
                    tf.stack([model(X, training=training_bool)[0] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+                    X = X[:,::2,::2,::2,:]
+                    Y = Y[:,::2,::2,::2,:]
+                else:
+                    y_predict = tf.stack([model(X, training=training_bool)[1] for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+            else:
+                y_predict = tf.stack([model(X, training=training_bool) for _ in range(MC_RUNS)]) # [MC,B,H,W,D,C] # [Note] with model.trainable=False and model.training=True we get dropout at inference time
+            toc_mcruns = time.time()
+        except tf.errors.ResourceExhaustedError as e:
+            print (' - [eval_3D_get_outputs()] OOM error even with MC_RUNS={}'.format(MC_RUNS))
+
+    # Step 3 - Calculate different metrics from the MC samples
+    #          predictive entropy : H[mean_mc(p)] = -sum_c p_mean*log(p_mean)
+    #          mutual information : H[mean_mc(p)] - mean_mc(H[p]); the expected-entropy term is
+    #          accumulated below as (1/MC)*sum_mc sum_c p*log(p), i.e. already with its sign flipped
+    if config.KEY_MIF in DO_KEYS:
+        y_predict_mif = y_predict * tf.math.log(y_predict + _EPSILON) # [MC,B,H,W,D,C]
+        y_predict_mif = tf.math.reduce_sum(y_predict_mif, axis=[0,-1])/MC_RUNS # [MC,B,H,W,D,C] -> [B,H,W,D]; this is -mean_mc(H[p])
+    else:
+        y_predict_mif = []
+
+    if config.KEY_STD in DO_KEYS:
+        y_predict_std = tf.math.reduce_std(y_predict, axis=0) # [MC,B,H,W,D,C] -> [B,H,W,D,C]
+    else:
+        y_predict_std = []
+
+    if config.KEY_PERC in DO_KEYS:
+        y_predict_perc = tfp.stats.percentile(y_predict, q=[30,70], axis=0, interpolation='nearest')
+        y_predict_unc = y_predict_perc[1] - y_predict_perc[0]
+        del y_predict_perc
+        gc.collect()
+    else:
+        y_predict_unc = []
+
+    y_predict = tf.math.reduce_mean(y_predict, axis=0)
+
+    if config.KEY_ENT in DO_KEYS:
+        y_predict_ent = -1*tf.math.reduce_sum(y_predict * tf.math.log(y_predict + _EPSILON), axis=-1) # [B,H,W,D,C] -> [B,H,W,D]; ent = -sum_c p*log(p)
+        y_predict_mif = y_predict_ent + y_predict_mif # [B,H,W,D] + [B,H,W,D] = [B,H,W,D]; MI = predictive entropy - expected entropy (the p*log(p) term above is already negative)
+    else:
+        # mutual info needs the entropy term, so it is only returned when entropy is also requested
+        y_predict_ent = []
+        y_predict_mif = []
+
+    return Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, toc_mcruns-tic_mcruns
+
+def eval_3D(model, dataset_eval, dataset_eval_gen, params, show=False, save=False, verbose=False):
+
+    try:
+
+        # Step 0.0 - Variables under debugging
+        pass
+
+        # Step 0.1 - Extract params
+        PROJECT_DIR  = params['PROJECT_DIR']
+        exp_name     = params['exp_name']
+        pid          = params['pid']
+        eval_type    = params['eval_type']
+        batch_size   = params['batch_size']
+        batch_size   = 2 # hard-coded override of params['batch_size'] for 3D eval
+        epoch        = params['epoch']
+        deepsup      = params['deepsup']
+        deepsup_eval = params['deepsup_eval']
+        label_map    = dict(dataset_eval.get_label_map())
+        label_colors = dict(dataset_eval.get_label_colors())
+
+        if verbose: print (''); print (' --------------------- eval_3D({}) ---------------------'.format(eval_type))
+
+        # Step 0.2 - Init results array
+        res = {}
+        ece_global_obj = {}
+        patient_grid_count = {}
+
+        # Step 0.3 - Init temp variables
+        patient_id_curr = None
+        w_grid, h_grid, d_grid = None, None, None
+        meta1_batch = None
+        patient_gt = None
+        patient_img = None
+        patient_pred_overlap = None
+        patient_pred_vals = None
+        model_folder_epoch_patches = None
+        model_folder_epoch_imgs = None
+
+        mc_runs       = params.get(config.KEY_MC_RUNS, None)
+        training_bool = params.get(config.KEY_TRAINING_BOOL, None)
+        model_folder_epoch_patches, model_folder_epoch_imgs = utils.get_eval_folders(PROJECT_DIR, exp_name, epoch, eval_type, mc_runs, training_bool, create=True)
+
+        # Step 0.4 - Debug vars
+        ttotal, t0, t99 = time.time(), None, None
+        times_mcruns = []
+
+        # Step 1 - Loop over dataset_eval (which provides patients & grids in an ordered manner)
+        print ('')
+        model.trainable = False
+        pbar_desc_prefix = 
'Eval3D_{} [batch={}]'.format(eval_type, batch_size) + training_bool = params.get('training_bool',True) # [True, False] + with tqdm.tqdm(total=len(dataset_eval), desc=pbar_desc_prefix, leave=False) as pbar_eval: + for (X,Y,meta1,meta2) in dataset_eval_gen.repeat(1): + + # Step 1.1 - Get MC results + MC_RUNS = params.get(config.KEY_MC_RUNS,10) + Y, y_predict, y_predict_std, y_predict_ent, y_predict_mif, y_predict_unc, mcruns_time = eval_3D_get_outputs(model, X, Y, training_bool, MC_RUNS, deepsup, deepsup_eval) + times_mcruns.append(mcruns_time) + + for batch_id in range(X.shape[0]): + + # Step 2 - Get grid info + patient_id_running = meta2[batch_id].numpy().decode('utf-8') + if patient_id_running in patient_grid_count: patient_grid_count[patient_id_running] += 1 + else: patient_grid_count[patient_id_running] = 1 + + meta1_batch = meta1[batch_id].numpy() + w_start, h_start, d_start = meta1_batch[1], meta1_batch[2], meta1_batch[3] + + # Step 3 - Check if its a new patient + if patient_id_running != patient_id_curr: + + # Step 3.1 - Sort out old patient (patient_id_curr) + if patient_id_curr != None: + + res, ece_global_obj = eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose) + + # Step 3.2 - Create variables for new patient + if verbose: t99 = time.time() + patient_id_curr = patient_id_running + patient_scan_size = meta1_batch[7:10] + dataset_name = patient_id_curr.split('-')[0] + dataset_this = dataset_eval.get_subdataset(param_name=dataset_name) + w_grid, h_grid, d_grid = dataset_this.w_grid, dataset_this.h_grid, dataset_this.d_grid + if deepsup_eval: + patient_scan_size = patient_scan_size//2 + w_grid, h_grid, d_grid = w_grid//2, h_grid//2, d_grid//2 + + patient_pred_size = list(patient_scan_size) + [len(dataset_this.LABEL_MAP)] + patient_pred_overlap = np.zeros(patient_scan_size, dtype=np.uint8) + patient_pred_ent = np.zeros(patient_scan_size, dtype=np.float32) + patient_pred_mif = np.zeros(patient_scan_size, dtype=np.float32) + patient_pred_vals = np.zeros(patient_pred_size, dtype=np.float32) + patient_pred_std = np.zeros(patient_pred_size, dtype=np.float32) + patient_pred_unc = np.zeros(patient_pred_size, dtype=np.float32) + patient_gt = np.zeros(patient_pred_size, dtype=np.float32) + if show or save: + patient_img = np.zeros(list(patient_scan_size) + [1], dtype=np.float32) + else: + patient_img = [] + + # Step 4 - If not new patient anymore, fill up data + if deepsup_eval: + w_start, h_start, d_start = w_start//2, h_start//2, d_start//2 + patient_pred_vals[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict[batch_id] + if len(y_predict_std): + patient_pred_std[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_std[batch_id] + if len(y_predict_ent): + patient_pred_ent[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_ent[batch_id] + if len(y_predict_mif): + patient_pred_mif[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_mif[batch_id] + if len(y_predict_unc): + patient_pred_unc[w_start:w_start + w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] += y_predict_unc[batch_id] + + patient_pred_overlap[w_start:w_start + w_grid, 
h_start:h_start+h_grid, d_start:d_start+d_grid] += np.ones(y_predict[batch_id].shape[:-1], dtype=np.uint8) + patient_gt[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = Y[batch_id] + if show or save: + patient_img[w_start:w_start+w_grid, h_start:h_start+h_grid, d_start:d_start+d_grid] = X[batch_id] + + pbar_eval.update(batch_size) + mem_used = utils.get_memory(pid) + memory = pbar_desc_prefix + ' [' + mem_used + ']' + pbar_eval.set_description(desc=memory, refresh=True) + + # Step 3 - For last patient + res, ece_global_obj = eval_3D_process_outputs(res, ece_global_obj, patient_id_curr, meta1_batch, patient_img, patient_gt, patient_pred_vals, patient_pred_overlap, patient_pred_std, patient_pred_ent, patient_pred_mif, patient_pred_unc + , deepsup_eval, model_folder_epoch_imgs, model_folder_epoch_patches, label_map, label_colors, t99, show=show, save=save, verbose=verbose) + print ('\n - [eval_3D()] Time passed to accumulate grids & process patients: ', round(time.time() - ttotal, 2), 's') + + return eval_3D_summarize(res, ece_global_obj, model, eval_type, deepsup_eval, label_map, model_folder_epoch_patches, times_mcruns, ttotal, save=save, show=show, verbose=verbose) + + except: + traceback.print_exc() + model.trainable = True + return -1, {} + +############################################################ +# VAL # +############################################################ + +def val(model, dataset, params, show=False, save=False, verbose=False): + try: + + # Step 1 - Load Model + + load_model_params = {'PROJECT_DIR': params['PROJECT_DIR'] + , 'exp_name': params['exp_name'] + , 'load_epoch': params['epoch'] + , 'optimizer': tf.keras.optimizers.Adam() + } + utils.load_model(model, load_type=config.MODE_VAL, params=load_model_params) + print ('') + print (' - [train.py][val()] Model({}) Loaded for {} at epoch-{} (validation purposes) !'.format(str(model), params['exp_name'], params['epoch'])) + print ('') + + # Step 3 - Calculate losses + dataset_gen = dataset.generator().batch(params['batch_size']).prefetch(params['prefetch_batch']) + loss_avg, loss_labels_avg = eval_3D(model, dataset, dataset_gen, params, show=show, save=save, verbose=verbose) + + except: + traceback.print_exc() + pdb.set_trace() + +def val_dropout(exp_name, model, dataset, epoch, batch_size=32, show=False): + try: + utils.load_model(exp_name, model, epoch, {'optimizer':tf.keras.optimizers.Adam()}, load_type='test') + print (' - [train.py][val()] Model Loaded at epoch-{} !'.format(epoch)) + + MC_RUNS = 10 + + with tqdm.tqdm(total=len(dataset), desc='') as pbar: + for (X,Y,meta1, meta2) in dataset.generator().batch(batch_size): + + if 1: + y_predict = tf.stack([model(X, training=True) for _ in range(MC_RUNS)]) # [Note] with model.trainable=False and model.training=True we get dropout at inference time + y_predict_mean = tf.math.reduce_mean(y_predict, axis=0) + y_predict_std = tf.math.reduce_std(y_predict, axis=0) + utils.viz_model_output_dropout(X,Y,y_predict_mean, y_predict_std, exp_name, epoch, dataset, meta2, mode='val_dropout') + + if 1: + y_predict = model(X, training=False) + utils.viz_model_output(X,Y,y_predict, exp_name, epoch, dataset, meta2, mode='val') + + pbar.update(batch_size) + + except: + traceback.print_exc() + pdb.set_trace() + +############################################################ +# TRAINER # +############################################################ + +class Trainer: + + def __init__(self, params): + + # Init + self.params = params + + # Print + self._train_preprint() 
+ + # Random Seeds + self._set_seed() + + # Set the dataloaders + self._set_dataloaders() + + # Set the model + self._set_model() + + # Set Metrics + self._set_metrics() + + # Other flags + self.write_model_done = False + + def _train_preprint(self): + print ('') + print (' -------------- {} ({})'.format(self.params['exp_name'], str(datetime.datetime.now()))) + + print ('') + print (' DATALOADER ') + print (' ---------- ') + print (' - dir_type: ', self.params['dataloader']['dir_type']) + + print (' -- resampled: ', self.params['dataloader']['resampled']) + print (' -- crop_init: ', self.params['dataloader']['crop_init']) + print (' -- grid: ', self.params['dataloader']['grid']) + print (' --- filter_grid : ', self.params['dataloader']['filter_grid']) + print (' --- random_grid : ', self.params['dataloader']['random_grid']) + print (' --- centred_prob : ', self.params['dataloader']['centred_prob']) + + print (' -- batch_size: ', self.params['dataloader']['batch_size']) + print (' -- prefetch_batch : ', self.params['dataloader']['prefetch_batch']) + print (' -- parallel_calls : ', self.params['dataloader']['parallel_calls']) + print (' -- shuffle : ', self.params['dataloader']['shuffle']) + + print (' -- single_sample: ', self.params['dataloader']['single_sample']) + if self.params['dataloader']['single_sample']: + print (' !!!!!!!!!!!!!!!!!!! SINGLE SAMPLE !!!!!!!!!!!!!!!!!!!') + print ('') + + print ('') + print (' MODEL ') + print (' ----- ') + print (' - Model: ', str(self.params['model']['name'])) + print (' -- KL Schedule : ', self.params['model']['kl_schedule']) + print (' -- KL Alpha Init: ', self.params['model']['kl_alpha_init']) + print (' -- KL Scaler : ', self.params['model']['kl_scale_factor']) + print (' -- Activation : ', self.params['model']['activation']) + print (' -- Kernel Reg : ', self.params['model']['kernel_reg']) + print (' -- Model TBoard : ', self.params['model']['model_tboard']) + print (' -- Profiler : ', self.params['model']['profiler']['profile']) + if self.params['model']['profiler']['profile']: + print (' ---- Profiler Epochs: ', self.params['model']['profiler']['epochs']) + print (' ---- Step Per Epochs: ', self.params['model']['profiler']['steps_per_epoch']) + print (' - Optimizer: ', str(self.params['model']['optimizer'])) + print (' -- Init LR : ', self.params['model']['init_lr']) + print (' -- Fixed LR : ', self.params['model']['fixed_lr']) + print (' -- Grad Persistent: ', self.params['model']['grad_persistent']) + if self.params['model']['grad_persistent']: + print (' !!!!!!!!!!!!!!!!!!! GRAD PERSISTENT !!!!!!!!!!!!!!!!!!!') + print ('') + print (' - Epochs: ', self.params['model']['epochs']) + print (' -- Save : every {} epochs'.format(self.params['model']['epochs_save'])) + print (' -- Eval3D : every {} epochs '.format(self.params['model']['epochs_eval'])) + print (' -- Viz3D : every {} epochs '.format(self.params['model']['epochs_viz'])) + + print ('') + print (' METRICS ') + print (' ------- ') + print (' - Logging-TBoard: ', self.params['metrics']['logging_tboard']) + if not self.params['metrics']['logging_tboard']: + print (' !!!!!!!!!!!!!!!!!!! 
NO LOGGING-TBOARD !!!!!!!!!!!!!!!!!!!') + print ('') + print (' - Eval: ', self.params['metrics']['metrics_eval']) + print (' - Loss: ', self.params['metrics']['metrics_loss']) + print (' -- Type of Loss : ', self.params['metrics']['loss_type']) + print (' -- Weighted Loss : ', self.params['metrics']['loss_weighted']) + print (' -- Masked Loss : ', self.params['metrics']['loss_mask']) + print (' -- Combo : ', self.params['metrics']['loss_combo']) + print (' -- Loss Epoch : ', self.params['metrics']['loss_epoch']) + print (' -- Loss Rate : ', self.params['metrics']['loss_rate']) + + print ('') + print (' DEVOPS ') + print (' ------ ') + self.pid = os.getpid() + print (' - OS-PID: ', self.pid) + print (' - Seed: ', self.params['random_seed']) + + print ('') + + def _set_seed(self): + np.random.seed(self.params['random_seed']) + tf.random.set_seed(self.params['random_seed']) + + def _set_dataloaders(self): + + # Params - Directories + data_dir = self.params['dataloader']['data_dir'] + dir_type = self.params['dataloader']['dir_type'] + dir_type_eval = ['_'.join(dir_type)] + + # Params - Single volume + resampled = self.params['dataloader']['resampled'] + crop_init = self.params['dataloader']['crop_init'] + grid = self.params['dataloader']['grid'] + filter_grid = self.params['dataloader']['filter_grid'] + random_grid = self.params['dataloader']['random_grid'] + centred_prob = self.params['dataloader']['centred_prob'] + + # Params - Dataloader + batch_size = self.params['dataloader']['batch_size'] + prefetch_batch = self.params['dataloader']['prefetch_batch'] + parallel_calls = self.params['dataloader']['parallel_calls'] + shuffle_size = self.params['dataloader']['shuffle'] + + # Params - Debug + single_sample = self.params['dataloader']['single_sample'] + + + # Datasets + self.dataset_train = get_dataloader_3D_train(data_dir, dir_type=dir_type + , grid=grid, crop_init=crop_init, filter_grid=filter_grid + , random_grid=random_grid + , resampled=resampled, single_sample=single_sample + , parallel_calls=parallel_calls + , centred_dataloader_prob=centred_prob + ) + self.dataset_train_eval = get_dataloader_3D_train_eval(data_dir, dir_type=dir_type_eval + , grid=grid, crop_init=crop_init + , resampled=resampled, single_sample=single_sample + ) + self.dataset_test_eval = get_dataloader_3D_test_eval(data_dir + , grid=grid, crop_init=crop_init + , resampled=resampled, single_sample=single_sample + ) + + # Get labels Ids + self.label_map = dict(self.dataset_train.get_label_map()) + self.label_ids = self.label_map.values() + self.params['internal'] = {} + self.params['internal']['label_map'] = self.label_map # for use in Metrics + self.params['internal']['label_ids'] = self.label_ids # for use in Metrics + self.label_weights = list(self.dataset_train.get_label_weights()) + + # Generators + self.dataset_train_gen = self.dataset_train.generator().repeat().shuffle(shuffle_size).batch(batch_size).apply(tf.data.experimental.prefetch_to_device(device='/GPU:0', buffer_size=prefetch_batch)) + self.dataset_train_eval_gen = self.dataset_train_eval.generator().batch(2).prefetch(prefetch_batch) + self.dataset_test_eval_gen = self.dataset_test_eval.generator().batch(2).prefetch(prefetch_batch) + + def _set_model(self): + + # Step 1 - Get class ids + class_count = len(self.label_ids) + deepsup = self.params['model']['deepsup'] + activation = self.params['model']['activation'] + + # Step 2 - Get model arch + self.kl_schedule = self.params['model']['kl_schedule'] + if self.kl_schedule == config.KL_DIV_FIXED: + 
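+        # [Note] kl_schedule controls the weight on the KL-divergence term of the flipout layers:
+        #   - KL_DIV_FIXED     : kl_alpha is held at kl_alpha_init for the whole run
+        #   - KL_DIV_ANNEALING : kl_alpha is increased during train(), roughly
+        #       kl_alpha(epoch) = min(kl_alpha_max, kl_alpha_init + (epoch - initial_epoch)/kl_epochs_change * kl_alpha_increase_per_epoch)
+        #     once epoch > initial_epoch (the update happens every kl_epochs_change epochs; see train() below)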
self.kl_alpha_init = self.params['model']['kl_alpha_init'] + elif self.kl_schedule == config.KL_DIV_ANNEALING: + if 0: + self.kl_alpha_init = 0.100 # [0.100, 0.050, 0.020, 0.010] + self.kl_alpha_increase_per_epoch = 0.001 + self.kl_alpha_max = 0.030 + self.initial_epoch = 250 + self.kl_epochs_change = 10 + elif 1: + self.kl_alpha_init = 0.05 + self.kl_alpha_increase_per_epoch = 0.0001 + self.initial_epoch = 250 + elif 0: + self.kl_alpha_init = 0.001 + self.kl_alpha_increase_per_epoch = 0.0001 + self.initial_epoch = 250 + + if self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT: + print (' - [Trainer][_models()] ModelFocusNetFlipOut') + self.model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup) + + elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT_V2: + print (' - [Trainer][_models()] ModelFocusNetFlipOutV2') + self.model = models.ModelFocusNetFlipOutV2(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup) + + elif self.params['model']['name'] == config.MODEL_FOCUSNET_FLIPOUT_POOL2: + print (' - [Trainer][_models()] ModelFocusNetFlipOutPool2') + self.model = models.ModelFocusNetFlipOutPool2(class_count=class_count, trainable=True, activation=activation, deepsup=deepsup) + + # print (' - [Trainer][_set_model()] initial_epoch: ', self.initial_epoch) + print (' - [Trainer][_set_model()] kl_alpha_init: ', self.kl_alpha_init) + # print (' - [Trainer][_set_model()] kl_alpha_increase_per_epoch: ', self.kl_alpha_increase_per_epoch) + + # Step 3 - Get optimizer + if self.params['model']['optimizer'] == config.OPTIMIZER_ADAM: + self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.params['model']['init_lr']) + + # Step 4 - Load model if needed + epochs = self.params['model']['epochs'] + if not self.params['model']['load_model']['load']: + # Step 4.1 - Set epoch range under non-loading situations + self.epoch_range = range(1,epochs+1) + else: + + # Step 4.2.1 - Some model-loading params + load_epoch = self.params['model']['load_model']['load_epoch'] + load_exp_name = self.params['model']['load_model']['load_exp_name'] + load_optimizer_lr = self.params['model']['load_model']['load_optimizer_lr'] + load_model_params = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'load_epoch': load_epoch, 'optimizer':self.optimizer} + + print ('') + print (' - [Trainer][_set_model()] Loading pretrained model') + print (' - [Trainer][_set_model()] Model: ', self.model) + + # Step 4.2.2.1 - If loading is done from the same exp_name + if load_exp_name is None: + load_model_params['exp_name'] = exp_name + self.epoch_range = range(load_epoch+1, epochs+1) + print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(load_epoch, epochs)) + # Step 4.2.2.1 - If loading is done from another exp_name + else: + self.epoch_range = range(1, epochs+1) + load_model_params['exp_name'] = load_exp_name + print (' - [Trainer][_set_model()] Training from epoch:{} to {}'.format(1, epochs)) + + print (' - [Trainer][_set_model()] exp_name: ', load_model_params['exp_name']) + + # Step 4.3 - Finally, load model from the checkpoint + utils.load_model(self.model, load_type=config.MODE_TRAIN, params=load_model_params) + print (' - [Trainer][_set_model()] Model Loaded at epoch-{} !'.format(load_epoch)) + print (' -- [Trainer][_set_model()] Optimizer.lr : ', self.optimizer.lr.numpy()) + if load_optimizer_lr is not None: + self.optimizer.lr.assign(load_optimizer_lr) + print (' -- [Trainer][_set_model()] Optimizer.lr : 
            ', self.optimizer.lr.numpy())
+
+        # Step 5 - Create model weights
+        # init_size = ((1,240,240,40,1))
+        init_size = ((1,140,140,40,1))
+        print ('\n -- [Trainer][_set_model()] Model weight creation with ', init_size, '\n')
+        X_tmp = tf.random.normal(init_size) # if the final dataloader does not have the same input size, the weight initialization gets screwed up.
+        _ = self.model(X_tmp)
+        self.layers_kl_std = self.get_layers_kl_std(std=True)
+        print (' -- [Trainer][_set_model()] Created model weights ')
+        try:
+            print (' --------------------------------------- ')
+            print (self.model.summary(line_length=150))
+            print (' --------------------------------------- ')
+            count = 0
+            for var in self.model.trainable_variables:
+                print (' - var: ', var.name)
+                count += 1
+                if count > 20:
+                    print (' ... ')
+                    break
+        except:
+            print (' - [Trainer][_set_model()] model.summary() failed')
+            pass
+
+    def _set_metrics(self):
+
+        self.metrics = {}
+        self.metrics[config.MODE_TRAIN] = ModelMetrics(metric_type=config.MODE_TRAIN, params=self.params)
+        self.metrics[config.MODE_TEST]  = ModelMetrics(metric_type=config.MODE_TEST, params=self.params)
+
+        deepsup = self.params['model']['deepsup']
+        if deepsup:
+            self.metrics[config.MODE_TRAIN_DEEPSUP] = ModelMetrics(metric_type=config.MODE_TRAIN_DEEPSUP, params=self.params)
+            self.metrics[config.MODE_TEST_DEEPSUP]  = ModelMetrics(metric_type=config.MODE_TEST_DEEPSUP, params=self.params)
+
+    def _set_profiler(self, epoch, epoch_step):
+        exp_name = self.params['exp_name']
+
+        if self.params['model']['profiler']['profile']:
+            if epoch in self.params['model']['profiler']['epochs']:
+                if epoch_step == self.params['model']['profiler']['starting_step']:
+                    self.logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, 'profiler', str(epoch))
+                    tf.profiler.experimental.start(str(self.logdir))
+                    print (' - tf.profiler.experimental.start(logdir)')
+                    print ('')
+                elif epoch_step == self.params['model']['profiler']['starting_step'] + self.params['model']['profiler']['steps_per_epoch']:
+                    print (' - tf.profiler.experimental.stop()')
+                    tf.profiler.experimental.stop()
+                    print ('')
+
+    @tf.function
+    def get_layers_kl_std(self, std=False):
+
+        res = {}
+        for layer in self.model.layers:
+            for loss_id, loss in enumerate(layer.losses):
+                layer_name = layer.name + '_' + str(loss_id)
+                res[layer_name] = {'kl': loss}
+
+                if std:
+                    if hasattr(layer, 'conv_layer'):
+                        if loss_id == 0:
+                            mean_vals = layer.conv_layer.submodules[1].kernel_posterior.distribution.loc
+                            std_vals  = layer.conv_layer.submodules[1].kernel_posterior.distribution.scale
+                            res[layer_name]['std']  = std_vals
+                            res[layer_name]['mean'] = mean_vals
+                        elif loss_id == 1:
+                            mean_vals = layer.conv_layer.submodules[3].kernel_posterior.distribution.loc
+                            std_vals  = layer.conv_layer.submodules[3].kernel_posterior.distribution.scale
+                            res[layer_name]['std']  = std_vals
+                            res[layer_name]['mean'] = mean_vals
+
+                    elif hasattr(layer, 'convblock_res'):
+                        if loss_id == 0:
+                            mean_vals = layer.convblock_res.conv_layer.submodules[1].kernel_posterior.distribution.loc
+                            std_vals  = layer.convblock_res.conv_layer.submodules[1].kernel_posterior.distribution.scale
+                            res[layer_name]['std']  = std_vals
+                            res[layer_name]['mean'] = mean_vals
+                        elif loss_id == 1:
+                            mean_vals = layer.convblock_res.conv_layer.submodules[3].kernel_posterior.distribution.loc
+                            std_vals  = layer.convblock_res.conv_layer.submodules[3].kernel_posterior.distribution.scale
+                            res[layer_name]['std']  = std_vals
+                            res[layer_name]['mean'] = mean_vals
+
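+                    # [Note] With tfp's default mean-field normal posterior, kernel_posterior.distribution.loc
+                    #        and .scale hold the per-weight variational mean and std; they are collected here
+                    #        (and in the branches above) so write_epoch_summary_std() can log the posterior
+                    #        spread per layer.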
+                    elif type(layer) == tfp.layers.Convolution3DFlipout:
+                        mean_vals = layer.kernel_posterior.distribution.loc
+                        std_vals  = layer.kernel_posterior.distribution.scale
+                        res[layer_name]['std']  = std_vals
+                        res[layer_name]['mean'] = mean_vals
+
+        return res
+
+    @tf.function
+    def _train_loss(self, Y, y_predict, meta1, epoch, mode):
+
+        trainMetrics  = self.metrics[mode]
+        metrics_loss  = self.params['metrics']['metrics_loss']
+        loss_weighted = self.params['metrics']['loss_weighted']
+        loss_mask     = self.params['metrics']['loss_mask']
+        loss_type     = self.params['metrics']['loss_type']
+        loss_combo    = self.params['metrics']['loss_combo']
+        loss_epoch    = self.params['metrics']['loss_epoch']
+        loss_rate     = self.params['metrics']['loss_rate']
+
+        label_ids     = self.label_ids
+        label_weights = tf.cast(self.label_weights, dtype=tf.float32)
+
+        loss_vals = tf.constant(0.0, dtype=tf.float32)
+        mask = losses.get_mask(meta1[:,-len(label_ids):], Y)
+
+        inf_flag = False
+        nan_flag = False
+
+        for metric_str in metrics_loss:
+
+            weights = []
+            if loss_weighted[metric_str]:
+                weights = label_weights
+
+            if not loss_mask[metric_str]:
+                mask = tf.cast(tf.cast(mask + 1, dtype=tf.bool), dtype=tf.float32)
+
+            loss_epoch_metric = tf.constant(loss_epoch[metric_str], dtype=tf.float32)
+
+            if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PAVPU] and epoch >= loss_epoch_metric:
+                loss_val_train, loss_labellist_train, metric_val_report, metric_labellist_report = trainMetrics.losses_obj[metric_str](Y, y_predict, label_mask=mask, weights=weights)
+                if metrics_loss[metric_str] in [config.LOSS_DICE, config.LOSS_CE, config.LOSS_FOCAL, config.LOSS_PAVPU]:
+                    nan_list = tf.math.is_nan(loss_labellist_train)
+                    nan_val  = tf.math.is_nan(loss_val_train)
+                    inf_list = tf.math.is_inf(loss_labellist_train)
+                    inf_val  = tf.math.is_inf(loss_val_train)
+                    if nan_val or tf.math.reduce_any(nan_list):
+                        nan_flag = True
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || nan_list: ', nan_list, ' || nan_val: ', nan_val)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || mask: ', mask)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss NaN spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+                    elif inf_val or tf.math.reduce_any(inf_list):
+                        inf_flag = True
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_val_train: ', loss_val_train)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || inf_list: ', inf_list, ' || inf_val: ', inf_val)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || mask: ', mask)
+                        tf.print (' - [ERROR][Trainer][_train_loss()] Loss Inf spotted: ', metric_str, ' || loss_vals: ', loss_vals)
+                    else:
+                        if len(metric_labellist_report):
+                            trainMetrics.update_metric_loss_labels(metric_str, metric_labellist_report) # in sub-3D settings, this value is only indicative of performance
+                        trainMetrics.update_metric_loss(metric_str, loss_val_train)
+
+                    if 0: # [DEBUG]
+                        tf.print(' - metric_str: ', metric_str, ' || loss_val_train: ', loss_val_train)
+
+                    if loss_type[metric_str] == config.LOSS_SCALAR:
+                        if metric_str in loss_combo:
+                            print (' - [Trainer][_train_loss()] Scalar loss')
+                            if loss_epoch_metric == 0.0:
+                                loss_val_train = loss_val_train*loss_combo[metric_str]
+                            else:
+                                # ramp the loss weight in linearly after loss_epoch: weight = loss_combo * min(1, (epoch - loss_epoch)/loss_rate)
+                                loss_rate_metric = tf.constant(loss_rate[metric_str], dtype=tf.float32)
+                                loss_factor = (epoch - loss_epoch_metric)/loss_rate_metric
+                                if loss_factor > 
1:loss_factor = tf.constant(1.0, dtype=tf.float32) + loss_val_train = loss_val_train*loss_combo[metric_str]*loss_factor + loss_vals = tf.math.add(loss_vals, loss_val_train) # Averaged loss + + elif loss_type[metric_str] == config.LOSS_VECTOR: + if metric_str in loss_combo: + print (' - [Trainer][_train_loss()] Vector loss') + loss_labellist_train = loss_labellist_train*loss_combo[metric_str] + loss_vals = tf.math.add(loss_vals, loss_labellist_train) # Averaged loss by label + + if nan_flag or inf_flag: + print (' - [INFO][Trainer][_train_loss()] loss_vals:', loss_vals) + loss_vals = 0.0 # no backprop when something was wrong + + return loss_vals + + @tf.function + def _train_step(self, X, Y, meta1, meta2, kl_alpha, epoch): + + try: + + if 1: + model = self.model + deepsup = self.params['model']['deepsup'] + optimizer = self.optimizer + grad_persistent = self.params['model']['grad_persistent'] + trainMetrics = self.metrics[config.MODE_TRAIN] + kl_scale_fac = self.params['model']['kl_scale_factor'] + MC_RUNS = 5 + + y_predict = None + loss_vals = None + gradients = None + + # Step 1 - Calculate loss + with tf.GradientTape(persistent=grad_persistent) as tape: + + + t2 = tf.timestamp() + if deepsup: (y_predict_deepsup, y_predict) = model(X, training=True) + else : y_predict = model(X, training=True) + t2_ = tf.timestamp() + + loss_vals = self._train_loss(Y, y_predict, meta1, epoch, mode=config.MODE_TRAIN) + + if loss_vals > 0: + if deepsup: + print (' - [Trainer][_train_step()] deepsup training') + Y_deepsup = Y[:,::2,::2,::2,:] + loss_vals_deepsup = self._train_loss(Y_deepsup, y_predict_deepsup, meta1, epoch, mode=config.MODE_TRAIN_DEEPSUP) + loss_vals += loss_vals_deepsup + + print (' - [Trainer][_train_step()] Model FlipOut') + kl = tf.math.add_n(model.losses) + kl_loss = kl*kl_alpha/kl_scale_fac + + kl_layers = {} + for layer in model.layers: + for loss_id, loss in enumerate(layer.losses): + layer_name = layer.name + '_' + str(loss_id) + kl_layers[layer_name] = {'kl': loss} + trainMetrics.update_metrics_kl(kl_alpha, kl, kl_layers) + trainMetrics.update_metrics_scalarloss(loss_vals, kl_loss) + + if 0: + with tf.GradientTape(watch_accessed_variables=False) as meh: + print (' - [Trainer][_train_step()] Loopy Loss Expectation') + loss_loops = 2 + for _ in tf.range(loss_loops-1): + y_predict = model(X, training=False) + loss_vals = tf.math.add(loss_vals, self._train_loss(Y, y_predict, meta1)) + loss_vals = tf.math.divide(loss_vals, loss_loops) + + loss_vals = loss_vals + kl_loss + + # Step 2 - Calculate gradients + t3 = tf.timestamp() + if not tf.math.reduce_any(tf.math.is_nan(loss_vals)) and loss_vals > 0: + all_vars = model.trainable_variables + + gradients = tape.gradient(loss_vals, all_vars) # dL/dW + + # Step 3 - Apply gradients + optimizer.apply_gradients(zip(gradients, all_vars)) + else: + tf.print('\n ====================== [NaN Error] ====================== ') + tf.print(' - [ERROR][Trainer][_train_step()] Loss NaN spotted || loss_vals: ', loss_vals) + tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1) + + t3_ = tf.timestamp() + return t2_-t2, t3-t2_, t3_-t3 + + except tf.errors.ResourceExhaustedError as e: + print (' - [ERROR][Trainer][_train_step()] OOM error') + return None, None, None + + except: + tf.print('\n ====================== [Some Error] ====================== ') + tf.print(' - [ERROR][Trainer][_train_step()] meta2: ', meta2, ' || meta1: ', meta1) + traceback.print_exc() + return None, None, None + + def train(self): + + # Global params + 
exp_name = self.params['exp_name'] + + # Dataloader params + batch_size = self.params['dataloader']['batch_size'] + + # Model/Training params + fixed_lr = self.params['model']['fixed_lr'] + init_lr = self.params['model']['init_lr'] + max_epoch = self.params['model']['epochs'] + epoch_range = iter(self.epoch_range) + epoch_length = len(self.dataset_train) + deepsup = self.params['model']['deepsup'] + params_save_model = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'optimizer':self.optimizer} + + # Metrics params + metrics_eval = self.params['metrics']['metrics_eval'] + trainMetrics = self.metrics[config.MODE_TRAIN] + trainMetrics.init_metrics_layers_kl_std(self.params, self.layers_kl_std) + trainMetricsDeepSup = None + if deepsup: trainMetricsDeepSup = self.metrics[config.MODE_TRAIN_DEEPSUP] + + # Eval Params + params_eval = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'pid': self.pid + , 'eval_type': config.MODE_TRAIN, 'batch_size': batch_size} + + # Viz params + epochs_save = self.params['model']['epochs_save'] + epochs_viz = self.params['model']['epochs_viz'] + epochs_eval = self.params['model']['epochs_eval'] + + # KL Divergence Params + kl_alpha = self.kl_alpha_init # [0.0, self.kl_alpha_init] + + # Random vars + t_start_time = time.time() + + epoch = None + try: + + epoch_step = 0 + pbar = None + t1 = time.time() + for (X,Y,meta1,meta2) in self.dataset_train_gen: + t1_ = time.time() + + try: + # Epoch starter code + if epoch_step == 0: + + # Get Epoch + epoch = next(epoch_range) + + # Metrics + trainMetrics.reset_metrics(self.params) + + # LR + if not fixed_lr: + set_lr(epoch, self.optimizer, init_lr) + self.model.trainable = True + + # Calculate kl_alpha (commented if alpha is fixed) + if self.kl_schedule == config.KL_DIV_ANNEALING: + if epoch > self.initial_epoch: + if epoch % self.kl_epochs_change == 0: + kl_alpha = tf.math.minimum(self.kl_alpha_max, self.kl_alpha_init + (epoch - self.initial_epoch)/float(self.kl_epochs_change) * self.kl_alpha_increase_per_epoch) + + # Pretty print + print ('') + print (' ===== [{}] EPOCH:{}/{} (LR={:3f}, kl_alpha={:3f}) =================='.format(exp_name, epoch, max_epoch,self.optimizer.lr.numpy(), kl_alpha)) + + # Start a fresh pbar + pbar = tqdm.tqdm(total=epoch_length, desc='') + + # Model Writing to tensorboard + if self.params['model']['model_tboard'] and self.write_model_done is False : + self.write_model_done = True + utils.write_model_tboard(self.model, X, self.params) + + # Start/Stop Profiling (after dataloader is kicked off) + self._set_profiler(epoch, epoch_step) + + # Calculate loss and gradients from them + time_predict, time_loss, time_backprop = self._train_step(X, Y, meta1, meta2, tf.constant(kl_alpha, dtype=tf.float32), tf.constant(epoch, dtype=tf.float32)) + + # Update metrics (time + eval + plots) + time_dataloader = t1_ - t1 + trainMetrics.update_metrics_time(time_dataloader, time_predict, time_loss, time_backprop) + + # Update looping stuff + epoch_step += batch_size + pbar.update(batch_size) + trainMetrics.update_pbar(pbar) + + except: + utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + params_save_model['epoch'] = epoch + utils.save_model(self.model, params_save_model) + traceback.print_exc() + + if epoch_step >= epoch_length: + + # Reset epoch-loop params + pbar.close() + epoch_step = 0 + + try: + # Model save + if epoch % epochs_save == 0: + params_save_model['epoch'] = epoch + utils.save_model(self.model, params_save_model) + + # Tensorboard for std + if epoch % 
epochs_save == 0: + layers_kl_std = self.get_layers_kl_std(std=True) + trainMetrics.write_epoch_summary_std(layers_kl_std, epoch=epoch) + + # Eval on full 3D + if epoch % epochs_eval == 0: + self.params['epoch'] = epoch + save=False + if epoch > 0 and epoch % epochs_viz == 0: + save=True + + self.model.trainable = False + for metric_str in metrics_eval: + if metrics_eval[metric_str] in [config.LOSS_DICE]: + params_eval['epoch'] = epoch + params_eval['deepsup'] = deepsup + params_eval['deepsup_eval'] = False + params_eval['eval_type'] = config.MODE_TRAIN + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save) + trainMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + if deepsup: + params_eval['deepsup'] = deepsup + params_eval['deepsup_eval'] = True + params_eval['eval_type'] = config.MODE_TRAIN_DEEPSUP + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_train_eval, self.dataset_train_eval_gen, params_eval, save=save) + trainMetricsDeepSup.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + # Test + if epoch % epochs_eval == 0: + self._test() + self.model.trainable = True + + # Epochs summary/wrapup + eval_condition = epoch % epochs_eval == 0 + trainMetrics.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition) + if deepsup: + trainMetricsDeepSup.write_epoch_summary(epoch, self.label_map, {'optimizer':self.optimizer}, eval_condition) + + if epoch > 0 and epoch % self.params['others']['epochs_timer'] == 0: + elapsed_seconds = time.time() - t_start_time + print (' - Total time elapsed : {}'.format( str(datetime.timedelta(seconds=elapsed_seconds)) )) + if epoch % self.params['others']['epochs_memory'] == 0: + mem_before = utils.get_memory(self.pid) + gc_n = gc.collect() + mem_after = utils.get_memory(self.pid) + print(' - Unreachable objects collected by GC: {} || ({}) -> ({})'.format(gc_n, mem_before, mem_after)) + + # Break out of loop at end of all epochs + if epoch == max_epoch: + print ('\n\n - [Trainer][train()] All epochs finished') + break + + except: + utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + params_save_model['epoch'] = epoch + utils.save_model(self.model, params_save_model) + traceback.print_exc() + pdb.set_trace() + + t1 = time.time() # reset dataloader time calculator + + except: + utils.print_exp_name(exp_name + '-' + config.MODE_TRAIN, epoch) + traceback.print_exc() + + def _test(self): + + exp_name = None + epoch = None + try: + + # Step 1.1 - Params + exp_name = self.params['exp_name'] + epoch = self.params['epoch'] + deepsup = self.params['model']['deepsup'] + + metrics_eval = self.params['metrics']['metrics_eval'] + epochs_viz = self.params['model']['epochs_viz'] + batch_size = self.params['dataloader']['batch_size'] + + # vars + testMetrics = self.metrics[config.MODE_TEST] + testMetrics.reset_metrics(self.params) + testMetricsDeepSup = None + if deepsup: + testMetricsDeepSup = self.metrics[config.MODE_TEST_DEEPSUP] + testMetricsDeepSup.reset_metrics(self.params) + params_eval = {'PROJECT_DIR': self.params['PROJECT_DIR'], 'exp_name': exp_name, 'pid': self.pid + , 'eval_type': config.MODE_TEST, 'batch_size': batch_size + , 'epoch':epoch} + + # Step 2 - Eval on full 3D + save=False + if epoch > 0 and epoch % epochs_viz == 0: + save=True + for metric_str in metrics_eval: + if metrics_eval[metric_str] in [config.LOSS_DICE]: + params_eval['deepsup'] = deepsup + 
params_eval['deepsup_eval'] = False + params_eval['eval_type'] = config.MODE_TEST + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save) + testMetrics.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + if deepsup: + params_eval['deepsup'] = deepsup + params_eval['deepsup_eval'] = True + params_eval['eval_type'] = config.MODE_TEST_DEEPSUP + eval_avg, eval_labels_avg = eval_3D(self.model, self.dataset_test_eval, self.dataset_test_eval_gen, params_eval, save=save) + testMetricsDeepSup.update_metric_eval_labels(metric_str, eval_labels_avg, do_average=True) + + testMetrics.write_epoch_summary(epoch, self.label_map, {}, True) + if deepsup: + testMetricsDeepSup.write_epoch_summary(epoch, self.label_map, {}, True) + + except: + utils.print_exp_name(exp_name + '-' + config.MODE_TEST, epoch) + traceback.print_exc() + pdb.set_trace() + +if __name__ == "__main__": + + ## Step 0 - Exp Name + + # exp_name = 'Grid_FocusNetFlipV21632GNorm_FixedKL010_33samB214014040PreGridNorm_WithNegCEScalar10_seed42' # <---- This + + # exp_name = 'Fin__FocusNetFlipV2-GNorm-1632-FixedKl010__33sam-B2-PreWindowNorm-24024040__CEScalar10-WithNeg_seed42' + # exp_name = 'Fin__FocusNetFlipV2-GNorm-1632-FixedKl010__33sam-B2-PreWindowNorm-24024040__10PA01Thresh05vPU-CEScalar10-WithNeg_seed42' + + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl010__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl010__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42_v2' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl005__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1224-FixedKl001__33sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42' + + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th05-nAC005-1CEScalar10-WithoutNeg_seed42_v2' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th05-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-Th04-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42' + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__10PAvPU-Th04-nAC005-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42' + + # exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-HErr-Th05-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42' + # exp_name = 'tmp_Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__33sam-B2-PreWindowNorm-14014040__05PAvPU-HErr-Th05-Annealed-Ep300-Fac10-1CEScalar10-WithNeg_seed42' + + exp_name = 'Grid__FocusNetFlipV2-GNorm-1632-FixedKl001__38sam-B2-PreWindowNorm-14014040__CEScalar10-WithNeg_seed42' + + if 0: + + ## Step 1 - Params + params = { + 'PROJECT_DIR': config.PROJECT_DIR + , 'random_seed':42 + , 'exp_name': exp_name + , 'dataloader':{ + 'data_dir': Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + , 'dir_type': [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD, config.DATALOADER_MICCAI2015_TESTONSITE] + # , 'dir_type' : [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD] + , 'resampled' : True + , 'crop_init' : True + , 'grid' : True + , 'random_grid' : True + , 'filter_grid' : False + , 'centred_prob' : 0.3 + , 'batch_size' : 2 # 
[2,4,8] + , 'shuffle' : 5 # [5,5,12] + , 'prefetch_batch': 4 # [4,3,2] + , 'parallel_calls': 4 # [4,4,4] + , 'single_sample' : False # [ !!!!!!!! WATCH OUT !!!!!!!!! ] + } + , 'model': { + 'name': config.MODEL_FOCUSNET_FLIPOUT_V2 # [MODEL_FOCUSNET_FLIPOUT, MODEL_FOCUSNET_FLIPOUT_V2, MODEL_FOCUSNET_FLIPOUT_POOL2] + , 'kl_alpha_init' : 0.01 + , 'kl_schedule' : config.KL_DIV_FIXED # [config.KL_DIV_FIXED, config.KL_DIV_ANNEALING] + , 'kl_scale_factor' : 154 # distribute the kl_div loss 'kl_scale_factor' times across an epoch i.e. dataset_len/batch_size=> + # |(240,240,80)= 78.0/2.0|, [240,240,40]=66/2, |(140,140,40)=308/2.0=154, 264/2=132|, |(96,96,40)=684/2.0=342| + , 'deepsup' : False + , 'activation': tf.keras.activations.softmax # ['softmax', 'tf.keras.activations.sigmoid'] + , 'kernel_reg': False + , 'optimizer' : config.OPTIMIZER_ADAM + , 'grad_persistent': False + , 'init_lr' : 0.001 # 0.01 # 0.005 + , 'fixed_lr' : False + , 'epochs' : 1500 # 1000 + , 'epochs_save': 50 # 20 + , 'epochs_eval': 50 # 40 + , 'epochs_viz' : 500 # 100 + , 'load_model':{ + 'load':False, 'load_exp_name': None, 'load_epoch':-1, 'load_optimizer_lr':None # DEFAULT OPTION + # 'load_exp_name':'Grid_FocusNetFlip1632_E250KL0001_WSum1My_14014040_NoAug_FocalScalar09_seed42', 'load':True, 'load_epoch':200 + # 'load_exp_name':None, 'load':True, 'load_epoch':100 + } + , 'profiler': { + 'profile': False + , 'epochs': [252,253] + , 'steps_per_epoch': 30 + , 'starting_step': 2 + } + , 'model_tboard': False + } + , 'metrics' : { + 'logging_tboard': True + # for full 3D volume + , 'metrics_eval': {'Dice': config.LOSS_DICE} + ## for smaller grid + # , 'metrics_loss':{'Dice': config.LOSS_DICE} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL] + # , 'loss_weighted': {'Dice': True} + # , 'loss_mask': {'Dice':True} + # , 'loss_type': {'Dice': config.LOSS_VECTOR} # [config.LOSS_SCALAR, config.LOSS_VECTOR] + # , 'loss_combo': {'Dice':1.0} + , 'metrics_loss' : {'CE': config.LOSS_CE} + , 'loss_type' : {'CE': config.LOSS_SCALAR} + , 'loss_weighted': {'CE': True} + , 'loss_mask' : {'CE': True} + , 'loss_combo' : {'CE': 1.0} + , 'loss_epoch' : {'CE': 0} + , 'loss_rate' : {'CE': 1} + # , 'metrics_loss' : {'Focal': config.LOSS_FOCAL} + # , 'loss_type' : {'Focal': config.LOSS_SCALAR} + # , 'loss_weighted': {'Focal': True} + # , 'loss_mask' : {'Focal': True} + # , 'loss_combo' : {'Focal': 1.0} + # , 'metrics_loss':{'Focal': config.LOSS_FOCAL, 'Dice': config.LOSS_DICE} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL] + # , 'loss_type': {'Focal': config.LOSS_SCALAR, 'Dice': config.LOSS_SCALAR} + # , 'loss_weighted': {'Focal': True, 'Dice': False} + # , 'loss_mask': {'Focal':True, 'Dice': True} + # , 'loss_combo': {'Focal': 1.0, 'Dice': 1.0} #[, ['Dice', 'CE']] + # , 'metrics_loss' : {'CE': config.LOSS_CE, 'pavpu': config.LOSS_PAVPU} # [config.LOSS_CE, config.LOSS_DICE, config.LOSS_FOCAL] + # , 'loss_weighted' : {'CE': True, 'pavpu': False} + # , 'loss_mask' : {'CE': True, 'pavpu': False} + # , 'loss_type' : {'CE': config.LOSS_SCALAR, 'pavpu': config.LOSS_SCALAR} + # , 'loss_combo' : {'CE': 1.0, 'pavpu': 5.0} + # , 'loss_epoch' : {'CE': 0 , 'pavpu': 300} + # , 'loss_rate' : {'CE': 1 , 'pavpu':10} + } + , 'others': { + 'epochs_timer': 20 + , 'epochs_memory':5 + } + } + + # Call the trainer + trainer = Trainer(params) + trainer.train() + + elif 1: + + # Step 1 - Dataloaders + resampled = True + single_sample = False + grid = True + crop_init = True + data_dir = Path(config.MAIN_DIR).joinpath('medical_dataloader', '_data') + if 0: + # 
dir_type = [config.DATALOADER_MICCAI2015_TRAIN, config.DATALOADER_MICCAI2015_TRAIN_ADD, config.DATALOADER_MICCAI2015_TESTONSITE]; eval_type = config.MODE_TRAIN_VAL; dir_type = ['_'.join(dir_type)] + dir_type = [config.DATALOADER_MICCAI2015_TEST]; eval_type = config.MODE_TEST + # dir_type = [config.DATALOADER_MICCAI2015_TESTONSITE]; eval_type = config.MODE_TEST + + dataset_test_eval = get_dataloader_3D_test_eval(data_dir, dir_type=dir_type + , grid=grid, crop_init=crop_init + , resampled=resampled, single_sample=single_sample + ) + else: + dir_type = [medconfig.DATALOADER_DEEPMINDTCIA_TEST]; annotator_type = [medconfig.DATALOADER_DEEPMINDTCIA_ONC]; eval_type = config.MODE_DEEPMINDTCIA_TEST_ONC + # dir_type = [medconfig.DATALOADER_DEEPMINDTCIA_TEST]; annotator_type = [medconfig.DATALOADER_DEEPMINDTCIA_RAD]; eval_type = config.MODE_DEEPMINDTCIA_TEST_RAD + dataset_test_eval = get_dataloader_deepmindtcia(data_dir, dir_type=dir_type, annotator_type=annotator_type + , grid=grid, crop_init=crop_init, resampled=resampled + , single_sample=single_sample + ) + + # Step 2 - Model Loading + class_count = len(dataset_test_eval.datasets[0].LABEL_MAP.values()) + activation = 'softmax' # tf.keras.activations.softmax + deepsup = False + # model = models.ModelFocusNetFlipOutPool2(class_count=class_count, trainable=False, activation=activation) + # model = models.ModelFocusNetFlipOut(class_count=class_count, trainable=False, activation=activation) + model = models.ModelFocusNetFlipOutV2(class_count=class_count, trainable=False, activation=activation) + + # Step 3 - Final function call + params = { + 'PROJECT_DIR' : config.PROJECT_DIR + , 'exp_name' : exp_name + , 'pid' : os.getpid() + , 'batch_size' : 2 + , 'prefetch_batch': 1 + , 'epoch' : 500 # [1000, 1500, 2500] + , 'eval_type' : eval_type + , 'MC_RUNS' : 30 + , 'deepsup' : deepsup + , 'deepsup_eval' : False + , 'training_bool' : True # [True=dropout-at-test-time, False=no-dropout-at-test-time] + } + print ('\n - eval_type: ', params['eval_type']) + print (' - epoch : ', params['epoch']) + print (' - MC_RUNS : ', params['MC_RUNS']) + val(model, dataset_test_eval, params, show=False, save=True, verbose=False) + + diff --git a/src/model/utils.py b/src/model/utils.py new file mode 100644 index 0000000..3c47491 --- /dev/null +++ b/src/model/utils.py @@ -0,0 +1,585 @@ +# Import internal libraries +import src.config as config + +# Import external libraries +import os +import pdb +import copy +import tqdm +import json +import psutil +import humanize +import traceback +import numpy as np +import matplotlib +from pathlib import Path +import tensorflow as tf + +from src.config import PROJECT_DIR + +############################################################ +# MODEL RELATED # +############################################################ +def save_model(model, params): + """ + The phrase "Saving a TensorFlow model" typically means one of two things: + - using the Checkpoints format (this code does checkpointing) + - using the SavedModel format. + - Ref: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint + - ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model); ckpt_obj.save(file_prefix='my_model_ckpt') + - tf.keras.Model.save_weights('my_model_save_weights') + - tf.keras.Model.save('my_model_save') + - Questions + - What is the SavedModel format + - SavedModel saves the execution graph. 
+ """ + try: + PROJECT_DIR = params['PROJECT_DIR'] + exp_name = params['exp_name'] + epoch = params['epoch'] + + folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(epoch) + model_folder = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name) + model_folder.mkdir(parents=True, exist_ok=True) + model_path = Path(model_folder).joinpath(folder_name) + + optimizer = params['optimizer'] + ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model) + ckpt_obj.save(file_prefix=model_path) + + if 0: # CHECK + model.save(str(Path(model_path).joinpath('model_save'))) # SavedModel format: creates a folder "model_save" with assets/, variables/ (contains weights) and a saved_model.pb (model architecture) + model.save_weights(str(Path(model_path).joinpath('model_save_weights'))) + + except: + traceback.print_exc() + pdb.set_trace() + +def load_model(model, load_type, params): + """ + - Ref: https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint + """ + try: + + PROJECT_DIR = params['PROJECT_DIR'] + exp_name = params['exp_name'] + load_epoch = params['load_epoch'] + + folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(load_epoch) + model_folder = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name) + + if load_type == config.MODE_TRAIN: + if 'optimizer' in params: + ckpt_obj = tf.train.Checkpoint(optimizer=params['optimizer'], model=model) + # ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).assert_existing_objects_matched() # shows errors + # ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).assert_consumed() + ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).expect_partial() + else: + print (' - [ERROR][utils.load_model] Optimizer not passed !') + pdb.set_trace() + + elif load_type in [config.MODE_VAL, config.MODE_TEST]: + ckpt_obj = tf.train.Checkpoint(model=model) + ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_folder))).expect_partial() + + elif load_type == config.MODE_VAL_NEW: + model.load_weights(str(Path(model_folder).joinpath('model.h5')), by_name=True, skip_mismatch=False) + + else: + print (' - [ERROR][utils.load_model] It should not be here!') + pdb.set_trace() + # tf.keras.Model.load_weights + # tf.train.list_variables(tf.train.latest_checkpoint(str(model_folder))) + + except: + traceback.print_exc() + pdb.set_trace() + +def get_tensorboard_writer(exp_name, suffix): + try: + import tensorflow as tf + + logdir = Path(config.MODEL_CHKPOINT_MAINFOLDER).joinpath(exp_name, config.MODEL_LOGS_FOLDERNAME, suffix) + writer = tf.summary.create_file_writer(str(logdir)) + return writer + + except: + traceback.print_exc() + pdb.set_trace() + +def make_summary(fieldname, epoch, writer1=None, value1=None, writer2=None, value2=None): + try: + import tensorflow as tf + + if writer1 is not None and value1 is not None: + with writer1.as_default(): + tf.summary.scalar(fieldname, value1, epoch) + writer1.flush() + if writer2 is not None and value2 is not None: + with writer2.as_default(): + tf.summary.scalar(fieldname, value2, epoch) + writer2.flush() + except: + traceback.print_exc() + pdb.set_trace() + +def make_summary_hist(fieldname, epoch, writer1=None, value1=None, writer2=None, value2=None): + try: + import tensorflow as tf + + if writer1 is not None and value1 is not None: + with writer1.as_default(): + tf.summary.histogram(fieldname, value1, epoch) + writer1.flush() + if writer2 is not None and value2 is not None: + with 
writer2.as_default(): + tf.summary.histogram(fieldname, value2, epoch) + writer2.flush() + except: + traceback.print_exc() + pdb.set_trace() + +def write_model(model, X, params, suffix='model'): + """ + - Ref: + - https://www.tensorflow.org/api_docs/python/tf/summary/trace_on + - https://www.tensorflow.org/api_docs/python/tf/summary/trace_export + - https://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model + - https://stackoverflow.com/questions/56690089/how-to-graph-tf-keras-model-in-tensorflow-2-0 + """ + + # Step 1 - Start trace + tf.summary.trace_on(graph=True, profiler=False) + + # Step 2 - Perform operation + _ = write_model_trace(model, X) + + # Step 3 - Export trace + writer = get_tensorboard_writer(params['exp_name'], suffix) + with writer.as_default(): + tf.summary.trace_export(name=model.name, step=0, profiler_outdir=None) + writer.flush() + + # Step 4 - Save as .png + # tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, expand_nested=True) # only works for the functional API. :( + +@tf.function +def write_model_trace(model, X): + return model(X) + +def set_lr(epoch, optimizer): + # if epoch == 200: + # optimizer.lr.assign(0.0001) + pass + +############################################################ +# WRITE RELATED # +############################################################ + +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NpEncoder, self).default(obj) + +def write_json(json_data, json_filepath): + + Path(json_filepath).parent.absolute().mkdir(parents=True, exist_ok=True) + + with open(str(json_filepath), 'w') as fp: + json.dump(json_data, fp, indent=4, cls=NpEncoder) + +############################################################ +# DEBUG RELATED # +############################################################ +def print_break_msg(): + print ('') + print (' ========================= break operator applied here =========================') + print ('') + +def get_memory_usage(filename): + mem = os.popen("ps aux | grep %s | awk '{sum=sum+$6}; END {print sum/1024 \" MB\"}'"% (filename)).read() + return mem.rstrip() + +def print_exp_name(exp_name, epoch): + print ('') + print (' [ERROR] ========================= {} (epoch={}) ========================='.format(exp_name, epoch)) + print ('') + +def get_memory(pid): + try: + process = psutil.Process(pid) + return humanize.naturalsize(process.memory_info().rss) + except: + return '-1' + +def get_tf_gpu_memory(): + # https://www.tensorflow.org/api_docs/python/tf/config/experimental/get_memory_usage + gpu_devices = tf.config.list_physical_devices('GPU') + if len(gpu_devices): + memory_bytes = tf.config.experimental.get_memory_usage(gpu_devices[0].name.split('/physical_device:')[-1]) + return 'GPU: {:2f}GB'.format(memory_bytes/1024.0/1024.0/1024.0) + else: + return '-1' + +############################################################ +# DATALOADER RELATED # +############################################################ +def get_info_from_label_id(label_id, label_map, label_colors=None): + """ + The label_id param has to be greater than 0 + """ + + label_name = [label for label in label_map if label_map[label] == label_id] + if len(label_name): + label_name = label_name[0] + else: + label_name = None + + if label_colors is not None: + label_color = np.array(label_colors[label_id]) + if np.any(label_color > 
0): + label_color = label_color/255.0 + else: + label_color = None + + return label_name, label_color + +def cmap_for_dataset(label_colors): + cmap_me = matplotlib.colors.ListedColormap(np.array([*label_colors.values()])/255.0) + norm = matplotlib.colors.BoundaryNorm(boundaries=range(0,cmap_me.N+1), ncolors=cmap_me.N) + + return cmap_me, norm + +############################################################ +# EVAL RELATED # +############################################################ +def get_eval_folders(PROJECT_DIR, exp_name, epoch, mode, mc_runs=None, training_bool=None, create=False): + folder_name = config.MODEL_CHKPOINT_NAME_FMT.format(epoch) + + if mc_runs is not None and training_bool is not None: # During manual eval + if mc_runs == 1 and training_bool == False: + model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode + config.SUFFIX_DET) + else: + model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode + config.SUFFIX_MC.format(mc_runs)) + else: # During automated training + eval + model_folder_epoch_save = Path(PROJECT_DIR).joinpath(config.MODEL_CHKPOINT_MAINFOLDER, exp_name, folder_name, config.MODEL_IMGS_FOLDERNAME, mode) + + model_folder_epoch_patches = Path(model_folder_epoch_save).joinpath('patches') + model_folder_epoch_imgs = Path(model_folder_epoch_save).joinpath('imgs') + + if create: + Path(model_folder_epoch_patches).mkdir(parents=True, exist_ok=True) + Path(model_folder_epoch_imgs).mkdir(parents=True, exist_ok=True) + + return model_folder_epoch_patches, model_folder_epoch_imgs + +############################################################ +# 3D # +############################################################ + +def viz_model_output_3d_old(X, y_true, y_predict, y_predict_std, patient_id, path_save, label_map, label_colors, VMAX_STD=0.3): + """ + X : [H,W,D,1] + y_true: [H,W,D,Labels] + + Takes only a single batch of data + """ + try: + import matplotlib + import matplotlib.pyplot as plt + import skimage + import skimage.measure + + Path(path_save).mkdir(exist_ok=True, parents=True) + + labels_ids = sorted(list(label_map.values())) + cmap_me, norm_me = cmap_for_dataset(label_colors) + cmap_img = 'rainbow' + VMIN = 0 + VMAX_STD = VMAX_STD + VMAX_MEAN = 1.0 + + slice_ids = list(range(X.shape[2])) + # with tqdm.tqdm(total=len(slice_ids), leave=False, desc='ID:' + patient_id) as pbar_slices: + with tqdm.tqdm(total=len(slice_ids), leave=False) as pbar_slices: + for slice_id in range(X.shape[2]): + + # Data + X_slice = X[:,:,slice_id,0] + y_true_slice = y_true[:,:,slice_id,:] + y_predict_slice = y_predict[:,:,slice_id,:] + y_predict_std_slice = y_predict_std[:,:,slice_id,:] + + # Matplotlib figure + filename = '\n'.join(patient_id.split('-')) + suptitle_str = 'Slice: {}'.format(filename) + fig = plt.figure(figsize=(15,15), dpi=200) + fig_std = plt.figure(figsize=(15,15), dpi=200) + spec = fig.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5) + spec_std = fig_std.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5) + fig.suptitle(suptitle_str + '\n Predictive Mean') + fig_std.suptitle(suptitle_str + '\n Predictive Std') + + # Top two images + ax3 = fig.add_subplot(spec[0, 4]) + ax3.imshow(X_slice, cmap='gray') + ax3.axis('off') + ax3.set_title('Raw data') + ax3_std = fig_std.add_subplot(spec_std[0, 4]) + ax3_std.imshow(X_slice, cmap='gray') + 
ax3_std.axis('off') + ax3_std.set_title('Raw data') + img_slice_mask_plot = np.zeros(y_true_slice[:,:,0].shape) + + # Other images + i,j = 2,0 + for label_id in labels_ids: + if label_id not in config.IGNORE_LABELS: + + # Get ground-truth and prediction slices + img_slice_mask_predict = copy.deepcopy(y_predict_slice[:,:,label_id]) + img_slice_mask_predict_std = copy.deepcopy(y_predict_std_slice[:,:,label_id]) + img_slice_mask_gt = copy.deepcopy(y_true_slice[:,:,label_id]) + + # Plot prediction heatmap + if j >= 5: + i = i + 1 + j = 0 + ax = fig.add_subplot(spec[i,j]) + ax_std = fig_std.add_subplot(spec_std[i,j]) + j += 1 + + # img_slice_mask_predict[img_slice_mask_predict > config.PREDICT_THRESHOLD_MASK] = label_id + ax.imshow(img_slice_mask_predict, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_MEAN) + ax_std.imshow(img_slice_mask_predict_std, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_STD) + + # Plot Gt contours + label, color = get_info_from_label_id(label_id, label_map, label_colors) + if label_id == 0: + ax4 = fig.add_subplot(spec[1, 4]) + fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img), ax=ax4) + ax4.axis('off') + ax4.set_title('Colorbar') + + ax4_std = fig_std.add_subplot(spec_std[1, 4]) + fig_std.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img, norm=matplotlib.colors.Normalize(vmin=0, vmax=VMAX_STD)), ax=ax4_std) + ax4_std.axis('off') + ax4_std.set_title('Colorbar') + + else: + contours_mask = skimage.measure.find_contours(img_slice_mask_gt, level=0.99) + for _, contour_mask in enumerate(contours_mask): + ax3.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color) + ax3_std.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color) + + if label is not None: + ax.set_title(label + '(' + str(label_id) + ')', color=color) + ax_std.set_title(label + '(' + str(label_id) + ')', color=color) + ax.axis('off') + ax_std.axis('off') + + # Gather GT mask + idxs_gt = np.argwhere(img_slice_mask_gt > 0) + img_slice_mask_plot[idxs_gt[:,0], idxs_gt[:,1]] = label_id + + # GT mask + ax1 = fig.add_subplot(spec[0:2, 0:2]) + ax1.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none') + ax1.set_title('GT Mask') + ax1.tick_params(labelsize=6) + ax1_std = fig_std.add_subplot(spec_std[0:2, 0:2]) + ax1_std.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none') + ax1_std.set_title('GT Mask') + ax1_std.tick_params(labelsize=6) + + # Predicted Mask + ax2 = fig.add_subplot(spec[0:2, 2:4]) + ax2.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none') + ax2.set_title('Predicted Mask (mean)') + ax2.tick_params(labelsize=6) + ax2_std = fig_std.add_subplot(spec_std[0:2, 2:4]) + ax2_std.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none') + ax2_std.set_title('Predicted Mask (mean)') + ax2_std.tick_params(labelsize=6) + + # Show and save + # path_savefig = Path(model_folder_epoch_images).joinpath(filename_meta.replace('.npy','.png')) + path_savefig = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_mean.png') + fig.savefig(str(path_savefig), bbox_inches='tight') + path_savefig_std = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_std.png') + fig_std.savefig(str(path_savefig_std), bbox_inches='tight') + plt.close(fig=fig) + plt.close(fig=fig_std) + pbar_slices.update(1) + + except: + traceback.print_exc() + pdb.set_trace() + +def viz_model_output_3d(exp_name, X, y_true, y_predict, y_predict_unc, 
patient_id, path_save, label_map, label_colors, vmax_unc=0.3, unc_title='', unc_savesufix=''): + """ + X : [H,W,D,1] + y_true: [H,W,D,Labels] + + Takes only a single batch of data + """ + try: + import matplotlib + import matplotlib.pyplot as plt + import skimage + import skimage.measure + + Path(path_save).mkdir(exist_ok=True, parents=True) + + labels_ids = sorted(list(label_map.values())) + cmap_me, norm_me = cmap_for_dataset(label_colors) + cmap_img = 'rainbow' + VMIN = 0 + VMAX_UNC = vmax_unc + VMAX_MEAN = 1.0 + + slice_ids = list(range(X.shape[2])) + # with tqdm.tqdm(total=len(slice_ids), leave=False, desc='ID:' + patient_id) as pbar_slices: + with tqdm.tqdm(total=len(slice_ids), leave=False) as pbar_slices: + for slice_id in range(X.shape[2]): + + if slice_id < 20: + pbar_slices.update(1) + continue + + # Data + X_slice = X[:,:,slice_id,0] + y_true_slice = y_true[:,:,slice_id,:] + y_predict_slice = y_predict[:,:,slice_id,:] + y_predict_unc_slice = y_predict_unc[:,:,slice_id,:] + + # Matplotlib figure + filename = '\n'.join(patient_id.split('-')) + suptitle_str = 'Exp: {}\nPatient: {}\nSlice: {}'.format(exp_name, filename, slice_id) + fig = plt.figure(figsize=(15,15), dpi=200) + fig_unc = plt.figure(figsize=(15,15), dpi=200) + spec = fig.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5) + spec_unc = fig_unc.add_gridspec(nrows=2 + np.ceil(len(labels_ids)/5).astype(int), ncols=5) + fig.suptitle(suptitle_str + '\n Predictive Mean') + fig_unc.suptitle(suptitle_str + '\n {}'.format(unc_title)) + + # Top two images + ax3 = fig.add_subplot(spec[0, 4]) + ax3.imshow(X_slice, cmap='gray') + ax3.axis('off') + ax3.set_title('Raw data') + ax3_unc = fig_unc.add_subplot(spec_unc[0, 4]) + ax3_unc.imshow(X_slice, cmap='gray') + ax3_unc.axis('off') + ax3_unc.set_title('Raw data') + img_slice_mask_plot = np.zeros(y_true_slice[:,:,0].shape) + + # Other images + i,j = 2,0 + for label_id in labels_ids: + if label_id not in config.IGNORE_LABELS: + + # Get ground-truth and prediction slices + img_slice_mask_predict = copy.deepcopy(y_predict_slice[:,:,label_id]) + img_slice_mask_predict_unc = copy.deepcopy(y_predict_unc_slice[:,:,label_id]) + img_slice_mask_gt = copy.deepcopy(y_true_slice[:,:,label_id]) + + # Plot prediction heatmap + if j >= 5: + i = i + 1 + j = 0 + ax = fig.add_subplot(spec[i,j]) + # ax_unc = fig_unc.add_subplot(spec_unc[i,j]) + j += 1 + + # img_slice_mask_predict[img_slice_mask_predict > config.PREDICT_THRESHOLD_MASK] = label_id + ax.imshow(img_slice_mask_predict, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_MEAN) + # ax_unc.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC) + + # Plot Gt contours + label, color = get_info_from_label_id(label_id, label_map, label_colors) + if label_id == 0: + ax4 = fig.add_subplot(spec[1, 4]) + fig.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img), ax=ax4) + ax4.axis('off') + ax4.set_title('Colorbar') + + ax4_unc = fig_unc.add_subplot(spec_unc[1, 4]) + fig_unc.colorbar(matplotlib.cm.ScalarMappable(cmap=cmap_img, norm=matplotlib.colors.Normalize(vmin=0, vmax=VMAX_UNC)), ax=ax4_unc) + ax4_unc.axis('off') + ax4_unc.set_title('Colorbar') + + else: + contours_mask = skimage.measure.find_contours(img_slice_mask_gt, level=0.99) + for _, contour_mask in enumerate(contours_mask): + ax3.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color) + ax3_unc.plot(contour_mask[:, 1], contour_mask[:, 0], linewidth=2, color=color) + + if label is not None: + ax.set_title(label + '(' 
+ str(label_id) + ')', color=color) + # ax_unc.set_title(label + '(' + str(label_id) + ')', color=color) + ax.axis('off') + # ax_unc.axis('off') + + # Gather GT mask + idxs_gt = np.argwhere(img_slice_mask_gt > 0) + img_slice_mask_plot[idxs_gt[:,0], idxs_gt[:,1]] = label_id + + # if 'ent' in unc_savesufix: + # ax_unc.cla() + + # GT mask + ax1 = fig.add_subplot(spec[0:2, 0:2]) + ax1.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none') + ax1.set_title('GT Mask') + ax1.tick_params(labelsize=6) + ax1_unc = fig_unc.add_subplot(spec_unc[0:2, 0:2]) + ax1_unc.imshow(img_slice_mask_plot, cmap=cmap_me, norm=norm_me, interpolation='none') + ax1_unc.set_title('GT Mask') + ax1_unc.tick_params(labelsize=6) + + # Predicted Mask + ax2 = fig.add_subplot(spec[0:2, 2:4]) + ax2.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none') + ax2.set_title('Predicted Mask (mean)') + ax2.tick_params(labelsize=6) + ax2_unc = fig_unc.add_subplot(spec_unc[0:2, 2:4]) + ax2_unc.imshow(np.argmax(y_predict_slice, axis=2), cmap=cmap_me, norm=norm_me, interpolation='none') + ax2_unc.set_title('Predicted Mask (mean)') + ax2_unc.tick_params(labelsize=6) + + # Specifically for entropy + if unc_savesufix in ['stdmax', 'ent', 'mif']: + + ax_unc_gt = fig_unc.add_subplot(spec[2:4, 0:2]) + ax_unc_gt.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC) + slice_binary_gt = img_slice_mask_plot + slice_binary_gt[slice_binary_gt > 0] = 1 + ax_unc_gt.imshow(slice_binary_gt, cmap='gray', interpolation='none', alpha=0.3) + + ax_unc_pred = fig_unc.add_subplot(spec[2:4, 2:4]) + ax_unc_pred.imshow(img_slice_mask_predict_unc, interpolation='none', cmap=cmap_img, vmin=0, vmax=VMAX_UNC) + slice_binary_pred = np.argmax(y_predict_slice, axis=2) + slice_binary_pred[slice_binary_pred > 0] = 1 + ax_unc_pred.imshow(slice_binary_pred, cmap='gray', interpolation='none', alpha=0.3) + + # Show and save + # path_savefig = Path(model_folder_epoch_images).joinpath(filename_meta.replace('.npy','.png')) + path_savefig = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_mean.png') + fig.savefig(str(path_savefig), bbox_inches='tight') + path_savefig_unc = Path(path_save).joinpath(patient_id + '_' + '%.3d' % (slice_id) + '_{}.png'.format(unc_savesufix)) + fig_unc.savefig(str(path_savefig_unc), bbox_inches='tight') + plt.close(fig=fig) + plt.close(fig=fig_unc) + pbar_slices.update(1) + + except: + traceback.print_exc() + pdb.set_trace() + \ No newline at end of file
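
---

The training loop in `src/model/trainer_flipout.py` above anneals the KL weight (`kl_alpha`) when `kl_schedule == config.KL_DIV_ANNEALING`: the weight stays at `kl_alpha_init` until `initial_epoch`, is only bumped on epochs that are multiples of `kl_epochs_change`, and is capped at `kl_alpha_max`. A minimal sketch of that update step, using the same attribute names as the Trainer; the default values in the usage example are illustrative assumptions, not the repository's settings:

```python
def update_kl_alpha(kl_alpha, epoch, *, kl_alpha_init, kl_alpha_max,
                    kl_alpha_increase_per_epoch, initial_epoch, kl_epochs_change):
    """Mirror of the annealing step in the training loop: kl_alpha only changes
    after initial_epoch, on epochs that are multiples of kl_epochs_change."""
    if epoch > initial_epoch and epoch % kl_epochs_change == 0:
        kl_alpha = min(
            kl_alpha_max,
            kl_alpha_init
            + (epoch - initial_epoch) / float(kl_epochs_change) * kl_alpha_increase_per_epoch,
        )
    return kl_alpha

# Illustrative usage (values chosen only for the sketch)
kl_alpha = 0.001
for epoch in range(1, 11):
    kl_alpha = update_kl_alpha(kl_alpha, epoch,
                               kl_alpha_init=0.001, kl_alpha_max=0.01,
                               kl_alpha_increase_per_epoch=0.001,
                               initial_epoch=2, kl_epochs_change=2)
```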
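`save_model()`/`load_model()` in `src/model/utils.py` persist models in the TensorFlow Checkpoint format (not SavedModel): saving bundles the optimizer with the model so training can resume, while validation/test loads restore weights only and call `expect_partial()` to tolerate unrestored optimizer slots. The round-trip reduces to the following sketch; the directory names are placeholders, since the real paths come from `config.MODEL_CHKPOINT_MAINFOLDER` and `config.MODEL_CHKPOINT_NAME_FMT`:

```python
import tensorflow as tf
from pathlib import Path

def save_checkpoint(model, optimizer, ckpt_dir, epoch):
    # One sub-folder per epoch, e.g. <ckpt_dir>/ckpt-epoch050/ckpt-epoch050-1.index
    epoch_dir = Path(ckpt_dir) / 'ckpt-epoch{:03d}'.format(epoch)
    epoch_dir.mkdir(parents=True, exist_ok=True)
    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
    ckpt.save(file_prefix=str(epoch_dir / epoch_dir.name))

def load_checkpoint(model, epoch_dir, optimizer=None):
    # With an optimizer -> resume training; without -> weights only (val/test).
    objects = {'model': model} if optimizer is None else {'model': model, 'optimizer': optimizer}
    ckpt = tf.train.Checkpoint(**objects)
    ckpt.restore(tf.train.latest_checkpoint(str(epoch_dir))).expect_partial()
```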
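The evaluation parameters in the demo (`MC_RUNS`, `training_bool=True`, i.e. dropout/Flipout kept stochastic at test time) and the `y_predict`/`y_predict_std` arguments of the visualisation helpers imply Monte-Carlo sampling of the network. The actual `val()`/`eval_3D()` routines, which also handle patch extraction, stitching and Dice computation, are not part of this excerpt, so the following is only a hedged sketch of the sampling step they are assumed to perform:

```python
import tensorflow as tf

def mc_predict(model, x_patch, mc_runs=30, training_bool=True):
    """Run mc_runs stochastic forward passes on one patch and summarise them.

    Returns the predictive mean (plotted as 'Predictive Mean') and the
    per-class standard deviation (plotted as 'Predictive Std').
    """
    samples = tf.stack(
        [model(x_patch, training=training_bool) for _ in range(mc_runs)], axis=0
    )  # shape: [MC_RUNS, B, H, W, D, C]
    y_mean = tf.reduce_mean(samples, axis=0)
    y_std = tf.math.reduce_std(samples, axis=0)
    return y_mean, y_std
```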