Skip to content

Commit

Permalink
rename resample to interpolate, rolling average to window average
Browse files Browse the repository at this point in the history
  • Loading branch information
qian-chu committed Sep 24, 2024
1 parent 5c342ca commit 5795e8e
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 136 deletions.
10 changes: 5 additions & 5 deletions pyneon/export/export_bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,18 @@ def export_motion_bids(
imu = rec.imu
if imu is None:
raise ValueError("No IMU data found in the recording.")
resamp_data = imu.resample()
motion_first_ts = resamp_data.loc[0, "timestamp [ns]"]
interp_data = imu.interpolate()
motion_first_ts = interp_data.loc[0, "timestamp [ns]"]
motion_acq_time = datetime.datetime.fromtimestamp(motion_first_ts / 1e9).strftime(
"%Y-%m-%dT%H:%M:%S.%f"
)
resamp_data = resamp_data.drop(columns=["timestamp [ns]", "time [s]"])
interp_data = interp_data.drop(columns=["timestamp [ns]", "time [s]"])

resamp_data.to_csv(
interp_data.to_csv(
motion_tsv_path, sep="\t", index=False, header=False, na_rep="n/a"
)

ch_names = resamp_data.columns
ch_names = interp_data.columns
ch_names = [re.sub(r"\s\[[^\]]*\]", "", ch) for ch in ch_names]
channels = pd.DataFrame(
{
Expand Down
6 changes: 3 additions & 3 deletions pyneon/preprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .preprocess import resample, concat_streams, concat_events, rolling_average
from .preprocess import interpolate, concat_streams, concat_events, window_average
from .mapping import map_gaze_to_video, estimate_scanpath, overlay_scanpath_on_video
from .epoch import create_epoch, extract_event_times, construct_event_times, Epoch

__all__ = [
"resample",
"interpolate",
"concat_streams",
"concat_events",
"rolling_average",
"window_average",
"map_gaze_to_video",
"estimate_scanpath",
"overlay_scanpath_on_video",
Expand Down
2 changes: 1 addition & 1 deletion pyneon/preprocess/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def map_gaze_to_video(
raise ValueError("No video data available.")

# Resample the gaze data to the video timestamps
mapped_gaze = rec.roll_gaze_on_video()
mapped_gaze = rec.gaze_on_video()

# Mark the fixation status of each frame
mapped_gaze["fixation status"] = pd.NA
Expand Down
163 changes: 76 additions & 87 deletions pyneon/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,34 @@
import numpy as np

from typing import TYPE_CHECKING, Union
from scipy import interpolate
from scipy.interpolate import interp1d

if TYPE_CHECKING:
from ..recording import NeonRecording


def resample(
def _check_data(data: pd.DataFrame) -> None:
if "timestamp [ns]" not in data.columns:
raise ValueError("Data must contain a 'timestamp [ns]' column")
if np.any(np.diff(data["timestamp [ns]"]) < 0):
raise ValueError("Timestamps must be monotonically increasing")


def interpolate(
new_ts: np.ndarray,
old_data: pd.DataFrame,
data: pd.DataFrame,
float_kind: str = "linear",
other_kind: str = "nearest",
) -> pd.DataFrame:
"""
Resample the stream to a new set of timestamps.
Interpolate a data stream to a new set of timestamps.
Parameters
----------
new_ts : np.ndarray, optional
New timestamps to resample the stream to. If ``None``,
the stream is resampled to its nominal sampling frequency according to
https://pupil-labs.com/products/neon/specs.
old_data : pd.DataFrame
Data to resample. Must contain a monotonically increasing
New timestamps to evaluate the interpolant at.
data : pd.DataFrame
Data to interpolate. Must contain a monotonically increasing
``timestamp [ns]`` column.
float_kind : str, optional
Kind of interpolation applied on columns of float type,
Expand All @@ -36,116 +41,101 @@ def resample(
Returns
-------
pandas.DataFrame
Resampled data.
Interpolated data.
"""
# Check that 'timestamp [ns]' is in the columns
if "timestamp [ns]" not in old_data.columns:
raise ValueError("old_data must contain a 'timestamp [ns]' column")
# Check that new_ts is monotonicically increasing
if np.any(np.diff(new_ts) < 0):
raise ValueError("new_ts must be monotonically increasing")
# Create a new dataframe with the new timestamps
resamp_data = pd.DataFrame(data=new_ts, columns=["timestamp [ns]"], dtype="Int64")
resamp_data["time [s]"] = (new_ts - new_ts[0]) / 1e9

for col in old_data.columns:
_check_data(data)
new_ts = np.sort(new_ts)
new_data = pd.DataFrame(data=new_ts, columns=["timestamp [ns]"], dtype="Int64")
new_data["time [s]"] = (new_ts - new_ts[0]) / 1e9
for col in data.columns:
# Skip time columns
if col == "timestamp [ns]" or col == "time [s]":
continue
if pd.api.types.is_float_dtype(old_data[col]):
resamp_data[col] = interpolate.interp1d(
old_data["timestamp [ns]"],
old_data[col],
# Float columns are interpolated with float_kind
if pd.api.types.is_float_dtype(data[col]):
new_data[col] = interp1d(
data["timestamp [ns]"],
data[col],
kind=float_kind,
bounds_error=False,
)(new_ts)
# Other columns are interpolated with other_kind
else:
resamp_data[col] = interpolate.interp1d(
old_data["timestamp [ns]"],
old_data[col],
new_data[col] = interp1d(
data["timestamp [ns]"],
data[col],
kind=other_kind,
bounds_error=False,
)(new_ts)
resamp_data[col] = resamp_data[col].astype(old_data[col].dtype)
return resamp_data
# Ensure the new column has the same dtype as the original
new_data[col] = new_data[col].astype(data[col].dtype)
return new_data


def rolling_average(
def window_average(
new_ts: np.ndarray,
old_data: pd.DataFrame,
data: pd.DataFrame,
window_size: Union[int, None] = None,
) -> pd.DataFrame:
"""
Apply rolling average over a time window to resampled data.
Take the average over a time window to obtain smoothed data at new timestamps.
Parameters
----------
new_ts : np.ndarray
New timestamps to resample the stream to.
old_data : pd.DataFrame
Data to apply rolling average to.
time_column : str, optional
Name of the time column in the data, by default "timestamp [ns]".
New timestamps to evaluate the window average at. The median of the differences
between the new timestamps must be larger than the median of the differences
between the old timestamps. In other words, only downsampling is supported.
data : pd.DataFrame
Data to apply window average to. Must contain a monotonically increasing
``timestamp [ns]`` column.
window_size : int, optional
Size of the time window in nanoseconds. If ``None``, the window size is
set to the median of the differences between the new timestamps.
Defaults to ``None``.
Returns
-------
pd.DataFrame
Data with rolling averages applied.
Data with window average applied.
"""
# Assert that 'timestamp [ns]' is present and monotonic
if "timestamp [ns]" not in old_data.columns:
raise ValueError("old_data must contain a 'timestamp [ns]' column")

if not np.all(np.diff(old_data["timestamp [ns]"]) > 0):
# call resample function to ensure monotonicity
old_data = resample(None, old_data)

# assert that the new_ts has a lower sampling frequency than the old data
if np.mean(np.diff(new_ts)) < np.mean(np.diff(old_data["timestamp [ns]"])):
_check_data(data)
new_ts = np.sort(new_ts)
new_ts_median_diff = np.median(np.diff(new_ts))
# Assert that the new_ts has a lower sampling frequency than the old data
if new_ts_median_diff < np.mean(np.diff(data["timestamp [ns]"])):
raise ValueError(
"new_ts must have a lower sampling frequency than the old data"
)

# Create a new DataFrame for the downsampled data
downsampled_data = pd.DataFrame(
data=new_ts, columns=["timestamp [ns]"], dtype="Int64"
)
downsampled_data["time [s]"] = (new_ts - new_ts[0]) / 1e9

# Convert window_size to nanoseconds
window_size = np.mean(np.diff(new_ts))

# Loop through each column (excluding time columns) to compute the downsampling
for col in old_data.columns:
if window_size is None:
window_size = new_ts_median_diff
new_data = pd.DataFrame(data=new_ts, columns=["timestamp [ns]"], dtype="Int64")
new_data["time [s]"] = (new_ts - new_ts[0]) / 1e9
for col in data.columns:
# Skip time columns
if col == "timestamp [ns]" or col == "time [s]":
continue

# Initialize an empty list to store the downsampled values
downsampled_values = []

# Loop through each new timestamp
new_values = [] # List to store the downsampled values
for ts in new_ts:
# Define the time window around the current new timestamp
lower_bound = ts - window_size / 2
upper_bound = ts + window_size / 2

# Select rows from old_data that fall within the time window
window_data = old_data[
(old_data["timestamp [ns]"] >= lower_bound)
& (old_data["timestamp [ns]"] <= upper_bound)
window_data = data[
(data["timestamp [ns]"] >= lower_bound)
& (data["timestamp [ns]"] <= upper_bound)
]

# Compute the average of the data within this window
if not window_data.empty:
mean_value = window_data[col].mean()
else:
mean_value = np.nan

# Append the averaged value to the list
downsampled_values.append(mean_value)

new_values.append(mean_value)
# Assign the downsampled values to the new DataFrame
downsampled_data[col] = downsampled_values
new_data[col] = new_values

return downsampled_data
return new_data


_VALID_STREAMS = {"3d_eye_states", "eye_states", "gaze", "imu"}
Expand All @@ -155,14 +145,14 @@ def concat_streams(
rec: "NeonRecording",
stream_names: Union[str, list[str]] = "all",
sampling_freq: Union[float, int, str] = "min",
resamp_float_kind: str = "linear",
resamp_other_kind: str = "nearest",
interp_float_kind: str = "linear",
interp_other_kind: str = "nearest",
inplace: bool = False,
) -> pd.DataFrame:
"""
Concatenate data from different streams under common timestamps.
Since the streams may have different timestamps and sampling frequencies,
resampling of all streams to a set of common timestamps is performed.
interpolation of all streams to a set of common timestamps is performed.
The latest start timestamp and earliest last timestamp of the selected streams
are used to define the common timestamps.
Expand All @@ -175,20 +165,20 @@ def concat_streams(
If a list, items must be in ``{"gaze", "imu", "eye_states"}``
(``"3d_eye_states"``) is also tolerated as an alias for ``"eye_states"``).
sampling_freq : float or int or str, optional
Sampling frequency to resample the streams to.
If numeric, the streams will be resampled to this frequency.
Sampling frequency of the concatenated streams.
If numeric, the streams will be interpolated to this frequency.
If ``"min"``, the lowest nominal sampling frequency
of the selected streams will be used.
If ``"max"``, the highest nominal sampling frequency will be used.
Defaults to ``"min"``.
resamp_float_kind : str, optional
interp_float_kind : str, optional
Kind of interpolation applied on columns of float type,
Defaults to ``"linear"``. For details see :class:`scipy.interpolate.interp1d`.
resamp_other_kind : str, optional
interp_other_kind : str, optional
Kind of interpolation applied on columns of other types.
Defaults to ``"nearest"``.
inplace : bool, optional
Replace selected stream data with resampled data during concatenation
Replace selected stream data with interpolated data during concatenation
if``True``. Defaults to ``False``.
Returns
Expand All @@ -203,7 +193,6 @@ def concat_streams(
raise ValueError(
"Invalid stream_names, must be 'all' or a list of stream names."
)

if len(stream_names) <= 1:
raise ValueError("Must provide at least two streams to concatenate.")

Expand Down Expand Up @@ -315,8 +304,8 @@ def concat_streams(
concat_data = pd.DataFrame(data=new_ts, columns=["timestamp [ns]"], dtype="Int64")
concat_data["time [s]"] = (new_ts - new_ts[0]) / 1e9
for stream in stream_info["stream"]:
resamp_df = stream.resample(
new_ts, resamp_float_kind, resamp_other_kind, inplace=inplace
resamp_df = stream.interpolate(
new_ts, interp_float_kind, interp_other_kind, inplace=inplace
)
assert concat_data.shape[0] == resamp_df.shape[0]
assert concat_data["timestamp [ns]"].equals(resamp_df["timestamp [ns]"])
Expand Down
17 changes: 7 additions & 10 deletions pyneon/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .preprocess import (
concat_streams,
concat_events,
rolling_average,
window_average,
map_gaze_to_video,
estimate_scanpath,
overlay_scanpath_on_video,
Expand Down Expand Up @@ -265,13 +265,10 @@ def video(self) -> Union[NeonVideo, None]:
"""
if self._video is None:
if (
self.contents.loc["scene_video", "exist"]
and self.contents.loc["world_timestamps", "exist"]
and self.contents.loc["scene_video_info", "exist"]
(video_file := self.contents.loc["scene_video", "path"])
and (timestamp_file := self.contents.loc["world_timestamps", "path"])
and (video_info_file := self.contents.loc["scene_video_info", "path"])
):
video_file = self.contents.loc["scene_video", "path"]
timestamp_file = self.contents.loc["world_timestamps", "path"]
video_info_file = self.contents.loc["scene_video_info", "path"]
self._video = NeonVideo(video_file, timestamp_file, video_info_file)
else:
warnings.warn(
Expand Down Expand Up @@ -428,13 +425,13 @@ def plot_distribution(
show,
)

def roll_gaze_on_video(
def gaze_on_video(
self,
) -> pd.DataFrame:
"""
Apply rolling average over a time window to gaze data.
Apply window average over video timestamps to gaze data.
"""
return rolling_average(self.video.ts, self.gaze.data)
return window_average(self.video.ts, self.gaze.data)

def map_gaze_to_video(
self,
Expand Down
Loading

0 comments on commit 5795e8e

Please sign in to comment.