Added functions for loading/saving LightningPose data (#92)
* Added function to load LP data and tests

* fixed code smell

* fixed comments from draft pull request

* fix minor inconsistencies in docstrings

* ensure docs are up-to-date and consistent

* validate that LP datasets are single-individual

* added function for saving poses to LP file

---------

Co-authored-by: niksirbi <[email protected]>
DhruvSkyy and niksirbi authored Dec 19, 2023
1 parent cc35888 commit a3671df
Showing 14 changed files with 257 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -60,6 +60,7 @@ instance/
docs/build/
docs/source/examples/
docs/source/api/
sg_execution_times.rst

# MkDocs documentation
/site/
2 changes: 2 additions & 0 deletions docs/source/api_index.rst
@@ -13,6 +13,7 @@ Input/Output
from_sleap_file
from_dlc_file
from_dlc_df
from_lp_file

.. currentmodule:: movement.io.save_poses
.. autosummary::
@@ -21,6 +22,7 @@ Input/Output
to_dlc_file
to_dlc_df
to_sleap_analysis_file
to_lp_file

.. currentmodule:: movement.io.validators
.. autosummary::
2 changes: 1 addition & 1 deletion docs/source/community/roadmap.md
@@ -19,7 +19,7 @@ The following features are being considered for the first stable version `v1.0`.
## Short-term milestone - `v0.1`
We plan to release version `v0.1` of movement in early 2024, providing a minimal set of features to demonstrate the project's potential and to gather feedback from users. At minimum, it should include the following features:

- Importing pose tracks from [DeepLabCut](dlc:) and [SLEAP](sleap:) into a common `xarray.Dataset` structure. This has been largely accomplished, but some remaining work is required to handle special cases.
- Importing pose tracks from [DeepLabCut](dlc:), [SLEAP](sleap:) and [LightningPose](lp:) into a common `xarray.Dataset` structure. This has already been accomplished.
- Visualisation of pose tracks using [napari](napari:). We aim to represent pose tracks via the [napari tracks layer](napari:howtos/layers/tracks) and overlay them on a video frame. This should be accompanied by a minimal GUI widget to allow selection of a subset of the tracks to plot. This line of work is still in a pilot phase. We may decide to use a different visualisation framework if we encounter roadblocks.
- At least one function for cleaning the pose tracks. Once the first one is in place, it can serve as a template for others.
- Computing velocity and acceleration from pose tracks. Again, this should serve as a template for other kinematic variables.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -160,4 +160,5 @@
"sleap": "https://sleap.ai/{{path}}#{{fragment}}",
"sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}",
"xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}",
"lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}",
}
37 changes: 30 additions & 7 deletions docs/source/getting_started.md
@@ -61,26 +61,26 @@ First import the `movement.io.load_poses` module:
from movement.io import load_poses
```

Then, use the `from_dlc_file` or `from_sleap_file` functions to load the data.
Then, depending on the source of your data, use one of the following functions:

::::{tab-set}

:::{tab-item} SLEAP

Load from [SLEAP analysis files](sleap:tutorials/analysis) (`.h5`):
Load from [SLEAP analysis files](sleap:tutorials/analysis) (.h5):
```python
ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30)
```
:::

:::{tab-item} DeepLabCut

Load pose estimation outputs from `.h5` files:
Load pose estimation outputs from .h5 files:
```python
ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)
```

You may also load `.csv` files (assuming they are formatted as DeepLabCut expects them):
You may also load .csv files (assuming they are formatted as DeepLabCut expects them):
```python
ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30)
```
@@ -95,6 +95,14 @@ ds = load_poses.from_dlc_df(df, fps=30)
```
:::

:::{tab-item} LightningPose

Load from LightningPose (LP) files (.csv):
```python
ds = load_poses.from_lp_file("/path/to/file.analysis.csv", fps=30)
```
:::

::::
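
Whichever loader you use, the resulting dataset has the same structure. As a quick check after loading (a minimal sketch; the file path is hypothetical):

```python
from movement.io import load_poses

ds = load_poses.from_lp_file("/path/to/file.csv", fps=30)

# The two data variables hold the predicted poses and their confidence scores
print(ds.pose_tracks.shape)
print(ds.confidence.shape)

# Metadata recorded at load time, e.g. the source software and source file
print(ds.attrs["source_software"])  # "LightningPose"
print(ds.attrs["source_file"])
```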

You can also try movement out on some sample data included in the package.
@@ -215,13 +223,13 @@ First import the `movement.io.save_poses` module:
from movement.io import save_poses
```

Then, use the `to_dlc_file` or `to_sleap_analysis_file` functions to save the data.
Then, depending on the desired format, use one of the following functions:

:::::{tab-set}

::::{tab-item} SLEAP

Save to SLEAP-style analysis files (`.h5`):
Save to SLEAP-style analysis files (.h5):
```python
save_poses.to_sleap_analysis_file(ds, "/path/to/file.h5")
```
@@ -240,7 +248,7 @@ each attribute and data variable represents, see the

::::{tab-item} DeepLabCut

Save to DeepLabCut-style files (`.h5` or `.csv`):
Save to DeepLabCut-style files (.h5 or .csv):
```python
save_poses.to_dlc_file(ds, "/path/to/file.h5") # preferred format
save_poses.to_dlc_file(ds, "/path/to/file.csv")
@@ -254,4 +262,19 @@ df = save_poses.to_dlc_df(ds)
and then save it to file using any `pandas` method, e.g. `to_hdf` or `to_csv`.
::::

::::{tab-item} LightningPose

Save to LightningPose (LP) files (.csv):
```python
save_poses.to_lp_file(ds, "/path/to/file.csv")
```
:::{note}
Because LP saves pose estimation outputs in the same format as single-animal
DeepLabCut projects, the above command is equivalent to:
```python
save_poses.to_dlc_file(ds, "/path/to/file.csv", split_individuals=True)
```
:::

::::
:::::
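
The loading and saving functions can also be combined to convert between formats. For example, single-animal LightningPose predictions could be re-saved in DeepLabCut format (a sketch with hypothetical file paths):

```python
from movement.io import load_poses, save_poses

# Load single-animal predictions from a LightningPose .csv file
ds = load_poses.from_lp_file("/path/to/predictions.csv", fps=30)

# Re-save them as a DeepLabCut-style .h5 file
save_poses.to_dlc_file(ds, "/path/to/predictions.h5")

# ...or write them back out in LightningPose format
# (the individual's name is appended to the file name, as noted above)
save_poses.to_lp_file(ds, "/path/to/predictions_copy.csv")
```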
2 changes: 1 addition & 1 deletion examples/load_and_explore_poses.py
@@ -38,7 +38,7 @@

# %%
# The loaded dataset contains two data variables:
# ``pose_tracks`` and ``confidence```
# ``pose_tracks`` and ``confidence``.
# To get the pose tracks:
pose_tracks = ds.pose_tracks

79 changes: 74 additions & 5 deletions movement/io/load_poses.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import h5py
import numpy as np
@@ -160,6 +160,36 @@ def from_sleap_file(
return ds


def from_lp_file(
file_path: Union[Path, str], fps: Optional[float] = None
) -> xr.Dataset:
"""Load pose tracking data from a LightningPose (LP) output file
into an xarray Dataset.

Parameters
----------
file_path : pathlib.Path or str
Path to the file containing the LP predicted poses, in .csv format.
fps : float, optional
The number of frames per second in the video. If None (default),
the `time` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
Dataset containing the pose tracks, confidence scores, and metadata.

Examples
--------
>>> from movement.io import load_poses
>>> ds = load_poses.from_lp_file("path/to/file.csv", fps=30)
"""

return _from_lp_or_dlc_file(
file_path=file_path, source_software="LightningPose", fps=fps
)


def from_dlc_file(
file_path: Union[Path, str], fps: Optional[float] = None
) -> xr.Dataset:
@@ -190,10 +220,41 @@ def from_dlc_file(
>>> ds = load_poses.from_dlc_file("path/to/file.h5", fps=30)
"""

return _from_lp_or_dlc_file(
file_path=file_path, source_software="DeepLabCut", fps=fps
)


def _from_lp_or_dlc_file(
file_path: Union[Path, str],
source_software: Literal["LightningPose", "DeepLabCut"],
fps: Optional[float] = None,
) -> xr.Dataset:
"""Loads pose tracking data from a DeepLabCut (DLC) or
a LightningPose (LP) output file into an xarray Dataset.

Parameters
----------
file_path : pathlib.Path or str
Path to the file containing the predicted poses, either in .h5
or .csv format.
source_software : {'LightningPose', 'DeepLabCut'}
Name of the software that produced the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the `time` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
Dataset containing the pose tracks, confidence scores, and metadata.
"""

expected_suffix = [".csv"]
if source_software == "DeepLabCut":
expected_suffix.append(".h5")

file = ValidFile(
file_path,
expected_permission="r",
expected_suffix=[".csv", ".h5"],
file_path, expected_permission="r", expected_suffix=expected_suffix
)

# Load the DLC poses into a DataFrame
@@ -207,9 +268,15 @@ def from_dlc_file(
ds = from_dlc_df(df=df, fps=fps)

# Add metadata as attrs
ds.attrs["source_software"] = "DeepLabCut"
ds.attrs["source_software"] = source_software
ds.attrs["source_file"] = file.path.as_posix()

# If source_software="LightningPose", we need to re-validate (because the
# validation call in from_dlc_df was run with source_software="DeepLabCut")
# This rerun enforces a single individual for LightningPose datasets.
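# A dataset containing more than one individual would therefore fail this
# check with "The dataset does not contain valid pose tracks.", since
# LightningPose only supports single-individual data.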
if source_software == "LightningPose":
ds.poses.validate()

logger.info(f"Loaded pose tracks from {file.path}:")
logger.info(ds)
return ds
@@ -259,6 +326,7 @@ def _load_from_sleap_analysis_file(
individual_names=individual_names,
keypoint_names=[n.decode() for n in f["node_names"][:]],
fps=fps,
source_software="SLEAP",
)


@@ -298,6 +366,7 @@ def _load_from_sleap_labels_file(
individual_names=individual_names,
keypoint_names=[kp.name for kp in labels.skeletons[0].nodes],
fps=fps,
source_software="SLEAP",
)


2 changes: 2 additions & 0 deletions movement/io/poses_accessor.py
@@ -69,13 +69,15 @@ def __init__(self, ds: xr.Dataset):
def validate(self) -> None:
"""Validate the PoseTracks dataset."""
fps = self._obj.attrs.get("fps", None)
source_software = self._obj.attrs.get("source_software", None)
try:
ValidPoseTracks(
tracks_array=self._obj[self.var_names[0]].values,
scores_array=self._obj[self.var_names[1]].values,
individual_names=self._obj.coords[self.dim_names[1]].values,
keypoint_names=self._obj.coords[self.dim_names[2]].values,
fps=fps,
source_software=source_software,
)
except Exception as e:
error_msg = "The dataset does not contain valid pose tracks."
43 changes: 39 additions & 4 deletions movement/io/save_poses.py
@@ -166,17 +166,18 @@ def to_dlc_file(
file_path : pathlib.Path or str
Path to the file to save the DLC poses to. The file extension
must be either .h5 (recommended) or .csv.
split_individuals : bool, optional
split_individuals : bool or "auto", optional
Whether to save individuals to separate files or to the same file.\n
If True, each individual will be saved to a separate file,
formatted as in a single-animal DeepLabCut project - i.e. without
the "individuals" column level. The individual's name will be appended
to the file path, just before the file extension, i.e.
"/path/to/filename_individual1.h5".
"/path/to/filename_individual1.h5".\n
If False, all individuals will be saved to the same file,
formatted as in a multi-animal DeepLabCut project - i.e. the columns
will include the "individuals" level. The file path will not be
modified.
If "auto" the argument's value be determined based on the number of
modified.\n
If "auto" the argument's value is determined based on the number of
individuals in the dataset: True if there is only one, and
False if there are more than one. This is the default.
@@ -226,6 +227,40 @@
logger.info(f"Saved PoseTracks dataset to {file.path}.")


def to_lp_file(
ds: xr.Dataset,
file_path: Union[str, Path],
) -> None:
"""Save the xarray dataset containing pose tracks to a LightningPose-style
.csv file. See Notes for more details.

Parameters
----------
ds : xarray.Dataset
Dataset containing pose tracks, confidence scores, and metadata.
file_path : pathlib.Path or str
Path to the .csv file to save the poses to.

Notes
-----
LightningPose saves pose estimation outputs as .csv files, using the same
format as single-animal DeepLabCut projects. Therefore, under the hood,
this function calls ``to_dlc_file`` with ``split_individuals=True``. This
setting means that each individual is saved to a separate file, with
the individual's name appended to the file path, just before the file
extension, i.e. "/path/to/filename_individual1.csv".

See Also
--------
to_dlc_file : Save the xarray dataset containing pose tracks to a
DeepLabCut-style .h5 or .csv file.
"""

file = _validate_file_path(file_path=file_path, expected_suffix=[".csv"])
_validate_dataset(ds)
to_dlc_file(ds, file.path, split_individuals=True)


def to_sleap_analysis_file(
ds: xr.Dataset, file_path: Union[str, Path]
) -> None:
13 changes: 12 additions & 1 deletion movement/io/validators.py
@@ -266,6 +266,9 @@ class ValidPoseTracks:
etc.
fps : float, optional
Frames per second of the video. Defaults to None.
source_software : str, optional
Name of the software from which the pose tracks were loaded.
Defaults to None.
"""

# Define class attributes
@@ -285,6 +288,10 @@ class ValidPoseTracks:
converters.optional(float), _set_fps_to_none_if_invalid
),
)
source_software: Optional[str] = field(
default=None,
validator=validators.optional(validators.instance_of(str)),
)

# Add validators
@tracks_array.validator
@@ -316,7 +323,11 @@ def _validate_scores_array(self, attribute, value):

@individual_names.validator
def _validate_individual_names(self, attribute, value):
_validate_list_length(attribute, value, self.tracks_array.shape[1])
if self.source_software == "LightningPose":
# LightningPose only supports a single individual
_validate_list_length(attribute, value, 1)
else:
_validate_list_length(attribute, value, self.tracks_array.shape[1])

@keypoint_names.validator
def _validate_keypoint_names(self, attribute, value):
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -28,6 +28,8 @@ def pytest_configure():
"SLEAP_three-mice_Aeon_proofread.predictions.slp",
"SLEAP_three-mice_Aeon_mixed-labels.analysis.h5",
"SLEAP_three-mice_Aeon_mixed-labels.predictions.slp",
"LP_mouse-face_AIND.predictions.csv",
"LP_mouse-twoview_AIND.predictions.csv",
]
}
