Skip to content

Commit

Permalink
tidy up estimate_scanpath
Browse files Browse the repository at this point in the history
  • Loading branch information
qian-chu committed Oct 16, 2024
1 parent dfa288c commit e5c3397
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 71 deletions.
3 changes: 2 additions & 1 deletion pyneon/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if TYPE_CHECKING:
from ..recording import NeonRecording
from tqdm import tqdm


# Data must be from a NeonStream object
Expand Down Expand Up @@ -130,7 +131,7 @@ def window_average(
new_data = pd.DataFrame(index=new_ts, columns=data.columns).astype(data.dtypes)
non_float_cols = data.select_dtypes(exclude="float").columns

for ts in new_ts:
for ts in tqdm(new_ts, desc="Computing window averages"):
lower_bound = ts - window_size / 2
upper_bound = ts + window_size / 2

Expand Down
3 changes: 2 additions & 1 deletion pyneon/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def sync_gaze_to_video(

def estimate_scanpath(
self,
sync_gaze: Optional["NeonGaze"] = None,
lk_params: Union[None, dict] = None,
) -> pd.DataFrame:
"""
Expand All @@ -473,7 +474,7 @@ def estimate_scanpath(
lk_params : dict
Parameters for the Lucas-Kanade optical flow algorithm.
"""
return estimate_scanpath(self, lk_params)
return estimate_scanpath(self, sync_gaze, lk_params)

def overlay_scanpath_on_video(
self,
Expand Down
152 changes: 83 additions & 69 deletions pyneon/video/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import cv2
import warnings
from pathlib import Path

from tqdm import tqdm
from typing import TYPE_CHECKING, Union, Optional

from ..preprocess import window_average

if TYPE_CHECKING:
from ..recording import NeonRecording
from ..stream import NeonGaze
from .video import NeonVideo


def sync_gaze_to_video(
Expand Down Expand Up @@ -65,6 +66,27 @@ def sync_gaze_to_video(
return sync_gaze


class ScanPath:
def __init__(
self,
rec: "NeonRecording",
gaze: Optional["NeonGaze"] = None,
lk_params: Optional[dict] = None,
):
self.video = rec.video
self.gaze = sync_gaze_to_video(rec) if gaze is None else gaze
self.lk_params = (
dict(
winSize=(90, 90),
maxLevel=3,
criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 0.03),
)
if lk_params is None
else lk_params
)
self.scanpath = None


def estimate_scanpath(
rec: "NeonRecording",
sync_gaze: Optional["NeonGaze"] = None,
Expand All @@ -80,110 +102,99 @@ def estimate_scanpath(
lk_params : dict
Parameters for the Lucas-Kanade optical flow algorithm.
"""

warnings.simplefilter(action="ignore", category=FutureWarning)

video = rec.video
if sync_gaze is None:
sync_gaze = sync_gaze_to_video(rec)
sync_gaze = rec.sync_gaze_to_video()
if video is None:
raise ValueError("No video data available.")
if not np.allclose(sync_gaze.ts, video.ts):
raise ValueError("Gaze and video timestamps do not match.")
# Default parameters for Lucas-Kanade optical flow from Neon
if lk_params is None:
lk_params = {
"winSize": (90, 90),
"maxLevel": 3,
"criteria": (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 0.03),
}
gaze_data = sync_gaze.data.copy().reset_index(drop=True)

# create a new dataframe of dataframes that stores the relevant data for each fixation
sync_gaze.rename(
columns={"time [s]": "time", "gaze x [px]": "x", "gaze y [px]": "y"},
inplace=True,
)
estimated_scanpath = pd.DataFrame(columns=["time", "fixations"])

for idx in range(sync_gaze.shape[0]):
estimated_scanpath = estimated_scanpath._append(
{
"time": sync_gaze.loc[idx, "time"],
"fixations": sync_gaze.loc[
[idx], ["fixation id", "x", "y", "fixation status"]
],
},
ignore_index=True,
)
scanpath = pd.DataFrame(index=sync_gaze.ts, columns=["fixations"], dtype="object")
scanpath["fixations"] = [
gaze_data.loc[
i, ["fixation id", "gaze x [px]", "gaze y [px]", "fixation status"]
].to_frame().T
for i in gaze_data.index
]

video = rec.video
# reset video to the beginning
video.set(cv2.CAP_PROP_POS_FRAMES, 0)

# Taken from Neon
if lk_params is None:
lk_params = dict(
winSize=(90, 90),
maxLevel=3,
criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 0.03),
)

prev_frame = None
prev_fix = None

for idx in estimated_scanpath.index:
prev_image = None

for i_frame in tqdm(range(scanpath.shape[0]), desc="Estimating scanpath"):
# Read the current frame from the video
ret, frame = video.read()
if not ret:
break

curr_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
curr_image = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

if idx >= 1:
# Access previous frame fixations and filter by status
prev_fixations = estimated_scanpath.at[idx - 1, "fixations"]
if i_frame >= 1:
# Estimate the new fixation points using optical flow for
# fixations that just ended or are being tracked
prev_fixations = scanpath.iat[i_frame - 1, 0].copy()

prev_fixations = prev_fixations[
(prev_fixations["fixation status"] == "end")
| (prev_fixations["fixation status"] == "tracked")
]

if not prev_fixations.empty:
# Prepare points for tracking
prev_pts = np.array(
prev_fixations[["x", "y"]].dropna().values, dtype=np.float32
).reshape(-1, 1, 2)
prev_pts = (
prev_fixations[["gaze x [px]", "gaze y [px]"]]
.to_numpy().astype(np.float32)
.reshape(-1, 1, 2)
)
prev_ids = prev_fixations["fixation id"].values

# Calculate optical flow to find new positions of the points
curr_pts, status, err = cv2.calcOpticalFlowPyrLK(
prev_frame, curr_frame, prev_pts, None, **lk_params
curr_pts, status, _ = cv2.calcOpticalFlowPyrLK(
prev_image, curr_image, prev_pts, None, **lk_params
)

# Update fixations for the current frame
curr_fixations = estimated_scanpath.at[idx, "fixations"].copy()
curr_fixations = scanpath.iloc[i_frame, 0].copy()

# Append new or updated fixation points
for i, (pt, s, e) in enumerate(zip(curr_pts, status, err)):
for i, (pt, s) in enumerate(zip(curr_pts, status)):
if s[0]: # Check if the point was successfully tracked
x, y = pt.ravel()
curr_fixations = curr_fixations._append(
{
"fixation id": prev_ids[i],
"x": x,
"y": y,
"fixation status": "tracked",
},
ignore_index=True,
)
fixation = pd.DataFrame({
"fixation id": prev_ids[i],
"gaze x [px]": x,
"gaze y [px]": y,
"fixation status": "tracked"
}, index=[prev_ids[i]])
else:
# Handle cases where the point could not be tracked
curr_fixations = curr_fixations._append(
{
"fixation id": prev_ids[i],
"x": None,
"y": None,
"fixation status": "lost",
},
ignore_index=True,
)

fixation = pd.DataFrame({
"fixation id": prev_ids[i],
"gaze x [px]": None,
"gaze y [px]": None,
"fixation status": "lost"
}, index=[prev_ids[i]])
curr_fixations = pd.concat([curr_fixations, fixation], ignore_index=True)

# Update the DataFrame with the modified fixations
estimated_scanpath.at[idx, "fixations"] = curr_fixations
scanpath.iat[i_frame, 0] = curr_fixations

# Update the previous frame for the next iteration
prev_frame = curr_frame

rec.scanpath = estimated_scanpath
prev_image = curr_image

return estimated_scanpath
return scanpath


def overlay_scanpath_on_video(
Expand Down Expand Up @@ -254,7 +265,10 @@ def overlay_scanpath_on_video(
continue

for i in range(len(fixations)):
fixation_x, fixation_y = fixations.iloc[i]["x"], fixations.iloc[i]["y"]
fixation_x, fixation_y = (
fixations.iloc[i]["gaze x [px]"],
fixations.iloc[i]["gaze y [px]"],
)
status = fixations.iloc[i]["fixation status"]
id = fixations.iloc[i]["fixation id"]

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"opencv-python",
"typeguard",
"requests",
"tqdm",
]

[project.optional-dependencies]
Expand Down

0 comments on commit e5c3397

Please sign in to comment.