From 9f65a963be2a45f9b2205d42e15353c428e91601 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 10 Oct 2024 16:44:09 +0100 Subject: [PATCH] fixed and refactor compute_path_length and its tests --- movement/analysis/kinematics.py | 187 +++++++++++++++++++-------- pyproject.toml | 1 + tests/conftest.py | 34 +++++ tests/test_unit/test_kinematics.py | 194 ++++++++++++++++++++++++++--- 4 files changed, 352 insertions(+), 64 deletions(-) diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index dd596b224..456b528e4 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -6,6 +6,7 @@ import xarray as xr from movement.utils.logging import log_error, log_warning +from movement.utils.reports import report_nan_values from movement.utils.vector import compute_norm from movement.validators.arrays import validate_dims_coords @@ -226,13 +227,15 @@ def compute_path_length( - ``"drop"``: drop any NaN values before computing path length. This is the default behavior, and it equates to assuming that a track follows a straight line between two valid points flanking a missing - segment. This approach tends to underestimate the path length, - and the error increases with the number of missing values. + segment. Missing segments at the beginning or end of the specified + time range are not counted. This approach tends to underestimate + the path length, and the error increases with the number of missing + values. - ``"scale"``: scale path length based on the proportion of valid - values per point track. For example, if only 80% of the values are - present, the path length will be computed based on these values, - and the result will be multiplied by 1/0.8 = 1.25. This approach - assumes that the point's dynamics are similar across present + segments per point track. For example, if only 80% of segments + are present, the path length will be computed based on these + and the result will be divided by 0.8. 
This approach + assumes that motion dynamics are similar across present and missing time segments, which may not be the case. nan_warn_threshold : float, optional If more than this proportion of values are missing in any point track, @@ -246,53 +249,15 @@ def compute_path_length( and ``space`` which will be removed. """ - # We validate the time dimension here, on top of its later validation - # inside compute_displacement, because we rely on it for start/stop times. - validate_dims_coords(data, {"time": []}) _validate_start_stop_times(data, start, stop) - - # Select data within the specified time range data = data.sel(time=slice(start, stop)) - - # Emit a warning for point tracks with many missing values - nan_counts = data.isnull().any(dim="space").sum(dim="time") - dims_to_stack = [dim for dim in data.dims if dim not in ["time", "space"]] - # Stack individual and keypoints dims into a single 'tracks' dimension - stacked_nan_counts = nan_counts.stack(tracks=dims_to_stack) - tracks_with_warning = stacked_nan_counts.where( - stacked_nan_counts > nan_warn_threshold, drop=True - ).tracks.values - if len(tracks_with_warning) > 0: - log_warning( - "The following point tracks have more than " - f"{nan_warn_threshold * 100}% missing values, which may lead to " - "unreliable path length estimates: " - f"{', '.join(tracks_with_warning)}." 
- ) + _warn_about_nan_proportion(data, nan_warn_threshold) if nan_policy == "drop": - stacked_data = data.stack(tracks=dims_to_stack) - # Create an empty data array to hold the path length for each track - stacked_path_length = xr.zeros_like(stacked_nan_counts) - # Compute path length for each track - for track_name in stacked_data.tracks: - track_data = stacked_data.sel(tracks=track_name, drop=True).dropna( - dim="time", how="any" - ) - stacked_path_length.loc[track_name] = compute_norm( - compute_displacement(track_data) - ).sum(dim="time") - # Return the unstacked path length (restore individual and keypoints) - return stacked_path_length.unstack("tracks") + return _compute_path_length_drop_nan(data) elif nan_policy == "scale": - valid_path_length = compute_norm(compute_displacement(data)).sum( - dim="time", - skipna=True, # path length only for valid points - ) - scale_factor = 1 / (1 - nan_counts / data.sizes["time"]) - return valid_path_length * scale_factor - + return _compute_scaled_path_length(data) else: raise log_error( ValueError, @@ -496,10 +461,15 @@ def _validate_start_stop_times( TypeError If the start or stop time is not numeric. ValueError - If either of the provided times is outside the time range of the data, - or if the start time is later than the stop time. + If the time dimension is missing, if either of the provided times is + outside the data time range, or if the start time is later than the + stop time. """ + # We validate the time dimension here, on top of any validation that may + # occur downstream, because we rely on it for start/stop times. 
+ validate_dims_coords(data, {"time": []}) + provided_time_points = {"start time": start, "stop time": stop} expected_time_range = (data.time.min(), data.time.max()) @@ -526,3 +496,122 @@ def _validate_start_stop_times( ValueError, "The start time must be earlier than the stop time.", ) + + +def _warn_about_nan_proportion( + data: xr.DataArray, nan_warn_threshold: float +) -> None: + """Print a warning if the proportion of NaN values exceeds a threshold. + + The NaN proportion is evaluated per point track, and a given point is + considered NaN if any of its ``space`` coordinates are NaN. The warning + specifically lists the point tracks that exceed the threshold. + + Parameters + ---------- + data : xarray.DataArray + The input data array. + nan_warn_threshold : float + The threshold for the proportion of NaN values. Must be a number + between 0 and 1. + + """ + nan_warn_threshold = float(nan_warn_threshold) + if nan_warn_threshold < 0 or nan_warn_threshold > 1: + raise log_error( + ValueError, + "nan_warn_threshold must be a number between 0 and 1.", + ) + + n_nans = data.isnull().any(dim="space").sum(dim="time") + data_to_warn_about = data.where( + n_nans > data.sizes["time"] * nan_warn_threshold, drop=True + ) + if len(data_to_warn_about) > 0: + log_warning( + "The result may be unreliable for point tracks with many " + "missing values. The following tracks have more than " + f"{nan_warn_threshold * 100:.3} %) NaN values:", + ) + report_nan_values(data_to_warn_about) + + +def _compute_scaled_path_length( + data: xr.DataArray, +) -> xr.DataArray: + """Compute scaled path length based on proportion of valid segments. + + Path length is first computed based on valid segments (non-NaN values + on both ends of the segment) and then scaled based on the proportion of + valid segments per point track - i.e. the result is divided by the + proportion of valid segments. 
+ + Parameters + ---------- + data : xarray.DataArray + The input data containing position information in Cartesian + coordinates, with ``time`` and ``space`` among the dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed path length. + Will have the same dimensions as the input data, except for ``time`` + and ``space`` which will be removed. + + """ + # Skip first displacement segment (always 0) to not mess up the scaling + displacement = compute_displacement(data).isel(time=slice(1, None)) + # count number of valid displacement segments per point track + valid_segments = (~displacement.isnull()).any(dim="space").sum(dim="time") + # compute proportion of valid segments per point track + valid_proportion = valid_segments / (data.sizes["time"] - 1) + # return scaled path length + return compute_norm(displacement).sum(dim="time") / valid_proportion + + +def _compute_path_length_drop_nan( + data: xr.DataArray, +) -> xr.DataArray: + """Compute path length by dropping NaN values before computation. + + This function iterates over point tracks, drops NaN values from each + track, and then computes the path length for the remaining valid + segments (takes the sum of the norms of the displacement vectors). + If there is no valid segment in a track, the path length for that + track will be NaN. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information in Cartesian + coordinates, with ``time`` and ``space`` among the dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed path length. + Will have the same dimensions as the input data, except for ``time`` + and ``space`` which will be removed. 
+ + """ + # Create array for holding results + path_length = xr.full_like( + data.isel(time=0, space=0, drop=True), fill_value=np.nan + ) + + # Stack data to iterate over point tracks + dims_to_stack = [d for d in data.dims if d not in ["time", "space"]] + stacked_data = data.stack(tracks=dims_to_stack) + for track_name in stacked_data.tracks.values: + # Drop NaN values from current point track + track_data = stacked_data.sel(tracks=track_name, drop=True).dropna( + dim="time", how="any" + ) + # Compute path length for current point track + # and store it in the result array + target_loc = {k: v for k, v in zip(dims_to_stack, track_name)} + path_length.loc[target_loc] = compute_norm( + compute_displacement(track_data) + ).sum(dim="time", min_count=1) # returns NaN if no valid segment + return path_length diff --git a/pyproject.toml b/pyproject.toml index 27348c291..2d7b1aa41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ fix = true ignore = [ "D203", # one blank line before class "D213", # multi-line-summary second line + "B905", # zip without explicit strict ] select = [ "E", # pycodestyle errors diff --git a/tests/conftest.py b/tests/conftest.py index 272e5eaa8..5cbea4c6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -518,6 +518,40 @@ def valid_poses_dataset_uniform_linear_motion( ) +@pytest.fixture +def valid_poses_dataset_uniform_linear_motion_with_nans( + valid_poses_dataset_uniform_linear_motion, +): + """Return a valid poses dataset with NaN values in the position array. 
+ + Specifically, we will introduce: + - 1 NaN value in the centroid keypoint of individual id_1 at time=0 + - 5 NaN values in the left keypoint of individual id_1 (frames 3-7) + - 10 NaN values in the right keypoint of individual id_1 (all frames) + """ + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "centroid", + "time": 0, + } + ] = np.nan + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "left", + "time": slice(3, 7), + } + ] = np.nan + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "right", + } + ] = np.nan + return valid_poses_dataset_uniform_linear_motion + + # -------------------- Invalid datasets fixtures ------------------------------ @pytest.fixture def not_a_dataset(): diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 4fb9c42ce..7d064dea8 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -218,36 +218,81 @@ def test_approximate_derivative_with_invalid_order(order): (2, None, does_not_raise()), (None, 6.3, does_not_raise()), # invalid time ranges - (0, 10, pytest.raises(ValueError)), # stop > n_frames - (-1, 9, pytest.raises(ValueError)), # start < 0 - (9, 0, pytest.raises(ValueError)), # start > stop - ("text", 9, pytest.raises(TypeError)), # start is not a number - (0, [0, 1], pytest.raises(TypeError)), # stop is not a number + ( + 0, + 10, + pytest.raises( + ValueError, match="stop time 10 is outside the time range" + ), + ), + ( + -1, + 9, + pytest.raises( + ValueError, match="start time -1 is outside the time range" + ), + ), + ( + 9, + 0, + pytest.raises( + ValueError, + match="start time must be earlier than the stop time", + ), + ), + ( + "text", + 9, + pytest.raises( + TypeError, match="Expected a numeric value for start" + ), + ), + ( + 0, + [0, 1], + pytest.raises( + TypeError, match="Expected a numeric 
value for stop" + ), + ), ], ) -def test_compute_path_length_across_time_ranges( +@pytest.mark.parametrize( + "nan_policy", + ["drop", "scale"], # results should be same for both here +) +def test_path_length_across_time_ranges( valid_poses_dataset_uniform_linear_motion, start, stop, + nan_policy, expected_exception, ): - """Test that the path length is computed correctly for a uniform linear - motion case. + """Test path length computation for a uniform linear motion case, + across different time ranges. + + The test dataset ``valid_poses_dataset_uniform_linear_motion`` + contains 2 individuals ("id_0" and "id_1"), moving + along x=y and x=-y lines, respectively, at a constant velocity. + At each frame they cover a distance of sqrt(2) in x-y space, so in total + we expect a path length of sqrt(2) * num_segments, where num_segments is + the number of selected frames minus 1. """ position = valid_poses_dataset_uniform_linear_motion.position with expected_exception: path_length = kinematics.compute_path_length( - position, start=start, stop=stop, nan_policy="scale" + position, start=start, stop=stop, nan_policy=nan_policy ) - # Expected number of steps (displacements) in selected time range - num_steps = 9 # full time range: 10 frames - 1 + + # Expected number of segments (displacements) in selected time range + num_segments = 9 # full time range: 10 frames - 1 if start is not None: - num_steps -= np.ceil(start) + num_segments -= np.ceil(start) if stop is not None: - num_steps -= 9 - np.floor(stop) - # Each step has a magnitude of sqrt(2) in x-y space + num_segments -= 9 - np.floor(stop) + print("num_segments", num_segments) + expected_path_length = xr.DataArray( - np.ones((2, 3)) * np.sqrt(2) * num_steps, + np.ones((2, 3)) * np.sqrt(2) * num_segments, dims=["individuals", "keypoints"], coords={ "individuals": position.coords["individuals"], @@ -257,6 +302,125 @@ def test_compute_path_length_across_time_ranges( xr.testing.assert_allclose(path_length, 
expected_path_length) +@pytest.mark.parametrize( + "nan_policy, expected_path_lengths_id_1, expected_exception", + [ + ( + "drop", + { + # 9 segments - 1 missing on edge + "centroid": np.sqrt(2) * 8, + # missing mid frames should have no effect + "left": np.sqrt(2) * 9, + "right": np.nan, # all frames missing + }, + does_not_raise(), + ), + ( + "scale", + { + # scaling should restore "true" path length + "centroid": np.sqrt(2) * 9, + "left": np.sqrt(2) * 9, + "right": np.nan, # all frames missing + }, + does_not_raise(), + ), + ( + "invalid", # invalid value for nan_policy + {}, + pytest.raises(ValueError, match="Invalid value for nan_policy"), + ), + ], +) +def test_path_length_with_nans( + valid_poses_dataset_uniform_linear_motion_with_nans, + nan_policy, + expected_path_lengths_id_1, + expected_exception, +): + """Test path length computation for a uniform linear motion case, + with varying number of missing values per individual and keypoint. + + The test dataset ``valid_poses_dataset_uniform_linear_motion_with_nans`` + contains 2 individuals ("id_0" and "id_1"), moving + along x=y and x=-y lines, respectively, at a constant velocity. + At each frame they cover a distance of sqrt(2) in x-y space. + + Individual "id_1" has some missing values per keypoint: + - "centroid" is missing a value on the very first frame + - "left" is missing 5 values in middle frames (not at the edges) + - "right" is missing values in all frames + + Individual "id_0" has no missing values. + + Because the underlying motion is uniform linear, the "scale" policy should + perfectly restore the path length for individual "id_1" to its true value. + The "drop" policy should do likewise if frames are missing in the middle, + but will not count any missing frames at the edges. 
+ """ + position = valid_poses_dataset_uniform_linear_motion_with_nans.position + with expected_exception: + path_length = kinematics.compute_path_length( + position, + nan_policy=nan_policy, + ) + # Initialise with expected path lengths for scenario without NaNs + expected_array = xr.DataArray( + np.ones((2, 3)) * np.sqrt(2) * 9, + dims=["individuals", "keypoints"], + coords={ + "individuals": position.coords["individuals"], + "keypoints": position.coords["keypoints"], + }, + ) + # insert expected path lengths for individual id_1 + for kpt, value in expected_path_lengths_id_1.items(): + target_loc = {"individuals": "id_1", "keypoints": kpt} + expected_array.loc[target_loc] = value + xr.testing.assert_allclose(path_length, expected_array) + + +@pytest.mark.parametrize( + "nan_warn_threshold, expected_exception", + [ + (1, does_not_raise()), + (0.2, does_not_raise()), + (-1, pytest.raises(ValueError, match="a number between 0 and 1")), + ], +) +def test_path_length_warns_about_nans( + valid_poses_dataset_uniform_linear_motion_with_nans, + nan_warn_threshold, + expected_exception, + caplog, +): + """Test that a warning is raised when the number of missing values + exceeds a given threshold. + + See the docstring of ``test_path_length_with_nans`` for details + about what's in the dataset. 
+ """ + position = valid_poses_dataset_uniform_linear_motion_with_nans.position + with expected_exception: + kinematics.compute_path_length( + position, nan_warn_threshold=nan_warn_threshold + ) + + if (nan_warn_threshold > 0.1) and (nan_warn_threshold < 0.5): + # Make sure that a warning was emitted + assert caplog.records[0].levelname == "WARNING" + assert "The result may be unreliable" in caplog.records[0].message + # Make sure that the NaN report only mentions + # the individual and keypoint that violate the threshold + assert caplog.records[1].levelname == "INFO" + assert "Individual: id_1" in caplog.records[1].message + assert "Individual: id_2" not in caplog.records[1].message + assert "left: 5/10 (50.0%)" in caplog.records[1].message + assert "right: 10/10 (100.0%)" in caplog.records[1].message + assert "centroid" not in caplog.records[1].message + + @pytest.fixture def valid_data_array_for_forward_vector(): """Return a position data array for an individual with 3 keypoints