diff --git a/tests/conftest.py b/tests/conftest.py
index f0c3a0f56..8b4727356 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -882,3 +882,77 @@ def count_consecutive_nans(da):
 def helpers():
     """Return an instance of the ``Helpers`` class."""
     return Helpers
+
+
+# --------- movement dataset assertion fixtures ---------
+class MovementDatasetAsserts:
+    """Assertions for validating ``movement`` poses or bboxes datasets."""
+
+    @staticmethod
+    def valid_dataset(dataset, expected_values):
+        """Assert the dataset is a proper ``movement`` Dataset.
+
+        Parameters
+        ----------
+        dataset : xarray.Dataset
+            The dataset to check.
+        expected_values : dict
+            Expected properties of ``dataset``. Required keys are
+            ``"vars_dims"`` (data variable name -> expected number of
+            dimensions) and ``"dim_names"`` (dimension names, in the
+            order of the ``position`` dimensions). Optional keys are
+            ``"file_path"`` (an object with ``as_posix()``),
+            ``"source_software"`` and ``"fps"``; when one is absent the
+            matching dataset attribute is expected to be ``None``.
+
+        Raises
+        ------
+        AssertionError
+            If any of the checks fails.
+        KeyError
+            If a required key is missing from ``expected_values``.
+
+        """
+        expected_vars_dims = expected_values["vars_dims"]
+        expected_dim_names = expected_values["dim_names"]
+        expected_file_path = expected_values.get("file_path")
+        assert isinstance(dataset, xr.Dataset)
+        # Expected variables are present and of right shape/type
+        for var, ndim in expected_vars_dims.items():
+            # ``dataset[var]`` would also resolve a coordinate, so check
+            # explicitly that ``var`` is a data variable
+            assert var in dataset.data_vars
+            data_var = dataset[var]
+            assert isinstance(data_var, xr.DataArray)
+            assert data_var.ndim == ndim
+        position_shape = dataset.position.shape
+        # Confidence has the same shape as position, except for the space dim
+        # (taken to be the second dimension of position)
+        assert (
+            dataset.confidence.shape == position_shape[:1] + position_shape[2:]
+        )
+        # Check the dims
+        expected_dim_length_dict = dict(
+            zip(expected_dim_names, position_shape, strict=True)
+        )
+        assert expected_dim_length_dict == dataset.sizes
+        # Check the coords
+        for dim in expected_dim_names[1:]:
+            assert all(isinstance(s, str) for s in dataset.coords[dim].values)
+        assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
+        # Check the metadata attributes
+        assert dataset.source_file == (
+            expected_file_path.as_posix()
+            if expected_file_path is not None
+            else None
+        )
+        assert dataset.source_software == expected_values.get(
+            "source_software"
+        )
+        assert dataset.fps == expected_values.get("fps")
+
+
+@pytest.fixture
+def movement_dataset_asserts():
+    """Return the ``MovementDatasetAsserts`` class."""
+    return MovementDatasetAsserts
diff --git a/tests/test_unit/test_load_bboxes.py
b/tests/test_unit/test_load_bboxes.py index 4200b1c90..697014576 100644 --- a/tests/test_unit/test_load_bboxes.py +++ b/tests/test_unit/test_load_bboxes.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import pytest -import xarray as xr from movement.io import load_bboxes from movement.validators.datasets import ValidBboxesDataset @@ -112,49 +111,6 @@ def update_attribute_column(df_input, attribute_column_name, dict_to_append): return df -def assert_dataset( - dataset, file_path=None, expected_source_software=None, expected_fps=None -): - """Assert that the dataset is a proper ``movement`` Dataset.""" - assert isinstance(dataset, xr.Dataset) - - # Expected variables are present and of right shape/type - for var in ["position", "shape", "confidence"]: - assert var in dataset.data_vars - assert isinstance(dataset[var], xr.DataArray) - assert dataset.position.ndim == 3 - assert dataset.shape.ndim == 3 - position_shape = dataset.position.shape - # Confidence has the same shape as position, except for the space dim - assert dataset.confidence.shape == position_shape[:1] + position_shape[2:] - # Check the dims and coords - dim_names = ValidBboxesDataset.DIM_NAMES - expected_dim_length_dict = dict( - zip(dim_names, position_shape, strict=True) - ) - assert expected_dim_length_dict == dataset.sizes - # Check the coords - for dim in dim_names[1:]: - assert all(isinstance(s, str) for s in dataset.coords[dim].values) - assert all(coord in dataset.coords["space"] for coord in ["x", "y"]) - # Check the metadata attributes - assert ( - dataset.source_file is None - if file_path is None - else dataset.source_file == file_path.as_posix() - ) - assert ( - dataset.source_software is None - if expected_source_software is None - else dataset.source_software == expected_source_software - ) - assert ( - dataset.fps is None - if expected_fps is None - else dataset.fps == expected_fps - ) - - def assert_time_coordinates(ds, fps, start_frame): """Assert that the time coordinates 
are as expected, depending on fps value and start_frame. @@ -210,10 +166,16 @@ def test_from_file(source_software, fps, use_frame_numbers_from_file): ) +expected_values_bboxes = { + "vars_dims": {"position": 3, "shape": 3, "confidence": 2}, + "dim_names": ValidBboxesDataset.DIM_NAMES, +} + + @pytest.mark.parametrize("fps", [None, 30, 60.0]) @pytest.mark.parametrize("use_frame_numbers_from_file", [True, False]) def test_from_via_tracks_file( - via_tracks_file, fps, use_frame_numbers_from_file + via_tracks_file, fps, use_frame_numbers_from_file, movement_dataset_asserts ): """Test that loading tracked bounding box data from a valid VIA tracks .csv file returns a proper Dataset @@ -223,8 +185,13 @@ def test_from_via_tracks_file( ds = load_bboxes.from_via_tracks_file( via_tracks_file, fps, use_frame_numbers_from_file ) - assert_dataset(ds, via_tracks_file, "VIA-tracks", fps) - + expected_values = { + **expected_values_bboxes, + "source_software": "VIA-tracks", + "fps": fps, + "file_path": via_tracks_file, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) # check time coordinates are as expected # in sample VIA tracks .csv file frame numbers start from 1 start_frame = 1 if use_frame_numbers_from_file else 0 @@ -240,28 +207,36 @@ def test_from_via_tracks_file( ) @pytest.mark.parametrize("fps", [None, 30, 60.0]) @pytest.mark.parametrize("source_software", [None, "VIA-tracks"]) -def test_from_numpy(valid_from_numpy_inputs, fps, source_software, request): +def test_from_numpy( + valid_from_numpy_inputs, + fps, + source_software, + movement_dataset_asserts, + request, +): """Test that loading bounding boxes trajectories from the input numpy arrays returns a proper Dataset. 
""" # get the input arrays from_numpy_inputs = request.getfixturevalue(valid_from_numpy_inputs) - # run general dataset checks ds = load_bboxes.from_numpy( **from_numpy_inputs, fps=fps, source_software=source_software, ) - assert_dataset( - ds, expected_source_software=source_software, expected_fps=fps - ) - + expected_values = { + **expected_values_bboxes, + "source_software": source_software, + "fps": fps, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) # check time coordinates are as expected - if "frame_array" in from_numpy_inputs: - start_frame = from_numpy_inputs["frame_array"][0, 0] - else: - start_frame = 0 + start_frame = ( + from_numpy_inputs["frame_array"][0, 0] + if "frame_array" in from_numpy_inputs + else 0 + ) assert_time_coordinates(ds, fps, start_frame) @@ -417,10 +392,11 @@ def test_fps_and_time_coords( assert ds.fps == expected_fps # check time coordinates - if use_frame_numbers_from_file: - start_frame = ds_in_frames_from_file.coords["time"].data[0] - else: - start_frame = 0 + start_frame = ( + ds_in_frames_from_file.coords["time"].data[0] + if use_frame_numbers_from_file + else 0 + ) assert_time_coordinates(ds, expected_fps, start_frame) diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 1a1a4c90c..722c67995 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -68,48 +68,23 @@ def sleap_file_without_tracks(request): return request.getfixturevalue(request.param) -def assert_dataset(dataset, file_path=None, expected_source_software=None): - """Assert that the dataset is a proper xarray Dataset.""" - assert isinstance(dataset, xr.Dataset) - # Expected variables are present and of right shape/type - for var in ["position", "confidence"]: - assert var in dataset.data_vars - assert isinstance(dataset[var], xr.DataArray) - assert dataset.position.ndim == 4 - - position_shape = dataset.position.shape - # Confidence has the same shape as position, except for 
the space dim - assert dataset.confidence.shape == position_shape[:1] + position_shape[2:] - # Check the dims - dim_names = ValidPosesDataset.DIM_NAMES - expected_dim_length_dict = dict( - zip(dim_names, position_shape, strict=True) - ) - assert expected_dim_length_dict == dataset.sizes - # Check the coords - for dim in dim_names[1:]: - assert all(isinstance(s, str) for s in dataset.coords[dim].values) - assert all(coord in dataset.coords["space"] for coord in ["x", "y"]) - # Check the metadata attributes - assert ( - dataset.source_file is None - if file_path is None - else dataset.source_file == file_path.as_posix() - ) - assert ( - dataset.source_software is None - if expected_source_software is None - else dataset.source_software == expected_source_software - ) - assert dataset.fps is None +expected_values_poses = { + "vars_dims": {"position": 4, "confidence": 3}, + "dim_names": ValidPosesDataset.DIM_NAMES, +} -def test_load_from_sleap_file(sleap_file): +def test_load_from_sleap_file(sleap_file, movement_dataset_asserts): """Test that loading pose tracks from valid SLEAP files returns a proper Dataset. """ ds = load_poses.from_sleap_file(sleap_file) - assert_dataset(ds, sleap_file, "SLEAP") + expected_values = { + **expected_values_poses, + "source_software": "SLEAP", + "file_path": sleap_file, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) def test_load_from_sleap_file_without_tracks(sleap_file_without_tracks): @@ -167,26 +142,37 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(slp_file, h5_file): "DLC_two-mice.predictions.csv", ], ) -def test_load_from_dlc_file(file_name): +def test_load_from_dlc_file(file_name, movement_dataset_asserts): """Test that loading pose tracks from valid DLC files returns a proper Dataset. 
""" file_path = DATA_PATHS.get(file_name) ds = load_poses.from_dlc_file(file_path) - assert_dataset(ds, file_path, "DeepLabCut") + expected_values = { + **expected_values_poses, + "source_software": "DeepLabCut", + "file_path": file_path, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) @pytest.mark.parametrize( "source_software", ["DeepLabCut", "LightningPose", None] ) -def test_load_from_dlc_style_df(dlc_style_df, source_software): +def test_load_from_dlc_style_df( + dlc_style_df, source_software, movement_dataset_asserts +): """Test that loading pose tracks from a valid DLC-style DataFrame returns a proper Dataset. """ ds = load_poses.from_dlc_style_df( dlc_style_df, source_software=source_software ) - assert_dataset(ds, expected_source_software=source_software) + expected_values = { + **expected_values_poses, + "source_software": source_software, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) def test_load_from_dlc_file_csv_or_h5_file_returns_same(): @@ -234,13 +220,18 @@ def test_fps_and_time_coords(fps, expected_fps, expected_time_unit): "LP_mouse-twoview_AIND.predictions.csv", ], ) -def test_load_from_lp_file(file_name): +def test_load_from_lp_file(file_name, movement_dataset_asserts): """Test that loading pose tracks from valid LightningPose (LP) files returns a proper Dataset. 
""" file_path = DATA_PATHS.get(file_name) ds = load_poses.from_lp_file(file_path) - assert_dataset(ds, file_path, "LightningPose") + expected_values = { + **expected_values_poses, + "source_software": "LightningPose", + "file_path": file_path, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) def test_load_from_lp_or_dlc_file_returns_same(): @@ -289,14 +280,15 @@ def test_from_file_delegates_correctly(source_software, fps): @pytest.mark.parametrize("source_software", [None, "SLEAP"]) -def test_from_numpy_valid(valid_position_array, source_software): +def test_from_numpy_valid( + valid_position_array, source_software, movement_dataset_asserts +): """Test that loading pose tracks from a multi-animal numpy array with valid parameters returns a proper Dataset. """ valid_position = valid_position_array("multi_individual_array") rng = np.random.default_rng(seed=42) valid_confidence = rng.random(valid_position.shape[:-1]) - ds = load_poses.from_numpy( valid_position, valid_confidence, @@ -305,7 +297,11 @@ def test_from_numpy_valid(valid_position_array, source_software): fps=None, source_software=source_software, ) - assert_dataset(ds, expected_source_software=source_software) + expected_values = { + **expected_values_poses, + "source_software": source_software, + } + movement_dataset_asserts.valid_dataset(ds, expected_values) def test_from_multiview_files():