diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index a7f86aa..de9cac5 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -1,16 +1,10 @@ # TODO add typing and docs from abc import abstractmethod -import os - -import nibabel as nib -import numpy as np -from brainles_hd_bet import run_hd_bet +from shutil import copyfile from auxiliary.nifti.io import read_nifti, write_nifti from auxiliary.turbopath import name_extractor - - -from shutil import copyfile +from brainles_hd_bet import run_hd_bet class BrainExtractor: @@ -68,6 +62,8 @@ def extract( log_file_path: str = None, # TODO convert mode to enum mode: str = "accurate", + device: int | str = 0, + do_tta: bool = True, ) -> None: # GPU + accurate + TTA """skullstrips images with HD-BET generates a skullstripped file and mask""" @@ -78,9 +74,9 @@ def extract( # TODO consider postprocessing # postprocess=False, mode=mode, - device=0, + device=device, postprocess=False, - do_tta=True, + do_tta=do_tta, keep_mask=True, overwrite=True, ) @@ -89,7 +85,7 @@ def extract( masked_image_path.parent + "/" + name_extractor(masked_image_path) - + "_masked.nii.gz" + + "_mask.nii.gz" ) copyfile( diff --git a/pyproject.toml b/pyproject.toml index 4c10ad0..4df8f63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ BrainLes-HD-BET = ">=0.0.5" # utils tqdm = "^4.64.1" -auxiliary = "^0.0.38" +auxiliary = "^0.0.40" rich = "^13.6.0" [tool.poetry.dev-dependencies] @@ -65,7 +65,7 @@ pytest = "^6.2" optional = true [tool.poetry.group.docs.dependencies] -Sphinx = ">=7.0.0" +Sphinx = ">=7.0.0" sphinx-copybutton = ">=0.5.2" sphinx-rtd-theme = ">=1.3.0" -myst-parser = ">=2.0.0" \ No newline at end of file +myst-parser = ">=2.0.0" diff --git a/tests/test_brain_extractor.py b/tests/test_brain_extractor.py new file mode 100644 index 0000000..54ca6d7 --- /dev/null +++ b/tests/test_brain_extractor.py @@ -0,0 +1,57 @@ +import os +import shutil +import unittest + +from auxiliary.turbopath import turbopath + +from brainles_preprocessing.brain_extraction import HDBetExtractor + + +class TestHDBetExtractor(unittest.TestCase): + def setUp(self): + test_data_dir = turbopath(__file__).parent + "/test_data" + input_dir = test_data_dir + "/input" + self.output_dir = test_data_dir + "/temp_output" + os.makedirs(self.output_dir, exist_ok=True) + + self.brain_extractor = HDBetExtractor() + self.input_image_path = input_dir + "/tcia_example_t1c.nii.gz" + self.masked_image_path = self.output_dir + "/bet_tcia_example_t1c.nii.gz" + self.brain_mask_path = self.output_dir + "/bet_tcia_example_t1c_masked.nii.gz" + + print(self.input_image_path) + print(self.masked_image_path) + + def tearDown(self): + # Clean up created files if they exist + shutil.rmtree(self.output_dir) + + def test_extract_creates_output_files(self): + # we try to run the fastest possible skullstripping on GPU + self.brain_extractor.extract( + input_image_path=self.input_image_path, + masked_image_path=self.masked_image_path, + brain_mask_path=self.brain_mask_path, + mode="fast", + device="cpu", + do_tta=False, + # TODO generate and also test for presence of log file + ) + + self.assertTrue( + os.path.exists(self.masked_image_path), "Masked image file was not created." + ) + self.assertTrue( + os.path.exists(self.brain_mask_path), + "Brain mask image file was not created.", + ) + + def test_apply_mask_creates_output_file(self): + # self.brain_extractor.apply_mask( + # self.input_image, self.mask_image, self.output_image + # ) + # self.assertTrue( + # os.path.exists(self.output_image_path), + # "Output image file was not created in apply_mask.", + # ) + ... diff --git a/tests/test_data/input/tcia_example_t1c.nii.gz b/tests/test_data/input/tcia_example_t1c.nii.gz new file mode 100644 index 0000000..9979cb9 Binary files /dev/null and b/tests/test_data/input/tcia_example_t1c.nii.gz differ