diff --git a/docs/requirements.txt b/docs/requirements.txt index 0a950f2e..2816d4d5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ --e . +-e .[napari] linkify-it-py myst-parser nbsphinx diff --git a/docs/source/conf.py b/docs/source/conf.py index 42ac0668..f3fff3e1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -178,6 +178,7 @@ "https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error ] + myst_url_schemes = { "http": None, "https": None, diff --git a/movement/napari/_loader_widget.py b/movement/napari/_loader_widget.py deleted file mode 100644 index 7da5c3ce..00000000 --- a/movement/napari/_loader_widget.py +++ /dev/null @@ -1,32 +0,0 @@ -from napari.utils.notifications import show_info -from napari.viewer import Viewer -from qtpy.QtWidgets import ( - QFormLayout, - QPushButton, - QWidget, -) - - -class Loader(QWidget): - """Widget for loading data from files.""" - - def __init__(self, napari_viewer: Viewer, parent=None): - """Initialize the loader widget.""" - super().__init__(parent=parent) - self.viewer = napari_viewer - self.setLayout(QFormLayout()) - # Create widgets - self._create_hello_widget() - - def _create_hello_widget(self): - """Create the hello widget. - - This widget contains a button that, when clicked, shows a greeting. - """ - hello_button = QPushButton("Say hello") - hello_button.clicked.connect(self._on_hello_clicked) - self.layout().addRow("Greeting", hello_button) - - def _on_hello_clicked(self): - """Show a greeting.""" - show_info("Hello, world!") diff --git a/movement/napari/_loader_widgets.py b/movement/napari/_loader_widgets.py new file mode 100644 index 00000000..6e3dc8d3 --- /dev/null +++ b/movement/napari/_loader_widgets.py @@ -0,0 +1,161 @@ +"""Widgets for loading movement datasets from file.""" + +import logging +from pathlib import Path + +from napari.settings import get_settings +from napari.utils.notifications import show_warning +from napari.viewer import Viewer +from qtpy.QtWidgets import ( + QComboBox, + QFileDialog, + QFormLayout, + QHBoxLayout, + QLineEdit, + QPushButton, + QSpinBox, + QWidget, +) + +from movement.io import load_poses +from movement.napari.convert import poses_to_napari_tracks +from movement.napari.layer_styles import PointsStyle + +logger = logging.getLogger(__name__) + +# Allowed poses file suffixes for each supported source software +SUPPORTED_POSES_FILES = { + "DeepLabCut": ["*.h5", "*.csv"], + "LightningPose": ["*.csv"], + "SLEAP": ["*.h5", "*.slp"], +} + + +class PosesLoader(QWidget): + """Widget for loading movement poses datasets from file.""" + + def __init__(self, napari_viewer: Viewer, parent=None): + """Initialize the loader widget.""" + super().__init__(parent=parent) + self.viewer = napari_viewer + self.setLayout(QFormLayout()) + # Create widgets + self._create_source_software_widget() + self._create_fps_widget() + self._create_file_path_widget() + self._create_load_button() + # Enable layer tooltips from napari settings + self._enable_layer_tooltips() + + def _create_source_software_widget(self): + """Create a combo box for selecting the source software.""" + self.source_software_combo = QComboBox() + self.source_software_combo.setObjectName("source_software_combo") + self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys()) + self.layout().addRow("source software:", self.source_software_combo) + + def _create_fps_widget(self): + """Create a spinbox for selecting the frames per second (fps).""" + self.fps_spinbox = QSpinBox() + self.fps_spinbox.setObjectName("fps_spinbox") + self.fps_spinbox.setMinimum(1) + self.fps_spinbox.setMaximum(1000) + self.fps_spinbox.setValue(30) + self.layout().addRow("fps:", self.fps_spinbox) + + def _create_file_path_widget(self): + """Create a line edit and browse button for selecting the file path. + + This allows the user to either browse the file system, + or type the path directly into the line edit. + """ + # File path line edit and browse button + self.file_path_edit = QLineEdit() + self.file_path_edit.setObjectName("file_path_edit") + self.browse_button = QPushButton("Browse") + self.browse_button.setObjectName("browse_button") + self.browse_button.clicked.connect(self._on_browse_clicked) + # Layout for line edit and button + self.file_path_layout = QHBoxLayout() + self.file_path_layout.addWidget(self.file_path_edit) + self.file_path_layout.addWidget(self.browse_button) + self.layout().addRow("file path:", self.file_path_layout) + + def _create_load_button(self): + """Create a button to load the file and add layers to the viewer.""" + self.load_button = QPushButton("Load") + self.load_button.setObjectName("load_button") + self.load_button.clicked.connect(lambda: self._on_load_clicked()) + self.layout().addRow(self.load_button) + + def _on_browse_clicked(self): + """Open a file dialog to select a file.""" + file_suffixes = SUPPORTED_POSES_FILES[ + self.source_software_combo.currentText() + ] + + file_path, _ = QFileDialog.getOpenFileName( + self, + caption="Open file containing predicted poses", + filter=f"Poses files ({' '.join(file_suffixes)})", + ) + + # A blank string is returned if the user cancels the dialog + if not file_path: + return + + # Add the file path to the line edit (text field) + self.file_path_edit.setText(file_path) + + def _on_load_clicked(self): + """Load the file and add as a Points layer to the viewer.""" + fps = self.fps_spinbox.value() + source_software = self.source_software_combo.currentText() + file_path = self.file_path_edit.text() + if file_path == "": + show_warning("No file path specified.") + return + ds = load_poses.from_file(file_path, source_software, fps) + + self.data, self.props = poses_to_napari_tracks(ds) + logger.info("Converted poses dataset to a napari Tracks array.") + logger.debug(f"Tracks array shape: {self.data.shape}") + + self.file_name = Path(file_path).name + self._add_points_layer() + + self._set_playback_fps(fps) + logger.debug(f"Set napari playback speed to {fps} fps.") + + def _add_points_layer(self): + """Add the predicted poses to the viewer as a Points layer.""" + # Style properties for the napari Points layer + points_style = PointsStyle( + name=f"poses: {self.file_name}", + properties=self.props, + ) + # Color the points by individual if there are multiple individuals + # Otherwise, color by keypoint + n_individuals = len(self.props["individual"].unique()) + points_style.set_color_by( + prop="individual" if n_individuals > 1 else "keypoint" + ) + # Add the points layer to the viewer + self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs()) + logger.info("Added poses dataset as a napari Points layer.") + + @staticmethod + def _set_playback_fps(fps: int): + """Set the playback speed for the napari viewer.""" + settings = get_settings() + settings.application.playback_fps = fps + + @staticmethod + def _enable_layer_tooltips(): + """Toggle on tooltip visibility for napari layers. + + This nicely displays the layer properties as a tooltip + when hovering over the layer in the napari viewer. + """ + settings = get_settings() + settings.appearance.layer_tooltip_visibility = True diff --git a/movement/napari/_meta_widget.py b/movement/napari/_meta_widget.py index 3ed09575..a2147790 100644 --- a/movement/napari/_meta_widget.py +++ b/movement/napari/_meta_widget.py @@ -3,7 +3,7 @@ from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer from napari.viewer import Viewer -from movement.napari._loader_widget import Loader +from movement.napari._loader_widgets import PosesLoader class MovementMetaWidget(CollapsibleWidgetContainer): @@ -18,9 +18,9 @@ def __init__(self, napari_viewer: Viewer, parent=None): super().__init__() self.add_widget( - Loader(napari_viewer, parent=self), + PosesLoader(napari_viewer, parent=self), collapsible=True, - widget_title="Load data", + widget_title="Load poses", ) self.loader = self.collapsible_widgets[0] diff --git a/movement/napari/convert.py b/movement/napari/convert.py new file mode 100644 index 00000000..890161be --- /dev/null +++ b/movement/napari/convert.py @@ -0,0 +1,73 @@ +"""Conversion functions from ``movement`` datasets to napari layers.""" + +import logging + +import numpy as np +import pandas as pd +import xarray as xr + +# get logger +logger = logging.getLogger(__name__) + + +def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame: + """Construct a properties DataFrame from a ``movement`` dataset.""" + return pd.DataFrame( + { + "individual": ds.coords["individuals"].values, + "keypoint": ds.coords["keypoints"].values, + "time": ds.coords["time"].values, + "confidence": ds["confidence"].values.flatten(), + } + ) + + +def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]: + """Convert poses dataset to napari Tracks array and properties. + + Parameters + ---------- + ds : xr.Dataset + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. + + Returns + ------- + data : np.ndarray + napari Tracks array with shape (N, 4), + where N is n_keypoints * n_individuals * n_frames + and the 4 columns are (track_id, frame_idx, y, x). + properties : pd.DataFrame + DataFrame with properties (individual, keypoint, time, confidence). + + Notes + ----- + A corresponding napari Points array can be derived from the Tracks array + by taking its last 3 columns: (frame_idx, y, x). See the documentation + on the napari Tracks [1]_ and Points [2]_ layers. + + References + ---------- + .. [1] https://napari.org/stable/howtos/layers/tracks.html + .. [2] https://napari.org/stable/howtos/layers/points.html + + """ + n_frames = ds.sizes["time"] + n_individuals = ds.sizes["individuals"] + n_keypoints = ds.sizes["keypoints"] + n_tracks = n_individuals * n_keypoints + # Construct the napari Tracks array + # Reorder axes to (individuals, keypoints, frames, xy) + yx_cols = np.transpose(ds.position.values, (1, 2, 0, 3)).reshape(-1, 2)[ + :, [1, 0] # swap x and y columns + ] + # Each keypoint of each individual is a separate track + track_id_col = np.repeat(np.arange(n_tracks), n_frames).reshape(-1, 1) + time_col = np.tile(np.arange(n_frames), (n_tracks)).reshape(-1, 1) + data = np.hstack((track_id_col, time_col, yx_cols)) + # Construct the properties DataFrame + # Stack 3 dimensions into a new single dimension named "tracks" + ds_ = ds.stack(tracks=("individuals", "keypoints", "time")) + properties = _construct_properties_dataframe(ds_) + + return data, properties diff --git a/movement/napari/layer_styles.py b/movement/napari/layer_styles.py new file mode 100644 index 00000000..d95dd048 --- /dev/null +++ b/movement/napari/layer_styles.py @@ -0,0 +1,64 @@ +"""Dataclasses containing layer styles for napari.""" + +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd +from napari.utils.colormaps import ensure_colormap + +DEFAULT_COLORMAP = "turbo" + + +@dataclass +class LayerStyle: + """Base class for napari layer styles.""" + + name: str + properties: pd.DataFrame + visible: bool = True + blending: str = "translucent" + + def as_kwargs(self) -> dict: + """Return the style properties as a dictionary of kwargs.""" + return self.__dict__ + + +@dataclass +class PointsStyle(LayerStyle): + """Style properties for a napari Points layer.""" + + symbol: str = "disc" + size: int = 10 + border_width: int = 0 + face_color: str | None = None + face_color_cycle: list[tuple] | None = None + face_colormap: str = DEFAULT_COLORMAP + text: dict = field(default_factory=lambda: {"visible": False}) + + def set_color_by(self, prop: str, cmap: str | None = None) -> None: + """Set the face_color to a column in the properties DataFrame. + + Parameters + ---------- + prop : str + The column name in the properties DataFrame to color by. + cmap : str, optional + The name of the colormap to use, otherwise use the face_colormap. + + """ + if cmap is None: + cmap = self.face_colormap + self.face_color = prop + self.text["string"] = prop + n_colors = len(self.properties[prop].unique()) + self.face_color_cycle = _sample_colormap(n_colors, cmap) + + +def _sample_colormap(n: int, cmap_name: str) -> list[tuple]: + """Sample n equally-spaced colors from a napari colormap. + + This includes the endpoints of the colormap. + """ + cmap = ensure_colormap(cmap_name) + samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int) + return [tuple(cmap.colors[i]) for i in samples] diff --git a/pyproject.toml b/pyproject.toml index c805d72e..577a80a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,12 +47,8 @@ entry-points."napari.manifest".movement = "movement.napari:napari.yaml" [project.optional-dependencies] napari = [ - "napari[all]>=0.4.19", - # the rest will be replaced by brainglobe-utils[qt]>=0.6 after release - "brainglobe-atlasapi>=2.0.7", - "brainglobe-utils>=0.5", - "qtpy", - "superqt", + "napari[all]>=0.5.0", + "brainglobe-utils[qt]>=0.6" # needed for collapsible widgets ] dev = [ "pytest", diff --git a/tests/test_integration/test_napari_plugin.py b/tests/test_integration/test_napari_plugin.py deleted file mode 100644 index e7225dc5..00000000 --- a/tests/test_integration/test_napari_plugin.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from qtpy.QtWidgets import QPushButton, QWidget - -from movement.napari._loader_widget import Loader -from movement.napari._meta_widget import MovementMetaWidget - - -@pytest.fixture -def meta_widget(make_napari_viewer_proxy) -> MovementMetaWidget: - """Fixture to expose the MovementMetaWidget for testing. - Simultaneously acts as a smoke test that the widget - can be instantiated without crashing. - """ - viewer = make_napari_viewer_proxy() - return MovementMetaWidget(viewer) - - -@pytest.fixture -def loader_widget(meta_widget) -> QWidget: - """Fixture to expose the Loader widget for testing.""" - loader = meta_widget.loader.content() - return loader - - -def test_meta_widget(meta_widget): - """Test that the meta widget is properly instantiated.""" - assert meta_widget is not None - assert len(meta_widget.collapsible_widgets) == 1 - - first_widget = meta_widget.collapsible_widgets[0] - assert first_widget._text == "Load data" - assert first_widget.isExpanded() - - -def test_loader_widget(loader_widget): - """Test that the loader widget is properly instantiated.""" - assert loader_widget is not None - assert loader_widget.layout().rowCount() == 1 - - -def test_hello_button_calls_on_hello_clicked(make_napari_viewer_proxy, mocker): - """Test that clicking the hello button calls _on_hello_clicked. - - Here we have to create a new Loader widget after mocking the method. - We cannot reuse the existing widget fixture because then it would be too - late to mock (the widget has already "decided" which method to call). - """ - mock_method = mocker.patch( - "movement.napari._loader_widget.Loader._on_hello_clicked" - ) - loader = Loader(make_napari_viewer_proxy) - hello_button = loader.findChildren(QPushButton)[0] - hello_button.click() - mock_method.assert_called_once() - - -def test_on_hello_clicked_outputs_message(loader_widget, capsys): - """Test that _on_hello_clicked outputs the expected message.""" - loader_widget._on_hello_clicked() - captured = capsys.readouterr() - assert "INFO: Hello, world!" in captured.out diff --git a/tests/test_unit/test_napari_plugin/test_convert.py b/tests/test_unit/test_napari_plugin/test_convert.py new file mode 100644 index 00000000..c3704f11 --- /dev/null +++ b/tests/test_unit/test_napari_plugin/test_convert.py @@ -0,0 +1,104 @@ +"""Test suite for the movement.napari.convert module.""" + +import numpy as np +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from movement.napari.convert import poses_to_napari_tracks + + +@pytest.fixture +def confidence_with_some_nan(valid_poses_dataset_uniform_linear_motion): + """Return a valid poses dataset with some NaNs in confidence values.""" + ds = valid_poses_dataset_uniform_linear_motion + ds["confidence"].loc[{"individuals": "id_1", "time": [3, 7, 8]}] = np.nan + return ds + + +@pytest.fixture +def confidence_with_all_nan(valid_poses_dataset_uniform_linear_motion): + """Return a valid poses dataset with all NaNs in confidence values.""" + ds = valid_poses_dataset_uniform_linear_motion + ds["confidence"].data = np.full_like(ds["confidence"].data, np.nan) + return ds + + +@pytest.mark.parametrize( + "ds_name", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_poses_dataset_uniform_linear_motion_with_nans", + "confidence_with_some_nan", + "confidence_with_all_nan", + ], +) +def test_valid_poses_to_napari_tracks(ds_name, request): + """Test that the conversion from movement poses dataset to napari + tracks returns the expected data and properties. + """ + ds = request.getfixturevalue(ds_name) + n_frames = ds.sizes["time"] + n_individuals = ds.sizes["individuals"] + n_keypoints = ds.sizes["keypoints"] + n_tracks = n_individuals * n_keypoints # total tracked points + + data, props = poses_to_napari_tracks(ds) + + # Prepare expected y, x positions and corresponding confidence values. + # Assume values are extracted from the dataset in a specific way, + # by iterating first over individuals and then over keypoints. + y_coords, x_coords, confidence = [], [], [] + for id in ds.individuals.values: + for kpt in ds.keypoints.values: + position = ds.position.sel(individuals=id, keypoints=kpt) + y_coords.extend(position.sel(space="y").values) + x_coords.extend(position.sel(space="x").values) + conf = ds.confidence.sel(individuals=id, keypoints=kpt) + confidence.extend(conf.values) + + # Generate expected data array + expected_track_ids = np.repeat(np.arange(n_tracks), n_frames) + expected_frame_ids = np.tile(np.arange(n_frames), n_tracks) + expected_yx = np.column_stack((y_coords, x_coords)) + expected_data = np.column_stack( + (expected_track_ids, expected_frame_ids, expected_yx) + ) + + # Generate expected properties DataFrame + expected_props = pd.DataFrame( + { + "individual": np.repeat( + ds.individuals.values.repeat(n_keypoints), n_frames + ), + "keypoint": np.repeat( + np.tile(ds.keypoints.values, n_individuals), n_frames + ), + "time": expected_frame_ids, + "confidence": confidence, + } + ) + + # Assert that the data array matches the expected data + np.testing.assert_allclose(data, expected_data, equal_nan=True) + + # Assert that the properties DataFrame matches the expected properties + assert_frame_equal(props, expected_props) + + +@pytest.mark.parametrize( + "ds_name, expected_exception", + [ + ("not_a_dataset", AttributeError), + ("empty_dataset", KeyError), + ("missing_var_poses_dataset", AttributeError), + ("missing_dim_poses_dataset", KeyError), + ], +) +def test_invalid_poses_to_napari_tracks(ds_name, expected_exception, request): + """Test that the conversion from movement poses dataset to napari + tracks raises the expected error for invalid datasets. + """ + ds = request.getfixturevalue(ds_name) + with pytest.raises(expected_exception): + poses_to_napari_tracks(ds) diff --git a/tests/test_unit/test_napari_plugin/test_layer_styles.py b/tests/test_unit/test_napari_plugin/test_layer_styles.py new file mode 100644 index 00000000..6f76a223 --- /dev/null +++ b/tests/test_unit/test_napari_plugin/test_layer_styles.py @@ -0,0 +1,114 @@ +"""Unit tests for the LayerStyle and PointsStyle classes.""" + +import pandas as pd +import pytest + +from movement.napari.layer_styles import ( + DEFAULT_COLORMAP, + LayerStyle, + PointsStyle, +) + + +@pytest.fixture +def sample_properties(): + """Fixture that provides a sample "properties" DataFrame.""" + data = {"category": ["A", "B", "A", "C", "B"], "value": [1, 2, 3, 4, 5]} + return pd.DataFrame(data) + + +@pytest.fixture +def sample_layer_style(sample_properties): + """Fixture that provides a sample LayerStyle or subclass instance.""" + + def _sample_layer_style(layer_class): + return layer_class(name="Layer1", properties=sample_properties) + + return _sample_layer_style + + +@pytest.fixture +def default_style_attributes(sample_properties): + """Fixture that provides expected attributes for LayerStyle and subclasses. + + It holds the default values we expect after initialisation, as well as the + "name" and "properties" attributes that are defined in this test module. + """ + return { + # Shared attributes for LayerStyle and all its subclasses + LayerStyle: { + "name": "Layer1", # as given in sample_layer_style + "visible": True, + "blending": "translucent", + "properties": sample_properties, # as given by fixture above + }, + # Additional attributes for PointsStyle + PointsStyle: { + "symbol": "disc", + "size": 10, + "border_width": 0, + "face_color": None, + "face_color_cycle": None, + "face_colormap": DEFAULT_COLORMAP, + "text": {"visible": False}, + }, + } + + +@pytest.mark.parametrize( + "layer_class", + [LayerStyle, PointsStyle], +) +def test_layer_style_initialization( + sample_layer_style, layer_class, default_style_attributes +): + """Test that LayerStyle and subclasses initialize with default values.""" + style = sample_layer_style(layer_class) + + # Expected attributes of base LayerStyle, shared by all subclasses + expected_attrs = default_style_attributes[LayerStyle].copy() + # Additional attributes, specific to subclasses of LayerStyle + if layer_class != LayerStyle: + expected_attrs.update(default_style_attributes[layer_class]) + + # Check that all attributes are set correctly + for attr, expected_value in expected_attrs.items(): + actual_value = getattr(style, attr) + if isinstance(expected_value, pd.DataFrame): + assert actual_value.equals(expected_value) + else: + assert actual_value == expected_value + + +def test_layer_style_as_kwargs(sample_layer_style, default_style_attributes): + """Test that the as_kwargs method returns the correct dictionary.""" + style = sample_layer_style(LayerStyle).as_kwargs() + expected_attrs = default_style_attributes[LayerStyle] + assert style == expected_attrs + + +@pytest.mark.parametrize( + "prop, expected_n_colors", + [ + ("category", 3), + ("value", 5), + ], +) +def test_points_style_set_color_by( + sample_layer_style, prop, expected_n_colors +): + """Test that set_color_by updates face_color and face_color_cycle.""" + points_style = sample_layer_style(PointsStyle) + + points_style.set_color_by(prop=prop) + # Check that face_color and text are updated correctly + assert points_style.face_color == prop + assert points_style.text == {"visible": False, "string": prop} + + # Check that face_color_cycle has the correct number of colors + assert len(points_style.face_color_cycle) == expected_n_colors + # Check that all colors are tuples of length 4 (RGBA) + assert all( + isinstance(c, tuple) and len(c) == 4 + for c in points_style.face_color_cycle + ) diff --git a/tests/test_unit/test_napari_plugin/test_meta_widget.py b/tests/test_unit/test_napari_plugin/test_meta_widget.py new file mode 100644 index 00000000..09802f48 --- /dev/null +++ b/tests/test_unit/test_napari_plugin/test_meta_widget.py @@ -0,0 +1,15 @@ +"""Test the napari plugin meta widget.""" + +from movement.napari._meta_widget import MovementMetaWidget + + +def test_meta_widget_instantiation(make_napari_viewer_proxy): + """Test that the meta widget can be properly instantiated.""" + viewer = make_napari_viewer_proxy() + meta_widget = MovementMetaWidget(viewer) + + assert len(meta_widget.collapsible_widgets) == 1 + + first_widget = meta_widget.collapsible_widgets[0] + assert first_widget._text == "Load poses" + assert first_widget.isExpanded() diff --git a/tests/test_unit/test_napari_plugin/test_poses_loader_widget.py b/tests/test_unit/test_napari_plugin/test_poses_loader_widget.py new file mode 100644 index 00000000..da0c7c8e --- /dev/null +++ b/tests/test_unit/test_napari_plugin/test_poses_loader_widget.py @@ -0,0 +1,173 @@ +"""Unit tests for loader widgets in the napari plugin. + +We instantiate the PosesLoader widget in each test instead of using a fixture. +This is because mocking widget methods would not work after the widget is +instantiated (the methods would have already been connected to signals). +""" + +import pytest +from napari.settings import get_settings +from pytest import DATA_PATHS +from qtpy.QtWidgets import QComboBox, QLineEdit, QPushButton, QSpinBox + +from movement.napari._loader_widgets import PosesLoader + + +# ------------------- tests for widget instantiation--------------------------# +def test_poses_loader_widget_instantiation(make_napari_viewer_proxy): + """Test that the loader widget is properly instantiated.""" + # Instantiate the poses loader widget + poses_loader_widget = PosesLoader(make_napari_viewer_proxy) + + # Check that the widget has the expected number of rows + assert poses_loader_widget.layout().rowCount() == 4 + + # Check that the expected widgets are present in the layout + expected_widgets = [ + (QComboBox, "source_software_combo"), + (QSpinBox, "fps_spinbox"), + (QLineEdit, "file_path_edit"), + (QPushButton, "load_button"), + (QPushButton, "browse_button"), + ] + assert all( + poses_loader_widget.findChild(widget_type, widget_name) is not None + for widget_type, widget_name in expected_widgets + ), "Some widgets are missing." + + # Make sure that layer tooltips are enabled + assert get_settings().appearance.layer_tooltip_visibility is True + + +# --------test connection between widget buttons and methods------------------# +@pytest.mark.parametrize("button", ["browse", "load"]) +def test_button_connected_to_on_clicked( + make_napari_viewer_proxy, mocker, button +): + """Test that clicking a button calls the right function.""" + mock_method = mocker.patch( + f"movement.napari._loader_widgets.PosesLoader._on_{button}_clicked" + ) + poses_loader_widget = PosesLoader(make_napari_viewer_proxy) + button = poses_loader_widget.findChild(QPushButton, f"{button}_button") + button.click() + mock_method.assert_called_once() + + +# ------------------- tests for widget methods--------------------------------# +# In these tests we check if calling a widget method has the expected effects + + +@pytest.mark.parametrize( + "file_path", + [ + # valid file path + str(DATA_PATHS.get("DLC_single-wasp.predictions.h5").parent), + # empty string, simulate user canceling the dialog + "", + ], +) +def test_on_browse_clicked(file_path, make_napari_viewer_proxy, mocker): + """Test that the _on_browse_clicked method correctly sets the + file path in the QLineEdit widget (file_path_edit). + The file path is provided by mocking the return of the + QFileDialog.getOpenFileName method. + """ + # Instantiate the napari viewer and the poses loader widget + viewer = make_napari_viewer_proxy() + poses_loader_widget = PosesLoader(viewer) + + # Mock the QFileDialog.getOpenFileName method to return the file path + mocker.patch( + "movement.napari._loader_widgets.QFileDialog.getOpenFileName", + return_value=(file_path, None), # tuple(file_path, filter) + ) + # Simulate the user clicking the 'Browse' button + poses_loader_widget._on_browse_clicked() + # Check that the file path edit text has been updated + assert poses_loader_widget.file_path_edit.text() == file_path + + +@pytest.mark.parametrize( + "source_software, expected_file_filter", + [ + ("DeepLabCut", "Poses files (*.h5 *.csv)"), + ("SLEAP", "Poses files (*.h5 *.slp)"), + ("LightningPose", "Poses files (*.csv)"), + ], +) +def test_file_filters_per_source_software( + source_software, expected_file_filter, make_napari_viewer_proxy, mocker +): + """Test that the file dialog is opened with the correct filters.""" + poses_loader_widget = PosesLoader(make_napari_viewer_proxy) + poses_loader_widget.source_software_combo.setCurrentText(source_software) + mock_file_dialog = mocker.patch( + "movement.napari._loader_widgets.QFileDialog.getOpenFileName", + return_value=("", None), + ) + poses_loader_widget._on_browse_clicked() + mock_file_dialog.assert_called_once_with( + poses_loader_widget, + caption="Open file containing predicted poses", + filter=expected_file_filter, + ) + + +def test_on_load_clicked_without_file_path(make_napari_viewer_proxy, capsys): + """Test that clicking 'Load' without a file path shows a warning.""" + # Instantiate the napari viewer and the poses loader widget + viewer = make_napari_viewer_proxy() + poses_loader_widget = PosesLoader(viewer) + # Call the _on_load_clicked method (pretend the user clicked "Load") + poses_loader_widget._on_load_clicked() + captured = capsys.readouterr() + assert "No file path specified." in captured.out + + +def test_on_load_clicked_with_valid_file_path( + make_napari_viewer_proxy, caplog +): + """Test clicking 'Load' with a valid file path. + + This test checks that the `_on_load_clicked` method causes the following: + - creates the `data`, `props`, and `file_name` attributes + - emits a series of expected log messages + - adds a Points layer to the viewer (with the expected name) + - sets the playback fps to the specified value + """ + # Instantiate the napari viewer and the poses loader widget + viewer = make_napari_viewer_proxy() + poses_loader_widget = PosesLoader(viewer) + # Set the file path to a valid file + file_path = pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5") + poses_loader_widget.file_path_edit.setText(file_path.as_posix()) + + # Set the fps to 60 + poses_loader_widget.fps_spinbox.setValue(60) + + # Call the _on_load_clicked method (pretend the user clicked "Load") + poses_loader_widget._on_load_clicked() + + # Check that class attributes have been created + assert poses_loader_widget.file_name == file_path.name + assert poses_loader_widget.data is not None + assert poses_loader_widget.props is not None + + # Check that the expected log messages were emitted + # Check that the expected log messages were emitted + expected_log_messages = { + "Converted poses dataset to a napari Tracks array.", + "Tracks array shape: (2170, 4)", + "Added poses dataset as a napari Points layer.", + "Set napari playback speed to 60 fps.", + } + log_messages = {record.getMessage() for record in caplog.records} + assert expected_log_messages <= log_messages + + # Check that a Points layer was added to the viewer + points_layer = poses_loader_widget.viewer.layers[0] + assert points_layer.name == f"poses: {file_path.name}" + + # Check that the playback fps was set correctly + assert get_settings().application.playback_fps == 60