diff --git a/brainles_preprocessing/registration/__init__.py b/brainles_preprocessing/registration/__init__.py index b4c8bcb..4c941cc 100644 --- a/brainles_preprocessing/registration/__init__.py +++ b/brainles_preprocessing/registration/__init__.py @@ -7,4 +7,12 @@ "ANTS package not found. If you want to use it, please install it using 'pip install brainles_preprocessing[ants]'" ) +try: + from .eReg.eReg import eRegRegistrator +except ImportError: + warnings.warn( + "eReg package not found. If you want to use it, please install it using 'pip install brainles_preprocessing[ereg]'" + ) + + from .niftyreg.niftyreg import NiftyRegRegistrator diff --git a/brainles_preprocessing/registration/eReg/__init__.py b/brainles_preprocessing/registration/eReg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/brainles_preprocessing/registration/eReg/config_files/TODO.md b/brainles_preprocessing/registration/eReg/config_files/TODO.md new file mode 100644 index 0000000..e69de29 diff --git a/brainles_preprocessing/registration/eReg/eReg.py b/brainles_preprocessing/registration/eReg/eReg.py new file mode 100644 index 0000000..4404c19 --- /dev/null +++ b/brainles_preprocessing/registration/eReg/eReg.py @@ -0,0 +1,100 @@ +# TODO add typing and docs +import os + +from ereg.registration import RegistrationClass + +from brainles_preprocessing.registration.registrator import Registrator + + +class eRegRegistrator(Registrator): + def __init__( + self, + # TODO define default + configuration_file: str | None = None, + ): + """ + # TODO + """ + self.configuration_file = configuration_file + + def register( + self, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str = None, + ) -> None: + """ + Register images using eReg. + + Args: + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix (output). + log_file_path (str): Path to the log file. + """ + # TODO do we need to handle kwargs? + registrator = RegistrationClass( + configuration_file=self.configuration_file, + ) + + matrix_path = _add_mat_suffix(matrix_path) + + registrator.register( + target_image=fixed_image_path, + moving_image=moving_image_path, + output_image=transformed_image_path, + transform_file=matrix_path, + log_file=log_file_path, + ) + + def transform( + self, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str = None, + ) -> None: + """ + Apply a transformation using eReg. + + Args: + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix. + log_file_path (str): Path to the log file. + """ + # TODO do we need to handle kwargs? + registrator = RegistrationClass( + configuration_file=self.configuration_file, + ) + + matrix_path = _add_mat_suffix(matrix_path) + + registrator.resample_image( + target_image=fixed_image_path, + moving_image=moving_image_path, + output_image=transformed_image_path, + transform_file=matrix_path, + log_file=log_file_path, + ) + + +def _add_mat_suffix(filename: str) -> str: + """ + Adds a ".mat" suffix to the filename if it doesn't have any extension. + + Parameters: + filename (str): The filename to check and potentially modify. + + Returns: + str: The filename with ".mat" suffix added if needed. + """ + base, ext = os.path.splitext(filename) + if not ext: + filename += ".mat" + return filename diff --git a/brainles_preprocessing/registration/niftyreg/niftyreg.py b/brainles_preprocessing/registration/niftyreg/niftyreg.py index afbe089..3d405fe 100644 --- a/brainles_preprocessing/registration/niftyreg/niftyreg.py +++ b/brainles_preprocessing/registration/niftyreg/niftyreg.py @@ -1,4 +1,3 @@ -# TODO add typing and docs import os from auxiliary.runscript import ScriptRunner diff --git a/example/example_modality_centric_preprocessor.py b/example/example_modality_centric_preprocessor.py index 4bf7c0e..2f78f1e 100644 --- a/example/example_modality_centric_preprocessor.py +++ b/example/example_modality_centric_preprocessor.py @@ -6,7 +6,11 @@ from brainles_preprocessing.brain_extraction import HDBetExtractor from brainles_preprocessing.modality import Modality from brainles_preprocessing.preprocessor import Preprocessor -from brainles_preprocessing.registration import ANTsRegistrator, NiftyRegRegistrator +from brainles_preprocessing.registration import ( + ANTsRegistrator, + NiftyRegRegistrator, + eRegRegistrator, +) def preprocess(inputDir): @@ -103,6 +107,7 @@ def preprocess(inputDir): # choose the registration backend you want to use # registrator=NiftyRegRegistrator(), registrator=ANTsRegistrator(), + # registrator=eRegRegistrator(), brain_extractor=HDBetExtractor(), temp_folder="temporary_directory", limit_cuda_visible_devices="0", diff --git a/pyproject.toml b/pyproject.toml index 899614b..52ef9da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,13 @@ rich = "^13.6.0" # optional registration backends antspyx = { version = "^0.4.2", optional = true } +ereg = { version = "^0.0.10", optional = true } + [tool.poetry.extras] -all = ["antspyx"] +all = ["antspyx", "ereg"] ants = ["antspyx"] +ereg = ["ereg"] [tool.poetry.dev-dependencies] diff --git a/tests/registrator_base.py b/tests/registrator_base.py new file mode 100644 index 0000000..7b529e2 --- /dev/null +++ b/tests/registrator_base.py @@ -0,0 +1,90 @@ +import os +from abc import abstractmethod + +from auxiliary.turbopath import turbopath +import shutil + + +class RegistratorBase: + + @abstractmethod + def get_registrator(self): + pass + + @abstractmethod + def get_method_and_extension(self): + pass + + def setUp(self): + self.registrator = self.get_registrator() + self.method_name, self.matrix_extension = self.get_method_and_extension() + + test_data_dir = turbopath(__file__).parent + "/test_data" + input_dir = test_data_dir + "/input" + self.output_dir = test_data_dir + f"/temp_output_{self.method_name}" + os.makedirs(self.output_dir, exist_ok=True) + + self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" + self.moving_image = input_dir + "/tcia_example_t1.nii.gz" + + self.matrix = self.output_dir + f"/{self.method_name}_matrix" + self.transform_matrix = input_dir + f"/{self.method_name}_matrix" + + def tearDown(self): + # Clean up created files if they exist + shutil.rmtree(self.output_dir) + # pass + + def test_register_creates_output_files(self): + transformed_image = ( + self.output_dir + f"/{self.method_name}_registered_image.nii.gz" + ) + log_file = self.output_dir + f"/{self.method_name}_registration.log" + + self.registrator.register( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=transformed_image, + matrix_path=self.matrix, + log_file_path=log_file, + ) + + self.assertTrue( + os.path.exists(transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(f"{self.matrix}.{self.matrix_extension}"), + "matrix file was not created.", + ) + + self.assertTrue( + os.path.exists(log_file), + "log file was not created.", + ) + + def test_transform_creates_output_files(self): + transformed_image = ( + self.output_dir + f"/{self.method_name}_transformed_image.nii.gz" + ) + log_file = self.output_dir + f"/{self.method_name}_transformation.log" + + print("tf m:", self.transform_matrix) + self.registrator.transform( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=transformed_image, + matrix_path=self.transform_matrix, + log_file_path=log_file, + ) + + self.assertTrue( + os.path.exists(transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(log_file), + "log file was not created.", + ) diff --git a/tests/test_data/input/ants_matrix.mat b/tests/test_data/input/ants_matrix.mat new file mode 100644 index 0000000..5833a29 Binary files /dev/null and b/tests/test_data/input/ants_matrix.mat differ diff --git a/tests/test_data/input/ereg_matrix.mat b/tests/test_data/input/ereg_matrix.mat new file mode 100644 index 0000000..3deea3c Binary files /dev/null and b/tests/test_data/input/ereg_matrix.mat differ diff --git a/tests/test_data/input/matrix.mat b/tests/test_data/input/matrix.mat deleted file mode 100644 index d5e7571..0000000 Binary files a/tests/test_data/input/matrix.mat and /dev/null differ diff --git a/tests/test_data/input/matrix.txt b/tests/test_data/input/matrix.txt deleted file mode 100644 index e4ca998..0000000 --- a/tests/test_data/input/matrix.txt +++ /dev/null @@ -1,4 +0,0 @@ -0.9999356 0.009583665 0.006066407 -0.2090089 --0.009398761 0.9995115 -0.02980891 -0.7436395 --0.006349129 0.02974998 0.9995372 0.3484979 -0 0 0 1 diff --git a/tests/test_data/input/niftyreg_matrix.txt b/tests/test_data/input/niftyreg_matrix.txt new file mode 100644 index 0000000..4447b4a --- /dev/null +++ b/tests/test_data/input/niftyreg_matrix.txt @@ -0,0 +1,4 @@ +0.9888756 0.1354215 0.06153017 -1.314461 +-0.1397797 0.9874795 0.07311541 -1.147088 +-0.05085838 -0.08090271 0.9954236 2.731441 +0 0 0 1 diff --git a/tests/test_data/input/tcia_example_t1.nii.gz b/tests/test_data/input/tcia_example_t1.nii.gz new file mode 100644 index 0000000..044fcf0 Binary files /dev/null and b/tests/test_data/input/tcia_example_t1.nii.gz differ diff --git a/tests/test_registrators.py b/tests/test_registrators.py index 1661926..be36fae 100644 --- a/tests/test_registrators.py +++ b/tests/test_registrators.py @@ -1,93 +1,10 @@ -import os -import shutil -import unittest -from abc import abstractmethod - -from auxiliary.turbopath import turbopath +from registrator_base import RegistratorBase from brainles_preprocessing.registration.ANTs.ANTs import ANTsRegistrator +from brainles_preprocessing.registration.eReg.eReg import eRegRegistrator from brainles_preprocessing.registration.niftyreg.niftyreg import NiftyRegRegistrator - -class RegistratorBase: - @abstractmethod - def get_registrator(self): - pass - - @abstractmethod - def get_method_and_extension(self): - pass - - def setUp(self): - self.registrator = self.get_registrator() - self.method_name, self.matrix_extension = self.get_method_and_extension() - - test_data_dir = turbopath(__file__).parent + "/test_data" - input_dir = test_data_dir + "/input" - self.output_dir = test_data_dir + f"/temp_output_{self.method_name}" - os.makedirs(self.output_dir, exist_ok=True) - - self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" - self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" - - self.matrix = self.output_dir + "/matrix" - self.transform_matrix = input_dir + f"/matrix.{self.matrix_extension}" - - def tearDown(self): - # Clean up created files if they exist - shutil.rmtree(self.output_dir) - - def test_register_creates_output_files(self): - transformed_image = self.output_dir + "/registered_image.nii.gz" - log_file = self.output_dir + "/registration.log" - - self.registrator.register( - fixed_image_path=self.fixed_image, - moving_image_path=self.moving_image, - transformed_image_path=transformed_image, - matrix_path=self.matrix, - log_file_path=log_file, - ) - - self.assertTrue( - os.path.exists(transformed_image), - "transformed file was not created.", - ) - - self.assertTrue( - os.path.exists(f"{self.matrix}.{self.matrix_extension}"), - "matrix file was not created.", - ) - - self.assertTrue( - os.path.exists(log_file), - "log file was not created.", - ) - - def test_transform_creates_output_files(self): - transformed_image = self.output_dir + "/transformed_image.nii.gz" - log_file = self.output_dir + "/transformation.log" - - self.registrator.transform( - fixed_image_path=self.fixed_image, - moving_image_path=self.moving_image, - transformed_image_path=transformed_image, - matrix_path=self.transform_matrix, - log_file_path=log_file, - ) - - self.assertTrue( - os.path.exists(transformed_image), - "transformed file was not created.", - ) - - self.assertTrue( - os.path.exists(log_file), - "log file was not created.", - ) - - -# TODO also test transform +import unittest class TestANTsRegistrator(RegistratorBase, unittest.TestCase): @@ -104,3 +21,11 @@ def get_registrator(self): def get_method_and_extension(self): return "niftyreg", "txt" + + +class TestEregRegistratorRegistrator(RegistratorBase, unittest.TestCase): + def get_registrator(self): + return eRegRegistrator() + + def get_method_and_extension(self): + return "ereg", "mat"