Skip to content

Commit

Permalink
appropriate defaults (float and int instead of str), added fina…
Browse files Browse the repository at this point in the history
…l logging info,
  • Loading branch information
sarthakpati committed Apr 1, 2024
1 parent c05e9f1 commit 8d53051
Showing 1 changed file with 28 additions and 37 deletions.
65 changes: 28 additions & 37 deletions ereg/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def register(
self.parameters["composite_transform"] = self.parameters.get(
"composite_transform", None
)
if self.parameters["composite_transform"] is not None:
if self.parameters.get("composite_transform"):
self.logger.info("Applying composite transform.")
transform_composite = sitk.ReadTransform(
self.parameters["composite_transform"]
Expand All @@ -191,7 +191,6 @@ def register(
transform_composite, self.transform
)

if self.parameters["composite_transform"]:
self.logger.info("Applying previous transforms.")
current_transform = None
for previous_transform in self.parameters["previous_transforms"]:
Expand Down Expand Up @@ -458,9 +457,23 @@ def _register_image_and_get_transform(
elif type(sampling_rate) in [np.ndarray, list]:
registration.SetMetricSamplingPercentagePerLevel(sampling_rate)

# initialize some defaults
self.parameters["optimizer_parameters"] = self.parameters.get(
"optimizer_parameters", {}
)
self.parameters["optimizer_parameters"]["type"] = self.parameters[
"optimizer_parameters"
].get("type", "regular_step_gradient_descent")
# set the optimizer parameters as either floats or integers
for key in self.parameters["optimizer_parameters"]:
if key not in ["type"]:
self.parameters["optimizer_parameters"][key] = float(
self.parameters["optimizer_parameters"][key]
)
if key == "iterations":
self.parameters["optimizer_parameters"][key] = int(
self.parameters["optimizer_parameters"][key]
)
if (
self.parameters["optimizer_parameters"].get("type").lower()
== "regular_step_gradient_descent"
Expand Down Expand Up @@ -682,8 +695,12 @@ def _register_image_and_get_transform(
# registration.SetOptimizerScalesFromJacobian()
registration.SetOptimizerScalesFromPhysicalShift()

registration.SetShrinkFactorsPerLevel(self.parameters.get("shrink_factors", [8, 4, 2]))
registration.SetSmoothingSigmasPerLevel(self.parameters.get("smoothing_sigmas", [3, 2, 1]))
registration.SetShrinkFactorsPerLevel(
self.parameters.get("shrink_factors", [8, 4, 2])
)
registration.SetSmoothingSigmasPerLevel(
self.parameters.get("smoothing_sigmas", [3, 2, 1])
)
registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

assert (
Expand Down Expand Up @@ -756,11 +773,16 @@ def _register_image_and_get_transform(
)
continue

if output_transform is None:
raise RuntimeError("Registration failed.")
assert output_transform is not None, "Registration failed."

self.logger.info(
f"Final Optimizer Parameters:: convergence={registration.GetOptimizerConvergenceValue()}, iterations={registration.GetOptimizerIteration()}, metric={registration.GetMetricValue()}, stop condition={registration.GetOptimizerStopConditionDescription()}"
)

registration_transform_sitk = output_transform
# if user is requesting a rigid registration, convert the transform to a rigid transform
if isinstance(output_transform, sitk.CompositeTransform):
registration_transform_sitk = output_transform.GetNthTransform(0)
if self.parameters["transform"] in ["euler", "versorrigid"]:
try:
# Euler Transform used:
Expand All @@ -778,35 +800,4 @@ def _register_image_and_get_transform(
tmp.SetTranslation(registration_transform_sitk.GetTranslation())
tmp.SetCenter(registration_transform_sitk.GetCenter())
registration_transform_sitk = tmp
## additional information
# print("Metric: ", registration.MetricEvaluate(target_image, moving_image), flush=True)
# print(
# "Optimizer stop condition: ",
# registration.GetOptimizerStopConditionDescription(),
# flush=True,
# )
# print("Number of iterations: ", registration.GetOptimizerIteration(), flush=True)
# print("Final metric value: ", registration.GetMetricValue(), flush=True)

# if rigid_registration:
# if target_image.GetDimension() == 2:
# output_transform = eval(
# "sitk.Euler%dDTransform(output_transform)"
# % (target_image.GetDimension())
# )
# elif target_image.GetDimension() == 3:
# output_transform = eval(
# "sitk.Euler%dDTransform(output_transform)"
# % (target_image.GetDimension())
# )
# # VersorRigid used: Transform from VersorRigid to Euler
# output_transform = eval(
# "sitk.VersorRigid%dDTransform(output_transform)"
# % (target_image.GetDimension())
# )
# tmp = eval("sitk.Euler%dDTransform()" % (target_image.GetDimension()))
# tmp.SetMatrix(output_transform.GetMatrix())
# tmp.SetTranslation(output_transform.GetTranslation())
# tmp.SetCenter(output_transform.GetCenter())
# output_transform = tmp
return registration_transform_sitk

0 comments on commit 8d53051

Please sign in to comment.