Skip to content

Commit

Permalink
Implement compute_speed and compute_path_length (#280)
Browse files Browse the repository at this point in the history
* implement compute_speed and compute_path_length functions

* added speed to existing kinematics unit test

* rewrote compute_path_length with various nan policies

* unit test compute_path_length across time ranges

* fixed and refactor compute_path_length and its tests

* fixed docstring for compute_path_length

* Accept suggestion on docstring wording

Co-authored-by: Chang Huan Lo <[email protected]>

* Remove print statement from test

Co-authored-by: Chang Huan Lo <[email protected]>

* Ensure nan report is printed

Co-authored-by: Chang Huan Lo <[email protected]>

* adapt warning message match in test

* change 'any' to 'all'

* uniform wording across path length docstrings

* (mostly) leave time range validation to xarray slice

* refactored parameters for test across time ranges

* simplified test for path length with nans

* replace drop policy with ffill

* remove B905 ruff rule

* make pre-commit happy

---------

Co-authored-by: Chang Huan Lo <[email protected]>
  • Loading branch information
niksirbi and lochhh committed Nov 19, 2024
1 parent 8a9017f commit 0c2917b
Show file tree
Hide file tree
Showing 3 changed files with 416 additions and 9 deletions.
188 changes: 187 additions & 1 deletion movement/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import xarray as xr
from scipy.spatial.distance import cdist

from movement.utils.logging import log_error
from movement.utils.logging import log_error, log_warning
from movement.utils.reports import report_nan_values
from movement.utils.vector import compute_norm
from movement.validators.arrays import validate_dims_coords

Expand Down Expand Up @@ -173,6 +174,30 @@ def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
return result


def compute_speed(data: xr.DataArray) -> xr.DataArray:
    """Compute instantaneous speed at each time point.

    Speed is a scalar quantity computed as the Euclidean norm (magnitude)
    of the velocity vector at each time point.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray containing the computed speed,
        with dimensions matching those of the input data,
        except ``space`` is removed.

    """
    # Speed = ||velocity||, taken over the ``space`` dimension.
    return compute_norm(compute_velocity(data))


def compute_forward_vector(
data: xr.DataArray,
left_keypoint: str,
Expand Down Expand Up @@ -675,3 +700,164 @@ def _validate_type_data_array(data: xr.DataArray) -> None:
TypeError,
f"Input data must be an xarray.DataArray, but got {type(data)}.",
)


def compute_path_length(
    data: xr.DataArray,
    start: float | None = None,
    stop: float | None = None,
    nan_policy: Literal["ffill", "scale"] = "ffill",
    nan_warn_threshold: float = 0.2,
) -> xr.DataArray:
    """Compute the length of a path travelled between two time points.

    The path length is defined as the sum of the norms (magnitudes) of the
    displacement vectors between two time points ``start`` and ``stop``,
    which should be provided in the time units of the data array.
    If not specified, the minimum and maximum time coordinates of the data
    array are used as start and stop times, respectively.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.
    start : float, optional
        The start time of the path. If None (default),
        the minimum time coordinate in the data is used.
    stop : float, optional
        The end time of the path. If None (default),
        the maximum time coordinate in the data is used.
    nan_policy : Literal["ffill", "scale"], optional
        Policy to handle NaN (missing) values. Can be one of the ``"ffill"``
        or ``"scale"``. Defaults to ``"ffill"`` (forward fill).
        See Notes for more details on the two policies.
    nan_warn_threshold : float, optional
        If more than this proportion of values are missing in any point track,
        a warning will be emitted. Defaults to 0.2 (20%).

    Returns
    -------
    xarray.DataArray
        An xarray DataArray containing the computed path length,
        with dimensions matching those of the input data,
        except ``time`` and ``space`` are removed.

    Notes
    -----
    Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill`
    to forward-fill missing segments (NaN values) across time.
    This equates to assuming that a track remains stationary for
    the duration of the missing segment and then instantaneously moves to
    the next valid position, following a straight line. This approach tends
    to underestimate the path length, and the error increases with the number
    of missing values.

    Choosing ``nan_policy="scale"`` will adjust the path length based on
    the proportion of valid segments per point track. For example, if only
    80% of segments are present, the path length will be computed based on
    these and the result will be divided by 0.8. This approach assumes
    that motion dynamics are similar across observed and missing time
    segments, which may not accurately reflect actual conditions.

    """
    # Fail fast on an invalid policy, before any slicing, validation,
    # or NaN-proportion warnings are performed.
    if nan_policy not in ("ffill", "scale"):
        raise log_error(
            ValueError,
            f"Invalid value for nan_policy: {nan_policy}. "
            "Must be one of 'ffill' or 'scale'.",
        )

    validate_dims_coords(data, {"time": [], "space": []})
    data = data.sel(time=slice(start, stop))
    # Check that the data is not empty or too short
    n_time = data.sizes["time"]
    if n_time < 2:
        raise log_error(
            ValueError,
            f"At least 2 time points are required to compute path length, "
            f"but {n_time} were found. Double-check the start and stop times.",
        )

    _warn_about_nan_proportion(data, nan_warn_threshold)

    if nan_policy == "ffill":
        return compute_norm(
            compute_displacement(data.ffill(dim="time")).isel(
                time=slice(1, None)
            )  # skip first displacement (always 0)
        ).sum(dim="time", min_count=1)  # return NaN if no valid segment

    # nan_policy == "scale" (the only remaining possibility)
    return _compute_scaled_path_length(data)


def _warn_about_nan_proportion(
    data: xr.DataArray, nan_warn_threshold: float
) -> None:
    """Print a warning if the proportion of NaN values exceeds a threshold.

    The NaN proportion is evaluated per point track, and a given point is
    considered NaN if any of its ``space`` coordinates are NaN. The warning
    specifically lists the point tracks that exceed the threshold.

    Parameters
    ----------
    data : xarray.DataArray
        The input data array.
    nan_warn_threshold : float
        The threshold for the proportion of NaN values. Must be a number
        between 0 and 1.

    """
    nan_warn_threshold = float(nan_warn_threshold)
    # ``not 0 <= x <= 1`` also rejects NaN thresholds.
    if not 0 <= nan_warn_threshold <= 1:
        raise log_error(
            ValueError,
            "nan_warn_threshold must be between 0 and 1.",
        )

    # A point counts as missing if ANY of its space coordinates are NaN.
    nan_counts_per_track = data.isnull().any(dim="space").sum(dim="time")
    max_allowed_nans = data.sizes["time"] * nan_warn_threshold
    tracks_over_threshold = data.where(
        nan_counts_per_track > max_allowed_nans, drop=True
    )
    if len(tracks_over_threshold):
        log_warning(
            "The result may be unreliable for point tracks with many "
            "missing values. The following tracks have more than "
            f"{nan_warn_threshold * 100:.3} % NaN values:",
        )
        # Print (rather than log) so the per-track NaN report is
        # always visible to the user.
        print(report_nan_values(tracks_over_threshold))


def _compute_scaled_path_length(
    data: xr.DataArray,
) -> xr.DataArray:
    """Compute scaled path length based on proportion of valid segments.

    Path length is first computed based on valid segments (non-NaN values
    on both ends of the segment) and then scaled based on the proportion of
    valid segments per point track - i.e. the result is divided by the
    proportion of valid segments.

    Parameters
    ----------
    data : xarray.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray containing the computed path length,
        with dimensions matching those of the input data,
        except ``time`` and ``space`` are removed.

    """
    # Drop the first displacement segment (always 0) so it doesn't
    # distort the valid-segment proportion.
    displacement_vectors = compute_displacement(data).isel(
        time=slice(1, None)
    )
    # A segment is valid only if every space coordinate is non-NaN.
    n_valid_segments = (
        (~displacement_vectors.isnull()).all(dim="space").sum(dim="time")
    )
    n_possible_segments = data.sizes["time"] - 1
    valid_proportion = n_valid_segments / n_possible_segments
    # Scale the summed segment norms up by the fraction that was observed.
    return compute_norm(displacement_vectors).sum(dim="time") / valid_proportion
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,40 @@ def valid_poses_dataset_uniform_linear_motion(
)


@pytest.fixture
def valid_poses_dataset_uniform_linear_motion_with_nans(
    valid_poses_dataset_uniform_linear_motion,
):
    """Return a valid poses dataset with NaN values in the position array.

    Specifically, we will introduce:

    - 1 NaN value in the centroid keypoint of individual id_1 at time=0
    - 5 NaN values in the left keypoint of individual id_1 (frames 3-7)
    - 10 NaN values in the right keypoint of individual id_1 (all frames)
    """
    # Single missing point: centroid of id_1 at the first frame only.
    valid_poses_dataset_uniform_linear_motion.position.loc[
        {
            "individuals": "id_1",
            "keypoints": "centroid",
            "time": 0,
        }
    ] = np.nan
    # Contiguous gap: left keypoint of id_1 across frames 3-7 (inclusive).
    valid_poses_dataset_uniform_linear_motion.position.loc[
        {
            "individuals": "id_1",
            "keypoints": "left",
            "time": slice(3, 7),
        }
    ] = np.nan
    # Fully-missing track: right keypoint of id_1 is NaN at every frame.
    valid_poses_dataset_uniform_linear_motion.position.loc[
        {
            "individuals": "id_1",
            "keypoints": "right",
        }
    ] = np.nan
    return valid_poses_dataset_uniform_linear_motion


# -------------------- Invalid datasets fixtures ------------------------------
@pytest.fixture
def not_a_dataset():
Expand Down
Loading

0 comments on commit 0c2917b

Please sign in to comment.