Skip to content

Commit

Permalink
Merge pull request #17 from BrainLesion/addd_resample_function
Browse files Browse the repository at this point in the history
Add `resample_image` method to `RegistrationClass`
  • Loading branch information
neuronflow authored Dec 13, 2023
2 parents 37012bc + 080946d commit e49bb0a
Showing 1 changed file with 75 additions and 21 deletions.
96 changes: 75 additions & 21 deletions ereg/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
]
self.total_attempts = 5
self.log_file = None
self.transform = None
if config_file is not None:
self.update_parameters(config_file)

Expand Down Expand Up @@ -275,60 +276,113 @@ def register(
if transform_file is not None:
if os.path.isfile(transform_file):
try:
transform = sitk.ReadTransform(transform_file)
self.transform = sitk.ReadTransform(transform_file)
compute_transform = False
except:
self.logger.info(
"Could not read transform file. Computing transform transform."
"Could not read transform file. Computing transform."
)
pass
if compute_transform:
self.logger.info(
f"Starting registration with parameters:: {self.parameters}"
)
transform = self._register_image_and_get_transform(
self.transform = self._register_image_and_get_transform(
target_image=target_image,
moving_image=moving_image,
)
if transform_file is not None:
sitk.WriteTransform(transform, transform_file)
sitk.WriteTransform(self.transform, transform_file)

# apply composite transform if provided
if self.parameters["composite_transform"] is not None:
self.logger.info("Applying composite transform.")
transform_composite = sitk.ReadTransform(
self.parameters["composite_transform"]
)
transform = sitk.CompositeTransform(transform_composite, transform)
self.transform = sitk.CompositeTransform(
transform_composite, self.transform
)

if self.parameters["composite_transform"]:
self.logger.info("Applying previous transforms.")
current_transform = None
for previous_transform in self.parameters["previous_transforms"]:
previous_transform = sitk.ReadTransform(previous_transform)
current_transform = (
sitk.CompositeTransform(previous_transform, transform)
sitk.CompositeTransform(previous_transform, self.transform)
if current_transform is None
else sitk.CompositeTransform(previous_transform, current_transform)
)

transform = current_transform

self.logger.info("Resampling image.")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(target_image)
interpolator_type = self.interpolator_type.get(self.parameters["interpolator"])
resampler.SetInterpolator(interpolator_type)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(transform)
output_image_struct = resampler.Execute(moving_image)
sitk.WriteImage(output_image_struct, output_image)
self.ssim_score = get_ssim(target_image, output_image_struct)
self.logger.info(
f"SSIM score of moving against target image: {self.ssim_score}"
)
self.transform = current_transform

# no need for logging since resample_image will log by itself
logging.shutdown()

# resample the moving image to the target image
self.resample_image(
target_image=target_image,
moving_image=moving_image,
output_image=output_image,
transform_file=transform_file,
)

def resample_image(
self,
target_image: Union[str, sitk.Image],
moving_image: Union[str, sitk.Image],
output_image: str,
transform_file: str = None,
**kwargs,
) -> None:
"""
Resample the moving image to the target image.
Args:
logger (logging.Logger): The logger to use.
target_image (Union[str, sitk.Image]): The target image.
moving_image (Union[str, sitk.Image]): The moving image.
output_image (str): The output image.
transform_file (str, optional): The transform file. Defaults to None.
"""

# check if output image exists
if not os.path.exists(output_image):
if self.transform is not None:
if self.log_file is None:
self.log_file = output_image.replace(".nii.gz", ".log")
logging.basicConfig(
filename=self.log_file,
format="%(asctime)s,%(name)s,%(levelname)s,%(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.logger = logging.getLogger("registration")

self.logger.info(
f"Target image: {target_image}, Moving image: {moving_image}, Transform file: {transform_file}"
)
target_image = read_image_and_cast_to_32bit_float(target_image)
moving_image = read_image_and_cast_to_32bit_float(moving_image)

self.logger.info("Resampling image.")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(target_image)
interpolator_type = self.interpolator_type.get(
self.parameters["interpolator"]
)
resampler.SetInterpolator(interpolator_type)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(self.transform)
output_image_struct = resampler.Execute(moving_image)
sitk.WriteImage(output_image_struct, output_image)
self.ssim_score = get_ssim(target_image, output_image_struct)
self.logger.info(
f"SSIM score of moving against target image: {self.ssim_score}"
)
logging.shutdown()

def _get_transform_wrapper(self, transform: str, dim: int) -> sitk.Transform:
"""
Get the transform class.
Expand Down

0 comments on commit e49bb0a

Please sign in to comment.