Skip to content

Commit

Permalink
(mostly) leave time range validation to xarray slice
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 31, 2024
1 parent 9fb0b26 commit e09f42d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 88 deletions.
69 changes: 10 additions & 59 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,17 @@ def compute_path_length(
segments, which may not accurately reflect actual conditions.
"""
_validate_start_stop_times(data, start, stop)
validate_dims_coords(data, {"time": [], "space": []})
data = data.sel(time=slice(start, stop))
# Check that the data is not empty or too short
n_time = data.sizes["time"]
if n_time < 2:
raise log_error(
ValueError,
f"At least 2 time points are required to compute path length, "
f"but {n_time} were found. Double-check the start and stop times.",
)

_warn_about_nan_proportion(data, nan_warn_threshold)

if nan_policy == "drop":
Expand Down Expand Up @@ -446,64 +455,6 @@ def _validate_type_data_array(data: xr.DataArray) -> None:
)


def _validate_start_stop_times(
    data: xr.DataArray,
    start: int | float | None,
    stop: int | float | None,
) -> None:
    """Validate the start and stop times for path length computation.

    Parameters
    ----------
    data : xarray.DataArray
        The input data array containing position information.
    start : float or None
        The start time point for path length computation.
        ``None`` means "not provided" and is not checked.
    stop : float or None
        The stop time point for path length computation.
        ``None`` means "not provided" and is not checked.

    Raises
    ------
    TypeError
        If the start or stop time is not a real number. NumPy scalar
        types (e.g. ``numpy.int64``) are accepted alongside ``int`` and
        ``float``.
    ValueError
        If either of the provided times is outside the time range of
        ``data``, or if the start time is not strictly earlier than the
        stop time (equal values are rejected too). Problems with the
        ``time`` dimension are reported by ``validate_dims_coords``.

    """
    # Local import: the rest of the file's import block is not touched.
    import numbers

    # We validate the time dimension here, on top of any validation that may
    # occur downstream, because we rely on it for start/stop times.
    validate_dims_coords(data, {"time": []})

    time_min = data.time.min()
    time_max = data.time.max()

    for name, value in (("start time", start), ("stop time", stop)):
        if value is None:  # Skip if the time point is not provided
            continue
        # ``numbers.Real`` covers int/float as well as NumPy integer and
        # floating scalars, which are numeric but not all subclasses of
        # the builtin types.
        if not isinstance(value, numbers.Real):
            raise log_error(
                TypeError,
                f"Expected a numeric value for {name}, but got {type(value)}.",
            )
        # Check that the provided value is within the time range of the data
        if value < time_min or value > time_max:
            # ``.item()`` unwraps the 0-d DataArrays so the message shows
            # plain numbers rather than multi-line DataArray reprs.
            raise log_error(
                ValueError,
                f"The provided {name} {value} is outside the time range "
                f"of the data array ({time_min.item()}, {time_max.item()}).",
            )

    # Check that the start time is earlier than the stop time
    if start is not None and stop is not None and start >= stop:
        raise log_error(
            ValueError,
            "The start time must be earlier than the stop time.",
        )


def _warn_about_nan_proportion(
data: xr.DataArray, nan_warn_threshold: float
) -> None:
Expand Down
46 changes: 17 additions & 29 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,61 +210,46 @@ def test_approximate_derivative_with_invalid_order(order):
# full time ranges
(None, None, does_not_raise()),
(0, None, does_not_raise()),
(None, 9, does_not_raise()),
(0, 9, does_not_raise()),
(0, 10, does_not_raise()), # xarray.sel will truncate to 0, 9
(-1, 9, does_not_raise()), # xarray.sel will truncate to 0, 9
# partial time ranges
(1, 8, does_not_raise()),
(1.5, 8.5, does_not_raise()),
(2, None, does_not_raise()),
(None, 6.3, does_not_raise()),
# invalid time ranges
(
0,
10,
pytest.raises(
ValueError, match="stop time 10 is outside the time range"
),
),
(
-1,
9,
pytest.raises(
ValueError, match="start time -1 is outside the time range"
),
),
# Empty time range (because start > stop)
(
9,
0,
pytest.raises(
ValueError,
match="start time must be earlier than the stop time",
match="At least 2 time points",
),
),
# Empty time range (because of invalid start type)
(
"text",
9,
pytest.raises(
TypeError, match="Expected a numeric value for start"
ValueError,
match="At least 2 time points",
),
),
# Time range too short
(
0,
[0, 1],
0.5,
pytest.raises(
TypeError, match="Expected a numeric value for stop"
ValueError,
match="At least 2 time points",
),
),
],
)
@pytest.mark.parametrize(
"nan_policy",
["drop", "scale"], # results should be same for both here
)
def test_path_length_across_time_ranges(
valid_poses_dataset_uniform_linear_motion,
start,
stop,
nan_policy,
expected_exception,
):
"""Test path length computation for a uniform linear motion case,
Expand All @@ -280,15 +265,18 @@ def test_path_length_across_time_ranges(
position = valid_poses_dataset_uniform_linear_motion.position
with expected_exception:
path_length = kinematics.compute_path_length(
position, start=start, stop=stop, nan_policy=nan_policy
position, start=start, stop=stop
)

# Expected number of segments (displacements) in selected time range
num_segments = 9 # full time range: 10 frames - 1
start = max(0, start) if start is not None else 0
stop = min(9, stop) if stop is not None else 9
if start is not None:
num_segments -= np.ceil(start)
num_segments -= np.ceil(max(0, start))
if stop is not None:
num_segments -= 9 - np.floor(stop)
stop = min(9, stop)
num_segments -= 9 - np.floor(min(9, stop))

expected_path_length = xr.DataArray(
np.ones((2, 3)) * np.sqrt(2) * num_segments,
Expand Down

0 comments on commit e09f42d

Please sign in to comment.