Fix and refactor compute_path_length and its tests
niksirbi committed Oct 10, 2024
1 parent 05182d8 commit ccbfeff
Showing 4 changed files with 352 additions and 64 deletions.
187 changes: 138 additions & 49 deletions movement/analysis/kinematics.py
@@ -6,6 +6,7 @@
 import xarray as xr
 
 from movement.utils.logging import log_error, log_warning
+from movement.utils.reports import report_nan_values
 from movement.utils.vector import compute_norm
 from movement.validators.arrays import validate_dims_coords
 
@@ -226,13 +227,15 @@ def compute_path_length(
         - ``"drop"``: drop any NaN values before computing path length. This
           is the default behavior, and it equates to assuming that a track
           follows a straight line between two valid points flanking a missing
-          segment. This approach tends to underestimate the path length,
-          and the error increases with the number of missing values.
+          segment. Missing segments at the beginning or end of the specified
+          time range are not counted. This approach tends to underestimate
+          the path length, and the error increases with the number of missing
+          values.
         - ``"scale"``: scale path length based on the proportion of valid
-          values per point track. For example, if only 80% of the values are
-          present, the path length will be computed based on these values,
-          and the result will be multiplied by 1/0.8 = 1.25. This approach
-          assumes that the point's dynamics are similar across present
+          segments per point track. For example, if only 80% of segments
+          are present, the path length will be computed based on these,
+          and the result will be divided by 0.8. This approach
+          assumes that motion dynamics are similar across present
           and missing time segments, which may not be the case.
     nan_warn_threshold : float, optional
         If more than this proportion of values are missing in any point track,
@@ -246,53 +249,15 @@
         and ``space`` which will be removed.
     """
-    # We validate the time dimension here, on top of its later validation
-    # inside compute_displacement, because we rely on it for start/stop times.
-    validate_dims_coords(data, {"time": []})
     _validate_start_stop_times(data, start, stop)
 
     # Select data within the specified time range
     data = data.sel(time=slice(start, stop))
 
-    # Emit a warning for point tracks with many missing values
-    nan_counts = data.isnull().any(dim="space").sum(dim="time")
-    dims_to_stack = [dim for dim in data.dims if dim not in ["time", "space"]]
-    # Stack individual and keypoints dims into a single 'tracks' dimension
-    stacked_nan_counts = nan_counts.stack(tracks=dims_to_stack)
-    tracks_with_warning = stacked_nan_counts.where(
-        stacked_nan_counts > nan_warn_threshold, drop=True
-    ).tracks.values
-    if len(tracks_with_warning) > 0:
-        log_warning(
-            "The following point tracks have more than "
-            f"{nan_warn_threshold * 100}% missing values, which may lead to "
-            "unreliable path length estimates: "
-            f"{', '.join(tracks_with_warning)}."
-        )
+    _warn_about_nan_proportion(data, nan_warn_threshold)
 
     if nan_policy == "drop":
-        stacked_data = data.stack(tracks=dims_to_stack)
-        # Create an empty data array to hold the path length for each track
-        stacked_path_length = xr.zeros_like(stacked_nan_counts)
-        # Compute path length for each track
-        for track_name in stacked_data.tracks:
-            track_data = stacked_data.sel(tracks=track_name, drop=True).dropna(
-                dim="time", how="any"
-            )
-            stacked_path_length.loc[track_name] = compute_norm(
-                compute_displacement(track_data)
-            ).sum(dim="time")
-        # Return the unstacked path length (restore individual and keypoints)
-        return stacked_path_length.unstack("tracks")
+        return _compute_path_length_drop_nan(data)
 
     elif nan_policy == "scale":
-        valid_path_length = compute_norm(compute_displacement(data)).sum(
-            dim="time",
-            skipna=True,  # path length only for valid points
-        )
-        scale_factor = 1 / (1 - nan_counts / data.sizes["time"])
-        return valid_path_length * scale_factor
-
+        return _compute_scaled_path_length(data)
     else:
         raise log_error(
             ValueError,
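
Editorial note: the behavioral difference between the two policies can be seen on a toy 1D track. This is a minimal numpy sketch of the formulas described in the docstring above, not a call into movement's own API:

import numpy as np

# 6 frames of a 1D track; frames 2 and 3 are missing.
x = np.array([0.0, 1.0, np.nan, np.nan, 4.0, 5.0])

# "drop": remove NaNs, then bridge gaps with straight lines.
drop_length = np.abs(np.diff(x[~np.isnan(x)])).sum()  # 1 + 3 + 1 = 5.0

# "scale": sum only the fully valid segments, then divide by the
# proportion of valid segments (here 2 of the 5 segments are valid).
segments = np.abs(np.diff(x))  # [1., nan, nan, nan, 1.]
valid_proportion = (~np.isnan(segments)).sum() / segments.size  # 0.4
scaled_length = np.nansum(segments) / valid_proportion  # 2 / 0.4 = 5.0

print(drop_length, scaled_length)  # 5.0 5.0

For uniform motion both recover the true length. If the track sped up across the gap, "drop" would underestimate the path length, while "scale" extrapolates the average speed of the valid segments over the missing ones.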
@@ -496,10 +461,15 @@ def _validate_start_stop_times(
     TypeError
         If the start or stop time is not numeric.
     ValueError
-        If either of the provided times is outside the time range of the data,
-        or if the start time is later than the stop time.
+        If the time dimension is missing, if either of the provided times is
+        outside the data time range, or if the start time is later than the
+        stop time.
     """
+    # We validate the time dimension here, on top of any validation that may
+    # occur downstream, because we rely on it for start/stop times.
+    validate_dims_coords(data, {"time": []})
+
     provided_time_points = {"start time": start, "stop time": stop}
     expected_time_range = (data.time.min(), data.time.max())
 
@@ -526,3 +496,122 @@
             ValueError,
             "The start time must be earlier than the stop time.",
         )
+
+
+def _warn_about_nan_proportion(
+    data: xr.DataArray, nan_warn_threshold: float
+) -> None:
+    """Print a warning if the proportion of NaN values exceeds a threshold.
+
+    The NaN proportion is evaluated per point track, and a given point is
+    considered NaN if any of its ``space`` coordinates are NaN. The warning
+    specifically lists the point tracks that exceed the threshold.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data array.
+    nan_warn_threshold : float
+        The threshold for the proportion of NaN values. Must be a number
+        between 0 and 1.
+
+    """
+    nan_warn_threshold = float(nan_warn_threshold)
+    if nan_warn_threshold < 0 or nan_warn_threshold > 1:
+        raise log_error(
+            ValueError,
+            "nan_warn_threshold must be a number between 0 and 1.",
+        )
+
+    n_nans = data.isnull().any(dim="space").sum(dim="time")
+    data_to_warn_about = data.where(
+        n_nans > data.sizes["time"] * nan_warn_threshold, drop=True
+    )
+    if len(data_to_warn_about) > 0:
+        log_warning(
+            "The result may be unreliable for point tracks with many "
+            "missing values. The following tracks have more than "
+            f"{nan_warn_threshold * 100:.3}% NaN values:",
+        )
+        report_nan_values(data_to_warn_about)
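
Editorial note: the counting rule above treats a point as missing whenever any of its spatial coordinates is NaN. A minimal self-contained sketch of that rule (toy values, not movement's API):

import numpy as np
import xarray as xr

# 3 frames x 2 spatial coordinates; frame 1 has a NaN in y only,
# which still marks the whole point as missing for that frame.
position = xr.DataArray(
    [[0.0, 0.0], [1.0, np.nan], [2.0, 2.0]],
    dims=["time", "space"],
    coords={"space": ["x", "y"]},
)
n_nans = position.isnull().any(dim="space").sum(dim="time")
print(int(n_nans))  # 1 -> proportion 1/3, compared against the threshold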
+
+
+def _compute_scaled_path_length(
+    data: xr.DataArray,
+) -> xr.DataArray:
+    """Compute scaled path length based on the proportion of valid segments.
+
+    Path length is first computed based on valid segments (non-NaN values
+    on both ends of the segment) and then scaled based on the proportion of
+    valid segments per point track, i.e. the result is divided by the
+    proportion of valid segments.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data containing position information in Cartesian
+        coordinates, with ``time`` and ``space`` among the dimensions.
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray DataArray containing the computed path length.
+        Will have the same dimensions as the input data, except for ``time``
+        and ``space`` which will be removed.
+
+    """
+    # Skip the first displacement segment (always 0), so it does not
+    # skew the scaling
+    displacement = compute_displacement(data).isel(time=slice(1, None))
+    # Count the number of valid displacement segments per point track
+    valid_segments = (~displacement.isnull()).any(dim="space").sum(dim="time")
+    # Compute the proportion of valid segments per point track
+    valid_proportion = valid_segments / (data.sizes["time"] - 1)
+    # Return the path length scaled by the proportion of valid segments
+    return compute_norm(displacement).sum(dim="time") / valid_proportion
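
Editorial note: a worked example of the scaling arithmetic, using xarray's diff as a stand-in for compute_displacement with its leading zero segment skipped (toy values, chosen for illustration):

import numpy as np
import xarray as xr

# 6 frames -> 5 displacement segments; the NaN frames 2-3 invalidate
# three of them, leaving 2 valid segments with total length 2.
position = xr.DataArray(
    [[0.0, 0.0], [1.0, 0.0], [np.nan, np.nan],
     [np.nan, np.nan], [4.0, 0.0], [5.0, 0.0]],
    dims=["time", "space"],
    coords={"space": ["x", "y"]},
)
displacement = position.diff(dim="time")
valid_segments = (~displacement.isnull()).any(dim="space").sum(dim="time")
valid_proportion = valid_segments / (position.sizes["time"] - 1)  # 2 / 5
norms = np.hypot(displacement.sel(space="x"), displacement.sel(space="y"))
print(float(norms.sum(dim="time") / valid_proportion))  # 2 / 0.4 = 5.0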
+
+
+def _compute_path_length_drop_nan(
+    data: xr.DataArray,
+) -> xr.DataArray:
+    """Compute path length by dropping NaN values before computation.
+
+    This function iterates over point tracks, drops NaN values from each
+    track, and then computes the path length for the remaining valid
+    segments (takes the sum of the norms of the displacement vectors).
+    If there is no valid segment in a track, the path length for that
+    track will be NaN.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data containing position information in Cartesian
+        coordinates, with ``time`` and ``space`` among the dimensions.
+
+    Returns
+    -------
+    xarray.DataArray
+        An xarray DataArray containing the computed path length.
+        Will have the same dimensions as the input data, except for ``time``
+        and ``space`` which will be removed.
+
+    """
+    # Create array for holding results
+    path_length = xr.full_like(
+        data.isel(time=0, space=0, drop=True), fill_value=np.nan
+    )
+
+    # Stack data to iterate over point tracks
+    dims_to_stack = [d for d in data.dims if d not in ["time", "space"]]
+    stacked_data = data.stack(tracks=dims_to_stack)
+    for track_name in stacked_data.tracks.values:
+        # Drop NaN values from current point track
+        track_data = stacked_data.sel(tracks=track_name, drop=True).dropna(
+            dim="time", how="any"
+        )
+        # Compute path length for current point track
+        # and store it in the result array
+        target_loc = {k: v for k, v in zip(dims_to_stack, track_name)}
+        path_length.loc[target_loc] = compute_norm(
+            compute_displacement(track_data)
+        ).sum(dim="time", min_count=1)  # returns NaN if no valid segment
+    return path_length
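
Editorial note: min_count=1 is what turns an all-NaN track into a NaN path length rather than a silent 0. After dropna, an all-NaN track is empty along time, and a plain sum over an empty array returns 0. A quick sketch of the difference:

import numpy as np
import xarray as xr

empty = xr.DataArray(np.array([], dtype=float), dims=["time"])
print(empty.sum(dim="time").item())               # 0.0
print(empty.sum(dim="time", min_count=1).item())  # nan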
1 change: 1 addition & 0 deletions pyproject.toml
@@ -105,6 +105,7 @@ fix = true
 ignore = [
     "D203", # one blank line before class
     "D213", # multi-line-summary second line
+    "B905", # zip without explicit strict
 ]
 select = [
     "E", # pycodestyle errors
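
Editorial note: B905 is Ruff's flake8-bugbear rule flagging zip() calls without an explicit strict= argument (available from Python 3.10). The ignore presumably covers the zip() in _compute_path_length_drop_nan above, where both inputs are guaranteed the same length. A sketch of the two spellings:

dims_to_stack = ["individuals", "keypoints"]
track_name = ("id_1", "centroid")

# What B905 flags: zip() with no strict= argument.
target_loc = {k: v for k, v in zip(dims_to_stack, track_name)}

# The explicit spelling raises ValueError if the lengths ever diverge.
target_loc = dict(zip(dims_to_stack, track_name, strict=True))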
34 changes: 34 additions & 0 deletions tests/conftest.py
@@ -518,6 +518,40 @@ def valid_poses_dataset_uniform_linear_motion(
     )
 
 
+@pytest.fixture
+def valid_poses_dataset_uniform_linear_motion_with_nans(
+    valid_poses_dataset_uniform_linear_motion,
+):
+    """Return a valid poses dataset with NaN values in the position array.
+
+    Specifically, we will introduce:
+    - 1 NaN value in the centroid keypoint of individual id_1 at time=0
+    - 5 NaN values in the left keypoint of individual id_1 (frames 3-7)
+    - 10 NaN values in the right keypoint of individual id_1 (all frames)
+    """
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "centroid",
+            "time": 0,
+        }
+    ] = np.nan
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "left",
+            "time": slice(3, 7),
+        }
+    ] = np.nan
+    valid_poses_dataset_uniform_linear_motion.position.loc[
+        {
+            "individuals": "id_1",
+            "keypoints": "right",
+        }
+    ] = np.nan
+    return valid_poses_dataset_uniform_linear_motion
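
Editorial note: inside a test, the NaN pattern this fixture introduces can be checked against its docstring. A hedged sketch (the test name is hypothetical; it assumes the dataset spans 10 integer-labeled frames, as implied by "all frames" = 10 and the label-based, inclusive slice(3, 7) = 5 frames):

def test_nan_pattern(valid_poses_dataset_uniform_linear_motion_with_nans):
    position = valid_poses_dataset_uniform_linear_motion_with_nans.position
    n_nans = position.isnull().any(dim="space").sum(dim="time")
    assert n_nans.sel(individuals="id_1", keypoints="centroid").item() == 1
    assert n_nans.sel(individuals="id_1", keypoints="left").item() == 5
    assert n_nans.sel(individuals="id_1", keypoints="right").item() == 10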
+
+
 # -------------------- Invalid datasets fixtures ------------------------------
 @pytest.fixture
 def not_a_dataset():