Skip to content

Commit

Permalink
Merge pull request #29 from BrainLesion/feat/refactor_brain_extraction
Browse files Browse the repository at this point in the history
Refactor brain extraction
  • Loading branch information
IsraMekki0 authored Nov 27, 2023
2 parents a60a1cf + a180348 commit 225af2b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 31 deletions.
87 changes: 62 additions & 25 deletions brainles_preprocessing/brain_extraction/brain_extractor.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,79 @@
# TODO add typing and docs
from abc import abstractmethod
import os

import nibabel as nib
import numpy as np
from brainles_hd_bet import run_hd_bet

from auxiliary.nifti.io import read_nifti, write_nifti
from auxiliary.turbopath import name_extractor


from shutil import copyfile


class BrainExtractor:
@abstractmethod
def extract(
self,
input_image,
output_image,
log_file,
mode,
):
input_image_path: str,
masked_image_path: str,
brain_mask_path: str,
log_file_path: str,
# TODO convert mode to enum
mode: str,
) -> None:
pass

def apply_mask(
self,
input_image,
mask_image,
output_image,
):
"""masks images with brain masks"""
inputnifti = nib.load(input_image)
mask = nib.load(mask_image)
input_image_path: str,
mask_image_path: str,
masked_image_path: str,
) -> None:
"""
Apply a brain mask to an input image.
# mask it
masked_file = np.multiply(inputnifti.get_fdata(), mask.get_fdata())
masked_file = nib.Nifti1Image(masked_file, inputnifti.affine, inputnifti.header)
Parameters:
- input_image_path (str): Path to the input image (NIfTI format).
- mask_image_path (str): Path to the brain mask image (NIfTI format).
- masked_image_path (str): Path to save the resulting masked image (NIfTI format).
# save it
nib.save(masked_file, output_image)
Returns:
- str: Path to the saved masked image.
"""

# read data
input_data = read_nifti(input_image_path)
mask_data = read_nifti(mask_image_path)

# mask and save it
masked_data = input_data * mask_data

write_nifti(
input_array=masked_data,
output_nifti_path=masked_image_path,
reference_nifti_path=input_image_path,
create_parent_directory=True,
)


class HDBetExtractor(BrainExtractor):
def extract(
self,
input_image,
masked_image,
# TODO implement logging!
log_file,
mode="accurate",
):
input_image_path: str,
masked_image_path: str,
brain_mask_path: str,
log_file_path: str = None,
# TODO convert mode to enum
mode: str = "accurate",
) -> None:
# GPU + accurate + TTA
"""skullstrips images with HD-BET generates a skullstripped file and mask"""
run_hd_bet(
mri_fnames=[input_image],
output_fnames=[masked_image],
mri_fnames=[input_image_path],
output_fnames=[masked_image_path],
# device=0,
# TODO consider postprocessing
# postprocess=False,
Expand All @@ -59,3 +84,15 @@ def extract(
keep_mask=True,
overwrite=True,
)

hdbet_mask_path = (
masked_image_path.parent
+ "/"
+ name_extractor(masked_image_path)
+ "_masked.nii.gz"
)

copyfile(
src=hdbet_mask_path,
dst=brain_mask_path,
)
13 changes: 7 additions & 6 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def apply_mask(
f"brain_masked__{self.modality_name}.nii.gz",
)
brain_extractor.apply_mask(
input_image=self.current,
mask_image=atlas_mask,
output_image=brain_masked,
input_image_path=self.current,
mask_image_path=atlas_mask,
masked_image_path=brain_masked,
)
self.current = brain_masked

Expand Down Expand Up @@ -153,9 +153,10 @@ def extract_brain_region(
)

brain_extractor.extract(
input_image=self.current,
masked_image=atlas_bet_cm,
log_file=bet_log,
input_image_path=self.current,
masked_image_path=atlas_bet_cm,
brain_mask_path=atlas_mask,
log_file_path=bet_log,
)
self.current = atlas_bet_cm
return atlas_mask

0 comments on commit 225af2b

Please sign in to comment.