diff --git a/data/default_rigid.yaml b/ereg/configurations/default_rigid.yaml similarity index 100% rename from data/default_rigid.yaml rename to ereg/configurations/default_rigid.yaml diff --git a/data/sample_config.yaml b/ereg/configurations/sample_config.yaml similarity index 100% rename from data/sample_config.yaml rename to ereg/configurations/sample_config.yaml diff --git a/ereg/functional.py b/ereg/functional.py index 5469e27..af946af 100644 --- a/ereg/functional.py +++ b/ereg/functional.py @@ -10,7 +10,7 @@ def registration_function( target_image: Union[str, sitk.Image], moving_image: Union[str, sitk.Image], output_image: str, - config_file: str, + configuration: str, transform_file: str = None, log_file: str = None, **kwargs, @@ -28,14 +28,14 @@ def registration_function( Returns: float: The structural similarity index. """ - if isinstance(config_file, str): - assert os.path.isfile(config_file), "Config file does not exist." - elif isinstance(config_file, dict): + if isinstance(configuration, str): + assert os.path.isfile(configuration), "Config file does not exist." + elif isinstance(configuration, dict): pass else: raise ValueError("Config file must be a string or dictionary.") - registration_obj = RegistrationClass(config_file) + registration_obj = RegistrationClass(configuration) registration_obj.register( target_image=target_image, moving_image=moving_image, diff --git a/ereg/registration.py b/ereg/registration.py index c4e3780..1bc8e3e 100644 --- a/ereg/registration.py +++ b/ereg/registration.py @@ -15,7 +15,7 @@ class RegistrationClass: def __init__( self, - config_file: Union[str, dict] = None, + configuration: Union[str, dict] = None, **kwargs, ) -> None: """ @@ -65,167 +65,56 @@ def __init__( ] self.total_attempts = 5 self.transform = None - if config_file is not None: - self.update_parameters(config_file) - def update_parameters(self, config_file: Union[str, dict], **kwargs): - """ - Update the parameters for the registration. - - Args: - config_file (Union[str, dict]): The config file or dictionary. - """ - if isinstance(config_file, str): - self.parameters = yaml.safe_load(open(config_file, "r")) - elif isinstance(config_file, dict): - self.parameters = config_file + if configuration is not None: + self.update_parameters(configuration) else: - raise ValueError("Config file must be a string or dictionary.") + self.parameters = self._generate_default_parameters() - self.parameters["metric"] = ( - self.parameters.get("metric", "mean_squares") - .replace("_", "") - .replace("-", "") - .lower() - ) - self.parameters["metric_parameters"] = self.parameters.get( - "metric_parameters", {} - ) - self.parameters["metric_parameters"]["histogram_bins"] = self.parameters[ - "metric_parameters" - ].get("histogram_bins", 50) - self.parameters["metric_parameters"]["radius"] = self.parameters[ - "metric_parameters" - ].get("radius", 5) - self.parameters["metric_parameters"][ - "intensityDifferenceThreshold" - ] = self.parameters["metric_parameters"].get( - "intensityDifferenceThreshold", 0.001 - ) - self.parameters["metric_parameters"][ - "varianceForJointPDFSmoothing" - ] = self.parameters["metric_parameters"].get( - "varianceForJointPDFSmoothing", 1.5 + def _generate_default_parameters(self) -> dict: + defaults_file = os.path.normpath( + os.path.abspath( + __file__ + "configurations/default_rigid.yaml", + ) ) + default_parameters = self.parameters = yaml.safe_load(open(defaults_file, "r")) + return default_parameters - self.parameters["transform"] = ( - self.parameters.get("transform", "versor") - .replace("_", "") - .replace("-", "") - .lower() - ) - assert self.parameters["transform"] in self.available_transforms, ( - f"Transform {self.parameters['transform']} not recognized. " - f"Available transforms: {self.available_transforms}" - ) - if self.parameters["transform"] in ["euler", "versorrigid"]: - self.parameters["rigid_registration"] = True - self.parameters["initialization"] = self.parameters.get( - "initialization", "geometry" - ).lower() - self.parameters["bias_correct"] = self.parameters.get( - "bias_correct", self.parameters.get("bias", False) - ) - self.parameters["interpolator"] = ( - self.parameters.get("interpolator", "linear") - .replace("_", "") - .replace("-", "") - .lower() - ) - self.parameters["shrink_factors"] = self.parameters.get( - "shrink_factors", self.parameters.get("shrink", [8, 4, 2]) - ) - self.parameters["smoothing_sigmas"] = self.parameters.get( - "smoothing_sigmas", self.parameters.get("smooth", [3, 2, 1]) - ) - assert len(self.parameters["shrink_factors"]) == len( - self.parameters["smoothing_sigmas"] - ), "The number of shrink factors and smoothing sigmas must be the same." - self.parameters["sampling_strategy"] = self.parameters.get( - "sampling_strategy", "none" - ) - self.parameters["sampling_percentage"] = self.parameters.get( - "sampling_percentage", 0.01 - ) - if isinstance(self.parameters["sampling_percentage"], int) or isinstance( - self.parameters["sampling_percentage"], float - ): - temp_percentage = self.parameters["sampling_percentage"] - self.parameters["sampling_percentage"] = [] - for _ in range(len(self.parameters["shrink_factors"])): - self.parameters["sampling_percentage"].append(temp_percentage) + @property + def configuration(self) -> dict: + return self.parameters - assert len(self.parameters["shrink_factors"]) == len( - self.parameters["sampling_percentage"] - ), "The number of shrink factors and sampling percentages must be the same." + @configuration.setter + def configuration( + self, + new_config_file: Union[str, dict], + ) -> None: + self.parameters = self._generate_default_parameters() + self.update_parameters(configuration=new_config_file) - self.parameters["attempts"] = self.parameters.get("attempts", 5) + def update_parameters( + self, + configuration: Union[str, dict], + ): + """ + Update the parameters for the registration. - # check for composite transforms - self.parameters["composite_transform"] = self.parameters.get( - "composite_transform", None - ) - if self.parameters["composite_transform"]: - self.parameters["previous_transforms"] = self.parameters.get( - "previous_transforms", [] + Args: + config_file (Union[str, dict]): The tring path pointing to a .yml configuration file or configuration dictionary. + """ + if isinstance(configuration, str): + config_data = yaml.safe_load(open(configuration, "r")) + elif isinstance(configuration, dict): + config_data = configuration + else: + raise ValueError( + "Configuration must be a string path pointing to a .yml file or dictionary." ) - # checks related to composite transforms - assert isinstance( - self.parameters["previous_transforms"], list - ), "Previous transforms must be a list." - assert ( - len(self.parameters["previous_transforms"]) > 0 - ), "No previous transforms provided." - - self.parameters["optimizer"] = self.parameters.get( - "optimizer", "regular_step_gradient_descent" - ) - - # this is taken directly from the sample_config.yaml - default_optimizer_parameters = { - "min_step": 1e-6, # regular_step_gradient_descent - "max_step": 1.0, # gradient_descent, regular_step_gradient_descent - "maximumStepSizeInPhysicalUnits": 1.0, # regular_step_gradient_descent, gradient_descent_line_search, gradient_descent, - "iterations": 1000, # regular_step_gradient_descent, gradient_descent_line_search, gradient_descent, conjugate, lbfgsb, lbfgsb2 - "learningrate": 1.0, # gradient_descent, gradient_descent_line_search - "convergence_minimum": 1e-6, # gradient_descent, gradient_descent_line_search - "convergence_window_size": 10, # gradient_descent, gradient_descent_line_search - "line_search_lower_limit": 0.0, # gradient_descent_line_search - "line_search_upper_limit": 5.0, # gradient_descent_line_search - "line_search_epsilon": 0.01, # gradient_descent_line_search - "step_length": 0.1, # conjugate, exhaustive, powell - "simplex_delta": 0.1, # amoeba - "maximum_number_of_corrections": 5, # lbfgsb, lbfgsb2 - "maximum_number_of_function_evaluations": 2000, # lbfgsb, lbfgsb2 - "solution_accuracy": 1e-5, # lbfgsb2 - "hessian_approximate_accuracy": 1e-5, # lbfgsb2 - "delta_convergence_distance": 1e-5, # lbfgsb2 - "delta_convergence_tolerance": 1e-5, # lbfgsb2 - "line_search_maximum_evaluations": 50, # lbfgsb2 - "line_search_minimum_step": 1e-20, # lbfgsb2 - "line_search_accuracy": 1e-4, # lbfgsb2 - "epsilon": 1e-8, # one_plus_one_evolutionary - "initial_radius": 1.0, # one_plus_one_evolutionary - "growth_factor": -1.0, # one_plus_one_evolutionary - "shrink_factor": -1.0, # one_plus_one_evolutionary - "maximum_line_iterations": 100, # powell - "step_tolerance": 1e-6, # powell - "value_tolerance": 1e-6, # powell - "relaxation": 0.5, # regular_step_gradient_descent - "tolerance": 1e-4, # regular_step_gradient_descent - "rigid_registration": False, - } - - # check for optimizer parameters in config file - self.parameters["optimizer_parameters"] = self.parameters.get( - "optimizer_parameters", {} - ) - - # for any optimizer parameters not in the config file, use the default values - for key, value in default_optimizer_parameters.items(): - if key not in self.parameters["optimizer_parameters"]: - self.parameters["optimizer_parameters"][key] = value + # Update only the keys present in the YAML file + for key, value in config_data.items(): + if key in self.parameters: + self.parameters[key] = value def register( self,