diff --git a/README.md b/README.md
index 844aead..ff7c5f8 100644
--- a/README.md
+++ b/README.md
@@ -1,42 +1,59 @@
-# PACKAGE_NAME
+# Deep Quality Estimation
 
-[![Python Versions](https://img.shields.io/pypi/pyversions/PACKAGE_NAME)](https://pypi.org/project/PACKAGE_NAME/)
-[![Stable Version](https://img.shields.io/pypi/v/PACKAGE_NAME?label=stable)](https://pypi.python.org/pypi/PACKAGE_NAME/)
-[![Documentation Status](https://readthedocs.org/projects/PACKAGE_NAME/badge/?version=latest)](http://PACKAGE_NAME.readthedocs.io/?badge=latest)
-[![tests](https://github.com/BrainLesion/PACKAGE_NAME/actions/workflows/tests.yml/badge.svg)](https://github.com/BrainLesion/PACKAGE_NAME/actions/workflows/tests.yml)
-[![codecov](https://codecov.io/gh/BrainLesion/PACKAGE_NAME/graph/badge.svg?token=A7FWUKO9Y4)](https://codecov.io/gh/BrainLesion/PACKAGE_NAME)
+[![Python Versions](https://img.shields.io/pypi/pyversions/deep_quality_estimation)](https://pypi.org/project/deep_quality_estimation/)
+[![Stable Version](https://img.shields.io/pypi/v/deep_quality_estimation?label=stable)](https://pypi.python.org/pypi/deep_quality_estimation/)
+[![Documentation Status](https://readthedocs.org/projects/deep_quality_estimation/badge/?version=latest)](http://deep_quality_estimation.readthedocs.io/?badge=latest)
+[![tests](https://github.com/BrainLesion/deep_quality_estimation/actions/workflows/tests.yml/badge.svg)](https://github.com/BrainLesion/deep_quality_estimation/actions/workflows/tests.yml)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+
 
-Description
 
-## Features
+Quality prediction for brain tumor segmentations on a scale ranging from 1 to 6 stars ⭐.
+It can be used to estimate the quality of a segmentation for evaluation purposes or, for example, as part of a loss function during model training.
 
+> [!NOTE]
+> This package expects images in atlas space and segmentation labels in BraTS style, i.e. label 1 is the necrotic and non-enhancing tumor core, label 2 is the peritumoral edema, and label 3 is the GD-enhancing tumor (label 4 in older datasets; both are supported).
 
 ## Installation
 
-With a Python 3.8+ environment, you can install `PACKAGE_NAME` directly from [PyPI](https://pypi.org/project/brats/):
+With a Python 3.9+ environment, you can install `deep_quality_estimation` directly from [PyPI](https://pypi.org/project/deep_quality_estimation/):
 
 ```bash
-pip install PACKAGE_NAME
+pip install deep_quality_estimation
 ```
 
 ## Use Cases and Tutorials
 
-A minimal example to create a segmentation could look like this:
+A minimal example to predict the quality of a segmentation could look like this:
+
 ```python
- # example
-```
+from deep_quality_estimation import DQE
+
+# the shown parameters are the default values but can be adapted to the use case
+dqe = DQE(device="cuda", cuda_devices="0")
+
+# inputs can be paths (str or pathlib.Path), NumPy NDArrays or a mix of both
+mean_score, scores_per_view = dqe.predict(
+    t1c="t1c.nii.gz", t1="t1.nii.gz", t2="t2.nii.gz", flair="flair.nii.gz", segmentation="segmentation.nii.gz"
+)
+```
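+
+The inputs can also be passed as NumPy arrays instead of file paths, or as a mix of both, e.g. when the images are already loaded in memory. A minimal sketch of this, using `nibabel` and the same placeholder file names as above:
+
+```python
+import nibabel as nib
+
+from deep_quality_estimation import DQE
+
+# without arguments, the device defaults to CUDA if available and to CPU otherwise
+dqe = DQE()
+
+# load some inputs yourself and pass the raw arrays instead of file paths
+t1c_array = nib.load("t1c.nii.gz").get_fdata()
+segmentation_array = nib.load("segmentation.nii.gz").get_fdata()
+
+mean_score, scores_per_view = dqe.predict(
+    t1c=t1c_array,
+    t1="t1.nii.gz",
+    t2="t2.nii.gz",
+    flair="flair.nii.gz",
+    segmentation=segmentation_array,
+)
+```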
 
 ## Citation
 
-If you use PACKAGE_NAME in your research, please cite it to support the development!
+If you use deep_quality_estimation in your research, please cite it to support the development!
+
+https://arxiv.org/abs/2205.10355
 
 ```
-TODO: https://arxiv.org/abs/2205.10355
+@misc{kofler2022deepqualityestimationcreating,
+      title={Deep Quality Estimation: Creating Surrogate Models for Human Quality Ratings},
+      author={Florian Kofler and Ivan Ezhov and Lucas Fidon and Izabela Horvath and Ezequiel de la Rosa and John LaMaster and Hongwei Li and Tom Finck and Suprosanna Shit and Johannes Paetzold and Spyridon Bakas and Marie Piraud and Jan Kirschke and Tom Vercauteren and Claus Zimmer and Benedikt Wiestler and Bjoern Menze},
+      year={2022},
+      eprint={2205.10355},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2205.10355},
+}
 ```
 
 ## Contributing
 
@@ -45,7 +62,7 @@ We welcome all kinds of contributions from the community!
 
 ### Reporting Bugs, Feature Requests and Questions
 
-Please open a new issue [here](https://github.com/BrainLesion/PACKAGE_NAME/issues).
+Please open a new issue [here](https://github.com/BrainLesion/deep_quality_estimation/issues).
 
 ### Code contributions
diff --git a/package_name/.gitignore b/deep_quality_estimation/.gitignore
similarity index 100%
rename from package_name/.gitignore
rename to deep_quality_estimation/.gitignore
diff --git a/deep_quality_estimation/__init__.py b/deep_quality_estimation/__init__.py
new file mode 100644
index 0000000..b76c152
--- /dev/null
+++ b/deep_quality_estimation/__init__.py
@@ -0,0 +1,5 @@
+from loguru import logger
+
+logger.disable("deep_quality_estimation")  # disable logging by default when used as a library
+
+from .model import DQE
diff --git a/deep_quality_estimation/center_of_mass.py b/deep_quality_estimation/center_of_mass.py
new file mode 100644
index 0000000..78fedb5
--- /dev/null
+++ b/deep_quality_estimation/center_of_mass.py
@@ -0,0 +1,73 @@
+from pathlib import Path
+from typing import List, Union
+
+import nibabel as nib
+import numpy as np
+import scipy.ndimage
+from numpy.typing import NDArray
+
+from deep_quality_estimation.enums import View
+
+
+def compute_center_of_mass(
+    segmentation: Union[Path, NDArray], fallback_to_edema: bool = True
+) -> List[int]:
+    """
+    Compute the center of mass of the tumor core in the given segmentation.
+    If no tumor core is found and `fallback_to_edema` is True, the center of mass of the edema is used as a fallback.
+
+    Args:
+        segmentation (Union[Path, NDArray]): Path to the segmentation file or the segmentation data as numpy NDArray.
+        fallback_to_edema (bool, optional): Use the edema CoM as fallback if no tumor core is found. Defaults to True.
+
+    Returns:
+        List[int]: List of the center of mass coordinates
+    """
+    if isinstance(segmentation, Path) or isinstance(segmentation, str):
+        segmentation = nib.load(segmentation).get_fdata()
+
+    assert isinstance(segmentation, np.ndarray)
+
+    mask = np.zeros(segmentation.shape)
+
+    # get mask of tumor core
+    # TODO: verify if this is correct? (label change 4 to 3 in new standard?)
+    mask[segmentation == 1] = 1
+    mask[segmentation == 3] = 1
+    mask[segmentation == 4] = 1
+
+    if (
+        np.sum(mask) == 0 and fallback_to_edema
+    ):  # if no tumor core is found, use the edema CoM
+        mask[segmentation > 0] = 1
+
+    center_of_mass = scipy.ndimage.center_of_mass(mask)
+    # convert to int (cuts decimals)
+    center_of_mass = [int(x) for x in center_of_mass]
+    return center_of_mass
+
+
+def get_center_of_mass_slices(
+    image: Union[Path, NDArray],
+    center_of_mass: List[int],
+) -> dict[View, NDArray]:
+    """
+    Get the slices of the given image that contain the center of mass in axial, coronal and sagittal view.
+
+    Args:
+        image (Union[Path, NDArray]): Path to the NIfTI image file or the image data as numpy NDArray
+        center_of_mass (List[int]): List of the center of mass coordinates
+
+    Returns:
+        Dict[View, NDArray]: Dictionary with the views as keys and the 2D image slices as values
+    """
+    if isinstance(image, Path) or isinstance(image, str):
+        image = nib.load(image).get_fdata()
+
+    assert isinstance(image, np.ndarray)
+
+    return {
+        View.AXIAL: image[:, :, center_of_mass[2]],
+        View.CORONAL: image[:, center_of_mass[1], :],
+        View.SAGITTAL: image[center_of_mass[0], :, :],
+    }
diff --git a/deep_quality_estimation/dataloader.py b/deep_quality_estimation/dataloader.py
new file mode 100644
index 0000000..c7cd93a
--- /dev/null
+++ b/deep_quality_estimation/dataloader.py
@@ -0,0 +1,180 @@
+from pathlib import Path
+from typing import Dict, Tuple, Union
+
+import numpy as np
+from monai.data import Dataset, pad_list_data_collate
+from monai.transforms import (
+    Compose,
+    ConcatItemsd,
+    Lambdad,
+    ScaleIntensityRangePercentilesd,
+    SpatialPadd,
+    ToTensord,
+)
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader
+
+from deep_quality_estimation.center_of_mass import (
+    compute_center_of_mass,
+    get_center_of_mass_slices,
+)
+from deep_quality_estimation.enums import View
+from deep_quality_estimation.transforms import (
+    CustomConvertToMultiChannelBasedOnBratsClassesd,
+)
+
+
+class DataHandler:
+
+    ONLY_IMAGES = ["images"]
+    ONLY_LABELS = ["labels"]
+    ALL_CHANNELS = [*ONLY_IMAGES, *ONLY_LABELS]
+
+    def __init__(
+        self,
+        t1c: Union[Path, NDArray],
+        t1: Union[Path, NDArray],
+        t2: Union[Path, NDArray],
+        flair: Union[Path, NDArray],
+        segmentation: Union[Path, NDArray],
+    ):
+        """
+        Initialize the data handler
+
+        Args:
+            t1c (Union[Path, NDArray]): Numpy NDArray or Path to the T1c NIfTI file
+            t1 (Union[Path, NDArray]): Numpy NDArray or Path to the T1 NIfTI file
+            t2 (Union[Path, NDArray]): Numpy NDArray or Path to the T2 NIfTI file
+            flair (Union[Path, NDArray]): Numpy NDArray or Path to the FLAIR NIfTI file
+            segmentation (Union[Path, NDArray]): Numpy NDArray or Path to the segmentation NIfTI file (in BraTS style)
+        """
+        self.t1c = t1c
+        self.t2 = t2
+        self.t1 = t1
+        self.flair = flair
+        self.segmentation = segmentation
+
+    def _get_transforms(self) -> Compose:
+        """
+        Returns the transforms to be applied to the dataset
+
+        Returns:
+            monai.transforms.Compose: A composition of the transforms
+        """
+        transforms = Compose(
+            [
+                Lambdad(self.ALL_CHANNELS, np.nan_to_num),
+                CustomConvertToMultiChannelBasedOnBratsClassesd(keys=self.ONLY_LABELS),
+                ScaleIntensityRangePercentilesd(
+                    keys=self.ONLY_IMAGES,
+                    lower=0.5,
+                    upper=99.5,
+                    b_min=0,
+                    b_max=1,
+                    clip=True,
+                    relative=False,
+                    # channel_wise=True,
+                    channel_wise=False,
+                ),
+                # pad all slices to at least 240x240 (coronal and sagittal views initially have 240x155)
+                SpatialPadd(
+                    keys=self.ALL_CHANNELS, spatial_size=(240, 240), mode="minimum"
+                ),
+                # concatenate image and label channels and convert them to a tensor
+                ConcatItemsd(
+                    keys=self.ALL_CHANNELS,
+                    name="inputs",
+                    dim=0,
+                    allow_missing_keys=False,
+                ),
+                ToTensord(keys=["inputs"]),
+            ]
+        )
+
+        return transforms
+
+    def _compute_slices(
+        self,
+    ) -> Tuple[
+        Dict[View, NDArray],
+        Dict[View, NDArray],
+        Dict[View, NDArray],
+        Dict[View, NDArray],
+        Dict[View, NDArray],
+    ]:
+        """
+        Computes the center of mass slices of all images and of the segmentation for the different views
+
+        Returns:
+            Tuple[Dict[View, NDArray], Dict[View, NDArray], Dict[View, NDArray], Dict[View, NDArray], Dict[View, NDArray]]: The t1c, t1, t2, flair and segmentation slices per view
+        """
+
+        center_of_mass = compute_center_of_mass(segmentation=self.segmentation)
+        t1c_slices = get_center_of_mass_slices(
+            image=self.t1c, center_of_mass=center_of_mass
+        )
+        t1_slices = get_center_of_mass_slices(
+            image=self.t1, center_of_mass=center_of_mass
+        )
+        t2_slices = get_center_of_mass_slices(
+            image=self.t2, center_of_mass=center_of_mass
+        )
+        flair_slices = get_center_of_mass_slices(
+            image=self.flair, center_of_mass=center_of_mass
+        )
+        segmentation_slices = get_center_of_mass_slices(
+            image=self.segmentation, center_of_mass=center_of_mass
+        )
+
+        return t1c_slices, t1_slices, t2_slices, flair_slices, segmentation_slices
+
+    def _build_dataset(self) -> Dataset:
+        """
+        Build the dataset consisting of the CoM slices of each image for each view,
+        i.e. one sample is e.g. the AXIAL view with a 2D slice of each image and the segmentation in BraTS-classes channel format
+
+        Returns:
+            Dataset: The dataset
+        """
+
+        t1c_slices, t1_slices, t2_slices, flair_slices, segmentation_slices = (
+            self._compute_slices()
+        )
+        data_dicts = []
+        for view in View:
+            data_dicts.append(
+                {
+                    "images": [
+                        t1c_slices[view],
+                        t1_slices[view],
+                        t2_slices[view],
+                        flair_slices[view],
+                    ],
+                    "labels": segmentation_slices[view],
+                    "view": view.name,
+                }
+            )
+
+        return Dataset(
+            data=data_dicts,
+            transform=self._get_transforms(),
+        )
+
+    def get_dataloader(self) -> DataLoader:
+        """
+        Get the dataloader for the dataset
+
+        Returns:
+            DataLoader: The dataloader
+        """
+
+        dataset = self._build_dataset()
+
+        return DataLoader(
+            dataset=dataset,
+            batch_size=1,
+            num_workers=8,
+            collate_fn=pad_list_data_collate,
+            shuffle=False,
+        )
diff --git a/deep_quality_estimation/enums.py b/deep_quality_estimation/enums.py
new file mode 100644
index 0000000..3c2d427
--- /dev/null
+++ b/deep_quality_estimation/enums.py
@@ -0,0 +1,9 @@
+from enum import Enum, auto
+
+
+class View(Enum):
+    """Enum for the different views of the brain"""
+
+    AXIAL = auto()
+    CORONAL = auto()
+    SAGITTAL = auto()
diff --git a/deep_quality_estimation/model.py b/deep_quality_estimation/model.py
new file mode 100644
index 0000000..a48bf2f
--- /dev/null
+++ b/deep_quality_estimation/model.py
@@ -0,0 +1,129 @@
+import os
+from pathlib import Path
+from typing import Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from loguru import logger
+from monai.networks.nets import DenseNet121
+from numpy.typing import NDArray
+
+from deep_quality_estimation.dataloader import DataHandler
+from deep_quality_estimation.enums import View
+
+PACKAGE_DIR = Path(__file__).parent
+
+
+class DQE:
+
+    def __init__(
+        self, device: Optional[torch.device] = None, cuda_devices: Optional[str] = "0"
+    ):
+        """
+        Initialize the Deep Quality Estimation model
+
+        Args:
+            device (Optional[torch.device], optional): Device to be used. Defaults to None.
+            cuda_devices (Optional[str], optional): Visible CUDA devices, e.g. "0" or "0,1". Defaults to "0".
+ """ + + self.device = self._set_device( + device=device, + cuda_devices=cuda_devices, + ) + + self.model = self._load_model() + + def _set_device(self, device: Optional[torch.device], cuda_devices: Optional[str]): + """ + Set the device to be used for the model + + Args: + device (Optional[torch.device]): Device + cuda_devices (Optional[str]): Visible CUDA devices, e.g. "0", "0,1" + + Returns: + torch.device: Device to be used + """ + + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device) + logger.info(f"Using device: {device}") + return device + + def _load_model(self): + """ + Load model weights and return initialized model + + Returns: + monai.networks.nets.DenseNet121: Model + """ + + checkpoint_path = PACKAGE_DIR / "weights/dqe_weights.pth" + model = DenseNet121( + spatial_dims=2, in_channels=7, out_channels=1, pretrained=False + ) + + checkpoint = torch.load( + checkpoint_path, + map_location=self.device, + weights_only=True, + ) + + if self.device == torch.device("cpu"): + if "module." in list(checkpoint.keys())[0]: + checkpoint = { + k.replace("module.", ""): v for k, v in checkpoint.items() + } + else: + model = torch.nn.parallel.DataParallel(model) + + model.load_state_dict(checkpoint) + model = model.to(self.device) + + logger.info(f"Model loaded from {checkpoint_path} and initialized") + return model + + def predict( + self, + t1c: Union[Path, NDArray], + t1: Union[Path, NDArray], + t2: Union[Path, NDArray], + flair: Union[Path, NDArray], + segmentation: Union[Path, NDArray], + ) -> Tuple[float, Dict[View, float]]: + """ + Predict the quality of the given Segmentation + + Args: + t1c (Union[Path, NDArray]): Numpy NDArray or Path to the T1c NIfTI file + t1 (Union[Path, NDArray]): Numpy NDArray or Path to the T1 NIfTI file + t2 (Union[Path, NDArray]): Numpy NDArray or Path to the T2 NIfTI file + flair (Union[Path, NDArray]): Numpy NDArray or Path to the FLAIR NIfTI file + segmentation (Union[Path, NDArray]): Numpy NDArray or Path to the segmentation NIfTI file (In BraTS style) + + Returns: + Tuple[float, Dict[View, float]]: The predicted mean score and a dict with the scores per view + """ + + # load and preprocess data + data_handler = DataHandler( + t1c=t1c, t2=t2, t1=t1, flair=flair, segmentation=segmentation + ) + dataloader = data_handler.get_dataloader() + + # predict ratings + scores = {} + self.model.eval() + with torch.no_grad(): + for data in dataloader: + # assuming batch size 1 + logger.debug(f"Predicting for view: {data['view'][0]}") + inputs = data["inputs"].float().to(self.device) + outputs = self.model(inputs) + scores[data["view"][0]] = outputs.cpu().item() + mean_score = np.mean(list(scores.values())) + return mean_score, scores diff --git a/deep_quality_estimation/transforms.py b/deep_quality_estimation/transforms.py new file mode 100644 index 0000000..2ffa5f1 --- /dev/null +++ b/deep_quality_estimation/transforms.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from collections.abc import Hashable, Mapping + +import numpy as np +import torch +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, Transform +from monai.utils.enums import TransformBackends + + +class CustomConvertToMultiChannelBasedOnBratsClasses(Transform): + """ + Adapted from Monai's ConvertToMultiChannelBasedOnBratsClasses since label 4 changed to 3 in 
+    newer datasets.
+
+    Convert labels to multi channels based on brats18 classes:
+    label 1 is the necrotic and non-enhancing tumor core
+    label 2 is the peritumoral edema
+    label 3 is the GD-enhancing tumor (used to be label 4 in older datasets)
+    The possible classes are TC (Tumor core), WT (Whole tumor)
+    and ET (Enhancing tumor).
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+        # if img has a channel dim, squeeze it
+        if img.ndim == 4 and img.shape[0] == 1:
+            img = img.squeeze(0)
+
+        # adapted mapping (channels: TC, WT, ET)
+        result = [
+            (img == 1) | (img == 3) | (img == 4),
+            (img == 1) | (img == 4) | (img == 3) | (img == 2),
+            (img == 4) | (img == 3),
+        ]
+        # WT merges labels 1 (non-enhancing tumor core), 3 (enhancing tumor, formerly 4) and 2 (edema)
+        # ET is label 3 (formerly 4)
+        return (
+            torch.stack(result, dim=0)
+            if isinstance(img, torch.Tensor)
+            else np.stack(result, axis=0)
+        )
+
+
+class CustomConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
+    """
+    Adapted from MONAI's ConvertToMultiChannelBasedOnBratsClassesd
+
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
+    Convert labels to multi channels based on brats18 classes:
+    label 1 is the necrotic and non-enhancing tumor core
+    label 2 is the peritumoral edema
+    label 3 is the GD-enhancing tumor (used to be label 4 in older datasets)
+    The possible classes are TC (Tumor core), WT (Whole tumor)
+    and ET (Enhancing tumor).
+    """
+
+    backend = CustomConvertToMultiChannelBasedOnBratsClasses.backend
+
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
+        super().__init__(keys, allow_missing_keys)
+        self.converter = CustomConvertToMultiChannelBasedOnBratsClasses()
+
+    def __call__(
+        self, data: Mapping[Hashable, NdarrayOrTensor]
+    ) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.converter(d[key])
+        return d
diff --git a/deep_quality_estimation/weights/dqe_weights.pth b/deep_quality_estimation/weights/dqe_weights.pth
new file mode 100644
index 0000000..2e56797
Binary files /dev/null and b/deep_quality_estimation/weights/dqe_weights.pth differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7bc126f..d76ebed 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -38,7 +38,7 @@
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 
-html_theme = "furo"
+html_theme = "pydata_sphinx_theme"
 
 # html_static_path = ["_static"]
 
 autodoc_default_options = {
diff --git a/pyproject.toml b/pyproject.toml
index c086d3d..74103a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,21 +6,32 @@ build-backend = "poetry_dynamic_versioning.backend"
 enable = true
 
 [tool.poetry]
-name = "package_name"
+name = "deep_quality_estimation"
 version = "0.0.0"
 description = ""
-authors = ["Florian Kofler "]
-repository = "https://www.TODO.com"
-homepage = "https://www.TODO.com"
+authors = [
+    "Marcel Rosier , Florian Kofler ",
+]
+repository = "https://github.com/BrainLesion/deep_quality_estimation"
+homepage = "https://github.com/BrainLesion/deep_quality_estimation"
 documentation = "https://www.TODO.com"
 readme = "README.md"
 
 # Add the exclude field directly under [tool.poetry]
-exclude = ["examples", "benchmark"]
+exclude = ["examples", "benchmark", "tests", "docs"]
 
 [tool.poetry.dependencies]
 python = ">=3.9"
+nibabel = "^5.3.2"
+scipy = [
+    { version = "^1.14.1", python = ">=3.10" },
+    { version = "<1.14", python = "<3.10" },
+]
+numpy = ">=1.21.2"
+torch = "^2.5.1"
+monai = "^1.4.0"
+loguru = "^0.7.2"
 
 
 [tool.poetry.dev-dependencies]
 
@@ -34,5 +45,5 @@ optional = true
 
 [tool.poetry.group.docs.dependencies]
 Sphinx = ">=7.0.0"
 sphinx-copybutton = ">=0.5.2"
-furo = ">=2024.8.6"
+pydata-sphinx-theme = ">=0.16.0"
 myst-parser = ">=2.0.0"