From d190ce57228f579dae345a552349cd986d4f9f4b Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 11 Dec 2024 15:25:32 +0100 Subject: [PATCH] Loading function for Anipose data (#358) * first draft of loading function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adapted to new dimensions order * adapted to work with new dims arrangement * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * anipose loader test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * validator for anipose file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * anipose validator finished * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * linting fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests/test_unit/test_validators/test_files_validators.py Co-authored-by: Niko Sirmpilatze * simplified validator test * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze * Update movement/validators/files.py Co-authored-by: Niko Sirmpilatze * implementing fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more consistency fixes * moved anipose loading test to load_poses * fixed validators tests * tests for anipose loading done properly * docstring fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Implementing direct anipose load from from_file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ruffed * trying to fix mypy check * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * final touches to docstrings * added entry in input_output docs * define anipose link in conf.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Niko Sirmpilatze --- docs/source/conf.py | 1 + docs/source/user_guide/input_output.md | 17 +++ movement/io/load_poses.py | 133 +++++++++++++++++- movement/validators/files.py | 88 ++++++++++++ tests/conftest.py | 55 ++++++++ tests/test_unit/test_load_poses.py | 24 +++- .../test_validators/test_files_validators.py | 36 +++++ 7 files changed, 350 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index cdf14786..cc9faf76 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -203,6 +203,7 @@ "xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}", "lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}", "via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}", + "anipose": "https://anipose.readthedocs.io/en/latest/", } intersphinx_mapping = { diff --git a/docs/source/user_guide/input_output.md b/docs/source/user_guide/input_output.md index a27691f0..4dd7445f 100644 --- a/docs/source/user_guide/input_output.md +++ b/docs/source/user_guide/input_output.md @@ -10,6 +10,7 @@ To analyse pose tracks, `movement` supports loading data from various frameworks - [DeepLabCut](dlc:) (DLC) - [SLEAP](sleap:) (SLEAP) - [LightingPose](lp:) (LP) +- [Anipose](anipose:) (Anipose) To analyse bounding boxes' tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html). @@ -84,6 +85,22 @@ ds = load_poses.from_file( ``` ::: +:::{tab-item} Anipose + +To load Anipose files in .csv format: +```python +ds = load_poses.from_anipose_file( + "/path/to/file.analysis.csv", fps=30, individual_name="individual_0" +) # We can optionally specify the individual name, by default it is "individual_0" + +# or equivalently +ds = load_poses.from_file( + "/path/to/file.analysis.csv", source_software="Anipose", fps=30, individual_name="individual_0" +) + +``` +::: + :::{tab-item} From NumPy In the example below, we create random position data for two individuals, ``Alice`` and ``Bob``, diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 745f64c1..9eaa814f 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -13,7 +13,12 @@ from movement.utils.logging import log_error, log_warning from movement.validators.datasets import ValidPosesDataset -from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5 +from movement.validators.files import ( + ValidAniposeCSV, + ValidDeepLabCutCSV, + ValidFile, + ValidHDF5, +) logger = logging.getLogger(__name__) @@ -91,8 +96,11 @@ def from_numpy( def from_file( file_path: Path | str, - source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + source_software: Literal[ + "DeepLabCut", "SLEAP", "LightningPose", "Anipose" + ], fps: float | None = None, + **kwargs, ) -> xr.Dataset: """Create a ``movement`` poses dataset from any supported file. @@ -104,11 +112,14 @@ def from_file( ``from_slp_file()`` or ``from_lp_file()`` functions. One of these these functions will be called internally, based on the value of ``source_software``. - source_software : "DeepLabCut", "SLEAP" or "LightningPose" + source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose" The source software of the file. fps : float, optional The number of frames per second in the video. If None (default), the ``time`` coordinates will be in frame numbers. + **kwargs : dict, optional + Additional keyword arguments to pass to the software-specific + loading functions that are listed under "See Also". Returns ------- @@ -121,6 +132,7 @@ def from_file( movement.io.load_poses.from_dlc_file movement.io.load_poses.from_sleap_file movement.io.load_poses.from_lp_file + movement.io.load_poses.from_anipose_file Examples -------- @@ -136,6 +148,8 @@ def from_file( return from_sleap_file(file_path, fps) elif source_software == "LightningPose": return from_lp_file(file_path, fps) + elif source_software == "Anipose": + return from_anipose_file(file_path, fps, **kwargs) else: raise log_error( ValueError, f"Unsupported source software: {source_software}" @@ -696,3 +710,116 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: "ds_type": "poses", }, ) + + +def from_anipose_style_df( + df: pd.DataFrame, + fps: float | None = None, + individual_name: str = "individual_0", +) -> xr.Dataset: + """Create a ``movement`` poses dataset from an Anipose 3D dataframe. + + Parameters + ---------- + df : pd.DataFrame + Anipose triangulation dataframe + fps : float, optional + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame units. + individual_name : str, optional + Name of the individual, by default "individual_0" + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. + + + Notes + ----- + Reshape dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z, + keypoint1_score,keypoint2_x, keypoint2_y, keypoint2_z, + keypoint2_score...to array of positions with dimensions + time, space, keypoints, individuals, and array of confidence (from scores) + with dimensions time, keypoints, individuals. + + """ + keypoint_names = sorted( + list( + set( + [ + col.rsplit("_", 1)[0] + for col in df.columns + if any(col.endswith(f"_{s}") for s in ["x", "y", "z"]) + ] + ) + ) + ) + + n_frames = len(df) + n_keypoints = len(keypoint_names) + + # Initialize arrays and fill + position_array = np.zeros( + (n_frames, 3, n_keypoints, 1) + ) # 1 for single individual + confidence_array = np.zeros((n_frames, n_keypoints, 1)) + for i, kp in enumerate(keypoint_names): + for j, coord in enumerate(["x", "y", "z"]): + position_array[:, j, i, 0] = df[f"{kp}_{coord}"] + confidence_array[:, i, 0] = df[f"{kp}_score"] + + individual_names = [individual_name] + + return from_numpy( + position_array=position_array, + confidence_array=confidence_array, + individual_names=individual_names, + keypoint_names=keypoint_names, + source_software="Anipose", + fps=fps, + ) + + +def from_anipose_file( + file_path: Path | str, + fps: float | None = None, + individual_name: str = "individual_0", +) -> xr.Dataset: + """Create a ``movement`` poses dataset from an Anipose 3D .csv file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the Anipose triangulation .csv file + fps : float, optional + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame units. + individual_name : str, optional + Name of the individual, by default "individual_0" + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. + + Notes + ----- + We currently do not load all information, only x, y, z, and score + (confidence) for each keypoint. Future versions will load n of cameras + and error. + + """ + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".csv"], + ) + anipose_file = ValidAniposeCSV(file.path) + anipose_df = pd.read_csv(anipose_file.path) + + return from_anipose_style_df( + anipose_df, fps=fps, individual_name=individual_name + ) diff --git a/movement/validators/files.py b/movement/validators/files.py index 0bc5c446..bfca9116 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -221,6 +221,94 @@ def _file_contains_expected_levels(self, attribute, value): ) +@define +class ValidAniposeCSV: + """Class for validating Anipose-style 3D pose .csv files. + + The validator ensures that the file contains the + expected column names in its header (first row). + + Attributes + ---------- + path : pathlib.Path + Path to the .csv file. + + Raises + ------ + ValueError + If the .csv file does not contain the expected Anipose columns. + + """ + + path: Path = field(validator=validators.instance_of(Path)) + + @path.validator + def _file_contains_expected_columns(self, attribute, value): + """Ensure that the .csv file contains the expected columns.""" + expected_column_suffixes = [ + "_x", + "_y", + "_z", + "_score", + "_error", + "_ncams", + ] + expected_non_keypoint_columns = [ + "fnum", + "center_0", + "center_1", + "center_2", + "M_00", + "M_01", + "M_02", + "M_10", + "M_11", + "M_12", + "M_20", + "M_21", + "M_22", + ] + + # Read the first line of the CSV to get the headers + with open(value) as f: + columns = f.readline().strip().split(",") + + # Check that all expected headers are present + if not all(col in columns for col in expected_non_keypoint_columns): + raise log_error( + ValueError, + "CSV file is missing some expected columns." + f"Expected: {expected_non_keypoint_columns}.", + ) + + # For other headers, check they have expected suffixes and base names + other_columns = [ + col for col in columns if col not in expected_non_keypoint_columns + ] + for column in other_columns: + # Check suffix + if not any( + column.endswith(suffix) for suffix in expected_column_suffixes + ): + raise log_error( + ValueError, + f"Column {column} ends with an unexpected suffix.", + ) + # Get base name by removing suffix + base = column.rsplit("_", 1)[0] + # Check base name has all expected suffixes + if not all( + f"{base}{suffix}" in columns + for suffix in expected_column_suffixes + ): + raise log_error( + ValueError, + f"Keypoint {base} is missing some expected suffixes." + f"Expected: {expected_column_suffixes};" + f"Got: {columns}.", + ) + + @define class ValidVIATracksCSV: """Class for validating VIA tracks .csv files. diff --git a/tests/conftest.py b/tests/conftest.py index 5443941e..4a760514 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,6 +199,61 @@ def dlc_style_df(): return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5")) +@pytest.fixture +def missing_keypoint_columns_anipose_csv_file(tmp_path): + """Return the file path for a fake single-individual .csv file.""" + file_path = tmp_path / "missing_keypoint_columns.csv" + columns = [ + "fnum", + "center_0", + "center_1", + "center_2", + "M_00", + "M_01", + "M_02", + "M_10", + "M_11", + "M_12", + "M_20", + "M_21", + "M_22", + ] + # Here we are missing kp0_z: + columns.extend(["kp0_x", "kp0_y", "kp0_score", "kp0_error", "kp0_ncams"]) + with open(file_path, "w") as f: + f.write(",".join(columns)) + f.write("\n") + f.write(",".join(["1"] * len(columns))) + return file_path + + +@pytest.fixture +def spurious_column_anipose_csv_file(tmp_path): + """Return the file path for a fake single-individual .csv file.""" + file_path = tmp_path / "spurious_column.csv" + columns = [ + "fnum", + "center_0", + "center_1", + "center_2", + "M_00", + "M_01", + "M_02", + "M_10", + "M_11", + "M_12", + "M_20", + "M_21", + "M_22", + ] + columns.extend(["funny_column"]) + with open(file_path, "w") as f: + f.write(",".join(columns)) + f.write("\n") + f.write(",".join(["1"] * len(columns))) + return file_path + + @pytest.fixture( params=[ "SLEAP_single-mouse_EPM.analysis.h5", diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 91bee074..8c07ae9d 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -257,7 +257,8 @@ def test_load_multi_individual_from_lp_file_raises(): @pytest.mark.parametrize( - "source_software", ["SLEAP", "DeepLabCut", "LightningPose", "Unknown"] + "source_software", + ["SLEAP", "DeepLabCut", "LightningPose", "Anipose", "Unknown"], ) @pytest.mark.parametrize("fps", [None, 30, 60.0]) def test_from_file_delegates_correctly(source_software, fps): @@ -268,6 +269,7 @@ def test_from_file_delegates_correctly(source_software, fps): "SLEAP": "movement.io.load_poses.from_sleap_file", "DeepLabCut": "movement.io.load_poses.from_dlc_file", "LightningPose": "movement.io.load_poses.from_lp_file", + "Anipose": "movement.io.load_poses.from_anipose_file", } if source_software == "Unknown": with pytest.raises(ValueError, match="Unsupported source"): @@ -318,3 +320,23 @@ def test_from_multiview_files(): assert isinstance(multi_view_ds, xr.Dataset) assert "view" in multi_view_ds.dims assert multi_view_ds.view.values.tolist() == view_names + + +def test_load_from_anipose_file(): + """Test that loading pose tracks from an Anipose triangulation + csv file returns the same Dataset. + """ + file_path = DATA_PATHS.get( + "anipose_mouse-paw_anipose-paper.triangulation.csv" + ) + ds = load_poses.from_anipose_file(file_path) + assert ds.position.shape == (246, 3, 6, 1) + assert ds.confidence.shape == (246, 6, 1) + assert ds.coords["keypoints"].values.tolist() == [ + "l-base", + "l-edge", + "l-middle", + "r-base", + "r-edge", + "r-middle", + ] diff --git a/tests/test_unit/test_validators/test_files_validators.py b/tests/test_unit/test_validators/test_files_validators.py index 7ce0a722..b3149d64 100644 --- a/tests/test_unit/test_validators/test_files_validators.py +++ b/tests/test_unit/test_validators/test_files_validators.py @@ -1,6 +1,7 @@ import pytest from movement.validators.files import ( + ValidAniposeCSV, ValidDeepLabCutCSV, ValidFile, ValidHDF5, @@ -175,3 +176,38 @@ def test_via_tracks_csv_validator_with_invalid_input( ValidVIATracksCSV(file_path) assert str(excinfo.value) == log_message + + +@pytest.mark.parametrize( + "invalid_input, log_message", + [ + ( + "invalid_single_individual_csv_file", + "CSV file is missing some expected columns.", + ), + ( + "missing_keypoint_columns_anipose_csv_file", + "Keypoint kp0 is missing some expected suffixes.", + ), + ( + "spurious_column_anipose_csv_file", + "Column funny_column ends with an unexpected suffix.", + ), + ], +) +def test_anipose_csv_validator_with_invalid_input( + invalid_input, log_message, request +): + """Test that invalid Anipose .csv files raise the appropriate errors. + + Errors to check: + - error if .csv is missing some columns + - error if .csv misses some of the expected columns for a keypoint + - error if .csv has columns that are not expected + (either common ones or keypoint-specific ones) + """ + file_path = request.getfixturevalue(invalid_input) + with pytest.raises(ValueError) as excinfo: + ValidAniposeCSV(file_path) + + assert log_message in str(excinfo.value)