Skip to content

Commit

Permalink
Merge branch 'SensorFusion' of https://github.com/ncc-brain/PyNeon in…
Browse files Browse the repository at this point in the history
…to SensorFusion
  • Loading branch information
JGHartel committed Dec 17, 2024
2 parents a362c14 + d87f415 commit 61114bf
Show file tree
Hide file tree
Showing 6 changed files with 328 additions and 75 deletions.
2 changes: 1 addition & 1 deletion pyneon/preprocess/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def smooth_camera_pose(
A DataFrame with 'frame_idx' and 'smoothed_camera_pos'.
"""
# Ensure the DataFrame is sorted by frame_idx
camera_position_raw = camera_position_raw.sort_values('frame_idx')
camera_position_raw = camera_position_raw.sort_values("frame_idx")

# Extract all frame indices and create a complete range
all_frames = np.arange(camera_position_raw['frame_idx'].min(),
Expand Down
39 changes: 17 additions & 22 deletions pyneon/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,11 @@ def estimate_scanpath(
if (video := self.video) is None:
raise ValueError("Estimating scanpath requires video data.")
return estimate_scanpath(video, sync_gaze, lk_params)

def detect_apriltags(
self,
tag_family: str ='tag36h11'
) -> pd.DataFrame:

def detect_apriltags(self, tag_family: str = "tag36h11") -> pd.DataFrame:
"""
Detect AprilTags in a video and report their data for every frame using the apriltag library.
Parameters
----------
tag_family : str, optional
Expand All @@ -494,11 +491,12 @@ def detect_apriltags(

all_detections = detect_apriltags(self.video, tag_family)
# Save to JSON
all_detections.to_json(self.recording_dir / "apriltags.json", orient="records", lines=True)
all_detections.to_json(
self.recording_dir / "apriltags.json", orient="records", lines=True
)

return all_detections


def compute_camera_positions(
self,
tag_locations_df: pd.DataFrame,
Expand Down Expand Up @@ -556,19 +554,17 @@ def compute_camera_positions(
if json_file.is_file() and not overwrite:
return pd.read_json(json_file, orient="records")

# Compute camera positions
camera_positions = compute_camera_positions(
video=self.video,
tag_locations_df=tag_locations_df,
all_detections=all_detections
self.video, tag_locations, tag_size, all_detections
)

# Save to JSON
camera_positions.to_json(json_file, orient="records")
camera_positions.to_json(
self.recording_dir / "camera_positions.json", orient="records"
)

return camera_positions

def smooth_camera_pose(
def smooth_camera_positions(
self,
camera_position_raw: pd.DataFrame = pd.DataFrame(),
state_dim: int = 3,
Expand All @@ -577,7 +573,6 @@ def smooth_camera_pose(
process_noise: float = 0.005,
measurement_noise: float = 0.005,
gating_threshold: float = 3.0,
bidirectional: bool = False,
) -> pd.DataFrame:
"""
Apply a Kalman filter to smooth camera positions and gate outliers based on Mahalanobis distance.
Expand Down Expand Up @@ -611,9 +606,9 @@ def smooth_camera_pose(
if (json_file := self.recording_dir / "camera_positions.json").is_file():
camera_position_raw = pd.read_json(json_file, orient="records")
# Ensure 'camera_pos' is parsed as NumPy arrays
camera_position_raw['camera_pos'] = camera_position_raw['camera_pos'].apply(
lambda pos: np.array(pos, dtype=float)
)
camera_position_raw["camera_pos"] = camera_position_raw[
"camera_pos"
].apply(lambda pos: np.array(pos, dtype=float))
else:
# Run the function to get the data
camera_position_raw = self.compute_camera_positions()
Expand All @@ -626,14 +621,14 @@ def smooth_camera_pose(
process_noise,
measurement_noise,
gating_threshold,
bidirectional
)

# Save to JSON
smoothed_positions.to_json(self.recording_dir / "smoothed_camera_positions.json", orient="records")
smoothed_positions.to_json(
self.recording_dir / "camera_positions.json", orient="records"
)

return smoothed_positions


def plot_scanpath_on_video(
self,
Expand Down
66 changes: 34 additions & 32 deletions pyneon/video/apriltags.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
from ..recording import NeonRecording
from .video import NeonVideo

def detect_apriltags(
video: "NeonVideo",
tag_family: str ='tag36h11'
):


def detect_apriltags(video: "NeonVideo", tag_family: str = "tag36h11"):
"""
Detect AprilTags in a video and report their data for every frame using the apriltag library.
Parameters
----------
video : cv2.VideoCapture or similar video object
Expand All @@ -45,44 +42,47 @@ def detect_apriltags(

all_detections = []
frame_idx = 0

while True:
ret, frame = video.read()
if not ret:
break

# Convert frame to grayscale for detection
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

# Detect AprilTags
detections = detector.detect(gray_frame)

for detection in detections:
# Extract the tag ID and corners
tag_id = detection.tag_id
corners = detection.corners

# Calculate the center of the tag
center = np.mean(corners, axis=0)

# Store the detection data
all_detections.append({
"frame_idx": frame_idx,
"tag_id": tag_id,
"corners": corners,
"center": center
})

all_detections.append(
{
"frame_idx": frame_idx,
"tag_id": tag_id,
"corners": corners,
"center": center,
}
)

frame_idx += 1

video.release()

# convert to pandas DataFrame
all_detections = pd.DataFrame(all_detections)

return all_detections



def compute_camera_positions(
video: "NeonVideo",
tag_locations_df: pd.DataFrame,
Expand Down Expand Up @@ -203,10 +203,7 @@ def compute_camera_positions(
image_points = np.array(image_points, dtype=np.float32)

success, rotation_vector, translation_vector = cv2.solvePnP(
object_points,
image_points,
camera_matrix,
dist_coeffs
object_points, image_points, camera_matrix, dist_coeffs
)

if not success:
Expand All @@ -215,11 +212,16 @@ def compute_camera_positions(
R, _ = cv2.Rodrigues(rotation_vector)
camera_pos = -R.T @ translation_vector

results.append({
"frame_idx": frame,
"translation_vector": translation_vector.reshape(-1),
"rotation_vector": rotation_vector.reshape(-1),
"camera_pos": camera_pos.reshape(-1)
})
results.append(
{
"frame_idx": frame,
"translation_vector": translation_vector.reshape(-1),
"rotation_vector": rotation_vector.reshape(-1),
"camera_pos": camera_pos.reshape(-1),
}
)

return pd.DataFrame(results, columns=["frame_idx", "translation_vector", "rotation_vector", "camera_pos"])
return pd.DataFrame(
results,
columns=["frame_idx", "translation_vector", "rotation_vector", "camera_pos"],
)
1 change: 1 addition & 0 deletions pyneon/video/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..vis import plot_frame


class NeonVideo(cv2.VideoCapture):
"""
Loaded video file with timestamps.
Expand Down
44 changes: 24 additions & 20 deletions source/tutorials/apriltag_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@
],
"source": [
"tag_locations = {\n",
" 0: [0.0, -0.25, 0.2],\n",
" 1: [0.0, 0.25, 0.2],\n",
" 2: [0.0, -0.25, -0.2],\n",
" 3: [0.0, 0.25, -0.2]\n",
" }\n",
" 0: [0.0, -0.25, 0.2],\n",
" 1: [0.0, 0.25, 0.2],\n",
" 2: [0.0, -0.25, -0.2],\n",
" 3: [0.0, 0.25, -0.2],\n",
"}\n",
"\n",
"# Size of each tag (e.g., 0.2 meters, meaning 20 cm each side)\n",
"tag_size = 0.075\n",
"\n",
"camera_position = recording.compute_camera_positions(tag_locations, tag_size)\n",
"print(camera_position.columns)\n"
"print(camera_position.columns)"
]
},
{
Expand All @@ -93,21 +93,21 @@
}
],
"source": [
"#plot the trajectory in xy\n",
"# plot the trajectory in xy\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"x = (camera_position['camera_pos'].apply(lambda x: x[0]).values) # Extract x values\n",
"y = camera_position['camera_pos'].apply(lambda x: x[1]).values # Extract y values\n",
"x = camera_position[\"camera_pos\"].apply(lambda x: x[0]).values # Extract x values\n",
"y = camera_position[\"camera_pos\"].apply(lambda x: x[1]).values # Extract y values\n",
"colors = np.arange(len(x)) # Create a color array based on the index\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111)\n",
"scatter = ax.scatter(x, y, c=colors, cmap='viridis')\n",
"plt.colorbar(scatter, label='Index')\n",
"ax.set_aspect('equal', 'box')\n",
"plt.show()\n"
"scatter = ax.scatter(x, y, c=colors, cmap=\"viridis\")\n",
"plt.colorbar(scatter, label=\"Index\")\n",
"ax.set_aspect(\"equal\", \"box\")\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -143,21 +143,25 @@
}
],
"source": [
"#plot the trajectory in xy\n",
"# plot the trajectory in xy\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"x = (camera_position['smoothed_camera_pos'].apply(lambda x: x[0]).values) # Extract x values\n",
"y = camera_position['smoothed_camera_pos'].apply(lambda x: x[1]).values # Extract y values\n",
"x = (\n",
" camera_position[\"smoothed_camera_pos\"].apply(lambda x: x[0]).values\n",
") # Extract x values\n",
"y = (\n",
" camera_position[\"smoothed_camera_pos\"].apply(lambda x: x[1]).values\n",
") # Extract y values\n",
"colors = np.arange(len(x)) # Create a color array based on the index\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111)\n",
"scatter = ax.scatter(x, y, c=colors, cmap='viridis')\n",
"plt.colorbar(scatter, label='Index')\n",
"ax.set_aspect('equal', 'box')\n",
"plt.show()\n"
"scatter = ax.scatter(x, y, c=colors, cmap=\"viridis\")\n",
"plt.colorbar(scatter, label=\"Index\")\n",
"ax.set_aspect(\"equal\", \"box\")\n",
"plt.show()"
]
}
],
Expand Down
Loading

0 comments on commit 61114bf

Please sign in to comment.