From 9fa861ac85c51ff873cd8ae86c649d117a2469eb Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Mon, 25 Nov 2024 12:25:31 +0100 Subject: [PATCH] Fix test prediction to be on cpu and increase error tolerance --- tests/test_model.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 98f8a7c..3e75a57 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,7 +18,7 @@ def setUp(self): self.segmentation = Path("tests/data/seg-BraTS23_1.nii.gz") def test_full_prediction_paths(self): - dqe = DQE() + dqe = DQE(device="cpu") mean_score, scores = dqe.predict( t1=self.t1, t1c=self.t1c, @@ -26,13 +26,13 @@ def test_full_prediction_paths(self): flair=self.flair, segmentation=self.segmentation, ) - self.assertAlmostEqual(mean_score, 5.315534591674805) - self.assertAlmostEqual(scores[View.AXIAL.name], 5.452404975891113) - self.assertAlmostEqual(scores[View.CORONAL.name], 5.138582229614258) - self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.355616569519043) + self.assertAlmostEqual(mean_score, 5.315626939137776, places=4) + self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4) + self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4) + self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4) def test_full_prediction_numpy(self): - dqe = DQE() + dqe = DQE(device="cpu") mean_score, scores = dqe.predict( t1=nib.load(self.t1).get_fdata(), t1c=nib.load(self.t1c).get_fdata(), @@ -40,7 +40,7 @@ def test_full_prediction_numpy(self): flair=nib.load(self.flair).get_fdata(), segmentation=nib.load(self.segmentation).get_fdata(), ) - self.assertAlmostEqual(mean_score, 5.315534591674805) - self.assertAlmostEqual(scores[View.AXIAL.name], 5.452404975891113) - self.assertAlmostEqual(scores[View.CORONAL.name], 5.138582229614258) - self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.355616569519043) + self.assertAlmostEqual(mean_score, 5.315626939137776, places=4) + self.assertAlmostEqual(scores[View.AXIAL.name], 5.45253849029541, places=4) + self.assertAlmostEqual(scores[View.CORONAL.name], 5.1386494636535645, places=4) + self.assertAlmostEqual(scores[View.SAGITTAL.name], 5.3556928634643555, places=4)