Skip to content

Commit

Permalink
Provide a generic load_poses.from_file() function (#110)
Browse files Browse the repository at this point in the history
* added from_files() function

* added from_files() function

* Use the updated upload_pypi action (#108)

* Bump action versions in upload_pypi workflow step

* reuse the upload_pypi action instead of custom steps

* adding code review suggestionas and tests

* added log error

* added from_files() function

* added from_files() function

* adding code review suggestionas and tests

* added log error

* formatted docstrign and added to API reference

* added regex matching to ValueError test

* documented new funciton in Getting started guide

* use from_file() for fetching sample data

---------

Co-authored-by: Niko Sirmpilatze <[email protected]>
  • Loading branch information
DhruvSkyy and niksirbi authored Feb 19, 2024
1 parent d5b1335 commit b4ec636
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/api_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Input/Output
.. autosummary::
:toctree: api

from_file
from_sleap_file
from_dlc_file
from_dlc_df
Expand Down
15 changes: 15 additions & 0 deletions docs/source/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ Then, depending on the source of your data, use one of the following functions:
Load from [SLEAP analysis files](sleap:tutorials/analysis) (.h5):
```python
ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.analysis.h5", source_software="SLEAP", fps=30
)
```
:::

Expand All @@ -86,6 +91,11 @@ ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)
You may also load .csv files (assuming they are formatted as DeepLabCut expects them):
```python
ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.csv", source_software="DeepLabCut", fps=30
)
```

If you have already imported the data into a pandas DataFrame, you can
Expand All @@ -103,6 +113,11 @@ ds = load_poses.from_dlc_df(df, fps=30)
Load from LightningPose (LP) files (.csv):
```python
ds = load_poses.from_lp_file("/path/to/file.analysis.csv", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.analysis.csv", source_software="LightningPose", fps=30
)
```
:::

Expand Down
46 changes: 46 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,52 @@
logger = logging.getLogger(__name__)


def from_file(
file_path: Union[Path, str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
fps: Optional[float] = None,
) -> xr.Dataset:
"""Load pose tracking data from a DeepLabCut (DLC), LightningPose (LP) or
SLEAP output file into an xarray Dataset.
Parameters
----------
file_path : pathlib.Path or str
Path to the file containing predicted poses. The file format must
be among those supported by the ``from_dlc_file()``,
``from_slp_file()`` or ``from_lp_file()`` functions,
since one of these functions will be called internally, based on
the value of ``source_software``.
source_software : "DeepLabCut", "SLEAP" or "LightningPose"
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
Returns
-------
xarray.Dataset
Dataset containing the pose tracks, confidence scores, and metadata.
See Also
--------
movement.io.load_poses.from_dlc_file
movement.io.load_poses.from_sleap_file
movement.io.load_poses.from_lp_file
"""

if source_software == "DeepLabCut":
return from_dlc_file(file_path, fps)
elif source_software == "SLEAP":
return from_sleap_file(file_path, fps)
elif source_software == "LightningPose":
return from_lp_file(file_path, fps)
else:
raise log_error(
ValueError, f"Unsupported source software: {source_software}"
)


def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
"""Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame.
Expand Down
11 changes: 5 additions & 6 deletions movement/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ def fetch_sample_data(
file for file in metadata if file["file_name"] == filename
)

if file_metadata["source_software"] == "SLEAP":
ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "DeepLabCut":
ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "LightningPose":
ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"])
ds = load_poses.from_file(
file_path,
source_software=file_metadata["source_software"],
fps=file_metadata["fps"],
)
return ds
24 changes: 24 additions & 0 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import h5py
import numpy as np
import pytest
Expand Down Expand Up @@ -239,3 +241,25 @@ def test_load_multi_animal_from_lp_file_raises(self):
file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv")
with pytest.raises(ValueError):
load_poses.from_lp_file(file_path)

@pytest.mark.parametrize(
"source_software", ["SLEAP", "DeepLabCut", "LightningPose", "Unknown"]
)
@pytest.mark.parametrize("fps", [None, 30, 60.0])
def test_from_file_delegates_correctly(self, source_software, fps):
"""Test that the from_file() function delegates to the correct
loader function according to the source_software."""

software_to_loader = {
"SLEAP": "movement.io.load_poses.from_sleap_file",
"DeepLabCut": "movement.io.load_poses.from_dlc_file",
"LightningPose": "movement.io.load_poses.from_lp_file",
}

if source_software == "Unknown":
with pytest.raises(ValueError, match="Unsupported source"):
load_poses.from_file("some_file", source_software)
else:
with patch(software_to_loader[source_software]) as mock_loader:
load_poses.from_file("some_file", source_software, fps)
mock_loader.assert_called_with("some_file", fps)

0 comments on commit b4ec636

Please sign in to comment.