diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8617802..fb59acb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - pip install -e . + pip install -e .[all] - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/brainles_preprocessing/modality.py b/brainles_preprocessing/modality.py index 476bfb8..b4a4c03 100644 --- a/brainles_preprocessing/modality.py +++ b/brainles_preprocessing/modality.py @@ -170,7 +170,9 @@ def register( str: Path to the registration matrix. """ registered = os.path.join(registration_dir, f"{moving_image_name}.nii.gz") - registered_matrix = os.path.join(registration_dir, f"{moving_image_name}.txt") + registered_matrix = os.path.join( + registration_dir, f"{moving_image_name}" + ) # note, add file ending depending on registration backend! registered_log = os.path.join(registration_dir, f"{moving_image_name}.log") registrator.register( diff --git a/brainles_preprocessing/registration/ANTs/ANTs.py b/brainles_preprocessing/registration/ANTs/ANTs.py new file mode 100644 index 0000000..0d3b214 --- /dev/null +++ b/brainles_preprocessing/registration/ANTs/ANTs.py @@ -0,0 +1,210 @@ +# TODO add typing and docs +import datetime +import os +import shutil + +import ants +from auxiliary.turbopath import turbopath + +from brainles_preprocessing.registration.registrator import Registrator + + +class ANTsRegistrator(Registrator): + def __init__( + self, + registration_params: dict = None, + transformation_params: dict = None, + ): + """ + Initialize an ANTsRegistrator instance. + + Parameters: + - registration_params (dict, optional): Dictionary of parameters for the registration method. + Defaults to None, which implies using default registration parameters with a rigid transformation. + - transformation_params (dict, optional): Dictionary of parameters for the transformation method. + Defaults to an empty dictionary. + + The registration_params dictionary may include the following keys: + - type_of_transform (str, optional): Type of transformation to use (default is "Rigid"). + + Example: + >>> reg_params = {'type_of_transform': 'Affine', 'reg_iterations': (30, 20, 10)} + >>> transform_params = {'interpolator': 'linear', 'imagetype': 1} + >>> registrator = ANTsRegistrator(registration_params=reg_params, transformation_params=transform_params) + """ + # Set default registration parameters + default_registration_params = {"type_of_transform": "Rigid"} + self.registration_params = registration_params or default_registration_params + + # Set default transformation parameters + self.transformation_params = transformation_params or {} + + def register( + self, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str, + **kwargs, + ) -> None: + """ + Register images using ANTs. + + Args: + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix (output). + log_file_path (str): Path to the log file. + **kwargs: Additional registration parameters to update the instantiated defaults. + """ + # we update the transformation parameters with the provided kwargs + + start_time = datetime.datetime.now() + + registration_kwargs = {**self.registration_params, **kwargs} + transformed_image_path = turbopath(transformed_image_path) + + matrix_path = turbopath(matrix_path) + if matrix_path.suffix != ".mat": + matrix_path = matrix_path.with_suffix(".mat") + + fixed_image = ants.image_read(fixed_image_path) + moving_image = ants.image_read(moving_image_path) + registration_result = ants.registration( + fixed=fixed_image, + moving=moving_image, + **registration_kwargs, + ) + transformed_image = registration_result["warpedmovout"] + os.makedirs(transformed_image_path.parent, exist_ok=True) + ants.image_write(transformed_image, transformed_image_path) + os.makedirs(matrix_path.parent, exist_ok=True) + shutil.copyfile(registration_result["fwdtransforms"][0], matrix_path) + + end_time = datetime.datetime.now() + + # TODO nicer logging + # we create a dummy log file for the moment to pass the tests + + self._log_to_file( + log_file_path=log_file_path, + fixed_image_path=fixed_image_path, + moving_image_path=moving_image_path, + transformed_image_path=transformed_image_path, + matrix_path=matrix_path, + operation_name="registration", + start_time=start_time, + end_time=end_time, + ) + + def transform( + self, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str, + **kwargs, + ) -> None: + """ + Apply a transformation using ANTs. + + Args: + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix. + log_file_path (str): Path to the log file. + **kwargs: Additional transformation parameters to update the instantiated defaults. + """ + start_time = datetime.datetime.now() + + # we update the transformation parameters with the provided kwargs + transform_kwargs = {**self.transformation_params, **kwargs} + fixed_image = ants.image_read(fixed_image_path) + moving_image = ants.image_read(moving_image_path) + transformed_image_path = turbopath(transformed_image_path) + os.makedirs(transformed_image_path.parent, exist_ok=True) + + matrix_path = turbopath(matrix_path) + if matrix_path.suffix != ".mat": + matrix_path = matrix_path.with_suffix(".mat") + transformed_image = ants.apply_transforms( + fixed=fixed_image, + moving=moving_image, + transformlist=[matrix_path], + **transform_kwargs, + ) + ants.image_write(transformed_image, transformed_image_path) + + end_time = datetime.datetime.now() + + # TODO nicer logging + # we create a dummy log file for the moment to pass the tests + + self._log_to_file( + log_file_path=log_file_path, + fixed_image_path=fixed_image_path, + moving_image_path=moving_image_path, + transformed_image_path=transformed_image_path, + matrix_path=matrix_path, + operation_name="transformation", + start_time=start_time, + end_time=end_time, + ) + + @staticmethod + def _log_to_file( + log_file_path: str, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + operation_name: str, + start_time, + end_time, + ): + + # Calculate the duration and make it human readable + duration = (end_time - start_time).total_seconds() + + hours = int(duration // 3600) + minutes = int((duration % 3600) // 60) + seconds = int(duration % 60) + milliseconds = int((duration - int(duration)) * 1000) + + # Format the duration as "0:0:0:0" + duration_formatted = f"{hours}h {minutes}m {seconds}s {milliseconds}ms" + + with open(log_file_path, "w") as f: + f.write(f"*** {operation_name} with antspyx ***\n") + f.write(f"start time: {start_time} \n") + f.write(f"fixed image: {fixed_image_path} \n") + f.write(f"moving image: {moving_image_path} \n") + f.write(f"transformed image: {transformed_image_path} \n") + f.write(f"matrix: {matrix_path} \n") + f.write(f"end time: {end_time} \n") + f.write(f"duration: {duration_formatted}\n") + + +if __name__ == "__main__": + # TODO move this into unit tests + reg = ANTsRegistrator() + + reg.register( + fixed_image_path="example/example_data/TCGA-DU-7294/AX_T1_POST_GD_FLAIR_TCGA-DU-7294_TCGA-DU-7294_GE_TCGA-DU-7294_AX_T1_POST_GD_FLAIR_RM_13_t1c.nii.gz", + moving_image_path="example/example_data/TCGA-DU-7294/AX_T2_FR-FSE_RF2_150_TCGA-DU-7294_TCGA-DU-7294_GE_TCGA-DU-7294_AX_T2_FR-FSE_RF2_150_RM_4_t2.nii.gz", + transformed_image_path="example/example_ants/transformed_image.nii.gz", + matrix_path="example/example_ants_matrix/matrix", + log_file_path="example/example_ants/log.txt", + ) + + reg.transform( + fixed_image_path="example/example_data/TCGA-DU-7294/AX_T1_POST_GD_FLAIR_TCGA-DU-7294_TCGA-DU-7294_GE_TCGA-DU-7294_AX_T1_POST_GD_FLAIR_RM_13_t1c.nii.gz", + moving_image_path="example/example_data/OtherEXampleFromTCIA/T1_AX_OtherEXampleTCIA_TCGA-FG-6692_Si_TCGA-FG-6692_T1_AX_SE_10_se2d1_t1.nii.gz", + transformed_image_path="example/example_ants_transformed/transformed_image.nii.gz", + matrix_path="example/example_ants_matrix/matrix.mat", + log_file_path="example/example_ants/log.txt", + ) diff --git a/brainles_preprocessing/registration/ANTs/TODO_ANTs_parameters.py b/brainles_preprocessing/registration/ANTs/TODO_ANTs_parameters.py new file mode 100644 index 0000000..2fce65b --- /dev/null +++ b/brainles_preprocessing/registration/ANTs/TODO_ANTs_parameters.py @@ -0,0 +1,195 @@ +import os +import shlex + +from app.project_e.image_processing.utilities.utils import eleSubprocess +from flask_socketio import SocketIO + + +def ants_registrator( + fixed_image, moving_image, outputmat, transformationalgorithm="rigid" +): + # ants call parameters + dimensionality = "-d 3" + initial_moving_transform = "-r [" + fixed_image + ", " + moving_image + ", 0]" + + # transformations + if transformationalgorithm == "rigid": + # rigid ants_transformation + transform_rigid = "-t rigid[0.1]" + metric_rigid = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.5]" + ) + convergence_rigid = "-c [1000x500x250, 1e-6, 10]" + smoothing_sigmas_rigid = "-s 3x2x1vox" + shrink_factors_rigid = "-f 8x4x2" + elif transformationalgorithm == "rigid+affine": + # rigid ants_transformation + transform_rigid = "-t rigid[0.1]" + metric_rigid = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.5]" + ) + convergence_rigid = "-c [1000x500x250, 1e-6, 10]" + smoothing_sigmas_rigid = "-s 3x2x1vox" + shrink_factors_rigid = "-f 8x4x2" + + # affine ants_transformation + transform_affine = "-t affine[0.1]" + metric_affine = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.5]" + ) + convergence_affine = "-c [1000x500x250, 1e-6, 10]" + smoothing_sigmas_affine = "-s 3x2x1vox" + shrink_factors_affine = "-f 8x4x2" + elif transformationalgorithm == "rex-dfc": + # translation + transform_translation = "-t translation[0.1]" + metric_translation = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.05]" + ) + convergence_translation = "-c [1000, 1e-8, 20]" + smoothing_sigmas_translation = "-s 4vox" + shrink_factors_translation = "-f 6" + + # rigid ants_transformation + transform_rigid = "-t rigid[0.1]" + metric_rigid = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.1]" + ) + convergence_rigid = "-c [1000x1000, 1e-8, 20]" + smoothing_sigmas_rigid = "-s 4x2vox" + shrink_factors_rigid = "-f 4x2" + + # affine ants_transformation + transform_affine = "-t affine[0.1]" + metric_affine = ( + "-m Mattes[" + fixed_image + "," + moving_image + ", 1, 32, Regular, 0.1]" + ) + convergence_affine = "-c [10000x1111x5, 1e-8, 20]" + smoothing_sigmas_affine = "-s 3x2x1vox" + shrink_factors_affine = "-f 8x4x2" + + # other parameters + use_estimate_learning_rate_once = "-l 1" + collapse_output_transforms = "-z 1" + interpolation = "-n BSpline[3]" + precision = "--float 1" + output = "-o " + "[" + outputmat + "]" + + # generate calls + if transformationalgorithm == "rigid": + ants_cmd = ( + "antsRegistration", + dimensionality, + initial_moving_transform, + # rigid ants_transformation + transform_rigid, + metric_rigid, + convergence_rigid, + smoothing_sigmas_rigid, + shrink_factors_rigid, + # other parameters + use_estimate_learning_rate_once, + collapse_output_transforms, + interpolation, + precision, + output, + ) + ants_call = shlex.split("%s %s %s %s %s %s %s %s %s %s %s %s %s" % ants_cmd) + + elif transformationalgorithm == "rigid+affine": + ants_cmd = ( + "antsRegistration", + dimensionality, + initial_moving_transform, + # rigid ants_transformation + transform_rigid, + metric_rigid, + convergence_rigid, + smoothing_sigmas_rigid, + shrink_factors_rigid, + # affine ants_transformation + transform_affine, + metric_affine, + convergence_affine, + smoothing_sigmas_affine, + shrink_factors_affine, + # other parameters + use_estimate_learning_rate_once, + collapse_output_transforms, + interpolation, + precision, + output, + ) + ants_call = shlex.split( + "%s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s" % ants_cmd + ) + + elif transformationalgorithm == "rex-dfc": + ants_cmd = ( + "antsRegistration", + dimensionality, + initial_moving_transform, + # translation + transform_translation, + metric_translation, + convergence_translation, + smoothing_sigmas_translation, + shrink_factors_translation, + # rigid ants_transformation + transform_rigid, + metric_rigid, + convergence_rigid, + smoothing_sigmas_rigid, + shrink_factors_rigid, + # affine ants_transformation + transform_affine, + metric_affine, + convergence_affine, + smoothing_sigmas_affine, + shrink_factors_affine, + # other parameters + use_estimate_learning_rate_once, + collapse_output_transforms, + interpolation, + precision, + output, + ) + ants_call = shlex.split( + "%s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s %s" + % ants_cmd + ) + + # construct call + readable_ants_call = " ".join(ants_cmd) + print("calling ants with the following call:") + print(readable_ants_call) + + # log file + logFilePath = outputmat + "registration.log" + # call it + eleSubprocess(logFilePath=logFilePath, call=ants_call) + + +def modality_registrator(examid, modality): + socketio = SocketIO(message_queue="redis://") + socketio.emit( + "ipstatus", {"examid": examid, "ipstatus": modality + " ants registration"} + ) + + niftipath = os.path.normpath(os.path.join("data/tmp/", examid, "raw_niftis")) + # fixed image + native_t1 = os.path.join(niftipath, examid + "_native_t1.nii.gz") + + # moving images + moving_image = os.path.join(niftipath, examid + "_native_" + modality + ".nii.gz") + + # output mats + exportpath = os.path.normpath(os.path.join("data/tmp/", examid, "registrations")) + os.makedirs(exportpath, exist_ok=True) + filename = examid + "_" + modality + "_to_t1_" + outputmat = os.path.join(exportpath, filename) + + # call it + ants_registrator( + native_t1, moving_image, outputmat, transformationalgorithm="rigid" + ) diff --git a/brainles_preprocessing/registration/ANTs/__init__.py b/brainles_preprocessing/registration/ANTs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/brainles_preprocessing/registration/__init__.py b/brainles_preprocessing/registration/__init__.py index 6f56b85..b4c8bcb 100644 --- a/brainles_preprocessing/registration/__init__.py +++ b/brainles_preprocessing/registration/__init__.py @@ -1 +1,10 @@ +import warnings + +try: + from .ANTs.ANTs import ANTsRegistrator +except ImportError: + warnings.warn( + "ANTS package not found. If you want to use it, please install it using 'pip install brainles_preprocessing[ants]'" + ) + from .niftyreg.niftyreg import NiftyRegRegistrator diff --git a/brainles_preprocessing/registration/niftyreg/niftyreg.py b/brainles_preprocessing/registration/niftyreg/niftyreg.py index 778484b..afbe089 100644 --- a/brainles_preprocessing/registration/niftyreg/niftyreg.py +++ b/brainles_preprocessing/registration/niftyreg/niftyreg.py @@ -67,6 +67,10 @@ def register( turbopath(__file__).parent + "/niftyreg_scripts/reg_aladin", ) + turbopath(matrix_path) + if matrix_path.suffix != ".txt": + matrix_path = matrix_path.with_suffix(".txt") + input_params = [ turbopath(niftyreg_executable), turbopath(fixed_image_path), @@ -110,12 +114,17 @@ def transform( turbopath(__file__).parent + "/niftyreg_scripts/reg_resample", ) + turbopath(matrix_path) + if matrix_path.suffix != ".txt": + matrix_path = matrix_path.with_suffix(".txt") + input_params = [ turbopath(niftyreg_executable), turbopath(fixed_image_path), turbopath(moving_image_path), turbopath(transformed_image_path), turbopath(matrix_path), + # we need to add txt as this is the format for niftyreg matrixes ] # Call the run method to execute the script and capture the output in the log file diff --git a/example/example_modality_centric_preprocessor.py b/example/example_modality_centric_preprocessor.py index 384a139..4bf7c0e 100644 --- a/example/example_modality_centric_preprocessor.py +++ b/example/example_modality_centric_preprocessor.py @@ -6,7 +6,7 @@ from brainles_preprocessing.brain_extraction import HDBetExtractor from brainles_preprocessing.modality import Modality from brainles_preprocessing.preprocessor import Preprocessor -from brainles_preprocessing.registration import NiftyRegRegistrator +from brainles_preprocessing.registration import ANTsRegistrator, NiftyRegRegistrator def preprocess(inputDir): @@ -100,7 +100,9 @@ def preprocess(inputDir): preprocessor = Preprocessor( center_modality=center, moving_modalities=moving_modalities, - registrator=NiftyRegRegistrator(), + # choose the registration backend you want to use + # registrator=NiftyRegRegistrator(), + registrator=ANTsRegistrator(), brain_extractor=HDBetExtractor(), temp_folder="temporary_directory", limit_cuda_visible_devices="0", diff --git a/pyproject.toml b/pyproject.toml index 4df8f63..899614b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,10 @@ classifiers = [ "Operating System :: OS Independent", ] - [tool.poetry.dependencies] python = "^3.10" # core -path = "^16.2.0" +path = "^16.10.0" ttictoc = "^0.5.6" pathlib = "^1.0.1" nibabel = "^3.2.1" @@ -55,9 +54,17 @@ BrainLes-HD-BET = ">=0.0.5" # utils tqdm = "^4.64.1" -auxiliary = "^0.0.40" +auxiliary = "^0.0.41" rich = "^13.6.0" +# optional registration backends +antspyx = { version = "^0.4.2", optional = true } + +[tool.poetry.extras] +all = ["antspyx"] +ants = ["antspyx"] + + [tool.poetry.dev-dependencies] pytest = "^6.2" diff --git a/tests/test_data/input/matrix.mat b/tests/test_data/input/matrix.mat new file mode 100644 index 0000000..d5e7571 Binary files /dev/null and b/tests/test_data/input/matrix.mat differ diff --git a/tests/test_data/input/matrix.txt b/tests/test_data/input/matrix.txt new file mode 100644 index 0000000..e4ca998 --- /dev/null +++ b/tests/test_data/input/matrix.txt @@ -0,0 +1,4 @@ +0.9999356 0.009583665 0.006066407 -0.2090089 +-0.009398761 0.9995115 -0.02980891 -0.7436395 +-0.006349129 0.02974998 0.9995372 0.3484979 +0 0 0 1 diff --git a/tests/test_niftyreg_registrator.py b/tests/test_niftyreg_registrator.py deleted file mode 100644 index 94bcaca..0000000 --- a/tests/test_niftyreg_registrator.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -import shutil -import unittest - -from auxiliary.turbopath import turbopath - -from brainles_preprocessing.registration.niftyreg.niftyreg import NiftyRegRegistrator - - -class TestNiftyRegRegistrator(unittest.TestCase): - def setUp(self): - test_data_dir = turbopath(__file__).parent + "/test_data" - input_dir = test_data_dir + "/input" - self.output_dir = test_data_dir + "/temp_output_niftyreg" - os.makedirs(self.output_dir, exist_ok=True) - - self.registrator = NiftyRegRegistrator() - - self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" - self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" - - self.transformed_image = self.output_dir + "/transformed_image.nii.gz" - self.matrix = self.output_dir + "/matrix.txt" - self.log_file = self.output_dir + "/registration.log" - - def tearDown(self): - # Clean up created files if they exist - shutil.rmtree(self.output_dir) - - def test_register_creates_output_files(self): - # we try to run the fastest possible skullstripping on GPU - self.registrator.register( - fixed_image_path=self.fixed_image, - moving_image_path=self.moving_image, - transformed_image_path=self.transformed_image, - matrix_path=self.matrix, - log_file_path=self.log_file, - ) - - self.assertTrue( - os.path.exists(self.transformed_image), - "transformed file was not created.", - ) - - self.assertTrue( - os.path.exists(self.matrix), - "matrix file was not created.", - ) - - self.assertTrue( - os.path.exists(self.log_file), - "log file was not created.", - ) - - -# TODO also test transform diff --git a/tests/test_registrators.py b/tests/test_registrators.py new file mode 100644 index 0000000..1661926 --- /dev/null +++ b/tests/test_registrators.py @@ -0,0 +1,106 @@ +import os +import shutil +import unittest +from abc import abstractmethod + +from auxiliary.turbopath import turbopath + +from brainles_preprocessing.registration.ANTs.ANTs import ANTsRegistrator +from brainles_preprocessing.registration.niftyreg.niftyreg import NiftyRegRegistrator + + +class RegistratorBase: + @abstractmethod + def get_registrator(self): + pass + + @abstractmethod + def get_method_and_extension(self): + pass + + def setUp(self): + self.registrator = self.get_registrator() + self.method_name, self.matrix_extension = self.get_method_and_extension() + + test_data_dir = turbopath(__file__).parent + "/test_data" + input_dir = test_data_dir + "/input" + self.output_dir = test_data_dir + f"/temp_output_{self.method_name}" + os.makedirs(self.output_dir, exist_ok=True) + + self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" + self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" + + self.matrix = self.output_dir + "/matrix" + self.transform_matrix = input_dir + f"/matrix.{self.matrix_extension}" + + def tearDown(self): + # Clean up created files if they exist + shutil.rmtree(self.output_dir) + + def test_register_creates_output_files(self): + transformed_image = self.output_dir + "/registered_image.nii.gz" + log_file = self.output_dir + "/registration.log" + + self.registrator.register( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=transformed_image, + matrix_path=self.matrix, + log_file_path=log_file, + ) + + self.assertTrue( + os.path.exists(transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(f"{self.matrix}.{self.matrix_extension}"), + "matrix file was not created.", + ) + + self.assertTrue( + os.path.exists(log_file), + "log file was not created.", + ) + + def test_transform_creates_output_files(self): + transformed_image = self.output_dir + "/transformed_image.nii.gz" + log_file = self.output_dir + "/transformation.log" + + self.registrator.transform( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=transformed_image, + matrix_path=self.transform_matrix, + log_file_path=log_file, + ) + + self.assertTrue( + os.path.exists(transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(log_file), + "log file was not created.", + ) + + +# TODO also test transform + + +class TestANTsRegistrator(RegistratorBase, unittest.TestCase): + def get_registrator(self): + return ANTsRegistrator() + + def get_method_and_extension(self): + return "ants", "mat" + + +class TestNiftyRegRegistratorRegistrator(RegistratorBase, unittest.TestCase): + def get_registrator(self): + return NiftyRegRegistrator() + + def get_method_and_extension(self): + return "niftyreg", "txt"