Skip to content

Commit

Permalink
SLEAP read user labels (#79)
Browse files Browse the repository at this point in the history
Co-authored-by: Niko Sirmpilatze <[email protected]>
  • Loading branch information
lochhh and niksirbi authored Nov 16, 2023
1 parent 73858ad commit 3b61db6
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
linkcheck_anchors_ignore_for_url = [
"https://gin.g-node.org/G-Node/Info/wiki/",
"https://neuroinformatics.zulipchat.com/",
"https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py",
]

myst_url_schemes = {
Expand Down
96 changes: 80 additions & 16 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import xarray as xr
from sleap_io.io.slp import read_labels
from sleap_io.model.labels import Labels

from movement.io.poses_accessor import PosesAccessor
from movement.io.validators import (
Expand Down Expand Up @@ -108,20 +109,21 @@ def from_sleap_file(
"Export Analysis HDF5…" from the "File" menu) [1]_. This is the
preferred format for loading pose tracks from SLEAP into *movement*.
You can also try directly loading the ".slp" file, but this feature is
experimental and does not work in all cases. If the ".slp" file contains
both user-labeled and predicted instances, only the predicted ones will be
loaded. If there are multiple videos in the file, only the first one will
be used.
You can also directly load the ".slp" file. However, if the file contains
multiple videos, only the pose tracks from the first video will be loaded.
If the file contains a mix of user-labelled and predicted instances, user
labels are prioritised over predicted instances to mirror SLEAP's approach
when exporting ".h5" analysis files [2]_.
*movement* expects the tracks to be assigned and proofread before loading
them, meaning each track is interpreted as a single individual/animal.
Follow the SLEAP guide for tracking and proofreading [2]_.
Follow the SLEAP guide for tracking and proofreading [3]_.
References
----------
.. [1] https://sleap.ai/tutorials/analysis.html
.. [2] https://sleap.ai/guides/proofreading.html
.. [2] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [3] https://sleap.ai/guides/proofreading.html
Examples
--------
Expand Down Expand Up @@ -235,15 +237,14 @@ def _load_from_sleap_analysis_file(
# transpose to shape: (n_frames, n_tracks, n_keypoints, n_space)
tracks = f["tracks"][:].transpose((3, 0, 2, 1))
# Create an array of NaNs for the confidence scores
scores = np.full(tracks.shape[:-1], np.nan, dtype="float32")
scores = np.full(tracks.shape[:-1], np.nan)
# If present, read the point-wise scores,
# and transpose to shape: (n_frames, n_tracks, n_keypoints)
if "point_scores" in f.keys():
scores = f["point_scores"][:].transpose((2, 0, 1))

return ValidPoseTracks(
tracks_array=tracks,
scores_array=scores,
tracks_array=tracks.astype(np.float32),
scores_array=scores.astype(np.float32),
individual_names=[n.decode() for n in f["track_names"][:]],
keypoint_names=[n.decode() for n in f["node_names"][:]],
fps=fps,
Expand Down Expand Up @@ -271,10 +272,8 @@ def _load_from_sleap_labels_file(
"""

file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])

labels = read_labels(file.path.as_posix())
tracks_with_scores = labels.numpy(untracked=False, return_confidence=True)

tracks_with_scores = _sleap_labels_to_numpy(labels)
return ValidPoseTracks(
tracks_array=tracks_with_scores[:, :, :, :-1],
scores_array=tracks_with_scores[:, :, :, -1],
Expand All @@ -284,16 +283,81 @@ def _load_from_sleap_labels_file(
)


def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
"""Convert a SLEAP `Labels` object to a NumPy array containing
pose tracks with point-wise confidence scores.
Parameters
----------
labels : Labels
A SLEAP `Labels` object.
Returns
-------
numpy.ndarray
A NumPy array containing pose tracks and confidence scores.
Notes
-----
This function only considers SLEAP instances in the first
video of the SLEAP `Labels` object. User-labelled instances are
prioritised over predicted instances, mirroring SLEAP's approach
when exporting ".h5" analysis files [1]_.
This function is adapted from `Labels.numpy()` from the
`sleap_io` package [2]_.
References
----------
.. [1] https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py#L129-L150
.. [2] https://github.com/talmolab/sleap-io
"""
# Select frames from the first video only
lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]]
# Figure out frame index range
frame_idxs = [lf.frame_idx for lf in lfs]
first_frame = min(0, min(frame_idxs))
last_frame = max(0, max(frame_idxs))

n_tracks = len(labels.tracks)
skeleton = labels.skeletons[-1] # Assume project only uses last skeleton
n_nodes = len(skeleton.nodes)
n_frames = int(last_frame - first_frame + 1)
tracks = np.full((n_frames, n_tracks, n_nodes, 3), np.nan, dtype="float32")

for lf in lfs:
i = int(lf.frame_idx - first_frame)
user_instances = lf.user_instances
predicted_instances = lf.predicted_instances
for j, track in enumerate(labels.tracks):
user_track_instances = [
inst for inst in user_instances if inst.track == track
]
predicted_track_instances = [
inst for inst in predicted_instances if inst.track == track
]
# Use user-labelled instance if available
if user_track_instances:
inst = user_track_instances[0]
tracks[i, j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), np.nan))
)
elif predicted_track_instances:
inst = predicted_track_instances[0]
tracks[i, j] = inst.numpy(scores=True)
return tracks


def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame:
"""If poses are loaded from a DeepLabCut .csv file, the DataFrame
lacks the multi-index columns that are present in the .h5 file. This
function parses the csv file to a pandas DataFrame with multi-index
function parses the .csv file to a pandas DataFrame with multi-index
columns, i.e. the same format as in the .h5 file.
Parameters
----------
file_path : pathlib.Path
Path to the DeepLabCut-style CSV file.
Path to the DeepLabCut-style .csv file.
Returns
-------
Expand Down
16 changes: 8 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setup_logging(tmp_path):
@pytest.fixture
def unreadable_file(tmp_path):
"""Return a dictionary containing the file path and
expected permission for an unreadable h5 file."""
expected permission for an unreadable .h5 file."""
file_path = tmp_path / "unreadable.h5"
with open(file_path, "w") as f:
f.write("unreadable data")
Expand All @@ -61,7 +61,7 @@ def unreadable_file(tmp_path):
@pytest.fixture
def unwriteable_file(tmp_path):
"""Return a dictionary containing the file path and
expected permission for an unwriteable h5 file."""
expected permission for an unwriteable .h5 file."""
unwriteable_dir = tmp_path / "no_write"
unwriteable_dir.mkdir()
os.chmod(unwriteable_dir, not stat.S_IWUSR)
Expand Down Expand Up @@ -101,7 +101,7 @@ def nonexistent_file(tmp_path):


@pytest.fixture
def directory(tmp_path): # used in save_poses, validators
def directory(tmp_path):
"""Return a dictionary containing the file path and
expected permission for a directory."""
file_path = tmp_path / "directory"
Expand All @@ -115,7 +115,7 @@ def directory(tmp_path): # used in save_poses, validators
@pytest.fixture
def h5_file_no_dataframe(tmp_path):
"""Return a dictionary containing the file path and
expected datasets for an h5 file with no dataframe."""
expected datasets for a .h5 file with no dataframe."""
file_path = tmp_path / "no_dataframe.h5"
with h5py.File(file_path, "w") as f:
f.create_dataset("data_in_list", data=[1, 2, 3])
Expand All @@ -126,7 +126,7 @@ def h5_file_no_dataframe(tmp_path):


@pytest.fixture
def fake_h5_file(tmp_path): # used in save_poses, validators
def fake_h5_file(tmp_path):
"""Return a dictionary containing the file path,
expected exception, and expected datasets for
a file with .h5 extension that is not in HDF5 format.
Expand All @@ -143,7 +143,7 @@ def fake_h5_file(tmp_path): # used in save_poses, validators

@pytest.fixture
def invalid_single_animal_csv_file(tmp_path):
"""Return the file path for a fake single-animal csv file."""
"""Return the file path for a fake single-animal .csv file."""
file_path = tmp_path / "fake_single_animal.csv"
with open(file_path, "w") as f:
f.write("scorer,columns\nsome,columns\ncoords,columns\n")
Expand All @@ -153,7 +153,7 @@ def invalid_single_animal_csv_file(tmp_path):

@pytest.fixture
def invalid_multi_animal_csv_file(tmp_path):
"""Return the file path for a fake multi-animal csv file."""
"""Return the file path for a fake multi-animal .csv file."""
file_path = tmp_path / "fake_multi_animal.csv"
with open(file_path, "w") as f:
f.write(
Expand Down Expand Up @@ -186,7 +186,7 @@ def sleap_file(request):

@pytest.fixture
def valid_tracks_array():
"""Return a function that generate different kinds
"""Return a function that generates different kinds
of valid tracks array."""

def _valid_tracks_array(array_type):
Expand Down
23 changes: 17 additions & 6 deletions tests/test_integration/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,30 @@
class TestPosesIO:
    """Tests for loading and saving pose tracks across file formats."""

    @pytest.fixture(params=["dlc.h5", "dlc.csv"])
    def dlc_output_file(self, request, tmp_path):
        """Return a temporary output path for a DLC .h5 or .csv file."""
        return tmp_path / request.param

    def test_load_and_save_to_dlc_df(self, dlc_style_df):
        """Round-trip a DLC-style DataFrame through a Dataset and back,
        checking that the data values are preserved."""
        dataset = load_poses.from_dlc_df(dlc_style_df)
        round_tripped = save_poses.to_dlc_df(dataset)
        np.testing.assert_allclose(round_tripped.values, dlc_style_df.values)

    def test_save_and_load_dlc_file(self, dlc_output_file, valid_pose_dataset):
        """Save a valid Dataset to a DLC .h5 or .csv file and reload it,
        checking that the reloaded Dataset matches the original."""
        save_poses.to_dlc_file(valid_pose_dataset, dlc_output_file)
        reloaded = load_poses.from_dlc_file(dlc_output_file)
        xr.testing.assert_allclose(reloaded, valid_pose_dataset)

    def test_convert_sleap_to_dlc_file(self, sleap_file, dlc_output_file):
        """Convert pose tracks loaded from SLEAP .slp/.h5 files into DLC
        .h5/.csv files, reload them, and check that the resulting
        Datasets are equivalent."""
        ds_from_sleap = load_poses.from_sleap_file(sleap_file)
        save_poses.to_dlc_file(ds_from_sleap, dlc_output_file)
        ds_from_dlc = load_poses.from_dlc_file(dlc_output_file)
        xr.testing.assert_allclose(ds_from_sleap, ds_from_dlc)
28 changes: 28 additions & 0 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,34 @@ def test_load_from_slp_file(self, sleap_file):
ds = load_poses.from_sleap_file(sleap_file)
self.assert_dataset(ds, sleap_file, "SLEAP")

@pytest.mark.parametrize(
"slp_file, h5_file",
[
(
"SLEAP_single-mouse_EPM.analysis.h5",
"SLEAP_single-mouse_EPM.predictions.slp",
),
(
"SLEAP_three-mice_Aeon_proofread.analysis.h5",
"SLEAP_three-mice_Aeon_proofread.predictions.slp",
),
(
"SLEAP_three-mice_Aeon_mixed-labels.analysis.h5",
"SLEAP_three-mice_Aeon_mixed-labels.predictions.slp",
),
],
)
def test_load_from_sleap_slp_file_or_h5_file_returns_same(
self, slp_file, h5_file
):
"""Test that loading pose tracks from SLEAP .slp and .h5 files
return the same Dataset."""
slp_file_path = POSE_DATA.get(slp_file)
h5_file_path = POSE_DATA.get(h5_file)
ds_from_slp = load_poses.from_sleap_file(slp_file_path)
ds_from_h5 = load_poses.from_sleap_file(h5_file_path)
xr.testing.assert_allclose(ds_from_h5, ds_from_slp)

@pytest.mark.parametrize(
"file_name",
[
Expand Down
18 changes: 16 additions & 2 deletions tests/test_unit/test_save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def new_file_wrong_ext(self, tmp_path):

@pytest.fixture
def new_dlc_h5_file(self, tmp_path):
"""Return the file path for a new DeepLabCut H5 file."""
"""Return the file path for a new DeepLabCut .h5 file."""
return tmp_path / "new_dlc_file.h5"

@pytest.fixture
def new_dlc_csv_file(self, tmp_path):
"""Return the file path for a new DeepLabCut csv file."""
"""Return the file path for a new DeepLabCut .csv file."""
return tmp_path / "new_dlc_file.csv"

@pytest.fixture
Expand All @@ -63,6 +63,20 @@ def missing_dim_dataset(self, valid_pose_dataset):
),
does_not_raise(),
), # valid dataset
(
load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
),
does_not_raise(),
), # valid dataset
(
load_poses.from_sleap_file(
POSE_DATA.get(
"SLEAP_three-mice_Aeon_proofread.predictions.slp"
)
),
does_not_raise(),
), # valid dataset
],
)
def test_to_dlc_df(self, ds, expected_exception):
Expand Down

0 comments on commit 3b61db6

Please sign in to comment.