diff --git a/tests/test_registrators.py b/tests/test_registrators.py index 79b27a3..bcc18e1 100644 --- a/tests/test_registrators.py +++ b/tests/test_registrators.py @@ -6,55 +6,83 @@ from auxiliary.turbopath import turbopath from brainles_preprocessing.registration.ANTs.ANTs import ANTsRegistrator -from brainles_preprocessing.registration.niftyreg.niftyreg import \ - NiftyRegRegistrator +from brainles_preprocessing.registration.niftyreg.niftyreg import NiftyRegRegistrator -class TestRegistratorBase(unittest.TestCase): +class RegistratorBase: @abstractmethod def get_registrator(self): pass + @abstractmethod + def get_method_and_extension(self): + pass + def setUp(self): + self.registrator = self.get_registrator() + self.method_name, self.matrix_extension = self.get_method_and_extension() + test_data_dir = turbopath(__file__).parent + "/test_data" input_dir = test_data_dir + "/input" - self.output_dir = test_data_dir + "/temp_output_niftyreg" + self.output_dir = test_data_dir + f"/temp_output_{self.method_name}" os.makedirs(self.output_dir, exist_ok=True) - self.registrator = self.get_registrator() - self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" - self.transformed_image = self.output_dir + "/transformed_image.nii.gz" - self.matrix = self.output_dir + "/matrix.txt" - self.log_file = self.output_dir + "/registration.log" + self.matrix = self.output_dir + "/matrix" + self.transform_matrix = f"{self.matrix}.{self.matrix_extension}" def tearDown(self): # Clean up created files if they exist shutil.rmtree(self.output_dir) def test_register_creates_output_files(self): + transformed_image = self.output_dir + "/registered_image.nii.gz" + log_file = self.output_dir + "/registration.log" + self.registrator.register( fixed_image_path=self.fixed_image, moving_image_path=self.moving_image, - transformed_image_path=self.transformed_image, + transformed_image_path=transformed_image, matrix_path=self.matrix, - log_file_path=self.log_file, + log_file_path=log_file, ) self.assertTrue( - os.path.exists(self.transformed_image), + os.path.exists(transformed_image), "transformed file was not created.", ) self.assertTrue( - os.path.exists(self.matrix), + os.path.exists(f"{self.matrix}.{self.matrix_extension}"), "matrix file was not created.", ) self.assertTrue( - os.path.exists(self.log_file), + os.path.exists(log_file), + "log file was not created.", + ) + + def test_transform_creates_output_files(self): + transformed_image = self.output_dir + "/transformed_image.nii.gz" + log_file = self.output_dir + "/transformation.log" + + self.registrator.transform( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=transformed_image, + matrix_path=self.transform_matrix, + log_file_path=log_file, + ) + + self.assertTrue( + os.path.exists(transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(log_file), "log file was not created.", ) @@ -62,10 +90,17 @@ def test_register_creates_output_files(self): # TODO also test transform -class TestANTsRegistratorBase(TestRegistratorBase): +class TestANTsRegistrator(RegistratorBase, unittest.TestCase): def get_registrator(self): return ANTsRegistrator() -class TestNiftyRegRegistratorRegistratorBase(TestRegistratorBase): + def get_method_and_extension(self): + return "ants", "mat" + + +class TestNiftyRegRegistratorRegistrator(RegistratorBase, unittest.TestCase): def get_registrator(self): return NiftyRegRegistrator() + + def get_method_and_extension(self): + return "niftyreg", "txt"