reading default values from yml files #23

Merged: 11 commits, Mar 27, 2024
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions ereg/functional.py
@@ -10,7 +10,7 @@ def registration_function(
target_image: Union[str, sitk.Image],
moving_image: Union[str, sitk.Image],
output_image: str,
config_file: str,
configuration: str,
transform_file: str = None,
log_file: str = None,
**kwargs,
@@ -28,14 +28,14 @@
Returns:
float: The structural similarity index.
"""
if isinstance(config_file, str):
assert os.path.isfile(config_file), "Config file does not exist."
elif isinstance(config_file, dict):
if isinstance(configuration, str):
assert os.path.isfile(configuration), "Config file does not exist."
elif isinstance(configuration, dict):
pass
else:
raise ValueError("Config file must be a string or dictionary.")

registration_obj = RegistrationClass(config_file)
registration_obj = RegistrationClass(configuration)
registration_obj.register(
target_image=target_image,
moving_image=moving_image,
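For illustration, a minimal usage sketch of the renamed `configuration` argument; the image paths and the override dict below are hypothetical, not part of this diff:

    from ereg.functional import registration_function

    # configuration accepts either a path to a .yml file or an in-memory dict
    ssim = registration_function(
        target_image="target.nii.gz",           # hypothetical paths
        moving_image="moving.nii.gz",
        output_image="registered.nii.gz",
        configuration={"transform": "versor"},  # or "my_config.yml"
        transform_file="transform.mat",
    )
    print(f"structural similarity: {ssim}")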
193 changes: 41 additions & 152 deletions ereg/registration.py
Expand Up @@ -15,7 +15,7 @@
class RegistrationClass:
def __init__(
self,
config_file: Union[str, dict] = None,
configuration: Union[str, dict] = None,
**kwargs,
) -> None:
"""
@@ -65,167 +65,56 @@ def __init__(
]
self.total_attempts = 5
self.transform = None
if config_file is not None:
self.update_parameters(config_file)

def update_parameters(self, config_file: Union[str, dict], **kwargs):
"""
Update the parameters for the registration.

Args:
config_file (Union[str, dict]): The config file or dictionary.
"""
if isinstance(config_file, str):
self.parameters = yaml.safe_load(open(config_file, "r"))
elif isinstance(config_file, dict):
self.parameters = config_file
if configuration is not None:
self.update_parameters(configuration)
else:
raise ValueError("Config file must be a string or dictionary.")
self.parameters = self._generate_default_parameters()

self.parameters["metric"] = (
self.parameters.get("metric", "mean_squares")
.replace("_", "")
.replace("-", "")
.lower()
)
self.parameters["metric_parameters"] = self.parameters.get(
"metric_parameters", {}
)
self.parameters["metric_parameters"]["histogram_bins"] = self.parameters[
"metric_parameters"
].get("histogram_bins", 50)
self.parameters["metric_parameters"]["radius"] = self.parameters[
"metric_parameters"
].get("radius", 5)
self.parameters["metric_parameters"][
"intensityDifferenceThreshold"
] = self.parameters["metric_parameters"].get(
"intensityDifferenceThreshold", 0.001
)
self.parameters["metric_parameters"][
"varianceForJointPDFSmoothing"
] = self.parameters["metric_parameters"].get(
"varianceForJointPDFSmoothing", 1.5
def _generate_default_parameters(self) -> dict:
        # resolve the defaults file relative to this module, not the CWD;
        # note: plain string concatenation with __file__ would produce a broken path
        defaults_file = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "configurations",
                "default_rigid.yaml",
            )
        )
        with open(defaults_file, "r") as f:
            default_parameters = yaml.safe_load(f)
        return default_parameters
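For context on what the new defaults file enables, a hedged sketch of a minimal `default_rigid.yaml` and how `yaml.safe_load` turns it into the parameters dictionary. The key names are taken from the defaults this PR deletes from the Python code; the actual shipped file is not shown in this diff:

    import yaml

    # hypothetical excerpt of default_rigid.yaml; the shipped file may differ
    example_defaults = """
    metric: mean_squares
    transform: versor
    initialization: geometry
    shrink_factors: [8, 4, 2]
    smoothing_sigmas: [3, 2, 1]
    sampling_strategy: none
    sampling_percentage: 0.01
    attempts: 5
    optimizer: regular_step_gradient_descent
    """

    parameters = yaml.safe_load(example_defaults)
    assert parameters["shrink_factors"] == [8, 4, 2]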

self.parameters["transform"] = (
self.parameters.get("transform", "versor")
.replace("_", "")
.replace("-", "")
.lower()
)
assert self.parameters["transform"] in self.available_transforms, (
f"Transform {self.parameters['transform']} not recognized. "
f"Available transforms: {self.available_transforms}"
)
if self.parameters["transform"] in ["euler", "versorrigid"]:
self.parameters["rigid_registration"] = True
self.parameters["initialization"] = self.parameters.get(
"initialization", "geometry"
).lower()
self.parameters["bias_correct"] = self.parameters.get(
"bias_correct", self.parameters.get("bias", False)
)
self.parameters["interpolator"] = (
self.parameters.get("interpolator", "linear")
Copy link
Contributor Author

@neuronflow neuronflow Dec 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand why this is done @sarthakpati . The

            .replace("_", "")
            .replace("-", "")
            .lower()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sarthakpati did you see this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, I think this is to just remove these specific special characters to make the comparisons easier in the next lines.

.replace("_", "")
.replace("-", "")
.lower()
)
self.parameters["shrink_factors"] = self.parameters.get(
"shrink_factors", self.parameters.get("shrink", [8, 4, 2])
)
self.parameters["smoothing_sigmas"] = self.parameters.get(
"smoothing_sigmas", self.parameters.get("smooth", [3, 2, 1])
)
assert len(self.parameters["shrink_factors"]) == len(
self.parameters["smoothing_sigmas"]
), "The number of shrink factors and smoothing sigmas must be the same."
self.parameters["sampling_strategy"] = self.parameters.get(
"sampling_strategy", "none"
)
self.parameters["sampling_percentage"] = self.parameters.get(
"sampling_percentage", 0.01
)
if isinstance(self.parameters["sampling_percentage"], int) or isinstance(
self.parameters["sampling_percentage"], float
):
temp_percentage = self.parameters["sampling_percentage"]
self.parameters["sampling_percentage"] = []
for _ in range(len(self.parameters["shrink_factors"])):
self.parameters["sampling_percentage"].append(temp_percentage)
@property
def configuration(self) -> dict:
return self.parameters

assert len(self.parameters["shrink_factors"]) == len(
self.parameters["sampling_percentage"]
), "The number of shrink factors and sampling percentages must be the same."
@configuration.setter
def configuration(
self,
new_config_file: Union[str, dict],
) -> None:
self.parameters = self._generate_default_parameters()
self.update_parameters(configuration=new_config_file)
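In use, assigning to the property reinstalls the defaults and then applies the given overrides; a small sketch (the override value is hypothetical, though "euler" is among the transforms the class accepts):

    reg = RegistrationClass()                   # initialized from default_rigid.yaml
    reg.configuration = {"transform": "euler"}  # reset to defaults, then override
    print(reg.configuration["transform"])       # -> "euler"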

self.parameters["attempts"] = self.parameters.get("attempts", 5)
def update_parameters(
self,
configuration: Union[str, dict],
):
"""
Update the parameters for the registration.

# check for composite transforms
self.parameters["composite_transform"] = self.parameters.get(
"composite_transform", None
)
if self.parameters["composite_transform"]:
self.parameters["previous_transforms"] = self.parameters.get(
"previous_transforms", []
Args:
configuration (Union[str, dict]): The string path pointing to a .yml configuration file, or a configuration dictionary.
"""
if isinstance(configuration, str):
config_data = yaml.safe_load(open(configuration, "r"))
elif isinstance(configuration, dict):
config_data = configuration
else:
raise ValueError(
"Configuration must be a string path pointing to a .yml file or dictionary."
)

# checks related to composite transforms
assert isinstance(
self.parameters["previous_transforms"], list
), "Previous transforms must be a list."
assert (
len(self.parameters["previous_transforms"]) > 0
), "No previous transforms provided."

self.parameters["optimizer"] = self.parameters.get(
"optimizer", "regular_step_gradient_descent"
)

# this is taken directly from the sample_config.yaml
default_optimizer_parameters = {
"min_step": 1e-6, # regular_step_gradient_descent
"max_step": 1.0, # gradient_descent, regular_step_gradient_descent
"maximumStepSizeInPhysicalUnits": 1.0, # regular_step_gradient_descent, gradient_descent_line_search, gradient_descent,
"iterations": 1000, # regular_step_gradient_descent, gradient_descent_line_search, gradient_descent, conjugate, lbfgsb, lbfgsb2
"learningrate": 1.0, # gradient_descent, gradient_descent_line_search
"convergence_minimum": 1e-6, # gradient_descent, gradient_descent_line_search
"convergence_window_size": 10, # gradient_descent, gradient_descent_line_search
"line_search_lower_limit": 0.0, # gradient_descent_line_search
"line_search_upper_limit": 5.0, # gradient_descent_line_search
"line_search_epsilon": 0.01, # gradient_descent_line_search
"step_length": 0.1, # conjugate, exhaustive, powell
"simplex_delta": 0.1, # amoeba
"maximum_number_of_corrections": 5, # lbfgsb, lbfgsb2
"maximum_number_of_function_evaluations": 2000, # lbfgsb, lbfgsb2
"solution_accuracy": 1e-5, # lbfgsb2
"hessian_approximate_accuracy": 1e-5, # lbfgsb2
"delta_convergence_distance": 1e-5, # lbfgsb2
"delta_convergence_tolerance": 1e-5, # lbfgsb2
"line_search_maximum_evaluations": 50, # lbfgsb2
"line_search_minimum_step": 1e-20, # lbfgsb2
"line_search_accuracy": 1e-4, # lbfgsb2
"epsilon": 1e-8, # one_plus_one_evolutionary
"initial_radius": 1.0, # one_plus_one_evolutionary
"growth_factor": -1.0, # one_plus_one_evolutionary
"shrink_factor": -1.0, # one_plus_one_evolutionary
"maximum_line_iterations": 100, # powell
"step_tolerance": 1e-6, # powell
"value_tolerance": 1e-6, # powell
"relaxation": 0.5, # regular_step_gradient_descent
"tolerance": 1e-4, # regular_step_gradient_descent
"rigid_registration": False,
}

# check for optimizer parameters in config file
self.parameters["optimizer_parameters"] = self.parameters.get(
"optimizer_parameters", {}
)

# for any optimizer parameters not in the config file, use the default values
for key, value in default_optimizer_parameters.items():
if key not in self.parameters["optimizer_parameters"]:
self.parameters["optimizer_parameters"][key] = value
# update only the keys that already exist in the default parameters
for key, value in config_data.items():
if key in self.parameters:
self.parameters[key] = value
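The loop above means unknown keys in a user config are silently ignored rather than raising; a standalone illustration of that merge semantics:

    parameters = {"metric": "mean_squares", "attempts": 5}  # defaults
    config_data = {"metric": "correlation", "typo_key": 1}  # user config

    for key, value in config_data.items():
        if key in parameters:
            parameters[key] = value

    assert parameters == {"metric": "correlation", "attempts": 5}
    # "typo_key" was dropped without warning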

def register(
self,