Skip to content

Commit

Permalink
Add RecordResidualsCallback and pass residuals through to progress
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed Dec 10, 2024
1 parent 8145f34 commit 4bdc221
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions mantidimaging/core/reconstruct/cil_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def __init__(self, verbose=1, progress: Progress | None = None) -> None:
def __call__(self, algo: Algorithm) -> None:
if self.progress:
extra_info = {'iterations': algo.iterations, 'losses': algo.loss}
if algo.last_residual and algo.last_residual[0] == algo.iteration:
extra_info["residual"] = algo.last_residual[1]
self.progress.update(
steps=1,
msg=f'CIL: Iteration {algo.iteration} of {algo.max_iteration}'
Expand All @@ -65,6 +67,7 @@ def __call__(self, algo: Algorithm) -> None:
forward_projection = algo.operator.direct(algo.solution)[1].as_array()
data = algo.f[1].b.as_array()
if len(forward_projection.shape) == 3:
# For a full 3D recon, just select the middle slice
slice = forward_projection.shape[0] // 2
forward_projection = forward_projection[slice]
data = data[slice]
Expand Down Expand Up @@ -299,7 +302,11 @@ def single_sino(sino: np.ndarray,
# this may be confusing for the user in case of SPDHG, because they will
# input num_iter and they will run num_iter * num_subsets
algo.max_iteration = num_iter
algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)])
algo.run(num_iter,
callbacks=[
RecordResidualsCallback(residual_interval=update_objective_interval),
MIProgressCallback(progress=progress)
])

finally:
if progress:
Expand Down Expand Up @@ -418,7 +425,11 @@ def full(images: ImageStack,
# this may be confusing for the user in case of SPDHG, because they will
# input num_iter and they will run num_iter * num_subsets
algo.max_iteration = num_iter
algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)])
algo.run(num_iter,
callbacks=[
RecordResidualsCallback(residual_interval=update_objective_interval),
MIProgressCallback(progress=progress)
])

if isinstance(algo.solution, BlockDataContainer):
# TGV case
Expand Down

0 comments on commit 4bdc221

Please sign in to comment.