Skip to content

Commit

Permalink
Added Type checking to the liveviewer and recon (#2252)
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc authored Jul 24, 2024
2 parents 638500c + 9b7232c commit 621d5cf
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 79 deletions.
10 changes: 5 additions & 5 deletions mantidimaging/gui/windows/live_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _is_image_file(file_name: str) -> bool:
image_extensions = ('tif', 'tiff', 'fits')
return file_name.rpartition(".")[2].lower() in image_extensions

def remove_path(self):
def remove_path(self) -> None:
"""
Remove the currently set path
"""
Expand All @@ -292,22 +292,22 @@ def update_recent_watcher(self, images: list[Image_Data]) -> None:
self.recent_file_watcher.removePaths(self.recent_file_watcher.files())
self.recent_file_watcher.addPaths([str(image.image_path) for image in images])

def handle_image_modified(self, file_path):
def handle_image_modified(self, file_path) -> None:
self.recent_image_changed.emit(Path(file_path))

def add_sub_directory(self, sub_dir: SubDirectory):
def add_sub_directory(self, sub_dir: SubDirectory) -> None:
if sub_dir.path not in self.sub_directories:
self.watcher.addPath(str(sub_dir.path))

self.sub_directories[sub_dir.path] = sub_dir

def remove_sub_directory(self, sub_dir: Path):
def remove_sub_directory(self, sub_dir: Path) -> None:
if sub_dir in self.sub_directories:
self.watcher.removePath(str(sub_dir))

del self.sub_directories[sub_dir]

def clear_deleted_sub_directories(self, directory: Path):
def clear_deleted_sub_directories(self, directory: Path) -> None:
for sub_dir in list(self.sub_directories):
if sub_dir.is_relative_to(directory) and not sub_dir.exists():
self.remove_sub_directory(sub_dir)
13 changes: 7 additions & 6 deletions mantidimaging/gui/windows/live_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def select_image(self, index: int) -> None:

self.display_image(self.selected_image.image_path)

def display_image(self, image_path: Path):
def display_image(self, image_path: Path) -> None:
"""
Display image in the view after validating contents
"""
Expand Down Expand Up @@ -119,20 +119,21 @@ def load_image(image_path: Path) -> np.ndarray:
image_data = fit[0].data
return image_data

def update_image_modified(self, image_path: Path):
def update_image_modified(self, image_path: Path) -> None:
"""
Update the displayed image when the file is modified
"""
if self.selected_image and image_path == self.selected_image.image_path:
self.display_image(image_path)

def update_image_operation(self):
def update_image_operation(self) -> None:
"""
Reload the current image if an operation has been performed on the current image
"""
self.display_image(self.selected_image.image_path)
if self.selected_image is not None:
self.display_image(self.selected_image.image_path)

def convert_image_to_imagestack(self, image_data):
def convert_image_to_imagestack(self, image_data) -> ImageStack:
"""
Convert the single image to an imagestack so the Operations framework can be used
"""
Expand All @@ -141,7 +142,7 @@ def convert_image_to_imagestack(self, image_data):
image_data_temp[0] = image_data
return ImageStack(image_data_temp)

def perform_operations(self, image_data):
def perform_operations(self, image_data) -> np.ndarray:
if not self.view.filter_params:
return image_data
image_stack = self.convert_image_to_imagestack(image_data)
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/gui/windows/live_viewer/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def closeEvent(self, e) -> None:
super().closeEvent(e)
self.presenter = None # type: ignore # View instance to be destroyed -type can be inconsistent

def set_image_rotation_angle(self):
def set_image_rotation_angle(self) -> None:
"""Set the image rotation angle which will be read in by the presenter"""
if self.rotate_angles_group.checkedAction().text() == "0°":
if "Rotate Stack" in self.filter_params:
Expand Down
32 changes: 16 additions & 16 deletions mantidimaging/gui/windows/recon/image_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, parent):
self.imageview_projection.enable_nonpositive_check()
self.imageview_sinogram.enable_nonpositive_check()

def cleanup(self):
def cleanup(self) -> None:
self.imageview_projection.cleanup()
self.imageview_sinogram.cleanup()
self.imageview_recon.cleanup()
Expand All @@ -60,10 +60,10 @@ def cleanup(self):
del self.imageview_sinogram
del self.imageview_recon

def slice_line_moved(self):
def slice_line_moved(self) -> None:
self.slice_changed(int(self.slice_line.value()))

