Skip to content

Commit

Permalink
Integrate shape selector into rest of plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Oct 11, 2024
1 parent a678a3d commit 8146b9a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 44 deletions.
36 changes: 27 additions & 9 deletions core/lls_core/models/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import groupby
from pathlib import Path

from typing import Iterable, Optional, Tuple, Union, cast, TYPE_CHECKING, overload
from typing import Iterable, Iterator, Optional, Tuple, TypeAlias, Union, cast, TYPE_CHECKING, overload
from typing_extensions import Generic, TypeVar
from pydantic.v1 import BaseModel, NonNegativeInt, Field
from lls_core.types import ArrayLike, is_arraylike
Expand Down Expand Up @@ -75,6 +75,19 @@ class ImageSlices(ProcessedSlices[ArrayLike]):
# This re-definition of the type is helpful for `mkdocs`
slices: Iterable[ProcessedSlice[ArrayLike]] = Field(description="Iterable of result slices. For a given slice, you can access the image data through the `slice.data` property, which is a numpy-like array.")

def roi_previews(self) -> Iterable[ArrayLike]:
"""
Extracts a single 3D image for each ROI
"""
import numpy as np
def _preview(slices: Iterable[ProcessedSlice[ArrayLike]]) -> ArrayLike:
for slice in slices:
return slice.data
raise Exception("This ROI has no images. This shouldn't be possible")

for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index):
yield _preview(slices)

def save_image(self):
"""
Saves result slices to disk
Expand All @@ -96,7 +109,8 @@ def save_image(self):
If a `DataFrame`, then it contains non-image data returned by your workflow.
"""

class WorkflowSlices(ProcessedSlices[Union[Tuple[RawWorkflowOutput], RawWorkflowOutput]]):
MaybeTupleRawWorkflowOutput: TypeAlias = Union[Tuple[RawWorkflowOutput], RawWorkflowOutput]
class WorkflowSlices(ProcessedSlices[MaybeTupleRawWorkflowOutput]):
"""
The counterpart of `ImageSlices`, but for workflow outputs.
This is needed because workflows have vastly different outputs that may include regular
Expand Down Expand Up @@ -159,16 +173,20 @@ def process(self) -> Iterable[Tuple[RoiIndex, ProcessedWorkflowOutput]]:
else:
yield roi, pd.DataFrame(element)

def extract_preview(self) -> NDArray:
def roi_previews(self) -> Iterable[NDArray]:
"""
Extracts a single 3D image for previewing purposes
Extracts a single 3D image for each ROI
"""
import numpy as np
for slice in self.slices:
for value in slice.as_tuple():
if is_arraylike(value):
return np.asarray(value)
raise Exception("No image was returned from this workflow")
def _preview(slices: Iterable[ProcessedSlice[MaybeTupleRawWorkflowOutput]]) -> NDArray:
for slice in slices:
for value in slice.as_tuple():
if is_arraylike(value):
return np.asarray(value)
raise Exception("This ROI has no images. This shouldn't be possible")

for roi_index, slices in groupby(self.slices, key=lambda slice: slice.roi_index):
yield _preview(slices)

def save(self) -> Iterable[Path]:
"""
Expand Down
15 changes: 7 additions & 8 deletions plugin/napari_lattice/dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,19 @@ def preview(self, header: str, time: int, channel: int):
lattice.dy,
lattice.dx
)
preview: ArrayLike
previews: Iterable[ArrayLike]

# We extract the first available image to use as a preview
# This works differently for workflows and non-workflows
if lattice.workflow is None:
for slice in lattice.process().slices:
preview = slice.data
break
previews = lattice.process().roi_previews()
else:
preview = lattice.process_workflow().extract_preview()
previews = lattice.process_workflow().roi_previews()

self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview")
max_z = np.argmax(np.sum(preview, axis=(1, 2)))
self.parent_viewer.dims.set_current_step(0, max_z)
for preview in previews:
self.parent_viewer.add_image(preview, scale=scale, name="Napari Lattice Preview")
max_z = np.argmax(np.sum(preview, axis=(1, 2)))
self.parent_viewer.dims.set_current_step(0, max_z)


