Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show residuals during reconstruction #2430

Merged
merged 7 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#2430: Add Residual Plot in AsyncTaskDialog Displayed During Reconstruction

38 changes: 33 additions & 5 deletions mantidimaging/core/reconstruct/cil_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,39 @@ class MIProgressCallback(Callback):
def __init__(self, verbose=1, progress: Progress | None = None) -> None:
super().__init__(verbose)
self.progress = progress
self.iteration_count = 1

def __call__(self, algo: Algorithm) -> None:
if self.progress:
extra_info = {'iterations': algo.iterations, 'losses': algo.loss}
if algo.last_residual and algo.last_residual[0] == algo.iteration:
extra_info["residual"] = algo.last_residual[1]
self.progress.update(
steps=1,
msg=f'CIL: Iteration {self.iteration_count } of {algo.max_iteration}'
msg=f'CIL: Iteration {algo.iteration} of {algo.max_iteration}'
f': Objective {algo.get_last_objective():.2f}',
force_continue=False,
extra_info=extra_info,
)
self.iteration_count += 1


class RecordResidualsCallback(Callback):

def __init__(self, verbose=1, residual_interval: int = 1) -> None:
super().__init__(verbose)
self.residual_interval = residual_interval

def __call__(self, algo: Algorithm) -> None:
if algo.iteration % self.residual_interval == 0:
if isinstance(algo, PDHG):
forward_projection = algo.operator.direct(algo.solution)[1].as_array()
data = algo.f[1].b.as_array()
if len(forward_projection.shape) == 3:
JackEAllen marked this conversation as resolved.
Show resolved Hide resolved
# For a full 3D recon, just select the middle slice
slice = forward_projection.shape[0] // 2
forward_projection = forward_projection[slice]
data = data[slice]
residual: np.ndarray = (data - forward_projection)**2
algo.last_residual = (algo.iteration, residual**2)


class CILRecon(BaseRecon):
Expand Down Expand Up @@ -282,7 +302,11 @@ def single_sino(sino: np.ndarray,
# this may be confusing for the user in case of SPDHG, because they will
# input num_iter and they will run num_iter * num_subsets
algo.max_iteration = num_iter
algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)])
algo.run(num_iter,
callbacks=[
RecordResidualsCallback(residual_interval=update_objective_interval),
MIProgressCallback(progress=progress)
])

finally:
if progress:
Expand Down Expand Up @@ -401,7 +425,11 @@ def full(images: ImageStack,
# this may be confusing for the user in case of SPDHG, because they will
# input num_iter and they will run num_iter * num_subsets
algo.max_iteration = num_iter
algo.run(num_iter, callbacks=[MIProgressCallback(progress=progress)])
algo.run(num_iter,
callbacks=[
RecordResidualsCallback(residual_interval=update_objective_interval),
MIProgressCallback(progress=progress)
])

if isinstance(algo.solution, BlockDataContainer):
# TGV case
Expand Down
8 changes: 7 additions & 1 deletion mantidimaging/gui/dialogs/async_task/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from collections.abc import Callable

import numpy as np
from PyQt5.QtCore import QObject, pyqtSignal

from mantidimaging.core.utility.progress_reporting import ProgressHandler
Expand All @@ -21,13 +22,15 @@ class Notification(Enum):
class AsyncTaskDialogPresenter(QObject, ProgressHandler):
progress_updated = pyqtSignal(float, str)
progress_plot_updated = pyqtSignal(list, list)
progress_residual_plot_updated = pyqtSignal(np.ndarray)

def __init__(self, view):
super().__init__()

self.view = view
self.progress_updated.connect(self.view.set_progress)
self.progress_plot_updated.connect(self.view.set_progress_plot)
self.progress_residual_plot_updated.connect(self.view.set_progress_residual_plot)

self.model = AsyncTaskDialogModel()
self.model.task_done.connect(self.view.handle_completion)
Expand Down Expand Up @@ -74,9 +77,12 @@ def progress_update(self) -> None:
extra_info = progress_info[-1].extra_info
self.progress_updated.emit(self.progress.completion(), msg if msg is not None else '')

if extra_info:
if extra_info and 'losses' in extra_info:
JackEAllen marked this conversation as resolved.
Show resolved Hide resolved
self.update_progress_plot(extra_info['iterations'], extra_info['losses'])

if extra_info and 'residual' in extra_info:
self.progress_residual_plot_updated.emit(extra_info["residual"])

def show_stop_button(self, show: bool = False) -> None:
self.view.show_cancel_button(show)

Expand Down
18 changes: 17 additions & 1 deletion mantidimaging/gui/dialogs/async_task/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from typing import Any
from collections.abc import Callable
from pyqtgraph import PlotWidget

import numpy as np
from pyqtgraph import PlotWidget, ImageView

from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.mvp_base import BaseDialogView
Expand All @@ -24,12 +26,15 @@ def __init__(self, parent: QMainWindow):

self.progressBar.setMinimum(0)
self.progressBar.setMaximum(1000)

self.progress_plot = PlotWidget()
self.PlotVerticalLayout.addWidget(self.progress_plot)
self.progress_plot.hide()
self.progress_plot.setLogMode(y=True)
self.progress_plot.setMinimumHeight(300)

self.residual_image_view: ImageView | None = None

self.show_timer = QTimer(self)
self.cancelButton.clicked.connect(self.presenter.stop_progress)
self.cancelButton.hide()
Expand Down Expand Up @@ -73,6 +78,17 @@ def set_progress_plot(self, x: list, y: list):
self.progress_plot.show()
self.progress_plot.plotItem.plot(x, y)

def set_progress_residual_plot(self, residual_image: np.ndarray) -> None:
if self.residual_image_view is None:
JackEAllen marked this conversation as resolved.
Show resolved Hide resolved
residual_image_view = ImageView()
residual_image_view.setMinimumSize(600, 400)
self.PlotVerticalLayout.addWidget(residual_image_view)
self.residual_image_view = residual_image_view
max_level = np.percentile(residual_image, 95) * 2
self.residual_image_view.setImage(residual_image, levels=(0, max_level))
self.residual_image_view.ui.histogram.gradient.loadPreset("viridis")
self.residual_image_view.ui.histogram.setHistogramRange(0, max_level)

def show_delayed(self, timeout) -> None:
self.show_timer.singleShot(timeout, self.show_from_timer)
self.show_timer.start()
Expand Down
Loading