generated from BrainLesion/brainles-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from BrainLesion/3-add-tests
3 add tests
- Loading branch information
Showing
12 changed files
with
139 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import unittest | ||
from pathlib import Path | ||
|
||
import nibabel as nib | ||
|
||
from deep_quality_estimation import DQE | ||
from deep_quality_estimation.enums import View | ||
|
||
|
||
class TestDQEModel(unittest.TestCase): | ||
|
||
def setUp(self): | ||
|
||
self.t1c = Path("tests/data/t1c.nii.gz") | ||
self.t2 = Path("tests/data/t2w.nii.gz") | ||
self.t1 = Path("tests/data/t1n.nii.gz") | ||
self.flair = Path("tests/data/t2f.nii.gz") | ||
self.segmentation = Path("tests/data/seg-BraTS23_1.nii.gz") | ||
|
||
def test_full_prediction_paths(self): | ||
dqe = DQE(device="cpu") | ||
mean_score, scores = dqe.predict( | ||
t1=self.t1, | ||
t1c=self.t1c, | ||
t2=self.t2, | ||
flair=self.flair, | ||
segmentation=self.segmentation, | ||
) | ||
self.assertAlmostEqual(mean_score, 5.315626939137776, places=4) | ||
self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4) | ||
self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4) | ||
self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4) | ||
|
||
def test_full_prediction_numpy(self): | ||
dqe = DQE(device="cpu") | ||
mean_score, scores = dqe.predict( | ||
t1=nib.load(self.t1).get_fdata(), | ||
t1c=nib.load(self.t1c).get_fdata(), | ||
t2=nib.load(self.t2).get_fdata(), | ||
flair=nib.load(self.flair).get_fdata(), | ||
segmentation=nib.load(self.segmentation).get_fdata(), | ||
) | ||
self.assertAlmostEqual(mean_score, 5.315626939137776, places=4) | ||
self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4) | ||
self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4) | ||
self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import unittest | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
||
from deep_quality_estimation.data_handler import DataHandler | ||
|
||
|
||
class TestTransforms(unittest.TestCase): | ||
|
||
def setUp(self): | ||
|
||
self.t1c = Path("tests/data/t1c.nii.gz") | ||
self.t2 = Path("tests/data/t2w.nii.gz") | ||
self.t1 = Path("tests/data/t1n.nii.gz") | ||
self.flair = Path("tests/data/t2f.nii.gz") | ||
self.segmentation_new_labels = Path("tests/data/seg-BraTS23_1.nii.gz") | ||
self.segmentation_new_labels_mapped = Path( | ||
"tests/data/seg-BraTS23_1_mapped.nii.gz" | ||
) | ||
self.segmentation_old_labels = Path("tests/data/seg-bratstoolkit_isen.nii.gz") | ||
|
||
def test_labels_transforms_feasible(self): | ||
""" | ||
Verify that for all segmentations the labels are transformed correctly (have correct shape, one hot and at least 1 labeled pixel for the given data) | ||
""" | ||
for segmentation in [ | ||
self.segmentation_new_labels_mapped, | ||
self.segmentation_new_labels, | ||
self.segmentation_old_labels, | ||
]: | ||
data_handler = DataHandler( | ||
t1c=self.t1c, | ||
t2=self.t2, | ||
t1=self.t1, | ||
flair=self.flair, | ||
segmentation=segmentation, | ||
) | ||
|
||
# get first element from dataloader | ||
data = next(iter(data_handler.get_dataloader())) | ||
|
||
self.assertEqual(data["labels"].shape, (1, 3, 240, 240)) | ||
for label_map in data["labels"][0]: | ||
self.assertTrue(np.all(np.isin(label_map, [0, 1]))) | ||
self.assertTrue(np.sum(label_map) > 0) | ||
|
||
def test_labels_transform_equal(self): | ||
""" | ||
Verify that for the same segmentation (with differing labeling convention) the transformed labels are equal | ||
""" | ||
|
||
data_handler_new_labels = DataHandler( | ||
t1c=self.t1c, | ||
t2=self.t2, | ||
t1=self.t1, | ||
flair=self.flair, | ||
segmentation=self.segmentation_new_labels, | ||
) | ||
data_handler_new_labels_mapped = DataHandler( | ||
t1c=self.t1c, | ||
t2=self.t2, | ||
t1=self.t1, | ||
flair=self.flair, | ||
segmentation=self.segmentation_new_labels_mapped, | ||
) | ||
|
||
data_new_labels = next(iter(data_handler_new_labels.get_dataloader())) | ||
data_new_labels_mapped = next( | ||
iter(data_handler_new_labels_mapped.get_dataloader()) | ||
) | ||
|
||
self.assertTrue( | ||
np.all(data_new_labels["labels"] == data_new_labels_mapped["labels"]) | ||
) |