Skip to content

Commit

Permalink
Refactor load from SLEAP labels file
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Nov 15, 2023
1 parent 1f571fa commit 724b8f0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 135 deletions.
115 changes: 30 additions & 85 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
import xarray as xr
from sleap_io.io.slp import read_labels
from sleap_io.model.instance import PredictedInstance
from sleap_io.model.labels import Labels

from movement.io.poses_accessor import PosesAccessor
Expand All @@ -17,7 +16,7 @@
ValidPosesCSV,
ValidPoseTracks,
)
from movement.logging import log_error, log_warning
from movement.logging import log_error

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,14 +109,8 @@ def from_sleap_file(
"Export Analysis HDF5…" from the "File" menu) [1]_. This is the
preferred format for loading pose tracks from SLEAP into *movement*.
You can also try directly loading the ".slp" file, but this feature
does not work in all cases. If the ".slp" file contains only predicted
instances, they will be imported alongside their point-wise confidence
scores. Alternatively, if it contains only user-labelled instances, they
will be loaded and assigned fixed point-wise confidence scores of 1.0.
Lastly, if a mix of both user-labelled and predicted instances is present,
only the predicted ones will be loaded.
If there are multiple videos in the file, only the first one will be used.
You can also directly load the ".slp" file. However, if the file contains
multiple videos, only the pose tracks from the first video will be loaded.
*movement* expects the tracks to be assigned and proofread before loading
them, meaning each track is interpreted as a single individual/animal.
Expand Down Expand Up @@ -232,12 +225,6 @@ def _load_from_sleap_analysis_file(
-------
movement.io.tracks_validators.ValidPoseTracks
The validated pose tracks and confidence scores.
Notes
-----
If the point-wise confidence scores in the SLEAP
analysis file are all NaNs, this function assumes that the pose
tracks are user-labelled, and will assign a fixed score of 1.0.
"""

file = ValidHDF5(file_path, expected_datasets=["tracks"])
Expand All @@ -251,15 +238,6 @@ def _load_from_sleap_analysis_file(
# and transpose to shape: (n_frames, n_tracks, n_keypoints)
if "point_scores" in f.keys():
scores = f["point_scores"][:].transpose((2, 0, 1))
if np.all(np.isnan(scores)):
# Assume user-labelled and assign 1.0 as score
mask = np.isnan(tracks[:, :, :, 1])
scores = np.where(mask, scores, 1.0)
log_warning(
f"Could not find confidence scores in {file.path}. "
"Assuming pose tracks are user-labelled and "
"assigning fixed confidence scores of 1.0."
)
return ValidPoseTracks(
tracks_array=tracks.astype(np.float32),
scores_array=scores.astype(np.float32),
Expand Down Expand Up @@ -290,21 +268,8 @@ def _load_from_sleap_labels_file(
"""

file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])

labels = read_labels(file.path.as_posix())

if _sleap_labels_contains_user_labels_only(labels):
tracks_with_scores = _sleap_user_labels_to_numpy(labels)
log_warning(
f"Could not find PredictedInstances in {file.path}. "
"Assuming pose tracks are user-labelled and "
"assigning fixed confidence scores of 1.0."
)
else:
tracks_with_scores = labels.numpy(
untracked=False, return_confidence=True
)

tracks_with_scores = _sleap_labels_to_numpy(labels)
return ValidPoseTracks(
tracks_array=tracks_with_scores[:, :, :, :-1],
scores_array=tracks_with_scores[:, :, :, -1],
Expand All @@ -314,32 +279,9 @@ def _load_from_sleap_labels_file(
)


def _sleap_labels_contains_user_labels_only(labels: Labels) -> bool:
    """Report whether a SLEAP `Labels` object has no predicted instances.

    Parameters
    ----------
    labels : Labels
        A SLEAP `Labels` object.

    Returns
    -------
    bool
        True when none of the instances found in the labeled frames is a
        `PredictedInstance` (this includes the case of no instances at
        all), False as soon as one predicted instance is present.

    """
    # "All instances are not predicted" is the same as
    # "no instance is predicted".
    has_predicted_instance = any(
        isinstance(instance, PredictedInstance)
        for labeled_frame in labels.labeled_frames
        for instance in labeled_frame.instances
    )
    return not has_predicted_instance


def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray:
def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
"""Convert a SLEAP `Labels` object to a NumPy array containing
pose tracks with point-wise confidence scores set as 1.0.
pose tracks with point-wise confidence scores.
Parameters
----------
Expand All @@ -354,20 +296,17 @@ def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray:
Notes
-----
This function only considers SLEAP instances in the first
video of the SLEAP `Labels` object. It is primarily meant
to be used with `Labels` containing only user-labelled instances.
If `Labels` contains predicted instances only, this function
will overwrite all point-wise confidence scores to 1.0. If `Labels`
contains both user-labelled and predicted instances, the returned
NumPy array will contain the last occurrence of each tracked instance
in each frame.
video of the SLEAP `Labels` object. User-labelled instances are
prioritised over predicted instances, mirroring SLEAP's approach
when exporting ".h5" analysis files [1]_.
This function is adapted from `Labels.numpy()` from the
`sleap_io` package [1]_ to allow user-labelled instances.
`sleap_io` package [2]_.
References
----------
.. [1] https://github.com/talmolab/sleap-io
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [2] https://github.com/talmolab/sleap-io
"""
# Select frames from the first video only
lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]]
Expand All @@ -384,18 +323,24 @@ def _sleap_user_labels_to_numpy(labels: Labels) -> np.ndarray:

