diff --git a/brainles_preprocessing/modality.py b/brainles_preprocessing/modality.py index 1131cea..0fa7b54 100644 --- a/brainles_preprocessing/modality.py +++ b/brainles_preprocessing/modality.py @@ -94,11 +94,11 @@ def register( registered_log = os.path.join(registration_dir, f"{moving_image_name}.log") registrator.register( - fixed_image=fixed_image_path, - moving_image=self.current, + fixed_image_path=fixed_image_path, + moving_image_path=self.current, transformed_image=registered, - matrix=registered_matrix, - log_file=registered_log, + matrix_path=registered_matrix, + log_file_path=registered_log, ) self.current = registered return registered_matrix @@ -133,11 +133,11 @@ def transform( transformed_log = os.path.join(registration_dir, f"{moving_image_name}.log") registrator.transform( - fixed_image=fixed_image_path, - moving_image=self.current, - transformed_image=transformed, - matrix=transformation_matrix, - log_file=transformed_log, + fixed_image_path=fixed_image_path, + moving_image_path=self.current, + transformed_image_path=transformed, + matrix_path=transformation_matrix, + log_file_path=transformed_log, ) self.current = transformed diff --git a/brainles_preprocessing/registration/niftyreg.py b/brainles_preprocessing/registration/niftyreg.py index 9516d4d..c702c90 100644 --- a/brainles_preprocessing/registration/niftyreg.py +++ b/brainles_preprocessing/registration/niftyreg.py @@ -42,25 +42,25 @@ def __init__( def register( self, - fixed_image, - moving_image, - transformed_image, - matrix, - log_file, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str, ): """ Register images using NiftyReg. Args: - fixed_image (str): Path to the fixed image. - moving_image (str): Path to the moving image. - transformed_image (str): Path to the transformed image (output). - matrix (str): Path to the transformation matrix (output). - log_file (str): Path to the log file. + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix (output). + log_file_path (str): Path to the log file. """ runner = ScriptRunner( script_path=self.registration_script, - log_path=log_file, + log_path=log_file_path, ) niftyreg_executable = str( @@ -69,10 +69,10 @@ def register( input_params = [ turbopath(niftyreg_executable), - turbopath(fixed_image), - turbopath(moving_image), - turbopath(transformed_image), - turbopath(matrix), + turbopath(fixed_image_path), + turbopath(moving_image_path), + turbopath(transformed_image_path), + turbopath(matrix_path), ] # Call the run method to execute the script and capture the output in the log file @@ -85,25 +85,25 @@ def register( def transform( self, - fixed_image, - moving_image, - transformed_image, - matrix, - log_file, + fixed_image_path: str, + moving_image_path: str, + transformed_image_path: str, + matrix_path: str, + log_file_path: str, ): """ Apply a transformation using NiftyReg. Args: - fixed_image (str): Path to the fixed image. - moving_image (str): Path to the moving image. - transformed_image (str): Path to the transformed image (output). - matrix (str): Path to the transformation matrix. - log_file (str): Path to the log file. + fixed_image_path (str): Path to the fixed image. + moving_image_path (str): Path to the moving image. + transformed_image_path (str): Path to the transformed image (output). + matrix_path (str): Path to the transformation matrix. + log_file_path (str): Path to the log file. """ runner = ScriptRunner( script_path=self.transformation_script, - log_path=log_file, + log_path=log_file_path, ) niftyreg_executable = str( @@ -112,10 +112,10 @@ def transform( input_params = [ turbopath(niftyreg_executable), - turbopath(fixed_image), - turbopath(moving_image), - turbopath(transformed_image), - turbopath(matrix), + turbopath(fixed_image_path), + turbopath(moving_image_path), + turbopath(transformed_image_path), + turbopath(matrix_path), ] # Call the run method to execute the script and capture the output in the log file diff --git a/brainles_preprocessing/registration/registrator.py b/brainles_preprocessing/registration/registrator.py index 95e8f23..56fb749 100644 --- a/brainles_preprocessing/registration/registrator.py +++ b/brainles_preprocessing/registration/registrator.py @@ -1,29 +1,56 @@ -# TODO add typing and docs from abc import ABC, abstractmethod +from typing import Any class Registrator(ABC): - def __init__(self, backend): - self.backend = backend + # TODO probably the init here should be removed? + # def __init__(self, backend): + # self.backend = backend @abstractmethod def register( self, - fixed_image, - moving_image, - transformed_image, - matrix, - log_file, - ): + fixed_image_path: Any, + moving_image_path: Any, + transformed_image_path: Any, + matrix_path: Any, + log_file_path: str, + ) -> None: + """ + Abstract method for registering images. + + Args: + fixed_image_path (Any): The fixed image for registration. + moving_image_path (Any): The moving image to be registered. + transformed_image_path (Any): The resulting transformed image after registration. + matrix_path (Any): The transformation matrix applied during registration. + log_file_path (str): The path to the log file for recording registration details. + + Returns: + None + """ pass @abstractmethod def transform( self, - fixed_image, - moving_image, - transformed_image, - matrix, - log_file, - ): + fixed_image_path: Any, + moving_image_path: Any, + transformed_image_path: Any, + matrix: Any, + log_file: str, + ) -> None: + """ + Abstract method for transforming images. + + Args: + fixed_image_path (Any): The fixed image to be transformed. + moving_image_path (Any): The moving image to be transformed. + transformed_image_path (Any): The resulting transformed image. + matrix_path (Any): The transformation matrix applied during transformation. + log_file_path (str): The path to the log file for recording transformation details. + + Returns: + None + """ pass diff --git a/tests/test_brain_extractor.py b/tests/test_hdbet_brain_extractor.py similarity index 94% rename from tests/test_brain_extractor.py rename to tests/test_hdbet_brain_extractor.py index 85c963a..a79d2a5 100644 --- a/tests/test_brain_extractor.py +++ b/tests/test_hdbet_brain_extractor.py @@ -11,12 +11,14 @@ class TestHDBetExtractor(unittest.TestCase): def setUp(self): test_data_dir = turbopath(__file__).parent + "/test_data" input_dir = test_data_dir + "/input" - self.output_dir = test_data_dir + "/temp_output" + self.output_dir = test_data_dir + "/temp_output_hdbet" os.makedirs(self.output_dir, exist_ok=True) self.brain_extractor = HDBetExtractor() + self.input_image_path = input_dir + "/tcia_example_t1c.nii.gz" self.input_brain_mask_path = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" + self.masked_image_path = self.output_dir + "/bet_tcia_example_t1c.nii.gz" self.brain_mask_path = self.output_dir + "/bet_tcia_example_t1c_mask.nii.gz" self.masked_again_image_path = ( @@ -30,9 +32,8 @@ def tearDown(self): # Clean up created files if they exist shutil.rmtree(self.output_dir) - def test_extract_creates_output_files(self): - # we try to run the fastest possible skullstripping on GPU + # we try to run the fastest possible skullstripping on CPU self.brain_extractor.extract( input_image_path=self.input_image_path, masked_image_path=self.masked_image_path, diff --git a/tests/test_niftyreg_registrator.py b/tests/test_niftyreg_registrator.py new file mode 100644 index 0000000..2315960 --- /dev/null +++ b/tests/test_niftyreg_registrator.py @@ -0,0 +1,56 @@ +import os +import shutil +import unittest + +from auxiliary.turbopath import turbopath + +from brainles_preprocessing.registration.niftyreg import NiftyRegRegistrator + + +class TestNiftyRegRegistrator(unittest.TestCase): + def setUp(self): + test_data_dir = turbopath(__file__).parent + "/test_data" + input_dir = test_data_dir + "/input" + self.output_dir = test_data_dir + "/temp_output_niftyreg" + os.makedirs(self.output_dir, exist_ok=True) + + self.registrator = NiftyRegRegistrator() + + self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz" + self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz" + + self.transformed_image = self.output_dir + "/transformed_image.nii.gz" + self.matrix = self.output_dir + "/matrix.txt" + self.log_file = self.output_dir + "/registration.log" + + def tearDown(self): + # Clean up created files if they exist + shutil.rmtree(self.output_dir) + + def test_register_creates_output_files(self): + # we try to run the fastest possible skullstripping on GPU + self.registrator.register( + fixed_image_path=self.fixed_image, + moving_image_path=self.moving_image, + transformed_image_path=self.transformed_image, + matrix_path=self.matrix, + log_file_path=self.log_file, + ) + + self.assertTrue( + os.path.exists(self.transformed_image), + "transformed file was not created.", + ) + + self.assertTrue( + os.path.exists(self.matrix), + "matrix file was not created.", + ) + + self.assertTrue( + os.path.exists(self.log_file), + "log file was not created.", + ) + + +# TODO also test transform