@set_design(text="Save")
Expand Down
24 changes: 12 additions & 12 deletions plugin/napari_lattice/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,6 @@ class CroppingFields(NapariFieldGroup):
fields_enabled = field(False, label="Enabled")

shapes= vfield(ShapeSelector)
z_range = field(Tuple[int, int]).with_options(
label = "Z Range",
value = (0, 1),
options = dict(
min = 0,
),
)
errors = field(Label).with_options(label="Errors")

@set_design(text="Import ROI")
def import_roi(self, path: Path):
Expand All @@ -457,7 +449,16 @@ def new_crop_layer(self):
from napari_lattice.utils import get_viewer
shapes = get_viewer().add_shapes(name="Napari Lattice Crop")
shapes.mode = "ADD_RECTANGLE"
self.shapes.value += [shapes]
# self.shapes.value += [shapes]

z_range = field(Tuple[int, int]).with_options(
label = "Z Range",
value = (0, 1),
options = dict(
min = 0,
),
)
errors = field(Label).with_options(label="Errors")

@connect_parent("deskew_fields.img_layer")
def _on_image_changed(self, field: MagicField):
Expand Down Expand Up @@ -488,9 +489,8 @@ def _make_model(self) -> Optional[CropParams]:
if self.fields_enabled.value:
deskew = self._get_deskew()
rois = []
for shape_layer in self.shapes.value:
for x in shape_layer.data:
rois.append(Roi.from_array(x / deskew.dy))
for shape in self.shapes.shapes.value:
rois.append(Roi.from_array(shape.get_array() / deskew.dy))

return CropParams(
# Convert from the input image space to the deskewed image space
Expand Down
37 changes: 22 additions & 15 deletions plugin/napari_lattice/shape_selector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Tuple, TYPE_CHECKING
from magicclass import field, magicclass
from magicgui.widgets import Select
from magicclass import field, magicclass, set_design
from magicgui.widgets import Select, Button
from napari.layers import Shapes
from napari.components.layerlist import LayerList
from collections import defaultdict
Expand Down Expand Up @@ -31,7 +31,27 @@ def get_array(self) -> NDArray:
@magicclass
class ShapeSelector:

def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]:
"""
Returns the choices to use for the Select box
"""
viewer = get_viewer()
for layer in viewer.layers:
if isinstance(layer, Shapes):
for index in layer.features.index:
result = Shape(layer=layer, index=index)
yield str(result), result

_blocked: bool
shapes = field(Select, options={"choices": _get_shape_choices, "label": "ROIs"})

@set_design(text="Select All")
def select_all(self) -> None:
self.shapes.value = self.shapes.choices

@set_design(text="Deselect All")
def deselect_all(self) -> None:
self.shapes.value = []

def __init__(self, enabled: bool, *args, **kwargs) -> None:
self._blocked = False
Expand All @@ -50,17 +70,6 @@ def _block(self):
yield True
self._blocked = False

def _get_shape_choices(self, widget: Select | None = None) -> Iterator[Tuple[str, Shape]]:
"""
Returns the choices to use for the Select box
"""
viewer = get_viewer()
for layer in viewer.layers:
if isinstance(layer, Shapes):
for index in layer.features.index:
result = Shape(layer=layer, index=index)
yield str(result), result

def _on_selection_change(self, event: Event) -> None:
"""
Triggered when the user clicks on one or more shapes.
Expand Down Expand Up @@ -119,8 +128,6 @@ def __post_init__(self) -> None:
if isinstance(layer, Shapes):
self._connect_shapes(layer)

shapes = field(Select, options={"choices": _get_shape_choices, "label": "ROIs"})

# values is a list[Shape], but if we use the correct annotation it breaks magicclass
@shapes.connect
def _widget_changed(self, values: list) -> None:
Expand Down

0 comments on commit 8146b9a

Please sign in to comment.