for lf in lfs:
i = int(lf.frame_idx - first_frame)
tracked_instances = [
inst for inst in lf.instances if inst.track is not None
]
for inst in tracked_instances:
j = labels.tracks.index(inst.track)
# Set all point scores to 1.0
tracks[i, j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), 1.0))
)
# Reset NaN point scores to NaN
mask = np.isnan(tracks[:, :, :, 1])
tracks[:, :, :, 2] = np.where(mask, np.nan, tracks[:, :, :, 2])
user_instances = lf.user_instances
predicted_instances = lf.predicted_instances
for j, track in enumerate(labels.tracks):
user_track_instances = [
inst for inst in user_instances if inst.track == track
]
predicted_track_instances = [
inst for inst in predicted_instances if inst.track == track
]
# Use user-labelled instance if available
if user_track_instances:
inst = user_track_instances[0]
tracks[i, j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), np.nan))
)
elif predicted_track_instances:
inst = predicted_track_instances[0]
tracks[i, j] = inst.numpy(scores=True)
return tracks


Expand Down
54 changes: 4 additions & 50 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,13 @@
import pytest
import xarray as xr
from pytest import POSE_DATA
from sleap_io.io.slp import read_labels

from movement.io import PosesAccessor, load_poses


class TestLoadPoses:
"""Test suite for the load_poses module."""

@pytest.fixture(scope="module")
def sleap_user_labels(self):
    """Provide the SLEAP `Labels` read from the proofread three-mice
    ".slp" sample file, which contains only user-labelled instances.
    """
    return read_labels(
        POSE_DATA.get("SLEAP_three-mice_Aeon_proofread.predictions.slp")
    )

@pytest.fixture(scope="module")
def sleap_predicted_labels(self):
    """Provide the SLEAP `Labels` read from the single-mouse EPM
    ".slp" sample file, which contains only predicted instances.
    """
    return read_labels(
        POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp")
    )

def assert_dataset(
self, dataset, file_path=None, expected_source_software=None
):
Expand Down Expand Up @@ -78,6 +59,10 @@ def test_load_from_slp_file(self, sleap_file):
"SLEAP_three-mice_Aeon_proofread.analysis.h5",
"SLEAP_three-mice_Aeon_proofread.predictions.slp",
),
(
"SLEAP_three-mice_Aeon_mixed-labels.analysis.h5",
"SLEAP_three-mice_Aeon_mixed-labels.predictions.slp",
),
],
)
def test_load_from_sleap_slp_file_or_h5_file_returns_same(
Expand All @@ -91,37 +76,6 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(
ds_from_h5 = load_poses.from_sleap_file(h5_file_path)
xr.testing.assert_allclose(ds_from_h5, ds_from_slp)

@pytest.mark.parametrize(
    "sleap_labels, expected",
    [
        ("sleap_user_labels", True),
        ("sleap_predicted_labels", False),
    ],
)
def test_sleap_labels_contains_user_labels_only(
    self, sleap_labels, expected, request
):
    """Check that the user-labels-only helper returns the expected
    boolean for each `Labels` fixture.
    """
    # Fixtures are passed by name, so resolve them at run time.
    labels = request.getfixturevalue(sleap_labels)
    result = load_poses._sleap_labels_contains_user_labels_only(labels)
    assert result == expected

@pytest.mark.parametrize(
    "sleap_labels",
    ["sleap_user_labels", "sleap_predicted_labels"],
)
def test_sleap_user_labels_to_numpy_confidence_equals_one(
    self, sleap_labels, request
):
    """Check that every non-NaN confidence value produced by
    `_sleap_user_labels_to_numpy` equals 1.
    """
    # Fixtures are passed by name, so resolve them at run time.
    labels = request.getfixturevalue(sleap_labels)
    tracks_with_scores = load_poses._sleap_user_labels_to_numpy(labels)
    # Index 2 on the last axis is read as the confidence value.
    confidence = tracks_with_scores[:, :, :, 2]
    is_valid = ~np.isnan(confidence)
    assert np.all(confidence[is_valid] == 1)

@pytest.mark.parametrize(
"file_name",
[
Expand Down

0 comments on commit 724b8f0

Please sign in to comment.