diff --git a/docs/source/conf.py b/docs/source/conf.py
index 007496d7..a122cf07 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -139,6 +139,7 @@
 linkcheck_anchors_ignore_for_url = [
     "https://gin.g-node.org/G-Node/Info/wiki/",
     "https://neuroinformatics.zulipchat.com/",
+    "https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py",
 ]
 
 myst_url_schemes = {
diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index 2ed8d4de..e12dec58 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -7,6 +7,7 @@
 import pandas as pd
 import xarray as xr
 from sleap_io.io.slp import read_labels
+from sleap_io.model.labels import Labels
 
 from movement.io.poses_accessor import PosesAccessor
 from movement.io.validators import (
@@ -108,20 +109,21 @@ def from_sleap_file(
     "Export Analysis HDF5…" from the "File" menu) [1]_. This is the
     preferred format for loading pose tracks from SLEAP into *movement*.
 
-    You can also try directly loading the ".slp" file, but this feature is
-    experimental and doesnot work in all cases. If the ".slp" file contains
-    both user-labeled and predicted instances, only the predicted ones will be
-    loaded. If there are multiple videos in the file, only the first one will
-    be used.
+    You can also directly load the ".slp" file. However, if the file contains
+    multiple videos, only the pose tracks from the first video will be loaded.
+    If the file contains a mix of user-labelled and predicted instances, user
+    labels are prioritised over predicted instances to mirror SLEAP's approach
+    when exporting ".h5" analysis files [2]_.
 
     *movement* expects the tracks to be assigned and proofread before loading
     them, meaning each track is interpreted as a single individual/animal.
-    Follow the SLEAP guide for tracking and proofreading [2]_.
+    Follow the SLEAP guide for tracking and proofreading [3]_.
 
     References
     ----------
     .. [1] https://sleap.ai/tutorials/analysis.html
-    .. [2] https://sleap.ai/guides/proofreading.html
+    .. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
+    .. [3] https://sleap.ai/guides/proofreading.html
 
     Examples
     --------
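The docstring above describes the public entry point this patch touches. For orientation, a minimal usage sketch (the file names here are hypothetical, and `fps` is the optional sampling-rate argument the loader accepts):

    from movement.io import load_poses

    # Preferred route: an analysis file exported from SLEAP
    ds = load_poses.from_sleap_file("session1.analysis.h5", fps=30)
    # Direct route: a SLEAP labels/predictions file
    ds = load_poses.from_sleap_file("session1.predictions.slp", fps=30)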
@@ -235,15 +237,14 @@ def _load_from_sleap_analysis_file(
         # transpose to shape: (n_frames, n_tracks, n_keypoints, n_space)
         tracks = f["tracks"][:].transpose((3, 0, 2, 1))
         # Create an array of NaNs for the confidence scores
-        scores = np.full(tracks.shape[:-1], np.nan, dtype="float32")
+        scores = np.full(tracks.shape[:-1], np.nan)
         # If present, read the point-wise scores,
         # and transpose to shape: (n_frames, n_tracks, n_keypoints)
         if "point_scores" in f.keys():
             scores = f["point_scores"][:].transpose((2, 0, 1))
-
         return ValidPoseTracks(
-            tracks_array=tracks,
-            scores_array=scores,
+            tracks_array=tracks.astype(np.float32),
+            scores_array=scores.astype(np.float32),
             individual_names=[n.decode() for n in f["track_names"][:]],
             keypoint_names=[n.decode() for n in f["node_names"][:]],
             fps=fps,
@@ -271,10 +272,8 @@ def _load_from_sleap_labels_file(
     """
     file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
-
     labels = read_labels(file.path.as_posix())
-    tracks_with_scores = labels.numpy(untracked=False, return_confidence=True)
-
+    tracks_with_scores = _sleap_labels_to_numpy(labels)
     return ValidPoseTracks(
         tracks_array=tracks_with_scores[:, :, :, :-1],
         scores_array=tracks_with_scores[:, :, :, -1],
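The two transposes above undo the axis order SLEAP uses on disk. A small sketch of the shape arithmetic, assuming the on-disk layout (n_tracks, n_space, n_keypoints, n_frames) implied by the comments in the hunk:

    import numpy as np

    dummy = np.zeros((2, 2, 3, 100))  # 2 tracks, x/y, 3 keypoints, 100 frames
    print(dummy.transpose((3, 0, 2, 1)).shape)
    # -> (100, 2, 3, 2), i.e. (n_frames, n_tracks, n_keypoints, n_space)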
@@ -284,16 +283,81 @@
+def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
+    """Convert a SLEAP `Labels` object to a NumPy array containing
+    pose tracks with point-wise confidence scores.
+
+    Parameters
+    ----------
+    labels : Labels
+        A SLEAP `Labels` object.
+
+    Returns
+    -------
+    numpy.ndarray
+        A NumPy array of shape (n_frames, n_tracks, n_nodes, 3), where
+        the last axis holds the x, y coordinates and the point-wise
+        confidence score of each keypoint.
+
+    Notes
+    -----
+    This function only considers SLEAP instances in the first
+    video of the SLEAP `Labels` object. User-labelled instances are
+    prioritised over predicted instances, mirroring SLEAP's approach
+    when exporting ".h5" analysis files [1]_.
+
+    This function is adapted from `Labels.numpy()` from the
+    `sleap_io` package [2]_.
+
+    References
+    ----------
+    .. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
+    .. [2] https://github.com/talmolab/sleap-io
+    """
+    # Select frames from the first video only
+    lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]]
+    # Figure out the frame index range; the range is clamped to include
+    # frame 0, so tracks are aligned to the start of the video
+    frame_idxs = [lf.frame_idx for lf in lfs]
+    first_frame = min(0, min(frame_idxs))
+    last_frame = max(0, max(frame_idxs))
+
+    n_tracks = len(labels.tracks)
+    skeleton = labels.skeletons[-1]  # Assume project only uses last skeleton
+    n_nodes = len(skeleton.nodes)
+    n_frames = int(last_frame - first_frame + 1)
+    tracks = np.full((n_frames, n_tracks, n_nodes, 3), np.nan, dtype="float32")
+
+    for lf in lfs:
+        i = int(lf.frame_idx - first_frame)
+        user_instances = lf.user_instances
+        predicted_instances = lf.predicted_instances
+        for j, track in enumerate(labels.tracks):
+            user_track_instances = [
+                inst for inst in user_instances if inst.track == track
+            ]
+            predicted_track_instances = [
+                inst for inst in predicted_instances if inst.track == track
+            ]
+            # Use the user-labelled instance if available; user labels carry
+            # no prediction scores, so the score column is filled with NaN
+            if user_track_instances:
+                inst = user_track_instances[0]
+                tracks[i, j] = np.hstack(
+                    (inst.numpy(), np.full((n_nodes, 1), np.nan))
+                )
+            elif predicted_track_instances:
+                inst = predicted_track_instances[0]
+                tracks[i, j] = inst.numpy(scores=True)
+    return tracks
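Because the helper above only relies on public `sleap_io` attributes (`labeled_frames`, `user_instances`, `predicted_instances`), the same objects can be used to inspect how many instances of each kind a file holds before loading it; a sketch with a hypothetical file name:

    from sleap_io.io.slp import read_labels

    labels = read_labels("example.predictions.slp")
    n_user = sum(len(lf.user_instances) for lf in labels.labeled_frames)
    n_pred = sum(len(lf.predicted_instances) for lf in labels.labeled_frames)
    print(f"user-labelled: {n_user}, predicted: {n_pred}")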
+
+
 def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame:
     """If poses are loaded from a DeepLabCut .csv file, the DataFrame
     lacks the multi-index columns that are present in the .h5 file. This
-    function parses the csv file to a pandas DataFrame with multi-index
+    function parses the .csv file to a pandas DataFrame with multi-index
     columns, i.e. the same format as in the .h5 file.
 
     Parameters
     ----------
     file_path : pathlib.Path
-        Path to the DeepLabCut-style CSV file.
+        Path to the DeepLabCut-style .csv file.
 
     Returns
     -------
diff --git a/tests/conftest.py b/tests/conftest.py
index 835f57e1..bc7a47e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -46,7 +46,7 @@ def setup_logging(tmp_path):
 @pytest.fixture
 def unreadable_file(tmp_path):
     """Return a dictionary containing the file path and
-    expected permission for an unreadable h5 file."""
+    expected permission for an unreadable .h5 file."""
     file_path = tmp_path / "unreadable.h5"
     with open(file_path, "w") as f:
         f.write("unreadable data")
@@ -61,7 +61,7 @@ def unreadable_file(tmp_path):
 @pytest.fixture
 def unwriteable_file(tmp_path):
     """Return a dictionary containing the file path and
-    expected permission for an unwriteable h5 file."""
+    expected permission for an unwriteable .h5 file."""
     unwriteable_dir = tmp_path / "no_write"
     unwriteable_dir.mkdir()
     os.chmod(unwriteable_dir, not stat.S_IWUSR)
@@ -101,7 +101,7 @@ def nonexistent_file(tmp_path):
 @pytest.fixture
-def directory(tmp_path):  # used in save_poses, validators
+def directory(tmp_path):
     """Return a dictionary containing the file path and
     expected permission for a directory."""
     file_path = tmp_path / "directory"
@@ -115,7 +115,7 @@
 @pytest.fixture
 def h5_file_no_dataframe(tmp_path):
     """Return a dictionary containing the file path and
-    expected datasets for an h5 file with no dataframe."""
+    expected datasets for a .h5 file with no dataframe."""
     file_path = tmp_path / "no_dataframe.h5"
     with h5py.File(file_path, "w") as f:
         f.create_dataset("data_in_list", data=[1, 2, 3])
@@ -126,7 +126,7 @@
 @pytest.fixture
-def fake_h5_file(tmp_path):  # used in save_poses, validators
+def fake_h5_file(tmp_path):
     """Return a dictionary containing the file path,
     expected exception, and expected datasets for
     a file with .h5 extension that is not in HDF5 format.
@@ -143,7 +143,7 @@
 
 @pytest.fixture
 def invalid_single_animal_csv_file(tmp_path):
-    """Return the file path for a fake single-animal csv file."""
+    """Return the file path for a fake single-animal .csv file."""
     file_path = tmp_path / "fake_single_animal.csv"
     with open(file_path, "w") as f:
         f.write("scorer,columns\nsome,columns\ncoords,columns\n")
@@ -153,7 +153,7 @@
 
 @pytest.fixture
 def invalid_multi_animal_csv_file(tmp_path):
-    """Return the file path for a fake multi-animal csv file."""
+    """Return the file path for a fake multi-animal .csv file."""
     file_path = tmp_path / "fake_multi_animal.csv"
     with open(file_path, "w") as f:
         f.write(
@@ -186,7 +186,7 @@ def sleap_file(request):
 
 @pytest.fixture
 def valid_tracks_array():
-    """Return a function that generate different kinds
+    """Return a function that generates different kinds
     of valid tracks array."""
 
     def _valid_tracks_array(array_type):
diff --git a/tests/test_integration/test_io.py b/tests/test_integration/test_io.py
index 468fd108..2070ae9e 100644
--- a/tests/test_integration/test_io.py
+++ b/tests/test_integration/test_io.py
@@ -8,6 +8,11 @@
 class TestPosesIO:
     """Test the IO functionalities of the PoseTracks class."""
 
+    @pytest.fixture(params=["dlc.h5", "dlc.csv"])
+    def dlc_output_file(self, request, tmp_path):
+        """Return the output file path for a DLC .h5 or .csv file."""
+        return tmp_path / request.param
+
     def test_load_and_save_to_dlc_df(self, dlc_style_df):
         """Test that loading pose tracks from a DLC-style DataFrame and
         converting back to a DataFrame returns the same data values."""
@@ -15,12 +20,18 @@ def test_load_and_save_to_dlc_df(self, dlc_style_df):
         df = save_poses.to_dlc_df(ds)
         np.testing.assert_allclose(df.values, dlc_style_df.values)
 
-    @pytest.mark.parametrize("file_name", ["dlc.h5", "dlc.csv"])
-    def test_save_and_load_dlc_file(
-        self, file_name, valid_pose_dataset, tmp_path
-    ):
+    def test_save_and_load_dlc_file(self, dlc_output_file, valid_pose_dataset):
         """Test that saving pose tracks to DLC .h5 and .csv files and then
         loading them back in returns the same Dataset."""
-        save_poses.to_dlc_file(valid_pose_dataset, tmp_path / file_name)
-        ds = load_poses.from_dlc_file(tmp_path / file_name)
+        save_poses.to_dlc_file(valid_pose_dataset, dlc_output_file)
+        ds = load_poses.from_dlc_file(dlc_output_file)
         xr.testing.assert_allclose(ds, valid_pose_dataset)
+
+    def test_convert_sleap_to_dlc_file(self, sleap_file, dlc_output_file):
+        """Test that pose tracks loaded from SLEAP .slp and .h5 files,
+        when converted to DLC .h5 and .csv files and re-loaded, return
+        the same Datasets."""
+        sleap_ds = load_poses.from_sleap_file(sleap_file)
+        save_poses.to_dlc_file(sleap_ds, dlc_output_file)
+        dlc_ds = load_poses.from_dlc_file(dlc_output_file)
+        xr.testing.assert_allclose(sleap_ds, dlc_ds)
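Outside the test suite, the round trip exercised by `test_convert_sleap_to_dlc_file` corresponds to this user-level sketch (file names are hypothetical):

    from movement.io import load_poses, save_poses

    sleap_ds = load_poses.from_sleap_file("video.predictions.slp")
    save_poses.to_dlc_file(sleap_ds, "video_dlc.h5")  # a .csv path also works
    dlc_ds = load_poses.from_dlc_file("video_dlc.h5")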
"SLEAP_three-mice_Aeon_mixed-labels.analysis.h5", + "SLEAP_three-mice_Aeon_mixed-labels.predictions.slp", + ), + ], + ) + def test_load_from_sleap_slp_file_or_h5_file_returns_same( + self, slp_file, h5_file + ): + """Test that loading pose tracks from SLEAP .slp and .h5 files + return the same Dataset.""" + slp_file_path = POSE_DATA.get(slp_file) + h5_file_path = POSE_DATA.get(h5_file) + ds_from_slp = load_poses.from_sleap_file(slp_file_path) + ds_from_h5 = load_poses.from_sleap_file(h5_file_path) + xr.testing.assert_allclose(ds_from_h5, ds_from_slp) + @pytest.mark.parametrize( "file_name", [ diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index 8446ef5a..2f086b4a 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -34,12 +34,12 @@ def new_file_wrong_ext(self, tmp_path): @pytest.fixture def new_dlc_h5_file(self, tmp_path): - """Return the file path for a new DeepLabCut H5 file.""" + """Return the file path for a new DeepLabCut .h5 file.""" return tmp_path / "new_dlc_file.h5" @pytest.fixture def new_dlc_csv_file(self, tmp_path): - """Return the file path for a new DeepLabCut csv file.""" + """Return the file path for a new DeepLabCut .csv file.""" return tmp_path / "new_dlc_file.csv" @pytest.fixture @@ -63,6 +63,20 @@ def missing_dim_dataset(self, valid_pose_dataset): ), does_not_raise(), ), # valid dataset + ( + load_poses.from_sleap_file( + POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5") + ), + does_not_raise(), + ), # valid dataset + ( + load_poses.from_sleap_file( + POSE_DATA.get( + "SLEAP_three-mice_Aeon_proofread.predictions.slp" + ) + ), + does_not_raise(), + ), # valid dataset ], ) def test_to_dlc_df(self, ds, expected_exception):