From 6d99009f750c05f733b4152a4fe3ceabc0d01e77 Mon Sep 17 00:00:00 2001 From: Marcel Rosier Date: Tue, 16 Apr 2024 13:48:08 +0200 Subject: [PATCH] migrate test to unit tests cleanup and fixes --- unit_tests/test_ereg.py | 130 ++++++++++++++++++++++++ unit_tests/test_full.py | 216 ---------------------------------------- 2 files changed, 130 insertions(+), 216 deletions(-) create mode 100644 unit_tests/test_ereg.py delete mode 100644 unit_tests/test_full.py diff --git a/unit_tests/test_ereg.py b/unit_tests/test_ereg.py new file mode 100644 index 0000000..45decf8 --- /dev/null +++ b/unit_tests/test_ereg.py @@ -0,0 +1,130 @@ +import logging +import os +import tempfile +import unittest + +import yaml +from ereg.cli.run import main +from ereg.functional import registration_function, resample_function +from ereg.registration import RegistrationClass +from ereg.utils.io import read_image_and_cast_to_32bit_float + + +class TestEReg(unittest.TestCase): + + def setUp(self): + # While this is already set within eReg it is necessary to specify it as well in the test environment + # else pytest will overwrite it with its own logging configuration (level WARNING) + logging.getLogger().setLevel(logging.DEBUG) + + test_data_dir = "data" + atlas_data_dir = "atlases" + self.moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") + self.atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") + self.test_config_file = os.path.join(test_data_dir, "test_config.yaml") + + test_config = {"initialization": "moments"} + with open(self.test_config_file, "w") as f: + yaml.dump(test_config, f) + + # Helper function + def _image_sanity_check(self, image1, image2): + image_1 = read_image_and_cast_to_32bit_float(image1) + image_2 = read_image_and_cast_to_32bit_float(image2) + assert image_1.GetSize() == image_2.GetSize(), "Image sizes do not match." + assert ( + image_1.GetSpacing() == image_2.GetSpacing() + ), "Image spacings do not match." + assert image_1.GetOrigin() == image_2.GetOrigin(), "Image origins do not match." + assert ( + image_1.GetDirection() == image_2.GetDirection() + ), "Image directions do not match." + + ###### TESTS ###### + + def test_cli_run_main(self): + + with tempfile.TemporaryDirectory() as temp_dir: + output_image = os.path.join(temp_dir, "reg.nii.gz") + transform_file = os.path.join(temp_dir, "trans.mat") + main( + [ + "--movingImg", + self.moving_image, + "--targetImg", + self.atlas_sri, + "--output", + output_image, + "--transfile", + transform_file, + "--config", + self.test_config_file, + ] + ) + self._image_sanity_check(self.atlas_sri, output_image) + + def test_registration_function(self): + test_config = {"initialization": "moments", "bias": True} + with tempfile.TemporaryDirectory() as temp_dir: + output_image = os.path.join(temp_dir, "reg.nii.gz") + transform_file = os.path.join(temp_dir, "trans.mat") + log_file = os.path.join(temp_dir, "reg.log") + registration_function( + target_image=self.atlas_sri, + moving_image=self.moving_image, + output_image=output_image, + transform_file=transform_file, + configuration=test_config, + log_file=log_file, + ) + + self._image_sanity_check(self.atlas_sri, output_image) + + assert os.path.exists(transform_file), "Transform file not created." + assert os.path.exists(log_file), "Log file not created." + # check if log_file is empty + assert os.path.getsize(log_file) > 0, "Log file is empty." + + def test_registration_and_resampling_function(self): + test_config = {"initialization": "moments", "bias": True} + with tempfile.TemporaryDirectory() as temp_dir: + reg_output_image = os.path.join(temp_dir, "reg.nii.gz") + transform_file = os.path.join(temp_dir, "trans.mat") + reg_log_file = os.path.join(temp_dir, "reg.log") + + registration_function( + target_image=self.atlas_sri, + moving_image=self.moving_image, + output_image=reg_output_image, + transform_file=transform_file, + configuration=test_config, + log_file=reg_log_file, + ) + + self._image_sanity_check(self.atlas_sri, reg_output_image) + + assert os.path.exists(transform_file), "Transform file not created." + assert os.path.exists(reg_log_file), "Registration log file not created." + assert os.path.getsize(reg_log_file) > 0, "Registration log file is empty." + + ## Resample + resample_log_file = os.path.join(temp_dir, "resample.log") + resample_output_image = os.path.join(temp_dir, "resample.nii.gz") + resample_function( + target_image=self.atlas_sri, + moving_image=self.moving_image, + output_image=resample_output_image, + transform_file=transform_file, + configuration=test_config, + log_file=resample_log_file, + ) + + self._image_sanity_check(self.atlas_sri, resample_output_image) + assert os.path.exists(transform_file), "Transform file not created" + assert os.path.exists(resample_log_file), "Resample log file not created" + assert os.path.getsize(resample_log_file) > 0, "Resample log file is empty" + + def test_bias_correct_image(self): + register_obj = RegistrationClass() + moving_bias = register_obj._bias_correct_image(self.moving_image) + self._image_sanity_check(self.moving_image, moving_bias) diff --git a/unit_tests/test_full.py b/unit_tests/test_full.py deleted file mode 100644 index a2ba6e0..0000000 --- a/unit_tests/test_full.py +++ /dev/null @@ -1,216 +0,0 @@ -import os -import tempfile -from pathlib import Path - -import yaml - -from ereg.cli.run import main -from ereg.functional import registration_function, resample_function -from ereg.registration import RegistrationClass -from ereg.utils.io import read_image_and_cast_to_32bit_float - - -def _image_sanity_check(image1, image2): - image_1 = read_image_and_cast_to_32bit_float(image1) - image_2 = read_image_and_cast_to_32bit_float(image2) - assert image_1.GetSize() == image_2.GetSize(), "Image sizes do not match." - assert image_1.GetSpacing() == image_2.GetSpacing(), "Image spacings do not match." - assert image_1.GetOrigin() == image_2.GetOrigin(), "Image origins do not match." - assert ( - image_1.GetDirection() == image_2.GetDirection() - ), "Image directions do not match." - - -def test_main(): - cwd = Path.cwd() - test_data_dir = (cwd / "data").absolute().as_posix() - atlas_data_dir = (cwd / "atlases").absolute().as_posix() - base_config_file = os.path.join(test_data_dir, "test_config.yaml") - moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") - temp_output_dir = tempfile.gettempdir() - output_image = os.path.join(temp_output_dir, "tcia_aaac_t1ce_registered.nii.gz") - transform_file = os.path.join(temp_output_dir, "tcia_aaac_t1ce_transform.mat") - atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") - test_config = {"initialization": "moments"} - with open(base_config_file, "w") as f: - yaml.dump(test_config, f) - - main( - [ - "--movingImg", - moving_image, - "--targetImg", - atlas_sri, - "--output", - output_image, - "--transfile", - transform_file, - "--config", - base_config_file, - ] - ) - _image_sanity_check(atlas_sri, output_image) - for file_to_delete in [output_image, transform_file]: - os.remove(file_to_delete) - - -## todo: this is not working for some reason -- will fix later -# def test_main_dir(): -# cwd = Path.cwd() -# test_data_dir = (cwd / "data").absolute().as_posix() -# atlas_data_dir = (cwd / "atlases").absolute().as_posix() -# base_config_file = os.path.join(test_data_dir, "test_config.yaml") -# atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") -# # check dir processing -# output_dir = os.path.join(tempfile.gettempdir(), "dir_output") -# test_config = {"initialization": "moments"} -# with open(base_config_file, "w") as f: -# yaml.dump(test_config, f) -# main( -# [ -# "--movingImg", -# test_data_dir, -# "--targetImg", -# atlas_sri, -# "--output", -# output_dir, -# "--config", -# base_config_file, -# ] -# ) -# shutil.rmtree(output_dir) - - -def test_registration_function(): - cwd = Path.cwd() - test_data_dir = (cwd / "data").absolute().as_posix() - atlas_data_dir = (cwd / "atlases").absolute().as_posix() - moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") - temp_output_dir = tempfile.gettempdir() - output_image = os.path.join(temp_output_dir, "tcia_aaac_t1ce_registered.nii.gz") - atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") - transform_file = os.path.join(temp_output_dir, "tcia_aaac_t1ce_transform.mat") - log_file = os.path.join(temp_output_dir, "tcia_aaac_t1ce_registration.log") - test_config = {"initialization": "moments", "bias": True} - - registration_function( - target_image=atlas_sri, - moving_image=moving_image, - output_image=output_image, - config_file=test_config, - transform_file=transform_file, - log_file=log_file, - ) - - # checks - _image_sanity_check(atlas_sri, output_image) - assert os.path.exists(transform_file), "Transform file not created." - assert os.path.exists(log_file), "Log file not created." - # check if log_file is empty - assert os.path.getsize(log_file) > 0, "Log file is empty." - - # cleanup - for file_to_delete in [output_image, transform_file, log_file]: - os.remove(file_to_delete) - - -def test_resample_function_using_previously_saved_transform(): - cwd = Path.cwd() - test_data_dir = (cwd / "data").absolute().as_posix() - atlas_data_dir = (cwd / "atlases").absolute().as_posix() - moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") - temp_output_dir = tempfile.gettempdir() - output_image = os.path.join(temp_output_dir, "tcia_aaac_t1ce_registered.nii.gz") - atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") - transform_file = os.path.join(test_data_dir, "tcia_aaac_t1ce_transform.mat") - log_file = os.path.join(temp_output_dir, "tcia_aaac_t1ce_transformation.log") - test_config = {"initialization": "moments", "bias": True} - - resample_function( - target_image=atlas_sri, - moving_image=moving_image, - output_image=output_image, - transform_file=transform_file, - configuration=test_config, - log_file=log_file, - ) - - # checks - _image_sanity_check(atlas_sri, output_image) - assert os.path.exists(log_file), "Log file not created." - - # cleanup - for file_to_delete in [output_image, log_file]: - os.remove(file_to_delete) - - -def test_registration_and_resampling_function(): - cwd = Path.cwd() - test_data_dir = (cwd / "data").absolute().as_posix() - atlas_data_dir = (cwd / "atlases").absolute().as_posix() - moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") - temp_output_dir = tempfile.gettempdir() - output_image = os.path.join(temp_output_dir, "tcia_aaac_t1ce_registered.nii.gz") - resample_output_image = os.path.join( - temp_output_dir, "tcia_aaac_t1ce_resampled.nii.gz" - ) - - atlas_sri = os.path.join(atlas_data_dir, "sri24", "image.nii.gz") - transform_file = os.path.join(temp_output_dir, "tcia_aaac_t1ce_transform.mat") - registration_log_file = os.path.join( - temp_output_dir, "tcia_aaac_t1ce_registration.log" - ) - resample_log_file = os.path.join( - temp_output_dir, "tcia_aaac_t1ce_transformation.log" - ) - test_config = {"initialization": "moments", "bias": True} - - registration_function( - target_image=atlas_sri, - moving_image=moving_image, - output_image=output_image, - config_file=test_config, - transform_file=transform_file, - log_file=registration_log_file, - ) - assert os.path.exists(registration_log_file), "Registration log file not created." - # check if registration_log_file is empty - assert os.path.getsize(registration_log_file) > 0, "Registration log file is empty." - - # checks - _image_sanity_check(atlas_sri, output_image) - assert os.path.exists(transform_file), "Transform file not created." - - resample_function( - target_image=atlas_sri, - moving_image=moving_image, - output_image=resample_output_image, - transform_file=transform_file, - configuration=test_config, - log_file=resample_log_file, - ) - - _image_sanity_check(atlas_sri, resample_output_image) - assert os.path.exists(transform_file), "Transform file got deleted, somehow." - assert os.path.exists(resample_log_file), "Transform log file not created." - # check if resample_log_file is empty - assert os.path.getsize(resample_log_file) > 0, "Resample log file is empty." - - # cleanup - for file_to_delete in [ - output_image, - resample_output_image, - transform_file, - registration_log_file, - resample_log_file, - ]: - os.remove(file_to_delete) - - -def test_bias(): - cwd = Path.cwd() - test_data_dir = (cwd / "data").absolute().as_posix() - moving_image = os.path.join(test_data_dir, "tcia_aaac_t1ce.nii.gz") - register_obj = RegistrationClass() - moving_bias = register_obj._bias_correct_image(moving_image) - _image_sanity_check(moving_image, moving_bias)