diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index f601474d..34444722 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -7,7 +7,6 @@ import pandas as pd import xarray as xr from sleap_io.io.slp import read_labels -from sleap_io.model.instance import PredictedInstance from sleap_io.model.labels import Labels from movement.io.poses_accessor import PosesAccessor @@ -17,7 +16,7 @@ ValidPosesCSV, ValidPoseTracks, ) -from movement.logging import log_error, log_warning +from movement.logging import log_error logger = logging.getLogger(__name__) @@ -110,14 +109,8 @@ def from_sleap_file( "Export Analysis HDF5…" from the "File" menu) [1]_. This is the preferred format for loading pose tracks from SLEAP into *movement*. - You can also try directly loading the ".slp" file, but this feature - does not work in all cases. If the ".slp" file contains only predicted - instances, they will be imported alongside their point-wise confidence - scores. Alternatively, if it contains only user-labelled instances, they - will be loaded and assigned fixed point-wise confidence scores of 1.0. - Lastly, if a mix of both user-labelled and predicted instances are present, - only the predicted ones will be loaded. - If there are multiple videos in the file, only the first one will be used. + You can also directly load the ".slp" file. However, if the file contains + multiple videos, only the pose tracks from the first video will be loaded. *movement* expects the tracks to be assigned and proofread before loading them, meaning each track is interpreted as a single individual/animal. @@ -232,12 +225,6 @@ def _load_from_sleap_analysis_file( ------- movement.io.tracks_validators.ValidPoseTracks The validated pose tracks and confidence scores. - - Notes - ----- - If the point-wise confidence scores in the SLEAP - analysis file are all NaNs, this function assumes that the pose - tracks are user-labelled, and will assign a fixed score of 1.0. """ file = ValidHDF5(file_path, expected_datasets=["tracks"]) @@ -251,15 +238,6 @@ def _load_from_sleap_analysis_file( # and transpose to shape: (n_frames, n_tracks, n_keypoints) if "point_scores" in f.keys(): scores = f["point_scores"][:].transpose((2, 0, 1)) - if np.all(np.isnan(scores)): - # Assume user-labelled and assign 1.0 as score - mask = np.isnan(tracks[:, :, :, 1]) - scores = np.where(mask, scores, 1.0) - log_warning( - f"Could not find confidence scores in {file.path}. " - "Assuming pose tracks are user-labelled and " - "assigning fixed confidence scores of 1.0." - ) return ValidPoseTracks( tracks_array=tracks.astype(np.float32), scores_array=scores.astype(np.float32), @@ -290,21 +268,8 @@ def _load_from_sleap_labels_file( """ file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"]) - labels = read_labels(file.path.as_posix()) - - if _sleap_labels_contains_user_labels_only(labels): - tracks_with_scores = _sleap_user_labels_to_numpy(labels) - log_warning( - f"Could not find PredictedInstances in {file.path}. " - "Assuming pose tracks are user-labelled and " - "assigning fixed confidence scores of 1.0." - ) - else: - tracks_with_scores = labels.numpy( - untracked=False, return_confidence=True - ) - + tracks_with_scores = _sleap_labels_to_numpy(labels) return ValidPoseTracks( tracks_array=tracks_with_scores[:, :, :, :-1], scores_array=tracks_with_scores[:, :, :, -1], @@ -314,32 +279,9 @@ def _load_from_sleap_labels_file( ) -def _sleap_labels_contains_user_labels_only(labels: Labels) -> bool: - """Check if a SLEAP `Labels` object contains only user-labelled instances. - - Parameters - ---------- - labels : Labels - A SLEAP `Labels` object. - - Returns - ------- - bool - A boolean value indicating whether the labels contain only - user-labelled instances. - """ - all_instances = [ - instance for lf in labels.labeled_frames for instance in lf.instances - ] - return all( - not isinstance(instance, PredictedInstance) - for instance in all_instances - ) - - -def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray: +def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray: """Convert a SLEAP `Labels` object to a NumPy array containing - pose tracks with point-wise confidence scores set as 1.0. + pose tracks with point-wise confidence scores. Parameters ---------- @@ -354,20 +296,17 @@ def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray: Notes ----- This function only considers SLEAP instances in the first - video of the SLEAP `Labels` object. It is primarily meant - to be used with `Labels` containing only user-labelled instances. - If `Labels` contains predicted instances only, this function - will overwrite all point-wise confidence scores to 1.0. If `Labels` - contains both user-labelled and predicted instances, the returned - NumPy array will contain the last occurrence of each tracked instance - in each frame. + video of the SLEAP `Labels` object. User-labelled instances are + prioritised over predicted instances, mirroring SLEAP's approach + when exporting ".h5" analysis files [1]_. This function is adapted from `Labels.numpy()` from the - `sleap_io` package [1]_ to allow user-labelled instances. + `sleap_io` package [2]_. References ---------- - .. [1] https://github.com/talmolab/sleap-io + .. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150 + .. [2] https://github.com/talmolab/sleap-io """ # Select frames from the first video only lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]] @@ -384,18 +323,24 @@ def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray: for lf in lfs: i = int(lf.frame_idx - first_frame) - tracked_instances = [ - inst for inst in lf.instances if inst.track is not None - ] - for inst in tracked_instances: - j = labels.tracks.index(inst.track) - # Set all point scores to 1.0 - tracks[i, j] = np.hstack( - (inst.numpy(), np.full((n_nodes, 1), 1.0)) - ) - # Reset NaN point scores to NaN - mask = np.isnan(tracks[:, :, :, 1]) - tracks[:, :, :, 2] = np.where(mask, np.nan, tracks[:, :, :, 2]) + user_instances = lf.user_instances + predicted_instances = lf.predicted_instances + for j, track in enumerate(labels.tracks): + user_track_instances = [ + inst for inst in user_instances if inst.track == track + ] + predicted_track_instances = [ + inst for inst in predicted_instances if inst.track == track + ] + # Use user-labelled instance if available + if user_track_instances: + inst = user_track_instances[0] + tracks[i, j] = np.hstack( + (inst.numpy(), np.full((n_nodes, 1), np.nan)) + ) + elif predicted_track_instances: + inst = predicted_track_instances[0] + tracks[i, j] = inst.numpy(scores=True) return tracks diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index d5875657..dca31d6c 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -2,7 +2,6 @@ import pytest import xarray as xr from pytest import POSE_DATA -from sleap_io.io.slp import read_labels from movement.io import PosesAccessor, load_poses @@ -10,24 +9,6 @@ class TestLoadPoses: """Test suite for the load_poses module.""" - @pytest.fixture(scope="module") - def sleap_user_labels(self): - """Return a SLEAP `Labels` object containing only - user-labelled instances. - """ - slp_file = POSE_DATA.get( - "SLEAP_three-mice_Aeon_proofread.predictions.slp" - ) - return read_labels(slp_file) - - @pytest.fixture(scope="module") - def sleap_predicted_labels(self): - """Return a SLEAP `Labels` object containing only - predicted instances. - """ - slp_file = POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp") - return read_labels(slp_file) - def assert_dataset( self, dataset, file_path=None, expected_source_software=None ): @@ -78,6 +59,10 @@ def test_load_from_slp_file(self, sleap_file): "SLEAP_three-mice_Aeon_proofread.analysis.h5", "SLEAP_three-mice_Aeon_proofread.predictions.slp", ), + ( + "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", + "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp", + ), ], ) def test_load_from_sleap_slp_file_or_h5_file_returns_same( @@ -91,37 +76,6 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same( ds_from_h5 = load_poses.from_sleap_file(h5_file_path) xr.testing.assert_allclose(ds_from_h5, ds_from_slp) - @pytest.mark.parametrize( - "sleap_labels, expected", - [ - ("sleap_user_labels", True), - ("sleap_predicted_labels", False), - ], - ) - def test_sleap_labels_contains_user_labels_only( - self, sleap_labels, expected, request - ): - labels = request.getfixturevalue(sleap_labels) - assert ( - load_poses._sleap_labels_contains_user_labels_only(labels) - == expected - ) - - @pytest.mark.parametrize( - "sleap_labels", - [ - "sleap_user_labels", - "sleap_predicted_labels", - ], - ) - def test_sleap_user_labels_to_numpy_confidence_equals_one( - self, sleap_labels, request - ): - labels = request.getfixturevalue(sleap_labels) - confidence = load_poses._sleap_user_labels_to_numpy(labels)[:, :, :, 2] - mask = np.isnan(confidence) - assert np.all(confidence[~mask] == 1) - @pytest.mark.parametrize( "file_name", [