diff --git a/ereg/registration.py b/ereg/registration.py index 360d2f7..7ab2bed 100644 --- a/ereg/registration.py +++ b/ereg/registration.py @@ -65,6 +65,7 @@ def __init__( ] self.total_attempts = 5 self.log_file = None + self.transform = None if config_file is not None: self.update_parameters(config_file) @@ -275,23 +276,23 @@ def register( if transform_file is not None: if os.path.isfile(transform_file): try: - transform = sitk.ReadTransform(transform_file) + self.transform = sitk.ReadTransform(transform_file) compute_transform = False except: self.logger.info( - "Could not read transform file. Computing transform transform." + "Could not read transform file. Computing transform." ) pass if compute_transform: self.logger.info( f"Starting registration with parameters:: {self.parameters}" ) - transform = self._register_image_and_get_transform( + self.transform = self._register_image_and_get_transform( target_image=target_image, moving_image=moving_image, ) if transform_file is not None: - sitk.WriteTransform(transform, transform_file) + sitk.WriteTransform(self.transform, transform_file) # apply composite transform if provided if self.parameters["composite_transform"] is not None: @@ -299,7 +300,9 @@ def register( transform_composite = sitk.ReadTransform( self.parameters["composite_transform"] ) - transform = sitk.CompositeTransform(transform_composite, transform) + self.transform = sitk.CompositeTransform( + transform_composite, self.transform + ) if self.parameters["composite_transform"]: self.logger.info("Applying previous transforms.") @@ -307,28 +310,79 @@ def register( for previous_transform in self.parameters["previous_transforms"]: previous_transform = sitk.ReadTransform(previous_transform) current_transform = ( - sitk.CompositeTransform(previous_transform, transform) + sitk.CompositeTransform(previous_transform, self.transform) if current_transform is None else sitk.CompositeTransform(previous_transform, current_transform) ) - transform = current_transform - - self.logger.info("Resampling image.") - resampler = sitk.ResampleImageFilter() - resampler.SetReferenceImage(target_image) - interpolator_type = self.interpolator_type.get(self.parameters["interpolator"]) - resampler.SetInterpolator(interpolator_type) - resampler.SetDefaultPixelValue(0) - resampler.SetTransform(transform) - output_image_struct = resampler.Execute(moving_image) - sitk.WriteImage(output_image_struct, output_image) - self.ssim_score = get_ssim(target_image, output_image_struct) - self.logger.info( - f"SSIM score of moving against target image: {self.ssim_score}" - ) + self.transform = current_transform + + # no need for logging since resample_image will log by itself logging.shutdown() + # resample the moving image to the target image + self.resample_image( + target_image=target_image, + moving_image=moving_image, + output_image=output_image, + transform_file=transform_file, + ) + + def resample_image( + self, + target_image: Union[str, sitk.Image], + moving_image: Union[str, sitk.Image], + output_image: str, + transform_file: str = None, + **kwargs, + ) -> None: + """ + Resample the moving image to the target image. + + Args: + logger (logging.Logger): The logger to use. + target_image (Union[str, sitk.Image]): The target image. + moving_image (Union[str, sitk.Image]): The moving image. + output_image (str): The output image. + transform_file (str, optional): The transform file. Defaults to None. + """ + + # check if output image exists + if not os.path.exists(output_image): + if self.transform is not None: + if self.log_file is None: + self.log_file = output_image.replace(".nii.gz", ".log") + logging.basicConfig( + filename=self.log_file, + format="%(asctime)s,%(name)s,%(levelname)s,%(message)s", + datefmt="%H:%M:%S", + level=logging.DEBUG, + ) + self.logger = logging.getLogger("registration") + + self.logger.info( + f"Target image: {target_image}, Moving image: {moving_image}, Transform file: {transform_file}" + ) + target_image = read_image_and_cast_to_32bit_float(target_image) + moving_image = read_image_and_cast_to_32bit_float(moving_image) + + self.logger.info("Resampling image.") + resampler = sitk.ResampleImageFilter() + resampler.SetReferenceImage(target_image) + interpolator_type = self.interpolator_type.get( + self.parameters["interpolator"] + ) + resampler.SetInterpolator(interpolator_type) + resampler.SetDefaultPixelValue(0) + resampler.SetTransform(self.transform) + output_image_struct = resampler.Execute(moving_image) + sitk.WriteImage(output_image_struct, output_image) + self.ssim_score = get_ssim(target_image, output_image_struct) + self.logger.info( + f"SSIM score of moving against target image: {self.ssim_score}" + ) + logging.shutdown() + def _get_transform_wrapper(self, transform: str, dim: int) -> sitk.Transform: """ Get the transform class.