def update_projection(self, image_data: np.ndarray, preview_slice_index: int, tilt_angle: Degrees | None):
def update_projection(self, image_data: np.ndarray, preview_slice_index: int, tilt_angle: Degrees | None) -> None:
self.imageview_projection.clear()
self.imageview_projection.setImage(image_data)
self.imageview_projection.histogram.imageChanged(autoLevel=True, autoRange=True)
Expand All @@ -75,13 +75,13 @@ def update_projection(self, image_data: np.ndarray, preview_slice_index: int, ti
self.hide_tilt()
set_histogram_log_scale(self.imageview_projection.histogram)

def update_sinogram(self, image):
def update_sinogram(self, image) -> None:
self.imageview_sinogram.clear()
self.imageview_sinogram.setImage(image)
self.imageview_sinogram.histogram.imageChanged(autoLevel=True, autoRange=True)
set_histogram_log_scale(self.imageview_sinogram.histogram)

def update_recon(self, image_data, reset_roi: bool = False):
def update_recon(self, image_data, reset_roi: bool = False) -> None:
self.imageview_recon.clear()
self.imageview_recon.setImage(image_data, autoLevels=False)
set_histogram_log_scale(self.imageview_recon.histogram)
Expand All @@ -90,34 +90,34 @@ def update_recon(self, image_data, reset_roi: bool = False):
else:
self.recon_line_profile.update()

def update_recon_hist(self):
def update_recon_hist(self) -> None:
self.imageview_recon.histogram.imageChanged(autoLevel=True, autoRange=True)

def mouse_click(self, ev, line: InfiniteLine):
def mouse_click(self, ev, line: InfiniteLine) -> None:
line.setPos(ev.pos())
self.slice_changed(CloseEnoughPoint(ev.pos()).y)

def slice_changed(self, slice_index):
def slice_changed(self, slice_index) -> None:
self.parent.presenter.do_preview_reconstruct_slice(slice_idx=slice_index)
self.sigSliceIndexChanged.emit(slice_index)

def clear_recon(self):
def clear_recon(self) -> None:
self.imageview_recon.clear()

def clear_recon_line_profile(self):
def clear_recon_line_profile(self) -> None:
self.recon_line_profile.clear_plot()

def clear_sinogram(self):
def clear_sinogram(self) -> None:
self.imageview_sinogram.clear()

def clear_projection(self):
def clear_projection(self) -> None:
self.imageview_projection.clear()

def reset_slice_and_tilt(self, slice_index):
def reset_slice_and_tilt(self, slice_index) -> None:
self.slice_line.setPos(slice_index)
self.hide_tilt()

def hide_tilt(self):
def hide_tilt(self) -> None:
"""
Hides the tilt line. This stops infinite zooming out loop that messes up the image view
(the line likes to be unbound when the degree isn't a multiple o 90 - and the tilt never is)
Expand All @@ -126,13 +126,13 @@ def hide_tilt(self):
if self.tilt_line.scene() is not None:
self.imageview_projection.viewbox.removeItem(self.tilt_line)

def set_tilt(self, tilt: Degrees, pos: int | None = None):
def set_tilt(self, tilt: Degrees, pos: int | None = None) -> None:
if not isnan(tilt.value): # is isnan it means there is no tilt, i.e. the line is vertical
if pos is not None:
self.tilt_line.setAngle(90)
self.tilt_line.setPos(pos)
self.tilt_line.setAngle(90 + tilt.value)
self.imageview_projection.viewbox.addItem(self.tilt_line)

def reset_recon_histogram(self):
def reset_recon_histogram(self) -> None:
self.imageview_recon.histogram.autoHistogramRange()
11 changes: 6 additions & 5 deletions mantidimaging/gui/windows/recon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def images(self):
def num_points(self) -> int:
return self.data_model.num_points

def initial_select_data(self, images: ImageStack | None):
def initial_select_data(self, images: ImageStack | None) -> None:
self._images = images
self.reset_cor_model()

Expand Down Expand Up @@ -134,7 +134,9 @@ def run_preview_recon(self,
images.projection_angles(recon_params.max_projection_angle),
recon_params,
progress=progress)

recon = self._apply_pixel_size(recon, recon_params)

return recon

def run_full_recon(self, recon_params: ReconstructionParameters, progress: Progress) -> ImageStack | None:
Expand All @@ -146,12 +148,11 @@ def run_full_recon(self, recon_params: ReconstructionParameters, progress: Progr
# get the image height based on the current ROI
recon = reconstructor.full(images, self.data_model.get_all_cors_from_regression(images.height), recon_params,
progress)

recon = self._apply_pixel_size(recon, recon_params, progress)
return recon

@staticmethod
def _apply_pixel_size(recon, recon_params: ReconstructionParameters, progress=None):
def _apply_pixel_size(recon: ImageStack, recon_params: ReconstructionParameters, progress=None) -> ImageStack:
if recon_params.pixel_size > 0.:
recon = DivideFilter.filter_func(recon, value=recon_params.pixel_size, unit="micron", progress=progress)
# update the reconstructed stack pixel size with the value actually used for division
Expand Down Expand Up @@ -255,7 +256,7 @@ def auto_find_correlation(self, progress: Progress) -> tuple[ScalarCoR, Degrees]
return find_center(self.images, progress)

@staticmethod
def proj_180_degree_shape_matches_images(images):
def proj_180_degree_shape_matches_images(images) -> bool:
return images.has_proj180deg() and images.height == images.proj180deg.height \
and images.width == images.proj180deg.width

Expand All @@ -269,7 +270,7 @@ def stack_contains_negative_values(self) -> bool:
return bool(np.any(self.images.data < 0))

@property
def stack_id(self):
def stack_id(self) -> uuid.UUID | None:
if self.images is not None:
return self.images.id
return None
48 changes: 27 additions & 21 deletions mantidimaging/gui/windows/recon/point_table_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any

from PyQt5.QtCore import QAbstractTableModel, QModelIndex, Qt

Expand All @@ -29,54 +30,59 @@ class CorTiltPointQtModel(QAbstractTableModel, CorTiltDataModel):
def __init__(self, parent=None):
super().__init__(parent)

def populate_slice_indices(self, begin, end, count, cor=0.0):
def populate_slice_indices(self, begin: int, end: int, count: int, cor: float = 0.0) -> None:
self.beginResetModel()
super().populate_slice_indices(begin, end, count, cor)
self.endResetModel()

def sort_points(self):
def sort_points(self) -> None:
self.layoutAboutToBeChanged.emit()
super().sort_points()
self.layoutChanged.emit()

def set_point(self, idx, slice_idx: int | None = None, cor: float | None = None, reset_results=True):
def set_point(self,
idx: int,
slice_idx: int | None = None,
cor: float | None = None,
reset_results: bool = True) -> None:
super().set_point(idx, slice_idx, cor, reset_results)
self.dataChanged.emit(self.index(idx, 0), self.index(idx, 1))

def columnCount(self, parent=None):
def columnCount(self, parent: QModelIndex | None = None) -> int:
return 2

def rowCount(self, parent):
def rowCount(self, parent: QModelIndex) -> int:
if parent.isValid():
return 0
return self.num_points

def flags(self, index):
def flags(self, index: QModelIndex) -> Qt.ItemFlags:
flags = super().flags(index)
flags |= Qt.ItemFlag.ItemIsEditable
return flags

def data(self, index, role=Qt.ItemDataRole.DisplayRole):
if not index.isValid():
def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> int | str | float | None:
if not index.isValid() or index.row() >= len(self._points):
return None

col = index.column()
col_field = Column(col)
point: Point = self._points[index.row()]
col_field: Column = Column(index.column())

if role == Qt.ItemDataRole.DisplayRole:
if col_field == Column.SLICE_INDEX:
return self._points[index.row()].slice_index
if col_field == Column.CENTRE_OF_ROTATION:
return self._points[index.row()].cor
return point.slice_index
elif col_field == Column.CENTRE_OF_ROTATION:
return point.cor

elif role == Qt.ItemDataRole.ToolTipRole:
if col_field == Column.SLICE_INDEX:
return 'Slice index (y coordinate of projection)'
elif col_field == Column.CENTRE_OF_ROTATION:
return 'Centre of rotation for specific slice'
return ''

def getColumn(self, column_index) -> list[int]:
return None

def getColumn(self, column_index: int) -> list[int]:
if column_index != 0 and column_index != 1:
return []
else:
Expand All @@ -85,7 +91,7 @@ def getColumn(self, column_index) -> list[int]:
column.append(point.slice_index)
return column

def setData(self, index, val, role=Qt.ItemDataRole.EditRole):
def setData(self, index: QModelIndex, val: Any, role: Qt.ItemDataRole = Qt.ItemDataRole.EditRole) -> bool:
if role != Qt.ItemDataRole.EditRole:
return False

Expand All @@ -107,15 +113,15 @@ def setData(self, index, val, role=Qt.ItemDataRole.EditRole):

return True

def insertRows(self, row, count, parent=None, slice_idx: int | None = None, cor: float | None = None):
def insertRows(self, row, count, parent=None, slice_idx: int | None = None, cor: float | None = None) -> None:
self.beginInsertRows(parent if parent is not None else QModelIndex(), row, row + count - 1)

for _ in range(count):
self.add_point(row, slice_idx, cor)

self.endInsertRows()

def removeRows(self, row, count, parent=None):
def removeRows(self, row, count, parent=None) -> None:
if self.empty:
return

Expand All @@ -126,20 +132,20 @@ def removeRows(self, row, count, parent=None):

self.endRemoveRows()

def removeAllRows(self, parent=None):
def removeAllRows(self, parent: QModelIndex | None = None) -> None:
if self.empty:
return

self.beginRemoveRows(parent if parent else QModelIndex(), 0, self.num_points - 1)
self.clear_points()
self.endRemoveRows()

def appendNewRow(self, row: int, slice_idx: int, cor: float = 0.0):
def appendNewRow(self, row: int, slice_idx: int, cor: float = 0.0) -> None:
self.insertRows(row, 1, slice_idx=slice_idx, cor=cor)
self.set_point(row, slice_idx, cor)
self.sort_points()

def headerData(self, section, orientation, role):
def headerData(self, section, orientation, role) -> str | None:
if orientation != Qt.Orientation.Horizontal:
return None

Expand Down
Loading

0 comments on commit 621d5cf

Please sign in to comment.