diff --git a/brats/utils/data_handling.py b/brats/utils/data_handling.py index 2b3f2bc..1eab8ff 100644 --- a/brats/utils/data_handling.py +++ b/brats/utils/data_handling.py @@ -89,7 +89,7 @@ def input_sanity_check( t2w (Path | str, optional): T2w image path (required for segmentation) mask (Path | str, optional): Mask image path (required for inpainting) """ - + # Filter out None values to only include provided images images = { "t1n": t1n, diff --git a/tests/core/test_brats_algorithm.py b/tests/core/test_brats_algorithm.py index c57086b..be6b4b0 100644 --- a/tests/core/test_brats_algorithm.py +++ b/tests/core/test_brats_algorithm.py @@ -7,8 +7,9 @@ from brats import AdultGliomaSegmenter from brats.utils.constants import OUTPUT_NAME_SCHEMA + class TestBraTSAlgorithm(unittest.TestCase): - + def setUp(self): # Create a temporary directory for testing self.test_dir = Path(tempfile.mkdtemp()) @@ -21,75 +22,88 @@ def setUp(self): self.subject_A_folder.mkdir(parents=True, exist_ok=True) # Create mock file paths self.input_files = { - "t1c":self.subject_A_folder / "A-t1c.nii.gz", - "t1n":self.subject_A_folder/ "A-t1n.nii.gz", - "t2f":self.subject_A_folder/ "A-t2f.nii.gz", - "t2w":self.subject_A_folder/ "A-t2w.nii.gz", + "t1c": self.subject_A_folder / "A-t1c.nii.gz", + "t1n": self.subject_A_folder / "A-t1n.nii.gz", + "t2f": self.subject_A_folder / "A-t2f.nii.gz", + "t2w": self.subject_A_folder / "A-t2w.nii.gz", } for file in self.input_files.values(): file.touch() # the core inference method is the same for all segmentation and inpainting algorithms, we use AdultGliomaSegmenter as an example during testing self.segmenter = AdultGliomaSegmenter() - + def tearDown(self): # Remove the temporary directory after the test shutil.rmtree(self.test_dir) - @patch("brats.core.brats_algorithm.run_container") @patch("brats.core.segmentation_algorithms.input_sanity_check") @patch("brats.core.brats_algorithm.InferenceSetup") - def test_infer_single(self,mock_inference_setup,mock_input_sanity_check, mock_run_container): - + def test_infer_single( + self, mock_inference_setup, mock_input_sanity_check, mock_run_container + ): + # Mock InferenceSetup context manager mock_inference_setup_ret = mock_inference_setup.return_value - mock_inference_setup_ret.__enter__.return_value = (self.data_folder, self.output_folder) - - + mock_inference_setup_ret.__enter__.return_value = ( + self.data_folder, + self.output_folder, + ) + def create_output_file(*args, **kwargs): - subject_id = self.segmenter.algorithm.run_args.input_name_schema.format(id=0) - alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[self.segmenter.task].format(subject_id=subject_id) + subject_id = self.segmenter.algorithm.run_args.input_name_schema.format( + id=0 + ) + alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[ + self.segmenter.task + ].format(subject_id=subject_id) alg_output_file.touch() + mock_run_container.side_effect = create_output_file - - + output_file = self.output_folder / "output.nii.gz" self.segmenter.infer_single( t1c=self.input_files["t1c"], t1n=self.input_files["t1n"], t2f=self.input_files["t2f"], t2w=self.input_files["t2w"], - output_file=output_file + output_file=output_file, ) mock_input_sanity_check.assert_called_once() mock_run_container.assert_called_once() - + self.assertTrue(output_file.exists()) - + @patch("brats.core.brats_algorithm.run_container") @patch("brats.core.segmentation_algorithms.input_sanity_check") @patch("brats.core.brats_algorithm.InferenceSetup") - def test_infer_batch(self,mock_inference_setup,mock_input_sanity_check, mock_run_container): - + def test_infer_batch( + self, mock_inference_setup, mock_input_sanity_check, mock_run_container + ): + # Mock InferenceSetup context manager mock_inference_setup_ret = mock_inference_setup.return_value - mock_inference_setup_ret.__enter__.return_value = (self.data_folder, self.output_folder) - - + mock_inference_setup_ret.__enter__.return_value = ( + self.data_folder, + self.output_folder, + ) + def create_output_file(*args, **kwargs): - subject_id = self.segmenter.algorithm.run_args.input_name_schema.format(id=0) - alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[self.segmenter.task].format(subject_id=subject_id) + subject_id = self.segmenter.algorithm.run_args.input_name_schema.format( + id=0 + ) + alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[ + self.segmenter.task + ].format(subject_id=subject_id) alg_output_file.touch() + mock_run_container.side_effect = create_output_file - - + self.segmenter.infer_batch( - data_folder=self.data_folder, - output_folder=self.output_folder + data_folder=self.data_folder, output_folder=self.output_folder ) mock_input_sanity_check.assert_called_once() mock_run_container.assert_called_once() output_file = self.output_folder / "A.nii.gz" self.assertTrue(output_file.exists()) - diff --git a/tests/core/test_inpainting_algorithms.py b/tests/core/test_inpainting_algorithms.py index 8be3dd8..1e15dae 100644 --- a/tests/core/test_inpainting_algorithms.py +++ b/tests/core/test_inpainting_algorithms.py @@ -27,15 +27,15 @@ def setUp(self): # Create dummy files for img in [self.t1n, self.mask]: img.touch(exist_ok=True) - + self.segmenter = Inpainter() def tearDown(self): # Remove the temporary directory after the test shutil.rmtree(self.test_dir) - + ### Standardization tests - + @patch("brats.core.inpainting_algorithms.input_sanity_check") def test_successful_single_standardization(self, mock_input_sanity_check): subject_id = "test_subject" @@ -45,16 +45,12 @@ def test_successful_single_standardization(self, mock_input_sanity_check): inputs={ "t1n": self.t1n, "mask": self.mask, - } + }, ) subject_folder = self.tmp_data_folder / subject_id self.assertTrue(subject_folder.exists()) - self.assertTrue( - (subject_folder / f"{subject_id}-t1n-voided.nii.gz").exists() - ) - self.assertTrue( - (subject_folder / f"{subject_id}-mask.nii.gz").exists() - ) + self.assertTrue((subject_folder / f"{subject_id}-t1n-voided.nii.gz").exists()) + self.assertTrue((subject_folder / f"{subject_id}-mask.nii.gz").exists()) @patch("brats.core.inpainting_algorithms.input_sanity_check") @patch("sys.exit") @@ -71,15 +67,13 @@ def test_single_standardize_handle_file_not_found_error( inputs={ "t1n": t1n, "mask": self.mask, - } + }, ) mock_logger.assert_called() mock_exit.assert_called_with(1) @patch("brats.core.inpainting_algorithms.Inpainter._standardize_single_inputs") - def test_standardize_segmentation_inputs_list( - self, mock_standardize_single_inputs - ): + def test_standardize_segmentation_inputs_list(self, mock_standardize_single_inputs): subjects = [f for f in self.data_folder.iterdir() if f.is_dir()] mapping = self.segmenter._standardize_batch_inputs( data_folder=self.tmp_data_folder, @@ -93,20 +87,16 @@ def test_standardize_segmentation_inputs_list( }, ) mock_standardize_single_inputs.assert_called_once() - + ### Initialization tests - + def test_inpainter_initialization(self): # Test default initialization inpainter = Inpainter() self.assertIsInstance(inpainter, Inpainter) - + # Test with custom arguments custom_inpainter = Inpainter( - algorithm=InpaintingAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=InpaintingAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_inpainter, Inpainter) - - \ No newline at end of file diff --git a/tests/core/test_segmentation_algorithms.py b/tests/core/test_segmentation_algorithms.py index d198e08..0c2d85d 100644 --- a/tests/core/test_segmentation_algorithms.py +++ b/tests/core/test_segmentation_algorithms.py @@ -6,11 +6,20 @@ from loguru import logger -from brats import (AdultGliomaSegmenter, AfricaSegmenter, MeningiomaSegmenter, - MetastasesSegmenter, PediatricSegmenter) -from brats.utils.constants import (AdultGliomaAlgorithms, AfricaAlgorithms, - MeningiomaAlgorithms, MetastasesAlgorithms, - PediatricAlgorithms) +from brats import ( + AdultGliomaSegmenter, + AfricaSegmenter, + MeningiomaSegmenter, + MetastasesSegmenter, + PediatricSegmenter, +) +from brats.utils.constants import ( + AdultGliomaAlgorithms, + AfricaAlgorithms, + MeningiomaAlgorithms, + MetastasesAlgorithms, + PediatricAlgorithms, +) class TestSegmentationAlgorithms(unittest.TestCase): @@ -32,15 +41,15 @@ def setUp(self): # Create dummy files for img in [self.t1c, self.t1n, self.t2f, self.t2w]: img.touch(exist_ok=True) - + self.segmenter = AdultGliomaSegmenter() def tearDown(self): # Remove the temporary directory after the test shutil.rmtree(self.test_dir) - + ### Standardization tests - + @patch("brats.core.segmentation_algorithms.input_sanity_check") def test_successful_single_standardization(self, mock_input_sanity_check): subject_id = "test_subject" @@ -52,7 +61,7 @@ def test_successful_single_standardization(self, mock_input_sanity_check): "t1n": self.t1n, "t2f": self.t2f, "t2w": self.t2w, - } + }, ) subject_folder = self.tmp_data_folder / subject_id self.assertTrue(subject_folder.exists()) @@ -78,15 +87,15 @@ def test_single_standardize_handle_file_not_found_error( "t1n": self.t1n, "t2f": self.t2f, "t2w": self.t2w, - } + }, ) mock_logger.assert_called() mock_exit.assert_called_with(1) - @patch("brats.core.segmentation_algorithms.SegmentationAlgorithm._standardize_single_inputs") - def test_standardize_segmentation_inputs_list( - self, mock_standardize_single_inputs - ): + @patch( + "brats.core.segmentation_algorithms.SegmentationAlgorithm._standardize_single_inputs" + ) + def test_standardize_segmentation_inputs_list(self, mock_standardize_single_inputs): subjects = [f for f in self.data_folder.iterdir() if f.is_dir()] mapping = self.segmenter._standardize_batch_inputs( data_folder=self.tmp_data_folder, @@ -100,19 +109,17 @@ def test_standardize_segmentation_inputs_list( }, ) mock_standardize_single_inputs.assert_called_once() - + ### Initialization tests - + def test_adult_glioma_segmenter_initialization(self): # Test default initialization segmenter = AdultGliomaSegmenter() self.assertIsInstance(segmenter, AdultGliomaSegmenter) - + # Test with custom arguments custom_segmenter = AdultGliomaSegmenter( - algorithm=AdultGliomaAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=AdultGliomaAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_segmenter, AdultGliomaSegmenter) @@ -120,12 +127,10 @@ def test_meningioma_segmenter_initialization(self): # Test default initialization segmenter = MeningiomaSegmenter() self.assertIsInstance(segmenter, MeningiomaSegmenter) - + # Test with custom arguments custom_segmenter = MeningiomaSegmenter( - algorithm=MeningiomaAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=MeningiomaAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_segmenter, MeningiomaSegmenter) @@ -133,12 +138,10 @@ def test_pediatric_segmenter_initialization(self): # Test default initialization segmenter = PediatricSegmenter() self.assertIsInstance(segmenter, PediatricSegmenter) - + # Test with custom arguments custom_segmenter = PediatricSegmenter( - algorithm=PediatricAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=PediatricAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_segmenter, PediatricSegmenter) @@ -146,12 +149,10 @@ def test_africa_segmenter_initialization(self): # Test default initialization segmenter = AfricaSegmenter() self.assertIsInstance(segmenter, AfricaSegmenter) - + # Test with custom arguments custom_segmenter = AfricaSegmenter( - algorithm=AfricaAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=AfricaAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_segmenter, AfricaSegmenter) @@ -159,13 +160,9 @@ def test_metastases_segmenter_initialization(self): # Test default initialization segmenter = MetastasesSegmenter() self.assertIsInstance(segmenter, MetastasesSegmenter) - + # Test with custom arguments custom_segmenter = MetastasesSegmenter( - algorithm=MetastasesAlgorithms.BraTS23_2, - cuda_devices="1", - force_cpu=True + algorithm=MetastasesAlgorithms.BraTS23_2, cuda_devices="1", force_cpu=True ) self.assertIsInstance(custom_segmenter, MetastasesSegmenter) - - diff --git a/tests/utils/test_data_handling.py b/tests/utils/test_data_handling.py index 7f38495..47e8c08 100644 --- a/tests/utils/test_data_handling.py +++ b/tests/utils/test_data_handling.py @@ -6,7 +6,12 @@ from loguru import logger -from brats.utils.data_handling import InferenceSetup, add_log_file_handler, input_sanity_check, remove_tmp_folder +from brats.utils.data_handling import ( + InferenceSetup, + add_log_file_handler, + input_sanity_check, + remove_tmp_folder, +) class TestDataHandlingUtils(unittest.TestCase): @@ -32,12 +37,15 @@ def setUp(self): def tearDown(self): # Remove the temporary directory after the test shutil.rmtree(self.test_dir) - + def test_inference_setup_with_log_file(self): # Create a temporary log file tmp_log_file = Path(tempfile.mktemp()) - - with InferenceSetup(log_file=tmp_log_file) as (tmp_data_folder, tmp_output_folder): + + with InferenceSetup(log_file=tmp_log_file) as ( + tmp_data_folder, + tmp_output_folder, + ): # Check that the folders are created self.assertTrue(tmp_data_folder.is_dir()) self.assertTrue(tmp_output_folder.is_dir()) @@ -51,23 +59,23 @@ def test_inference_setup_with_log_file(self): # Remove the temporary log file tmp_log_file.unlink(missing_ok=True) - + def test_inference_setup_without_log_file(self): # Create a temporary log file tmp_log_file = Path(tempfile.mktemp()) - + with InferenceSetup() as (tmp_data_folder, tmp_output_folder): # Check that the folders are created self.assertTrue(tmp_data_folder.is_dir()) self.assertTrue(tmp_output_folder.is_dir()) # Check that the log file exists - self.assertFalse(tmp_log_file.exists()) # Log file should not be created + self.assertFalse(tmp_log_file.exists()) # Log file should not be created # Check if folders are cleaned up self.assertFalse(tmp_data_folder.exists()) self.assertFalse(tmp_output_folder.exists()) - + def test_remove_tmp_folder_success(self): # Test successful removal of a folder temp_folder = Path(tempfile.mkdtemp()) @@ -106,8 +114,6 @@ def test_add_log_file_handler(self): logger.remove(handler_id) log_file.unlink(missing_ok=True) - - @patch("brats.utils.data_handling.nib.load") @patch("brats.utils.data_handling.logger.warning") def test_input_sanity_check_correct_shape(self, mock_warning, mock_nib_load):