Skip to content

Commit

Permalink
Merge pull request #91 from multimeric/shape-selector
Browse files Browse the repository at this point in the history
Shape selector
  • Loading branch information
multimeric authored Nov 17, 2024
2 parents 454d204 + 523e2ba commit baaec71
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 34 deletions.
38 changes: 28 additions & 10 deletions core/lls_core/models/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

from typing import Iterable, Optional, Tuple, Union, cast, TYPE_CHECKING, overload
from typing_extensions import Generic, TypeVar
from typing_extensions import Generic, TypeVar, TypeAlias
from pydantic.v1 import BaseModel, NonNegativeInt, Field
from lls_core.types import ArrayLike, is_arraylike
from lls_core.utils import make_filename_suffix
Expand Down Expand Up @@ -75,6 +75,19 @@ class ImageSlices(ProcessedSlices[ArrayLike]):
# This re-definition of the type is helpful for `mkdocs`
slices: Iterable[ProcessedSlice[ArrayLike]] = Field(description="Iterable of result slices. For a given slice, you can access the image data through the `slice.data` property, which is a numpy-like array.")

def roi_previews(self) -> Iterable[ArrayLike]:
"""
Extracts a single 3D image for each ROI
"""
import numpy as np
def _preview(slices: Iterable[ProcessedSlice[ArrayLike]]) -> ArrayLike:
for slice in slices:
return slice.data
raise Exception("This ROI has no images. This shouldn't be possible")

for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index):
yield _preview(slices)

def save_image(self):
"""
Saves result slices to disk
Expand Down Expand Up @@ -124,15 +137,16 @@ def save(self) -> Path:
else:
return self.data

class WorkflowSlices(ProcessedSlices[Tuple[RawWorkflowOutput, ...]]):
MaybeTupleRawWorkflowOutput: TypeAlias = Union[Tuple[RawWorkflowOutput], RawWorkflowOutput]
class WorkflowSlices(ProcessedSlices[MaybeTupleRawWorkflowOutput]):
"""
The counterpart of `ImageSlices`, but for workflow outputs.
This is needed because workflows have vastly different outputs that may include regular
Python types rather than only image slices.
"""

# This re-definition of the type is helpful for `mkdocs`
slices: Iterable[ProcessedSlice[Tuple[RawWorkflowOutput, ...]]] = Field(description="Iterable of raw workflow results, the exact nature of which is determined by the author of the workflow. Not typically useful directly, and using he result of `.process()` is recommended instead.")
slices: Iterable[ProcessedSlice[MaybeTupleRawWorkflowOutput]] = Field(description="Iterable of raw workflow results, the exact nature of which is determined by the author of the workflow. Not typically useful directly, and using he result of `.process()` is recommended instead.")

def process(self) -> Iterable[ProcessedWorkflowOutput]:
"""
Expand Down Expand Up @@ -189,16 +203,20 @@ def process(self) -> Iterable[ProcessedWorkflowOutput]:
else:
yield ProcessedWorkflowOutput(index=i, roi_index=roi, data=pd.DataFrame(element), lattice_data=self.lattice_data)

def extract_preview(self) -> NDArray:
def roi_previews(self) -> Iterable[NDArray]:
"""
Extracts a single 3D image for previewing purposes
Extracts a single 3D image for each ROI
"""
import numpy as np
for slice in self.slices:
for value in slice.as_tuple():
if is_arraylike(value):
return np.asarray(value)
raise Exception("No image was returned from this workflow")
def _preview(slices: Iterable[ProcessedSlice[MaybeTupleRawWorkflowOutput]]) -> NDArray:
for slice in slices:
for value in slice.as_tuple():
if is_arraylike(value):
return np.asarray(value)
raise Exception("This ROI has no images. This shouldn't be possible")

for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index):
yield _preview(slices)

