Skip to content

Commit

Permalink
Merge pull request #5 from BrainLesion/3-add-tests
Browse files Browse the repository at this point in the history
3 add tests
  • Loading branch information
neuronflow authored Nov 26, 2024
2 parents 53608b7 + b9f6e3e commit a837403
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 11 deletions.
27 changes: 17 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ Quality prediction for brain tumor segmentation on scale ranging from 1 to 6 sta
Can be used to estimate the quality of a segmentation for evaluation purposes or as e.g. as part of a loss function during model training.

> [!NOTE]
> This package expects images in atlas space and segementation labels in brats style, i.e. label 1 is the necrotic and non-enhancing tumor core, label 2 is the peritumoral edema, label 3 is the GD-enhancing tumor (used to be label 4 in older datasets, both are supported)
> This package expects images in atlas space and segmentation labels in brats style, i.e.
> - `label 1` is the necrotic and non-enhancing tumor core
> - `label 2` is the peritumoral edema
> - `label 3` is the GD-enhancing tumor (used to be `label 4` in older data, both are supported)
## Installation

Expand All @@ -27,15 +30,19 @@ pip install deep_quality_estimation
A minimal example to predict the quality of a segmentation could look like this:

```python
from deep_quality_estimation import DQE

# shown parameters are default values but can be adapted to usecase
dqe = DQE(device="cuda", cuda_devices="0")

# inputs can be Paths (str or pathlib.Path object), NumPy NDArrays or a mix
mean_score, scores_per_view = dqe.predict(
t1c="t1c.nii.gz", t1="t1.nii.gz", t2="t2.nii.gz", flair="flair.nii.gz", segmentation="segmentation.nii.gz"
)
from deep_quality_estimation import DQE

# shown parameters are default values but can be adapted to usecase
dqe = DQE(device="cuda", cuda_devices="0")

# inputs can be Paths (str or pathlib.Path object), NumPy NDArrays or a mix
mean_score, scores_per_view = dqe.predict(
t1c="t1c.nii.gz",
t1="t1.nii.gz",
t2="t2.nii.gz",
flair="flair.nii.gz",
segmentation="segmentation.nii.gz",
)
```


Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion deep_quality_estimation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from monai.networks.nets import DenseNet121
from numpy.typing import NDArray

from deep_quality_estimation.dataloader import DataHandler
from deep_quality_estimation.data_handler import DataHandler
from deep_quality_estimation.enums import View

PACKAGE_DIR = Path(__file__).parent
Expand Down
Binary file added tests/data/seg-BraTS23_1.nii.gz
Binary file not shown.
Binary file added tests/data/seg-BraTS23_1_mapped.nii.gz
Binary file not shown.
Binary file added tests/data/seg-bratstoolkit_isen.nii.gz
Binary file not shown.
Binary file added tests/data/t1c.nii.gz
Binary file not shown.
Binary file added tests/data/t1n.nii.gz
Binary file not shown.
Binary file added tests/data/t2f.nii.gz
Binary file not shown.
Binary file added tests/data/t2w.nii.gz
Binary file not shown.
46 changes: 46 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
from pathlib import Path

import nibabel as nib

from deep_quality_estimation import DQE
from deep_quality_estimation.enums import View


class TestDQEModel(unittest.TestCase):

def setUp(self):

self.t1c = Path("tests/data/t1c.nii.gz")
self.t2 = Path("tests/data/t2w.nii.gz")
self.t1 = Path("tests/data/t1n.nii.gz")
self.flair = Path("tests/data/t2f.nii.gz")
self.segmentation = Path("tests/data/seg-BraTS23_1.nii.gz")

def test_full_prediction_paths(self):
dqe = DQE(device="cpu")
mean_score, scores = dqe.predict(
t1=self.t1,
t1c=self.t1c,
t2=self.t2,
flair=self.flair,
segmentation=self.segmentation,
)
self.assertAlmostEqual(mean_score, 5.315626939137776, places=4)
self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4)
self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4)
self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4)

def test_full_prediction_numpy(self):
dqe = DQE(device="cpu")
mean_score, scores = dqe.predict(
t1=nib.load(self.t1).get_fdata(),
t1c=nib.load(self.t1c).get_fdata(),
t2=nib.load(self.t2).get_fdata(),
flair=nib.load(self.flair).get_fdata(),
segmentation=nib.load(self.segmentation).get_fdata(),
)
self.assertAlmostEqual(mean_score, 5.315626939137776, places=4)
self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4)
self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4)
self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4)
75 changes: 75 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import unittest
from pathlib import Path

import numpy as np

from deep_quality_estimation.data_handler import DataHandler


class TestTransforms(unittest.TestCase):

def setUp(self):

self.t1c = Path("tests/data/t1c.nii.gz")
self.t2 = Path("tests/data/t2w.nii.gz")
self.t1 = Path("tests/data/t1n.nii.gz")
self.flair = Path("tests/data/t2f.nii.gz")
self.segmentation_new_labels = Path("tests/data/seg-BraTS23_1.nii.gz")
self.segmentation_new_labels_mapped = Path(
"tests/data/seg-BraTS23_1_mapped.nii.gz"
)
self.segmentation_old_labels = Path("tests/data/seg-bratstoolkit_isen.nii.gz")

def test_labels_transforms_feasible(self):
"""
Verify that for all segmentations the labels are transformed correctly (have correct shape, one hot and at least 1 labeled pixel for the given data)
"""
for segmentation in [
self.segmentation_new_labels_mapped,
self.segmentation_new_labels,
self.segmentation_old_labels,
]:
data_handler = DataHandler(
t1c=self.t1c,
t2=self.t2,
t1=self.t1,
flair=self.flair,
segmentation=segmentation,
)

# get first element from dataloader
data = next(iter(data_handler.get_dataloader()))

self.assertEqual(data["labels"].shape, (1, 3, 240, 240))
for label_map in data["labels"][0]:
self.assertTrue(np.all(np.isin(label_map, [0, 1])))
self.assertTrue(np.sum(label_map) > 0)

def test_labels_transform_equal(self):
"""
Verify that for the same segmentation (with differing labeling convention) the transformed labels are equal
"""

data_handler_new_labels = DataHandler(
t1c=self.t1c,
t2=self.t2,
t1=self.t1,
flair=self.flair,
segmentation=self.segmentation_new_labels,
)
data_handler_new_labels_mapped = DataHandler(
t1c=self.t1c,
t2=self.t2,
t1=self.t1,
flair=self.flair,
segmentation=self.segmentation_new_labels_mapped,
)

data_new_labels = next(iter(data_handler_new_labels.get_dataloader()))
data_new_labels_mapped = next(
iter(data_handler_new_labels_mapped.get_dataloader())
)

self.assertTrue(
np.all(data_new_labels["labels"] == data_new_labels_mapped["labels"])
)

0 comments on commit a837403

Please sign in to comment.