From e09f42dcf67e523ea73a53259a9c91690a5a8098 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 31 Oct 2024 19:53:44 +0000 Subject: [PATCH] (mostly) leave time range validation to xarray slice --- movement/analysis/kinematics.py | 69 +++++------------------------- tests/test_unit/test_kinematics.py | 46 ++++++++------------ 2 files changed, 27 insertions(+), 88 deletions(-) diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index 1dd0944c..8f20748f 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -255,8 +255,17 @@ def compute_path_length( segments, which may not accurately reflect actual conditions. """ - _validate_start_stop_times(data, start, stop) + validate_dims_coords(data, {"time": [], "space": []}) data = data.sel(time=slice(start, stop)) + # Check that the data is not empty or too short + n_time = data.sizes["time"] + if n_time < 2: + raise log_error( + ValueError, + f"At least 2 time points are required to compute path length, " + f"but {n_time} were found. Double-check the start and stop times.", + ) + _warn_about_nan_proportion(data, nan_warn_threshold) if nan_policy == "drop": @@ -446,64 +455,6 @@ def _validate_type_data_array(data: xr.DataArray) -> None: ) -def _validate_start_stop_times( - data: xr.DataArray, - start: int | float | None, - stop: int | float | None, -) -> None: - """Validate the start and stop times for path length computation. - - Parameters - ---------- - data: xarray.DataArray - The input data array containing position information. - start : float or None - The start time point for path length computation. - stop : float or None - The stop time point for path length computation. - - Raises - ------ - TypeError - If the start or stop time is not numeric. - ValueError - If the time dimension is missing, if either of the provided times is - outside the data time range, or if the start time is later than the - stop time. - - """ - # We validate the time dimension here, on top of any validation that may - # occur downstream, because we rely on it for start/stop times. - validate_dims_coords(data, {"time": []}) - - provided_time_points = {"start time": start, "stop time": stop} - expected_time_range = (data.time.min(), data.time.max()) - - for name, value in provided_time_points.items(): - if value is None: # Skip if the time point is not provided - continue - # Check that the provided value is numeric - if not isinstance(value, int | float): - raise log_error( - TypeError, - f"Expected a numeric value for {name}, but got {type(value)}.", - ) - # Check that the provided value is within the time range of the data - if value < expected_time_range[0] or value > expected_time_range[1]: - raise log_error( - ValueError, - f"The provided {name} {value} is outside the time range " - f"of the data array ({expected_time_range}).", - ) - - # Check that the start time is earlier than the stop time - if start is not None and stop is not None and start >= stop: - raise log_error( - ValueError, - "The start time must be earlier than the stop time.", - ) - - def _warn_about_nan_proportion( data: xr.DataArray, nan_warn_threshold: float ) -> None: diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 453a5d83..8fe7a147 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -210,61 +210,46 @@ def test_approximate_derivative_with_invalid_order(order): # full time ranges (None, None, does_not_raise()), (0, None, does_not_raise()), - (None, 9, does_not_raise()), (0, 9, does_not_raise()), + (0, 10, does_not_raise()), # xarray.sel will truncate to 0, 9 + (-1, 9, does_not_raise()), # xarray.sel will truncate to 0, 9 # partial time ranges (1, 8, does_not_raise()), (1.5, 8.5, does_not_raise()), (2, None, does_not_raise()), - (None, 6.3, does_not_raise()), - # invalid time ranges - ( - 0, - 10, - pytest.raises( - ValueError, match="stop time 10 is outside the time range" - ), - ), - ( - -1, - 9, - pytest.raises( - ValueError, match="start time -1 is outside the time range" - ), - ), + # Empty time range (because start > stop) ( 9, 0, pytest.raises( ValueError, - match="start time must be earlier than the stop time", + match="At least 2 time points", ), ), + # Empty time range (because of invalid start type) ( "text", 9, pytest.raises( - TypeError, match="Expected a numeric value for start" + ValueError, + match="At least 2 time points", ), ), + # Time range too short ( 0, - [0, 1], + 0.5, pytest.raises( - TypeError, match="Expected a numeric value for stop" + ValueError, + match="At least 2 time points", ), ), ], ) -@pytest.mark.parametrize( - "nan_policy", - ["drop", "scale"], # results should be same for both here -) def test_path_length_across_time_ranges( valid_poses_dataset_uniform_linear_motion, start, stop, - nan_policy, expected_exception, ): """Test path length computation for a uniform linear motion case, @@ -280,15 +265,18 @@ def test_path_length_across_time_ranges( position = valid_poses_dataset_uniform_linear_motion.position with expected_exception: path_length = kinematics.compute_path_length( - position, start=start, stop=stop, nan_policy=nan_policy + position, start=start, stop=stop ) # Expected number of segments (displacements) in selected time range num_segments = 9 # full time range: 10 frames - 1 + start = max(0, start) if start is not None else 0 + stop = min(9, stop) if stop is not None else 9 if start is not None: - num_segments -= np.ceil(start) + num_segments -= np.ceil(max(0, start)) if stop is not None: - num_segments -= 9 - np.floor(stop) + stop = min(9, stop) + num_segments -= 9 - np.floor(min(9, stop)) expected_path_length = xr.DataArray( np.ones((2, 3)) * np.sqrt(2) * num_segments,