Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement compute_speed and compute_path_length #280

Merged
merged 18 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 187 additions & 1 deletion movement/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import xarray as xr
from scipy.spatial.distance import cdist

from movement.utils.logging import log_error
from movement.utils.logging import log_error, log_warning
from movement.utils.reports import report_nan_values
from movement.utils.vector import compute_norm
from movement.validators.arrays import validate_dims_coords

Expand Down Expand Up @@ -173,6 +174,30 @@ def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
return result


def compute_speed(data: xr.DataArray) -> xr.DataArray:
"""Compute instantaneous speed at each time point.

Speed is a scalar quantity computed as the Euclidean norm (magnitude)
of the velocity vector at each time point.


Parameters
----------
data : xarray.DataArray
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed speed,
with dimensions matching those of the input data,
except ``space`` is removed.

"""
return compute_norm(compute_velocity(data))


def compute_forward_vector(
data: xr.DataArray,
left_keypoint: str,
Expand Down Expand Up @@ -675,3 +700,164 @@ def _validate_type_data_array(data: xr.DataArray) -> None:
TypeError,
f"Input data must be an xarray.DataArray, but got {type(data)}.",
)


def compute_path_length(
data: xr.DataArray,
start: float | None = None,
stop: float | None = None,
nan_policy: Literal["ffill", "scale"] = "ffill",
nan_warn_threshold: float = 0.2,
) -> xr.DataArray:
"""Compute the length of a path travelled between two time points.

The path length is defined as the sum of the norms (magnitudes) of the
displacement vectors between two time points ``start`` and ``stop``,
which should be provided in the time units of the data array.
If not specified, the minimum and maximum time coordinates of the data
array are used as start and stop times, respectively.

Parameters
----------
data : xarray.DataArray
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
start : float, optional
The start time of the path. If None (default),
the minimum time coordinate in the data is used.
stop : float, optional
The end time of the path. If None (default),
the maximum time coordinate in the data is used.
nan_policy : Literal["ffill", "scale"], optional
Policy to handle NaN (missing) values. Can be one of the ``"ffill"``
or ``"scale"``. Defaults to ``"ffill"`` (forward fill).
See Notes for more details on the two policies.
nan_warn_threshold : float, optional
If more than this proportion of values are missing in any point track,
a warning will be emitted. Defaults to 0.2 (20%).

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed path length,
with dimensions matching those of the input data,
except ``time`` and ``space`` are removed.

Notes
-----
Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill`
to forward-fill missing segments (NaN values) across time.
This equates to assuming that a track remains stationary for
the duration of the missing segment and then instantaneously moves to
the next valid position, following a straight line. This approach tends
to underestimate the path length, and the error increases with the number
of missing values.

Choosing ``nan_policy="scale"`` will adjust the path length based on the
the proportion of valid segments per point track. For example, if only
80% of segments are present, the path length will be computed based on
these and the result will be divided by 0.8. This approach assumes
that motion dynamics are similar across observed and missing time
segments, which may not accurately reflect actual conditions.

"""
validate_dims_coords(data, {"time": [], "space": []})
data = data.sel(time=slice(start, stop))
# Check that the data is not empty or too short
n_time = data.sizes["time"]
if n_time < 2:
raise log_error(
ValueError,
f"At least 2 time points are required to compute path length, "
f"but {n_time} were found. Double-check the start and stop times.",
)

_warn_about_nan_proportion(data, nan_warn_threshold)

if nan_policy == "ffill":
return compute_norm(
compute_displacement(data.ffill(dim="time")).isel(
time=slice(1, None)
) # skip first displacement (always 0)
).sum(dim="time", min_count=1) # return NaN if no valid segment

elif nan_policy == "scale":
return _compute_scaled_path_length(data)
else:
raise log_error(
ValueError,
f"Invalid value for nan_policy: {nan_policy}. "
"Must be one of 'ffill' or 'scale'.",
)


def _warn_about_nan_proportion(
data: xr.DataArray, nan_warn_threshold: float
) -> None:
"""Print a warning if the proportion of NaN values exceeds a threshold.

The NaN proportion is evaluated per point track, and a given point is
considered NaN if any of its ``space`` coordinates are NaN. The warning
specifically lists the point tracks that exceed the threshold.

Parameters
----------
data : xarray.DataArray
The input data array.
nan_warn_threshold : float
The threshold for the proportion of NaN values. Must be a number
between 0 and 1.

"""
nan_warn_threshold = float(nan_warn_threshold)
if not 0 <= nan_warn_threshold <= 1:
raise log_error(
ValueError,
"nan_warn_threshold must be between 0 and 1.",
)

n_nans = data.isnull().any(dim="space").sum(dim="time")
data_to_warn_about = data.where(
n_nans > data.sizes["time"] * nan_warn_threshold, drop=True
)
if len(data_to_warn_about) > 0:
log_warning(
"The result may be unreliable for point tracks with many "
"missing values. The following tracks have more than "
f"{nan_warn_threshold * 100:.3} % NaN values:",
)
print(report_nan_values(data_to_warn_about))


def _compute_scaled_path_length(
data: xr.DataArray,
) -> xr.DataArray:
"""Compute scaled path length based on proportion of valid segments.

Path length is first computed based on valid segments (non-NaN values
on both ends of the segment) and then scaled based on the proportion of
valid segments per point track - i.e. the result is divided by the
proportion of valid segments.

Parameters
----------
data : xarray.DataArray
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed path length,
with dimensions matching those of the input data,
except ``time`` and ``space`` are removed.

"""
# Skip first displacement segment (always 0) to not mess up the scaling
displacement = compute_displacement(data).isel(time=slice(1, None))
# count number of valid displacement segments per point track
valid_segments = (~displacement.isnull()).all(dim="space").sum(dim="time")
# compute proportion of valid segments per point track
valid_proportion = valid_segments / (data.sizes["time"] - 1)
# return scaled path length
return compute_norm(displacement).sum(dim="time") / valid_proportion
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,40 @@ def valid_poses_dataset_uniform_linear_motion(
)


@pytest.fixture
def valid_poses_dataset_uniform_linear_motion_with_nans(
valid_poses_dataset_uniform_linear_motion,
):
"""Return a valid poses dataset with NaN values in the position array.

Specifically, we will introducde:
- 1 NaN value in the centroid keypoint of individual id_1 at time=0
- 5 NaN values in the left keypoint of individual id_1 (frames 3-7)
- 10 NaN values in the right keypoint of individual id_1 (all frames)
"""
valid_poses_dataset_uniform_linear_motion.position.loc[
{
"individuals": "id_1",
"keypoints": "centroid",
"time": 0,
}
] = np.nan
valid_poses_dataset_uniform_linear_motion.position.loc[
{
"individuals": "id_1",
"keypoints": "left",
"time": slice(3, 7),
}
] = np.nan
valid_poses_dataset_uniform_linear_motion.position.loc[
{
"individuals": "id_1",
"keypoints": "right",
}
] = np.nan
return valid_poses_dataset_uniform_linear_motion
niksirbi marked this conversation as resolved.
Show resolved Hide resolved


# -------------------- Invalid datasets fixtures ------------------------------
@pytest.fixture
def not_a_dataset():
Expand Down
Loading
Loading