def save(self) -> Iterable[Path]:
"""
Expand Down
5 changes: 3 additions & 2 deletions core/tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ def test_sum_preview(rbc_tiny: Path):
workflow = "core/tests/workflows/binarisation/workflow.yml",
save_dir = tmpdir
)
preview = params.process_workflow().extract_preview()
np.sum(preview, axis=(1, 2))
previews = list(params.process_workflow().roi_previews())
assert len(previews) == 1, "There should be 1 preview when cropping is disabled"
assert previews[0].ndim == 3, "A preview should be a 3D image"

def test_crop_workflow(rbc_tiny: Path):
# Tests that crop workflows only process each ROI lazily
Expand Down
15 changes: 7 additions & 8 deletions plugin/napari_lattice/dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,19 @@ def preview(self, header: str, time: int, channel: int):
lattice.dy,
lattice.dx
)
preview: ArrayLike
previews: Iterable[ArrayLike]

# We extract the first available image to use as a preview
# This works differently for workflows and non-workflows
if lattice.workflow is None:
for slice in lattice.process().slices:
preview = slice.data
break
previews = lattice.process().roi_previews()
else:
preview = lattice.process_workflow().extract_preview()
previews = lattice.process_workflow().roi_previews()

self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview")
max_z = np.argmax(np.sum(preview, axis=(1, 2)))
self.parent_viewer.dims.set_current_step(0, max_z)
for preview in previews:
self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview")
max_z = np.argmax(np.sum(preview, axis=(1, 2)))
self.parent_viewer.dims.set_current_step(0, max_z)


@set_design(text="Save")
Expand Down
35 changes: 21 additions & 14 deletions plugin/napari_lattice/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from lls_core.models.deskew import DefinedPixelSizes
from lls_core.models.output import SaveFileType
from lls_core.workflow import workflow_from_path
from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design
from magicclass import FieldGroup, MagicTemplate, field, magicclass, set_design, vfield
from magicclass.fields import MagicField
from magicclass.widgets import ComboBox, Label, Widget
from napari.layers import Image, Shapes
Expand All @@ -32,9 +32,11 @@
from qtpy.QtWidgets import QTabWidget
from strenum import StrEnum
from napari_lattice.parent_connect import connect_parent
from napari_lattice.shape_selector import ShapeSelector

if TYPE_CHECKING:
from magicgui.widgets.bases import RangedWidget
from numpy.typing import NDArray

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -429,15 +431,8 @@ class CroppingFields(NapariFieldGroup):
This is to support the workflow of performing a preview deskew and using that to calculate the cropping coordinates.
"""), widget_type="Label")
fields_enabled = field(False, label="Enabled")
shapes= field(List[Shapes], widget_type="Select", label = "ROI Shape Layers").with_options(choices=lambda _x, _y: get_layers(Shapes))
z_range = field(Tuple[int, int]).with_options(
label = "Z Range",
value = (0, 1),
options = dict(
min = 0,
),
)
errors = field(Label).with_options(label="Errors")

shapes= vfield(ShapeSelector)

@set_design(text="Import ROI")
def import_roi(self, path: Path):
Expand All @@ -455,7 +450,16 @@ def new_crop_layer(self):
from napari_lattice.utils import get_viewer
shapes = get_viewer().add_shapes(name="Napari Lattice Crop")
shapes.mode = "ADD_RECTANGLE"
self.shapes.value += [shapes]
# self.shapes.value += [shapes]

z_range = field(Tuple[int, int]).with_options(
label = "Z Range",
value = (0, 1),
options = dict(
min = 0,
),
)
errors = field(Label).with_options(label="Errors")

@connect_parent("deskew_fields.img_layer")
def _on_image_changed(self, field: MagicField):
Expand Down Expand Up @@ -486,9 +490,12 @@ def _make_model(self) -> Optional[CropParams]:
if self.fields_enabled.value:
deskew = self._get_deskew()
rois = []
for shape_layer in self.shapes.value:
for x in shape_layer.data:
rois.append(Roi.from_array(x / deskew.dy))
for shape in self.shapes.shapes.value:
# The Napari shape is an array with 2 dimensions.
# Each column is an axis and each row is a point defining the shape
# We drop all but the last two axes, giving us a 2D shape with XY coordinates
array: NDArray = shape.get_array()[..., -2:] / deskew.dy
rois.append(Roi.from_array(array))

return CropParams(
# Convert from the input image space to the deskewed image space
Expand Down
58 changes: 58 additions & 0 deletions plugin/napari_lattice/shape_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations
from napari.utils.events import EventEmitter, Event
from napari.layers import Shapes

class ShapeLayerChangedEvent(Event):
"""
Event triggered when the shape layer selection changes.
"""

class ShapeSelectionListener(EventEmitter):
"""
Manages shape selection events for a given Shapes layer.
Examples:
This example code will open the viewer with an empty shape layer.
Any selection changes to that layer will trigger a notification popup.
>>> from napari import Viewer
>>> from napari.layers import Shapes
>>> viewer = Viewer()
>>> shapes = viewer.add_shapes()
>>> shape_selection = ShapeSelection(shapes)
>>> shape_selection.connect(lambda event: print("Shape selection changed!"))
"""
last_selection: set[int]
layer: Shapes

def __init__(self, layer) -> None:
"""
Initializes the ShapeSelection with the given Shapes layer.
Parameters:
layer: The Shapes layer to listen to.
"""
super().__init__(source=layer, event_class=ShapeLayerChangedEvent, type_name="shape_layer_selection_changed")
self.layer = layer
self.last_selection = set()
layer.events.highlight.connect(self._on_highlight)

def _on_highlight(self, event) -> None:
new_selection = self.layer.selected_data
if new_selection != self.last_selection:
self()
self.last_selection = set(new_selection)

def test_script():
"""
Demo for testing the event behaviour.
"""
from napari import run, Viewer
from napari.utils.notifications import show_info
viewer = Viewer()
shapes = viewer.add_shapes()
event = ShapeSelectionListener(shapes)
event.connect(lambda x: show_info("Shape selection changed!"))
run()

if __name__ == "__main__":
test_script()
Loading

0 comments on commit baaec71

Please sign in to comment.