-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Qt widget for loading pose datasets as napari Points layers (#253)
* initialise napari plugin development * Create skeleton for napari plugin with collapsible widgets (#218) * initialise napari plugin development * initialise napari plugin development * create skeleton for napari plugin with collapsible widgets * add basic widget smoke tests and allow headless testing * do not depend on napari from pip * include napari option in install instructions * make meta_widget module private * pin atlasapi version to avoid unnecessary dependencies * pin napari >= 0.4.19 from conda-forge * switched to pip install of napari[all] * seperation of concerns in widget tests * add pytest-mock dev dependency * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * Added loader widget for poses * update widget tests * simplify dependency on brainglobe-utils * consistent monospace formatting for movement in public docstrings * get rid of code that's only relevant for displaying Tracks * enable visibility of napari layer tooltips * renamed widget to PosesLoader * make cmap optional in set_color_by method * wrote unit tests for napari convert module * wrote unit-tests for the layer styles module * linkcheck ignore zenodo redirects * move _sample_colormap out of PointsStyle class * small refactoring in the loader widget * Expand tests for loader widget * added comments and docstrings to napari plugin tests * refactored all napari tests into separate unit test folder * added napari-video to dependencies * replaced deprecated edge_width with border_width * got rid of widget pytest fixtures * remove duplicate word from docstring * remove napari-video dependency * include napari extras in docs requirements * add test for _on_browse_clicked method * getOpenFileName returns tuple, not str * simplify poses_to_napari_tracks Co-authored-by: Chang Huan Lo <[email protected]> * [pre-commit.ci] pre-commit autoupdate (#338) updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.9 → v0.7.2](astral-sh/ruff-pre-commit@v0.6.9...v0.7.2) - [github.com/pre-commit/mirrors-mypy: v1.11.2 → v1.13.0](pre-commit/mirrors-mypy@v1.11.2...v1.13.0) - [github.com/mgedmin/check-manifest: 0.49 → 0.50](mgedmin/check-manifest@0.49...0.50) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Implement `compute_speed` and `compute_path_length` (#280) * implement compute_speed and compute_path_length functions * added speed to existing kinematics unit test * rewrote compute_path_length with various nan policies * unit test compute_path_length across time ranges * fixed and refactor compute_path_length and its tests * fixed docstring for compute_path_length * Accept suggestion on docstring wording Co-authored-by: Chang Huan Lo <[email protected]> * Remove print statement from test Co-authored-by: Chang Huan Lo <[email protected]> * Ensure nan report is printed Co-authored-by: Chang Huan Lo <[email protected]> * adapt warning message match in test * change 'any' to 'all' * uniform wording across path length docstrings * (mostly) leave time range validation to xarray slice * refactored parameters for test across time ranges * simplified test for path lenght with nans * replace drop policy with ffill * remove B905 ruff rule * make pre-commit happy --------- Co-authored-by: Chang Huan Lo <[email protected]> * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * initialise napari plugin development * avoid redefining duplicate attributes in child dataclass * modify test case to match poses_to_napari_tracks simplification * expected_log_messages should be a subset of captured messages Co-authored-by: Chang Huan Lo <[email protected]> * fix typo Co-authored-by: Chang Huan Lo <[email protected]> * use names for Qwidgets * reorganised test_valid_poses_to_napari_tracks * parametrised layer style tests * delet integration test which was reintroduced after conflict resolution * added test about file filters * deleted obsolete loader widget file (had snuck back in due to conflict merging) * combine tests for button callouts Co-authored-by: Chang Huan Lo <[email protected]> * Simplify test_layer_style_as_kwargs Co-authored-by: Chang Huan Lo <[email protected]> --------- Co-authored-by: Chang Huan Lo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
5e3947a
commit 31e98c7
Showing
13 changed files
with
711 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
-e . | ||
-e .[napari] | ||
ablog | ||
linkify-it-py | ||
myst-parser | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""Widgets for loading movement datasets from file.""" | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
from napari.settings import get_settings | ||
from napari.utils.notifications import show_warning | ||
from napari.viewer import Viewer | ||
from qtpy.QtWidgets import ( | ||
QComboBox, | ||
QFileDialog, | ||
QFormLayout, | ||
QHBoxLayout, | ||
QLineEdit, | ||
QPushButton, | ||
QSpinBox, | ||
QWidget, | ||
) | ||
|
||
from movement.io import load_poses | ||
from movement.napari.convert import poses_to_napari_tracks | ||
from movement.napari.layer_styles import PointsStyle | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Allowed poses file suffixes for each supported source software | ||
SUPPORTED_POSES_FILES = { | ||
"DeepLabCut": ["*.h5", "*.csv"], | ||
"LightningPose": ["*.csv"], | ||
"SLEAP": ["*.h5", "*.slp"], | ||
} | ||
|
||
|
||
class PosesLoader(QWidget): | ||
"""Widget for loading movement poses datasets from file.""" | ||
|
||
def __init__(self, napari_viewer: Viewer, parent=None): | ||
"""Initialize the loader widget.""" | ||
super().__init__(parent=parent) | ||
self.viewer = napari_viewer | ||
self.setLayout(QFormLayout()) | ||
# Create widgets | ||
self._create_source_software_widget() | ||
self._create_fps_widget() | ||
self._create_file_path_widget() | ||
self._create_load_button() | ||
# Enable layer tooltips from napari settings | ||
self._enable_layer_tooltips() | ||
|
||
def _create_source_software_widget(self): | ||
"""Create a combo box for selecting the source software.""" | ||
self.source_software_combo = QComboBox() | ||
self.source_software_combo.setObjectName("source_software_combo") | ||
self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys()) | ||
self.layout().addRow("source software:", self.source_software_combo) | ||
|
||
def _create_fps_widget(self): | ||
"""Create a spinbox for selecting the frames per second (fps).""" | ||
self.fps_spinbox = QSpinBox() | ||
self.fps_spinbox.setObjectName("fps_spinbox") | ||
self.fps_spinbox.setMinimum(1) | ||
self.fps_spinbox.setMaximum(1000) | ||
self.fps_spinbox.setValue(30) | ||
self.layout().addRow("fps:", self.fps_spinbox) | ||
|
||
def _create_file_path_widget(self): | ||
"""Create a line edit and browse button for selecting the file path. | ||
This allows the user to either browse the file system, | ||
or type the path directly into the line edit. | ||
""" | ||
# File path line edit and browse button | ||
self.file_path_edit = QLineEdit() | ||
self.file_path_edit.setObjectName("file_path_edit") | ||
self.browse_button = QPushButton("Browse") | ||
self.browse_button.setObjectName("browse_button") | ||
self.browse_button.clicked.connect(self._on_browse_clicked) | ||
# Layout for line edit and button | ||
self.file_path_layout = QHBoxLayout() | ||
self.file_path_layout.addWidget(self.file_path_edit) | ||
self.file_path_layout.addWidget(self.browse_button) | ||
self.layout().addRow("file path:", self.file_path_layout) | ||
|
||
def _create_load_button(self): | ||
"""Create a button to load the file and add layers to the viewer.""" | ||
self.load_button = QPushButton("Load") | ||
self.load_button.setObjectName("load_button") | ||
self.load_button.clicked.connect(lambda: self._on_load_clicked()) | ||
self.layout().addRow(self.load_button) | ||
|
||
def _on_browse_clicked(self): | ||
"""Open a file dialog to select a file.""" | ||
file_suffixes = SUPPORTED_POSES_FILES[ | ||
self.source_software_combo.currentText() | ||
] | ||
|
||
file_path, _ = QFileDialog.getOpenFileName( | ||
self, | ||
caption="Open file containing predicted poses", | ||
filter=f"Poses files ({' '.join(file_suffixes)})", | ||
) | ||
|
||
# A blank string is returned if the user cancels the dialog | ||
if not file_path: | ||
return | ||
|
||
# Add the file path to the line edit (text field) | ||
self.file_path_edit.setText(file_path) | ||
|
||
def _on_load_clicked(self): | ||
"""Load the file and add as a Points layer to the viewer.""" | ||
fps = self.fps_spinbox.value() | ||
source_software = self.source_software_combo.currentText() | ||
file_path = self.file_path_edit.text() | ||
if file_path == "": | ||
show_warning("No file path specified.") | ||
return | ||
ds = load_poses.from_file(file_path, source_software, fps) | ||
|
||
self.data, self.props = poses_to_napari_tracks(ds) | ||
logger.info("Converted poses dataset to a napari Tracks array.") | ||
logger.debug(f"Tracks array shape: {self.data.shape}") | ||
|
||
self.file_name = Path(file_path).name | ||
self._add_points_layer() | ||
|
||
self._set_playback_fps(fps) | ||
logger.debug(f"Set napari playback speed to {fps} fps.") | ||
|
||
def _add_points_layer(self): | ||
"""Add the predicted poses to the viewer as a Points layer.""" | ||
# Style properties for the napari Points layer | ||
points_style = PointsStyle( | ||
name=f"poses: {self.file_name}", | ||
properties=self.props, | ||
) | ||
# Color the points by individual if there are multiple individuals | ||
# Otherwise, color by keypoint | ||
n_individuals = len(self.props["individual"].unique()) | ||
points_style.set_color_by( | ||
prop="individual" if n_individuals > 1 else "keypoint" | ||
) | ||
# Add the points layer to the viewer | ||
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs()) | ||
logger.info("Added poses dataset as a napari Points layer.") | ||
|
||
@staticmethod | ||
def _set_playback_fps(fps: int): | ||
"""Set the playback speed for the napari viewer.""" | ||
settings = get_settings() | ||
settings.application.playback_fps = fps | ||
|
||
@staticmethod | ||
def _enable_layer_tooltips(): | ||
"""Toggle on tooltip visibility for napari layers. | ||
This nicely displays the layer properties as a tooltip | ||
when hovering over the layer in the napari viewer. | ||
""" | ||
settings = get_settings() | ||
settings.appearance.layer_tooltip_visibility = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Conversion functions from ``movement`` datasets to napari layers.""" | ||
|
||
import logging | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
# get logger | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame: | ||
"""Construct a properties DataFrame from a ``movement`` dataset.""" | ||
return pd.DataFrame( | ||
{ | ||
"individual": ds.coords["individuals"].values, | ||
"keypoint": ds.coords["keypoints"].values, | ||
"time": ds.coords["time"].values, | ||
"confidence": ds["confidence"].values.flatten(), | ||
} | ||
) | ||
|
||
|
||
def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]: | ||
"""Convert poses dataset to napari Tracks array and properties. | ||
Parameters | ||
---------- | ||
ds : xr.Dataset | ||
``movement`` dataset containing pose tracks, confidence scores, | ||
and associated metadata. | ||
Returns | ||
------- | ||
data : np.ndarray | ||
napari Tracks array with shape (N, 4), | ||
where N is n_keypoints * n_individuals * n_frames | ||
and the 4 columns are (track_id, frame_idx, y, x). | ||
properties : pd.DataFrame | ||
DataFrame with properties (individual, keypoint, time, confidence). | ||
Notes | ||
----- | ||
A corresponding napari Points array can be derived from the Tracks array | ||
by taking its last 3 columns: (frame_idx, y, x). See the documentation | ||
on the napari Tracks [1]_ and Points [2]_ layers. | ||
References | ||
---------- | ||
.. [1] https://napari.org/stable/howtos/layers/tracks.html | ||
.. [2] https://napari.org/stable/howtos/layers/points.html | ||
""" | ||
n_frames = ds.sizes["time"] | ||
n_individuals = ds.sizes["individuals"] | ||
n_keypoints = ds.sizes["keypoints"] | ||
n_tracks = n_individuals * n_keypoints | ||
# Construct the napari Tracks array | ||
# Reorder axes to (individuals, keypoints, frames, xy) | ||
yx_cols = np.transpose(ds.position.values, (1, 2, 0, 3)).reshape(-1, 2)[ | ||
:, [1, 0] # swap x and y columns | ||
] | ||
# Each keypoint of each individual is a separate track | ||
track_id_col = np.repeat(np.arange(n_tracks), n_frames).reshape(-1, 1) | ||
time_col = np.tile(np.arange(n_frames), (n_tracks)).reshape(-1, 1) | ||
data = np.hstack((track_id_col, time_col, yx_cols)) | ||
# Construct the properties DataFrame | ||
# Stack 3 dimensions into a new single dimension named "tracks" | ||
ds_ = ds.stack(tracks=("individuals", "keypoints", "time")) | ||
properties = _construct_properties_dataframe(ds_) | ||
|
||
return data, properties |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Dataclasses containing layer styles for napari.""" | ||
|
||
from dataclasses import dataclass, field | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from napari.utils.colormaps import ensure_colormap | ||
|
||
DEFAULT_COLORMAP = "turbo" | ||
|
||
|
||
@dataclass | ||
class LayerStyle: | ||
"""Base class for napari layer styles.""" | ||
|
||
name: str | ||
properties: pd.DataFrame | ||
visible: bool = True | ||
blending: str = "translucent" | ||
|
||
def as_kwargs(self) -> dict: | ||
"""Return the style properties as a dictionary of kwargs.""" | ||
return self.__dict__ | ||
|
||
|
||
@dataclass | ||
class PointsStyle(LayerStyle): | ||
"""Style properties for a napari Points layer.""" | ||
|
||
symbol: str = "disc" | ||
size: int = 10 | ||
border_width: int = 0 | ||
face_color: str | None = None | ||
face_color_cycle: list[tuple] | None = None | ||
face_colormap: str = DEFAULT_COLORMAP | ||
text: dict = field(default_factory=lambda: {"visible": False}) | ||
|
||
def set_color_by(self, prop: str, cmap: str | None = None) -> None: | ||
"""Set the face_color to a column in the properties DataFrame. | ||
Parameters | ||
---------- | ||
prop : str | ||
The column name in the properties DataFrame to color by. | ||
cmap : str, optional | ||
The name of the colormap to use, otherwise use the face_colormap. | ||
""" | ||
if cmap is None: | ||
cmap = self.face_colormap | ||
self.face_color = prop | ||
self.text["string"] = prop | ||
n_colors = len(self.properties[prop].unique()) | ||
self.face_color_cycle = _sample_colormap(n_colors, cmap) | ||
|
||
|
||
def _sample_colormap(n: int, cmap_name: str) -> list[tuple]: | ||
"""Sample n equally-spaced colors from a napari colormap. | ||
This includes the endpoints of the colormap. | ||
""" | ||
cmap = ensure_colormap(cmap_name) | ||
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int) | ||
return [tuple(cmap.colors[i]) for i in samples] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.