Skip to content

Commit

Permalink
Merge pull request #46 from BrainLesion/typing_for_mod
Browse files Browse the repository at this point in the history
Typing for mod and prep
  • Loading branch information
neuronflow authored Dec 13, 2023
2 parents b17f805 + 24cfedc commit 24be746
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 52 deletions.
120 changes: 90 additions & 30 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# todo add typing and docs
import os
import shutil

from auxiliary.nifti.io import read_nifti, write_nifti
from auxiliary.normalization.normalizer_base import Normalizer
from auxiliary.turbopath import turbopath

from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
from brainles_preprocessing.registration.registrator import Registrator


Expand Down Expand Up @@ -34,6 +34,7 @@ class Modality:
... output_path="/path/to/preprocessed_t1.nii",
... bet=True
... )
"""

def __init__(
Expand All @@ -53,10 +54,19 @@ def __init__(

def normalize(
self,
temporary_directory,
store_unnormalized=None,
):
# TODO we need docstrings
temporary_directory: str,
store_unnormalized: str | None = None,
) -> None:
"""
Normalize the image using the specified normalizer.
Args:
temporary_directory (str): Path to the temporary directory.
store_unnormalized (str, optional): Path to store unnormalized images.
Returns:
None
"""
# Backup the unnormalized file
if store_unnormalized is not None:
os.makedirs(store_unnormalized, exist_ok=True)
Expand Down Expand Up @@ -85,11 +95,23 @@ def normalize(

def register(
self,
registrator,
registrator: Registrator,
fixed_image_path: str,
registration_dir: str,
moving_image_name: str,
):
) -> str:
"""
Register the current modality to a fixed image using the specified registrator.
Args:
registrator (Registrator): The registrator object.
fixed_image_path (str): Path to the fixed image.
registration_dir (str): Directory to store registration results.
moving_image_name (str): Name of the moving image.
Returns:
str: Path to the registration matrix.
"""
registered = os.path.join(registration_dir, f"{moving_image_name}.nii.gz")
registered_matrix = os.path.join(registration_dir, f"{moving_image_name}.txt")
registered_log = os.path.join(registration_dir, f"{moving_image_name}.log")
Expand All @@ -106,58 +128,96 @@ def register(

def apply_mask(
self,
brain_extractor,
brain_masked_dir,
atlas_mask,
):
brain_extractor: BrainExtractor,
brain_masked_dir_path: str,
atlas_mask_path: str,
) -> None:
"""
Apply a brain mask to the current modality using the specified brain extractor.
Args:
brain_extractor (BrainExtractor): The brain extractor object.
brain_masked_dir_path (str): Directory to store masked images.
atlas_mask_path (str): Path to the brain mask.
Returns:
None
"""
if self.bet:
brain_masked = os.path.join(
brain_masked_dir,
brain_masked_dir_path,
f"brain_masked__{self.modality_name}.nii.gz",
)
brain_extractor.apply_mask(
input_image_path=self.current,
mask_image_path=atlas_mask,
mask_image_path=atlas_mask_path,
masked_image_path=brain_masked,
)
self.current = brain_masked

def transform(
self,
registrator: Registrator,
fixed_image_path,
registration_dir,
moving_image_name,
transformation_matrix,
):
transformed = os.path.join(registration_dir, f"{moving_image_name}.nii.gz")
transformed_log = os.path.join(registration_dir, f"{moving_image_name}.log")
fixed_image_path: str,
registration_dir_path: str,
moving_image_name: str,
transformation_matrix_path: str,
) -> None:
"""
Transform the current modality using the specified registrator and transformation matrix.
Args:
registrator (Registrator): The registrator object.
fixed_image_path (str): Path to the fixed image.
registration_dir_path (str): Directory to store transformation results.
moving_image_name (str): Name of the moving image.
transformation_matrix_path (str): Path to the transformation matrix.
Returns:
None
"""
transformed = os.path.join(registration_dir_path, f"{moving_image_name}.nii.gz")
transformed_log = os.path.join(
registration_dir_path, f"{moving_image_name}.log"
)

registrator.transform(
fixed_image_path=fixed_image_path,
moving_image_path=self.current,
transformed_image_path=transformed,
matrix_path=transformation_matrix,
matrix_path=transformation_matrix_path,
log_file_path=transformed_log,
)
self.current = transformed

def extract_brain_region(
self,
brain_extractor,
bet_dir,
):
bet_log = os.path.join(bet_dir, "brain-extraction.log")
atlas_bet_cm = os.path.join(bet_dir, f"atlas_bet_{self.modality_name}.nii.gz")
atlas_mask = os.path.join(
bet_dir, f"atlas_bet_{self.modality_name}_mask.nii.gz"
brain_extractor: BrainExtractor,
bet_dir_path: str,
) -> str:
"""
Extract the brain region using the specified brain extractor.
Args:
brain_extractor (BrainExtractor): The brain extractor object.
bet_dir_path (str): Directory to store brain extraction results.
Returns:
str: Path to the extracted brain mask.
"""
bet_log = os.path.join(bet_dir_path, "brain-extraction.log")
atlas_bet_cm = os.path.join(
bet_dir_path, f"atlas_bet_{self.modality_name}.nii.gz"
)
atlas_mask_path = os.path.join(
bet_dir_path, f"atlas_bet_{self.modality_name}_mask.nii.gz"
)

brain_extractor.extract(
input_image_path=self.current,
masked_image_path=atlas_bet_cm,
brain_mask_path=atlas_mask,
brain_mask_path=atlas_mask_path,
log_file_path=bet_log,
)
self.current = atlas_bet_cm
return atlas_mask
return atlas_mask_path
88 changes: 66 additions & 22 deletions brainles_preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO add typing and documentation
import os
import shutil
import tempfile
from typing import List, Optional

from auxiliary.turbopath import turbopath

Expand All @@ -11,27 +11,39 @@


class Preprocessor:
"""
Preprocesses medical image modalities using coregistration, normalization, brain extraction, and more.
Args:
center_modality (Modality): The central modality for coregistration.
moving_modalities (List[Modality]): List of modalities to be coregistered to the central modality.
registrator (Registrator): The registrator object for coregistration and registration to the atlas.
brain_extractor (BrainExtractor): The brain extractor object for brain extraction.
atlas_image_path (str, optional): Path to the atlas image for registration (default is the T1 atlas).
temp_folder (str, optional): Path to a temporary folder for storing intermediate results.
"""

def __init__(
self,
center_modality: Modality,
moving_modalities: list[Modality],
moving_modalities: List[Modality],
registrator: Registrator,
brain_extractor: BrainExtractor,
atlas_image_path: str = turbopath(__file__).parent
+ "/registration/atlas/t1_brats_space.nii",
temp_folder=None,
temp_folder: Optional[str] = None,
):
self.center_modality = center_modality
self.moving_modalities = moving_modalities
self.atlas_image_path = turbopath(atlas_image_path)
self.registrator = registrator
self.brain_extractor = brain_extractor

# create temporary storage
# Create temporary storage
if temp_folder:
os.makedirs(temp_folder, exist_ok=True)
self.temp_folder = turbopath(temp_folder)
# custom temporary storage for debugging etc
else:
storage = tempfile.TemporaryDirectory()
self.temp_folder = turbopath(storage.name)
Expand All @@ -43,11 +55,22 @@ def run(
self,
brain_extraction: bool,
normalization: bool,
save_dir_coregistration: str = None,
save_dir_atlas_registration: str = None,
save_dir_brain_extraction: str = None,
save_dir_unnormalized: str = None,
save_dir_coregistration: Optional[str] = None,
save_dir_atlas_registration: Optional[str] = None,
save_dir_brain_extraction: Optional[str] = None,
save_dir_unnormalized: Optional[str] = None,
):
"""
Run the preprocessing pipeline.
Args:
brain_extraction (bool): Whether to perform brain extraction.
normalization (bool): Whether to perform intensity normalization.
save_dir_coregistration (str, optional): Directory to save coregistration results.
save_dir_atlas_registration (str, optional): Directory to save atlas registration results.
save_dir_brain_extraction (str, optional): Directory to save brain extraction results.
save_dir_unnormalized (str, optional): Directory to save unnormalized images.
"""
# Coregister moving modalities to center modality
coregistration_dir = os.path.join(self.temp_folder, "coregistration")
os.makedirs(coregistration_dir, exist_ok=True)
Expand Down Expand Up @@ -80,9 +103,9 @@ def run(
moving_modality.transform(
registrator=self.registrator,
fixed_image_path=self.atlas_image_path,
registration_dir=self.atlas_dir,
registration_dir_path=self.atlas_dir,
moving_image_name=file_name,
transformation_matrix=transformation_matrix,
transformation_matrix_path=transformation_matrix,
)
self._save_output(
src=self.atlas_dir,
Expand All @@ -97,13 +120,13 @@ def run(
os.makedirs(brain_masked_dir, exist_ok=True)

atlas_mask = self.center_modality.extract_brain_region(
brain_extractor=self.brain_extractor, bet_dir=bet_dir
brain_extractor=self.brain_extractor, bet_dir_path=bet_dir
)
for moving_modality in self.moving_modalities:
moving_modality.apply_mask(
brain_extractor=self.brain_extractor,
brain_masked_dir=brain_masked_dir,
atlas_mask=atlas_mask,
brain_masked_dir_path=brain_masked_dir,
atlas_mask_path=atlas_mask,
)

self._save_output(
Expand Down Expand Up @@ -132,8 +155,8 @@ def all_modalities(self):

def _save_output(
self,
src,
save_dir,
src: str,
save_dir: Optional[str],
):
if save_dir:
save_dir = turbopath(save_dir)
Expand All @@ -145,8 +168,8 @@ def _save_output(

def _save_coregistration(
self,
coregistration_dir,
save_dir_coregistration,
coregistration_dir: str,
save_dir_coregistration: Optional[str],
):
if save_dir_coregistration:
save_dir_coregistration = turbopath(save_dir_coregistration)
Expand All @@ -167,16 +190,30 @@ def _save_coregistration(


class PreprocessorGPU(Preprocessor):
"""
Preprocesses medical image modalities using GPU acceleration.
Args:
center_modality (Modality): The central modality for coregistration.
moving_modalities (List[Modality]): List of modalities to be coregistered to the central modality.
registrator (Registrator): The registrator object for coregistration and registration to the atlas.
brain_extractor (BrainExtractor): The brain extractor object for brain extraction.
atlas_image_path (str, optional): Path to the atlas image for registration (default is the T1 atlas).
temp_folder (str, optional): Path to a temporary folder for storing intermediate results.
limit_cuda_visible_devices (str, optional): Limit CUDA visible devices for GPU acceleration.
"""

def __init__(
self,
center_modality: Modality,
moving_modalities: list[Modality],
moving_modalities: List[Modality],
registrator: Registrator,
brain_extractor: BrainExtractor,
atlas_image_path: str = turbopath(__file__).parent
+ "/registration/atlas/t1_brats_space.nii",
temp_folder=None,
limit_cuda_visible_devices: str = None,
temp_folder: Optional[str] = None,
limit_cuda_visible_devices: Optional[str] = None,
):
super().__init__(
center_modality,
Expand All @@ -192,8 +229,15 @@ def __init__(

def set_cuda_devices(
self,
limit_cuda_visible_devices: str = None,
limit_cuda_visible_devices: Optional[str] = None,
):
"""
Set CUDA devices for GPU acceleration.
Args:
limit_cuda_visible_devices (str, optional): Limit CUDA visible devices for GPU acceleration.
"""
if limit_cuda_visible_devices:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = limit_cuda_visible_devices

0 comments on commit 24be746

Please sign in to comment.