From 4bdc221c443409f2c1a920e3ec672d860efeba1f Mon Sep 17 00:00:00 2001 From: Sam Tygier Date: Mon, 9 Dec 2024 17:44:07 +0000 Subject: [PATCH] Add RecordResidualsCallback and pass residuals through to progress --- mantidimaging/core/reconstruct/cil_recon.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mantidimaging/core/reconstruct/cil_recon.py b/mantidimaging/core/reconstruct/cil_recon.py index ace2cf1b3d3..269b9e98e96 100644 --- a/mantidimaging/core/reconstruct/cil_recon.py +++ b/mantidimaging/core/reconstruct/cil_recon.py @@ -44,6 +44,8 @@ def __init__(self, verbose=1, progress: Progress | None = None) -> None: def __call__(self, algo: Algorithm) -> None: if self.progress: extra_info = {'iterations': algo.iterations, 'losses': algo.loss} + if algo.last_residual and algo.last_residual[0] == algo.iteration: + extra_info["residual"] = algo.last_residual[1] self.progress.update( steps=1, msg=f'CIL: Iteration {algo.iteration} of {algo.max_iteration}' @@ -65,6 +67,7 @@ def __call__(self, algo: Algorithm) -> None: forward_projection = algo.operator.direct(algo.solution)[1].as_array() data = algo.f[1].b.as_array() if len(forward_projection.shape) == 3: + # For a full 3D recon, just select the middle slice slice = forward_projection.shape[0] // 2 forward_projection = forward_projection[slice] data = data[slice] @@ -299,7 +302,11 @@ def single_sino(sino: np.ndarray, # this may be confusing for the user in case of SPDHG, because they will # input num_iter and they will run num_iter * num_subsets algo.max_iteration = num_iter - algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)]) + algo.run(num_iter, + callbacks=[ + RecordResidualsCallback(residual_interval=update_objective_interval), + MIProgressCallback(progress=progress) + ]) finally: if progress: @@ -418,7 +425,11 @@ def full(images: ImageStack, # this may be confusing for the user in case of SPDHG, because they will # input num_iter and they will run num_iter * num_subsets algo.max_iteration = num_iter - algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)]) + algo.run(num_iter, + callbacks=[ + RecordResidualsCallback(residual_interval=update_objective_interval), + MIProgressCallback(progress=progress) + ]) if isinstance(algo.solution, BlockDataContainer): # TGV case