diff --git a/brainglobe_registration/elastix/register.py b/brainglobe_registration/elastix/register.py index d04d918..4b9e69b 100644 --- a/brainglobe_registration/elastix/register.py +++ b/brainglobe_registration/elastix/register.py @@ -1,7 +1,9 @@ -from typing import List +from pathlib import Path +from typing import List, Tuple import itk import numpy as np +import numpy.typing as npt from brainglobe_atlasapi import BrainGlobeAtlas @@ -25,24 +27,33 @@ def get_atlas_by_name(atlas_name: str) -> BrainGlobeAtlas: def run_registration( - atlas_image, - moving_image, - annotation_image, + atlas_image: npt.NDArray, + moving_image: npt.NDArray, + annotation_image: npt.NDArray, + atlas_voxel_size: Tuple[float, ...], + moving_voxel_size: Tuple[float, ...], parameter_lists: List[tuple[str, dict]], -) -> tuple[np.ndarray, itk.ParameterObject, np.ndarray]: + output_directory: Path, +) -> Tuple[np.ndarray, itk.ParameterObject, np.ndarray]: """ Run the registration process on the given images. Parameters ---------- - atlas_image : np.ndarray + atlas_image : npt.NDArray The atlas image. - moving_image : np.ndarray + moving_image : npt.NDArray The moving image. - annotation_image : np.ndarray + atlas_voxel_size : Tuple[float, ...] + The voxel size of the atlas image in um. + moving_voxel_size : Tuple[float, ...] + The voxel size of the moving image in um. + annotation_image : npt.NDArray The annotation image. parameter_lists : List[tuple[str, dict]], optional The list of parameter lists, by default None + output_directory: Path + The output directory for the registration process. Returns ------- @@ -54,6 +65,13 @@ def run_registration( # convert to ITK, view only atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F) moving_image = itk.GetImageViewFromArray(moving_image).astype(itk.F) + annotation_image = itk.GetImageViewFromArray(annotation_image).astype( + itk.F + ) + + atlas_image.SetSpacing(atlas_voxel_size) + annotation_image.SetSpacing(atlas_voxel_size) + moving_image.SetSpacing(moving_voxel_size) # This syntax needed for 3D images elastix_object = itk.ElastixRegistrationMethod.New( @@ -63,6 +81,7 @@ def run_registration( parameter_object = setup_parameter_object(parameter_lists=parameter_lists) elastix_object.SetParameterObject(parameter_object) + elastix_object.SetOutputDirectory(str(output_directory)) # update filter object elastix_object.UpdateLargestPossibleRegion() @@ -82,6 +101,16 @@ def run_registration( result_transform_parameters, ) + # Load Transformix Object + transformix_object = itk.TransformixFilter.New(annotation_image) + transformix_object.SetTransformParameterObject(result_transform_parameters) + + # Update object (required) + transformix_object.UpdateLargestPossibleRegion() + + # Results of Transformation + annotation_image_transformix = transformix_object.GetOutput() + result_transform_parameters.SetParameter( "FinalBSplineInterpolationOrder", temp_interp_order ) diff --git a/brainglobe_registration/registration_widget.py b/brainglobe_registration/registration_widget.py index b58b24a..2f3cebe 100644 --- a/brainglobe_registration/registration_widget.py +++ b/brainglobe_registration/registration_widget.py @@ -23,6 +23,7 @@ from napari.qt.threading import thread_worker from napari.utils.notifications import show_error from napari.viewer import Viewer +from ome_zarr.dask_utils import resize from pytransform3d.rotations import active_matrix_from_angle from qtpy.QtWidgets import ( QPushButton, @@ -293,16 +294,21 @@ def _on_sample_dropdown_index_changed(self, index): self._viewer, self._sample_images[index] ) self._moving_image = self._viewer.layers[viewer_index] - self._moving_image_data_backup = self._moving_image.data.copy() + self._moving_image_data_backup = da.asarray(self._moving_image.data) - def _on_adjust_moving_image(self, x: int, y: int, rotate: float): + if len(self._moving_image.data.shape) == 3: + self.adjust_moving_image_widget.set_moving_image_to3d() + else: + self.adjust_moving_image_widget.set_moving_image_to2d() + + def _on_adjust_moving_image(self, x: int, y: int, z: int, rotate: float): if not self._moving_image: show_error( "No moving image selected. " "Please select a moving image before adjusting" ) return - adjust_napari_image_layer(self._moving_image, x, y, rotate) + adjust_napari_image_layer(self._moving_image, x, y, z, rotate) def _on_adjust_moving_image_reset_button_click(self): if not self._moving_image: @@ -371,14 +377,35 @@ def _on_crop_atlas_z_signal(self, start: int, end: int): ) def _on_run_button_click(self): + if not (self._atlas and self._moving_image): + show_error( + "Sample image or atlas not selected. " + "Please select a sample image and atlas before running" + ) + return + + if len(self._moving_image.data.shape) == 3: + reference_data = self._atlas_data_layer.data + annotation_data = self._atlas_annotations_layer.data.compute() + else: + current_atlas_slice = self._viewer.dims.current_step[0] + reference_data = self._atlas_data_layer.data[ + current_atlas_slice, :, : + ] + annotation_data = self._atlas_annotations_layer.data[ + current_atlas_slice, :, : + ] - current_atlas_slice = self._viewer.dims.current_step[0] + output_path = Path.home() / "NIU-dev" / "elastix_output" result, parameters, registered_annotation_image = run_registration( - self._atlas_data_layer.data[current_atlas_slice, :, :], - self._moving_image.data, - self._atlas_annotations_layer.data[current_atlas_slice, :, :], - self.transform_selections, + atlas_image=reference_data, + moving_image=self._moving_image.data, + atlas_voxel_size=self._atlas.resolution, + moving_voxel_size=(25, 25, 25), + annotation_image=annotation_data, + parameter_lists=self.transform_selections, + output_directory=output_path, ) boundaries = find_boundaries( @@ -475,7 +502,7 @@ def _on_sample_popup_about_to_show(self): self._sample_images = get_image_layer_names(self._viewer) self.get_atlas_widget.update_sample_image_names(self._sample_images) - def _on_scale_moving_image(self, x: float, y: float): + def _on_scale_moving_image(self, x: float, y: float, z: float = 1.0): """ Scale the moving image to have resolution equal to the atlas. @@ -485,11 +512,13 @@ def _on_scale_moving_image(self, x: float, y: float): Moving image x pixel size (> 0.0). y : float Moving image y pixel size (> 0.0). + z : float + Moving image z pixel size (> 0.0). Will show an error if the pixel sizes are less than or equal to 0. Will show an error if the moving image or atlas is not selected. """ - if x <= 0 or y <= 0: + if x <= 0 or y <= 0 or z <= 0: show_error("Pixel sizes must be greater than 0") return @@ -501,18 +530,34 @@ def _on_scale_moving_image(self, x: float, y: float): return if self._moving_image_data_backup is None: - self._moving_image_data_backup = self._moving_image.data.copy() + self._moving_image_data_backup = da.asarray( + self._moving_image.data + ) x_factor = x / self._atlas.resolution[0] y_factor = y / self._atlas.resolution[1] - - self._moving_image.data = rescale( - self._moving_image_data_backup, - (y_factor, x_factor), - mode="constant", - preserve_range=True, - anti_aliasing=True, - ) + # z_factor = z / self._atlas.resolution[2] + + if len(self._moving_image.data.shape) == 3: + self._moving_image.data = resize( + self._moving_image_data_backup, + ( + self._atlas.reference.shape[0], + self._atlas.reference.shape[2], + self._atlas.reference.shape[1], + ), + mode="constant", + preserve_range=True, + anti_aliasing=True, + ).compute() + else: + self._moving_image.data = rescale( + self._moving_image_data_backup, + (y_factor, x_factor), + mode="constant", + preserve_range=True, + anti_aliasing=True, + ) def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float): if not ( @@ -580,6 +625,8 @@ def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float): worker = self.compute_atlas_rotation(self._atlas_data_layer.data) worker.returned.connect(self.set_atlas_layer_data) worker.start() + self._atlas_data_layer.experimental_clipping_planes = None + self._atlas_annotations_layer.experimental_clipping_planes = None @thread_worker def compute_atlas_rotation(self, dask_array: da.Array): diff --git a/brainglobe_registration/utils/utils.py b/brainglobe_registration/utils/utils.py index 2bb9891..a0dd993 100644 --- a/brainglobe_registration/utils/utils.py +++ b/brainglobe_registration/utils/utils.py @@ -8,7 +8,11 @@ def adjust_napari_image_layer( - image_layer: napari.layers.Image, x: int, y: int, rotate: float + image_layer: napari.layers.Image, + x: int, + y: int, + z: int = 0, + rotate: float = 0, ): """ Adjusts the napari image layer by the given x, y, and rotation values. @@ -28,19 +32,31 @@ def adjust_napari_image_layer( The x-coordinate for the translation. y : int The y-coordinate for the translation. - rotate : float + z : int, optional + The z-coordinate for the translation. + rotate : float, optional The angle of rotation in degrees. Returns -------- None """ - image_layer.translate = (y, x) + num_dimensions = len(image_layer.data.shape) + if num_dimensions == 3: + image_layer.translate = (z, y, x) + translation = np.asarray([z, y, x]) + else: + image_layer.translate = (y, x) + translation = np.asarray([y, x]) + + rotation_matrix = np.eye(num_dimensions + 1) + rotation_matrix[:num_dimensions, :num_dimensions] = ( + active_matrix_from_angle(0, np.deg2rad(rotate)) + ) - rotation_matrix = active_matrix_from_angle(2, np.deg2rad(rotate)) - translate_matrix = np.eye(3) - origin = np.asarray(image_layer.data.shape) // 2 + np.asarray([y, x]) - translate_matrix[:2, -1] = origin + translate_matrix = np.eye(num_dimensions + 1) + origin = np.asarray(image_layer.data.shape) // 2 + translation + translate_matrix[:num_dimensions, -1] = origin transform_matrix = ( translate_matrix @ rotation_matrix @ np.linalg.inv(translate_matrix) ) diff --git a/brainglobe_registration/widgets/adjust_moving_image_view.py b/brainglobe_registration/widgets/adjust_moving_image_view.py index 82d51bb..715680f 100644 --- a/brainglobe_registration/widgets/adjust_moving_image_view.py +++ b/brainglobe_registration/widgets/adjust_moving_image_view.py @@ -41,8 +41,8 @@ class AdjustMovingImageView(QWidget): Resets the pitch, yaw, and roll to 0 and emits the atlas_reset_signal. """ - adjust_image_signal = Signal(int, int, float) - scale_image_signal = Signal(float, float) + adjust_image_signal = Signal(int, int, int, float) + scale_image_signal = Signal(float, float, float) atlas_rotation_signal = Signal(float, float, float) reset_atlas_signal = Signal() reset_image_signal = Signal() @@ -67,9 +67,16 @@ def __init__(self, parent=None): self.adjust_moving_image_pixel_size_x = QDoubleSpinBox(parent=self) self.adjust_moving_image_pixel_size_x.setDecimals(2) self.adjust_moving_image_pixel_size_x.setRange(0.01, 100.00) + self.adjust_moving_image_pixel_size_y = QDoubleSpinBox(parent=self) self.adjust_moving_image_pixel_size_y.setDecimals(2) self.adjust_moving_image_pixel_size_y.setRange(0.01, 100.00) + + self.adjust_moving_image_pixel_size_z = QDoubleSpinBox(parent=self) + self.adjust_moving_image_pixel_size_z.setDecimals(2) + self.adjust_moving_image_pixel_size_z.setRange(0.01, 100.00) + self.adjust_moving_image_pixel_size_z.setEnabled(False) + self.scale_moving_image_button = QPushButton() self.scale_moving_image_button.setText("Scale Image") self.scale_moving_image_button.clicked.connect( @@ -105,6 +112,11 @@ def __init__(self, parent=None): self.adjust_moving_image_y.setRange(-offset_range, offset_range) self.adjust_moving_image_y.valueChanged.connect(self._on_adjust_image) + self.adjust_moving_image_z = QSpinBox(parent=self) + self.adjust_moving_image_z.setRange(-offset_range, offset_range) + self.adjust_moving_image_z.valueChanged.connect(self._on_adjust_image) + self.adjust_moving_image_z.setEnabled(False) + self.adjust_moving_image_rotate = QDoubleSpinBox(parent=self) self.adjust_moving_image_rotate.setRange( -rotation_range, rotation_range @@ -129,6 +141,10 @@ def __init__(self, parent=None): "Sample image Y pixel size (\u03BCm / pixel):", self.adjust_moving_image_pixel_size_y, ) + self.layout().addRow( + "Sample image Z pixel size (\u03BCm / pixel):", + self.adjust_moving_image_pixel_size_z, + ) self.layout().addRow(self.scale_moving_image_button) self.layout().addRow(QLabel("Adjust the atlas pitch and yaw: ")) @@ -141,6 +157,7 @@ def __init__(self, parent=None): self.layout().addRow(QLabel("Adjust the moving image position: ")) self.layout().addRow("X offset:", self.adjust_moving_image_x) self.layout().addRow("Y offset:", self.adjust_moving_image_y) + self.layout().addRow("Z offset:", self.adjust_moving_image_z) self.layout().addRow( "Rotation (degrees):", self.adjust_moving_image_rotate ) @@ -154,6 +171,7 @@ def _on_adjust_image(self): self.adjust_image_signal.emit( self.adjust_moving_image_x.value(), self.adjust_moving_image_y.value(), + self.adjust_moving_image_z.value(), self.adjust_moving_image_rotate.value(), ) @@ -164,6 +182,7 @@ def _on_reset_image_button_click(self): """ self.adjust_moving_image_x.setValue(0) self.adjust_moving_image_y.setValue(0) + self.adjust_moving_image_z.setValue(0) self.adjust_moving_image_rotate.setValue(0) self.reset_image_signal.emit() @@ -175,6 +194,7 @@ def _on_scale_image_button_click(self): self.scale_image_signal.emit( self.adjust_moving_image_pixel_size_x.value(), self.adjust_moving_image_pixel_size_y.value(), + self.adjust_moving_image_pixel_size_z.value(), ) def _on_adjust_atlas_rotation(self): @@ -196,3 +216,11 @@ def _on_atlas_reset(self): self.adjust_atlas_roll.setValue(0) self.reset_atlas_signal.emit() + + def set_moving_image_to3d(self): + self.adjust_moving_image_z.setEnabled(True) + self.adjust_moving_image_pixel_size_z.setEnabled(True) + + def set_moving_image_to2d(self): + self.adjust_moving_image_z.setEnabled(False) + self.adjust_moving_image_pixel_size_z.setEnabled(False)