@mikelgg93
Last active October 23, 2025 09:54
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "click",
# "opencv-python",
# "pupil-labs-neon-recording",
# "tqdm",
# "ultralytics",
# "rich",
# "scipy",
# "mediapipe",
# "requests",
# "depth-pro",
# "rfdetr",
# ]
#
# [tool.uv.sources]
# depth-pro = { git = "https://github.com/apple/ml-depth-pro.git" }
# ///
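# NOTE: The block above is PEP 723 inline script metadata, so a compatible runner
# such as uv can resolve and install the dependencies before executing the script,
# e.g. `uv run <this_script>.py <recording_directory> [options]` (script name and
# path are placeholders; see the CLI options defined on `main` below).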
import collections.abc
import shutil
import subprocess
from collections.abc import Generator
from itertools import pairwise
from typing import Any
import click
import cv2
import depth_pro
import mediapipe as mpipe
import numpy as np
import pupil_labs.video as plv
import supervision as sv
import torch
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from pupil_labs.neon_recording import NeonRecording
from pupil_labs.neon_recording.calib import Calibration
from pupil_labs.neon_recording.timeseries.av.video import GrayFrame
from pupil_labs.neon_recording.timeseries.blinks import BlinkRecord
from pupil_labs.neon_recording.timeseries.eyeball import EyeballRecord
from pupil_labs.neon_recording.timeseries.eyelid import EyelidRecord
from pupil_labs.neon_recording.timeseries.gaze import GazeRecord
from pupil_labs.neon_recording.timeseries.pupil import PupilRecord
from pupil_labs.video.frame import VideoFrame
from rfdetr import RFDETRBase
from rfdetr.util.coco_classes import COCO_CLASSES
from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeRemainingColumn,
)
from rich.table import Table
from scipy.spatial.distance import cdist
from ultralytics import RTDETR, YOLO
from ultralytics.utils.downloads import download
from upath import UPath
console: Console = Console()
OFFSET = np.array([0, 0]) # x,y offset to apply to gaze points
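# InterSamples is a thin marker wrapper: timeseries passed to
# RecordingWithSampler.sections_sampler() wrapped in InterSamples (e.g. audio) are
# not re-sampled at the output FPS; instead, every sample that falls between two
# consecutive output frames is collected and yielded as a batch.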
class InterSamples:
def __init__(self, timeseries):
self.timeseries = timeseries
class RecordingWithSampler(NeonRecording):
"""An extended version of NeonRecording that includes a custom timeserie sampler method."""
def sections_sampler( # noqa: C901
self,
*timeseries,
fps: int,
start_event_name: str = "recording.begin",
end_event_name: str = "recording.end",
) -> tuple[Generator[tuple[Any, ...], None, None], np.ndarray]:
"""Creates a synchronized generator for specified data timeseries.
This method samples 'standard' timeseries at a fixed FPS and collects all data
points for 'grouped' timeseries (wrapped in `InterSamples`) that fall between
the fixed samples.
Args:
*timeseries: A sequence of stream objects from the recording (e.g., self.gaze).
Wrap timeseries in `InterSamples()` to get all values between frames.
fps: The frames per second to sample the standard timeseries at.
start_event_name: The event name marking the beginning of a section.
end_event_name: The event name marking the end of a section.
        Returns:
            A tuple of (generator, combined_output_time).
            The generator yields tuples of the format:
            (section_start_diff, timestamp, std_stream_1_value, std_stream_2_value,
            ..., grouped_stream_1_data, grouped_stream_2_data, ...)
            `section_start_diff` is the offset (in nanoseconds) added to each
            timestamp to map the section onto the concatenated output timeline.
            `timestamp` is the timestamp of the current frame in nanoseconds.
            `std_stream_*_value` are the sampled values from the standard timeseries.
            `grouped_stream_*_data` are (timestamps, values) tuples holding all
            samples of the grouped timeseries up to the next frame.
            `combined_output_time` is a numpy array of timestamps for all frames to
            be sampled.
"""
std_timeseries = []
grouped_timeseries = []
for s in timeseries:
if isinstance(s, InterSamples):
grouped_timeseries.append(s.timeseries)
else:
std_timeseries.append(s)
if (
start_event_name not in self.events.by_name
or end_event_name not in self.events.by_name
):
raise ValueError(
f"Event '{start_event_name}' or '{end_event_name}' not found."
)
start_events_time = self.events.by_name[start_event_name]
end_events_time = self.events.by_name[end_event_name]
sections = []
for start_time in start_events_time:
end_time = next((et for et in end_events_time if et > start_time), None)
if end_time is not None:
sections.append((start_time, end_time))
if not sections:
raise ValueError("No valid start/end event pairs found.")
sections.sort()
all_time_arrays = [
np.arange(start_time, end_time, 1e9 / fps, dtype=np.int64)
for start_time, end_time in sections
]
combined_output_time = (
np.concatenate(all_time_arrays)
if any(len(arr) > 0 for arr in all_time_arrays)
else np.array([], dtype=np.int64)
)
def _generator_fn() -> Generator[tuple[Any, ...], None, None]:
total_prev_length = 0
for section_idx, (start_time, end_time) in enumerate(sections):
section_output_time = all_time_arrays[section_idx]
if not len(section_output_time):
continue
section_start_diff = -start_time + total_prev_length
section_grouped_data = []
for timeserie in grouped_timeseries:
if timeserie and len(timeserie.time) > 0:
sec_start_idx = np.searchsorted(
timeserie.time, start_time, side="left"
)
sec_end_idx = np.searchsorted(
timeserie.time, end_time, side="right"
)
ts_slice = timeserie.time[sec_start_idx:sec_end_idx]
if len(ts_slice) > 0:
frames_slice = timeserie.sample(ts_slice)
section_grouped_data.append({
"time": ts_slice,
"frames": frames_slice,
"search_idx": 0,
})
else:
section_grouped_data.append(None)
else:
section_grouped_data.append(None)
std_samples_list = [
s.sample(section_output_time)
if s
else [None] * len(section_output_time)
for s in std_timeseries
]
def _get_section_grouped_samples(end_frame_time: np.int64) -> list:
yield_values = []
for g_data in section_grouped_data:
if g_data:
start_idx = g_data["search_idx"]
end_idx = np.searchsorted(
g_data["time"], end_frame_time, side="right"
)
ts_slice = g_data["time"][start_idx:end_idx]
frames_slice = g_data["frames"][start_idx:end_idx]
yield_values.append((ts_slice, frames_slice))
g_data["search_idx"] = end_idx
else:
yield_values.append((np.array([], dtype=np.int64), []))
return yield_values
for i, (t_start, t_end) in enumerate(pairwise(section_output_time)):
current_std_values = [s[i] for s in std_samples_list]
grouped_yield_values = _get_section_grouped_samples(t_end)
yield (
section_start_diff,
t_start,
*current_std_values,
*grouped_yield_values,
)
last_t = section_output_time[-1]
last_std_values = [s[-1] for s in std_samples_list]
grouped_yield_values = _get_section_grouped_samples(end_time)
yield (
section_start_diff,
last_t,
*last_std_values,
*grouped_yield_values,
)
total_prev_length += end_time - start_time
return _generator_fn(), combined_output_time
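# Registry of the supported detectors. MediaPipe and Depth Pro entries carry the
# checkpoint name/URL that _initialize_detectors() downloads on first use; the
# "ultralytics" and "rfdetr" entries are filled from the --model CLI option via
# _fill_model_params().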
MODEL_CONFIG: dict[str, Any] = {
"hand": {
"name": "hand_landmarker.task",
"url": "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task",
"options": {
"num_hands": 4,
"min_hand_detection_confidence": 0.25,
"min_hand_presence_confidence": 0.25,
},
"model_name": "MediaPipe Hand Landmarker",
},
"pose": {
"name": "pose_landmarker_heavy.task",
"url": "https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_heavy/float16/1/pose_landmarker_heavy.task",
"options": {
"num_poses": 1,
"min_pose_detection_confidence": 0.5,
"min_pose_presence_confidence": 0.25,
},
"model_name": "MediaPipe Pose Landmarker",
},
"depth": {
"name": "depth_pro.pt",
"url": "https://ml-site.cdn-apple.com/models/depth-pro/depth_pro.pt",
"model_name": "ml-depth-pro",
},
"facial_landmarks": {
"name": "face_landmarker.task",
"url": "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task",
"options": {
"num_faces": 1,
"min_face_detection_confidence": 0.5,
"min_face_presence_confidence": 0.5,
"min_tracking_confidence": 0.5,
},
"model_name": "MediaPipe Face Landmarker",
},
"ultralytics": {},
"rfdetr": {},
}
def _deep_merge(base: dict, updates: dict) -> dict:
"""Recursively merges dictionaries, with `updates` taking precedence."""
merged = base.copy()
for key, value in updates.items():
if isinstance(merged.get(key), dict) and isinstance(
value, collections.abc.Mapping
):
merged[key] = _deep_merge(merged[key], value)
else:
merged[key] = value
return merged
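# Example of the merge semantics (a sketch, values are illustrative):
#   _deep_merge({"options": {"num_hands": 4}}, {"options": {"num_hands": 2}})
#   -> {"options": {"num_hands": 2}}  # nested keys from `updates` take precedence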
def _fill_model_params(params: dict[str, Any]) -> dict:
filled = {}
for key, value in params.items():
filled[key] = _deep_merge(MODEL_CONFIG[key], value)
return filled
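# Builds one detector instance per requested model key: MediaPipe landmarkers and
# the Depth Pro checkpoint are downloaded into `base_path` when missing, Ultralytics
# models are loaded by name or local path, and the depth entry is stored as a
# (model, transform) tuple so the network and its preprocessing travel together.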
def _initialize_detectors(
device: str, model_params: dict[str, Any], base_path: UPath = UPath("./checkpoints")
) -> dict[
str,
vision.HandLandmarker
| vision.PoseLandmarker
| vision.FaceLandmarker
| YOLO
| RTDETR
| tuple[depth_pro.depth_pro.DepthPro, depth_pro.depth_pro.Compose]
| RFDETRBase
| None,
]:
detectors = {}
for model_key, config in model_params.items():
console.log(f"Initializing {config['model_name']}...")
if model_key in ["hand", "pose", "facial_landmarks", "depth"]:
model_path = UPath(base_path / config["name"])
if not model_path.exists():
console.log(
f"Downloading {config['model_name']} model to {model_path}..."
)
download(
config["url"],
dir=base_path,
)
if model_key in ["hand", "pose", "facial_landmarks"]:
delegate = (
mpipe.tasks.BaseOptions.Delegate.GPU
if "cuda" in device or "gpu" in device
else mpipe.tasks.BaseOptions.Delegate.CPU
)
base_options = python.BaseOptions(
model_asset_path=str(model_path), delegate=delegate
)
elif model_key == "depth":
model, depth_transform = depth_pro.create_model_and_transforms(
device=device, precision=torch.half
)
model.eval()
if model_key == "hand":
hand_options = vision.HandLandmarkerOptions(
base_options=base_options,
running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
**config["options"],
)
model = vision.HandLandmarker.create_from_options(hand_options)
elif model_key == "pose":
pose_options = vision.PoseLandmarkerOptions(
base_options=base_options,
running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
**config["options"],
)
model = vision.PoseLandmarker.create_from_options(pose_options)
elif model_key == "facial_landmarks":
face_options = vision.FaceLandmarkerOptions(
base_options=base_options,
running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
**config["options"],
)
model = vision.FaceLandmarker.create_from_options(face_options)
elif model_key == "ultralytics":
model_path = UPath(base_path / config["model_name"])
if not model_path.exists():
model = YOLO(config["model_name"])
else:
model = YOLO(model_path)
elif model_key == "rfdetr":
model = RFDETRBase(
model_name=config["model_name"],
device=device,
classes=COCO_CLASSES,
)
model.optimize_for_inference()
detectors[model_key] = (
model if model_key != "depth" else (model, depth_transform)
)
if len(detectors) < len(model_params):
console.log(
"[bold red]Failed to initialize all detectors. Some models may not be available.[/bold red]"
)
return detectors
def _print_arguments_table():
ctx = click.get_current_context()
table = Table(
title="Arguments Provided",
show_header=True,
header_style="bold magenta",
)
table.add_column("Parameter", style="cyan")
table.add_column("Value Provided", style="white")
provided_params = [
(param_name, ctx.params[param_name]) for param_name in ctx.params
]
if not provided_params:
console.print(
"[yellow]No custom parameters were provided. Using all defaults.[/yellow]"
)
return
for key, value in provided_params:
table.add_row(key, str(value))
console.print(table)
def _predict_depth(
frame: np.ndarray,
f_px: float,
transform: depth_pro.depth_pro.Compose,
model: depth_pro.depth_pro.DepthPro,
) -> np.ndarray:
"""Predicts depth based on single frame data using ml-depth-pro."""
transformed = transform(frame[:, :, :3])
if not isinstance(transformed, torch.Tensor):
transformed = torch.from_numpy(transformed)
prediction = model.infer(transformed, f_px=f_px)
depth = prediction["depth"].detach().cpu().numpy().squeeze()
return depth
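# The visual angle between two pixel locations is obtained by undistorting both
# points with the scene camera intrinsics, lifting them to rays through the camera
# origin, and taking the arccos of the dot product of the normalized rays:
#   angle = arccos((v1 / |v1|) . (v2 / |v2|)),  v_i = [x_undist_i, y_undist_i, 1]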
def _calculate_visual_angle(
p1: tuple[float, float], p2: tuple[float, float], calibration: Calibration
) -> float:
K = calibration.scene_camera_matrix
D = calibration.scene_distortion_coefficients
points_distorted = np.array([p1, p2], dtype=np.float32).reshape(-1, 1, 2)
points_undistorted_normalized = cv2.undistortPoints(points_distorted, K, D)
vec1 = np.array([
points_undistorted_normalized[0, 0, 0],
points_undistorted_normalized[0, 0, 1],
1.0,
])
vec2 = np.array([
points_undistorted_normalized[1, 0, 0],
points_undistorted_normalized[1, 0, 1],
1.0,
])
vec1_normalized = vec1 / np.linalg.norm(vec1)
vec2_normalized = vec2 / np.linalg.norm(vec2)
dot_product = np.clip(np.dot(vec1_normalized, vec2_normalized), -1.0, 1.0)
angle_rad = np.arccos(dot_product)
return np.rad2deg(angle_rad)
def _annotate_frame(
frame: np.ndarray,
gaze: GazeRecord,
blink: BlinkRecord | None,
results: dict[
str,
vision.HandLandmarkerResult
| vision.PoseLandmarkerResult
| vision.FaceLandmarkerResult
| sv.Detections
| np.ndarray
| None,
],
eye_frame: VideoFrame,
time: np.int64,
fps: int,
calibration: Calibration | None = None,
) -> tuple[np.ndarray, dict[str, float]]:
"""Annotates the frame with gaze, blink, hand landmarks, and ball detection."""
global OFFSET
object_center = None
right_hand_data = {}
left_hand_data = {}
gaze_point = np.array([gaze.point[0] + OFFSET[0], gaze.point[1] + OFFSET[1]])
if results.get("depth") is not None:
depth_vis = np.nan_to_num(results.get("depth").astype(np.float32), nan=0.0)
depth_vis_clipped = np.clip(depth_vis, 0, 10) # For visualization purposes
depth_vis_norm = ((depth_vis_clipped / 10.0) * 255).astype(np.uint8)
depth_vis_color = cv2.applyColorMap(depth_vis_norm, cv2.COLORMAP_JET)
        # Alternative (disabled): blend the depth map over only the left half of the frame
# h, w = frame.shape[:2]
# margin_width = 100
# center_x = w // 2
# margin_start = center_x - (margin_width // 2)
# margin_end = center_x + (margin_width // 2)
# alpha_mask = np.full((h, w), 0.7, dtype=np.float32)
# gradient = np.linspace(0.7, 0.0, margin_width)
# alpha_mask[:, margin_start:margin_end] = gradient
# alpha_mask[:, margin_end:] = 0.0
# alpha_3d = alpha_mask[:, :, np.newaxis]
# frame_float = frame.astype(np.float32)
# depth_vis_float = depth_vis_color.astype(np.float32)
# blended_float = depth_vis_float * alpha_3d + frame_float * (1 - alpha_3d)
# frame = blended_float.astype(np.uint8)
frame = cv2.addWeighted(depth_vis_color, 0.9, frame, 0.1, 0)
if results.get("object") is not None:
object_result = results["object"]
frame = sv.BoxCornerAnnotator(color=sv.Color.ROBOFLOW).annotate(
scene=frame, detections=object_result
)
frame = sv.DotAnnotator(color=sv.Color.BLUE, radius=10).annotate(
scene=frame, detections=object_result
)
frame = sv.RichLabelAnnotator(
text_position=sv.Position.TOP_RIGHT, color=sv.Color.ROBOFLOW
).annotate(
scene=frame,
detections=object_result,
labels=[
f"{object_result.data['class_name'][0]} {object_result.confidence[0]:.2f}"
]
if hasattr(object_result.data, "class_name")
else [
f"{COCO_CLASSES[class_id]} {confidence:.2f}"
for class_id, confidence in zip(
object_result.class_id, object_result.confidence, strict=False
)
],
)
object_center = (
object_result.get_anchors_coordinates(sv.Position.CENTER)[0]
if object_result.xyxy.size != 0
else None
)
if object_result.mask is not None:
frame = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW).annotate(
scene=frame, detections=object_result
)
if results.get("hand") and results["hand"].hand_landmarks:
hand_result = results["hand"]
for i, hand_landmarks in enumerate(hand_result.hand_landmarks):
hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
hand_landmarks_proto.landmark.extend([
landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
for lm in hand_landmarks
])
solutions.drawing_utils.draw_landmarks(
frame,
hand_landmarks_proto,
solutions.hands.HAND_CONNECTIONS,
solutions.drawing_styles.get_default_hand_landmarks_style(),
solutions.drawing_styles.get_default_hand_connections_style(),
)
landmark_coords = np.array([
(lm.x * frame.shape[1], lm.y * frame.shape[0]) for lm in hand_landmarks
])
handedness = hand_result.handedness[i][0].category_name.lower()
dists_to_gaze = cdist(landmark_coords, [gaze_point])
closest_gaze_idx = np.argmin(dists_to_gaze)
closest_gaze_coord = tuple(map(int, landmark_coords[closest_gaze_idx]))
current_hand_data = {"min_dist_to_gaze": dists_to_gaze[closest_gaze_idx][0]}
if object_center is not None:
dists_to_obj = cdist(landmark_coords, [object_center])
closest_obj_idx = np.argmin(dists_to_obj)
closest_obj_coord = tuple(map(int, landmark_coords[closest_obj_idx]))
current_hand_data["min_dist_to_obj"] = dists_to_obj[closest_obj_idx][0]
cv2.line(
frame,
tuple(map(int, object_center)),
closest_obj_coord,
(0, 255, 255),
2,
)
cv2.line(
frame,
closest_gaze_coord,
tuple(map(int, object_center)),
(0, 128, 255),
2,
)
cv2.line(
frame,
closest_obj_coord,
tuple(map(int, gaze_point)),
(128, 0, 255),
2,
)
cv2.line(
frame, tuple(map(int, gaze_point)), closest_gaze_coord, (255, 0, 0), 2
)
if handedness == "right":
right_hand_data = current_hand_data
elif handedness == "left":
left_hand_data = current_hand_data
if results.get("pose") and results["pose"].pose_landmarks:
pose_result = results["pose"]
for _, pose_landmarks in enumerate(pose_result.pose_landmarks):
pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
pose_landmarks_proto.landmark.extend([
landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
for lm in pose_landmarks
])
solutions.drawing_utils.draw_landmarks(
frame,
pose_landmarks_proto,
solutions.pose.POSE_CONNECTIONS,
solutions.drawing_styles.get_default_pose_landmarks_style(),
)
if results.get("facial_landmarks") and results["facial_landmarks"].face_landmarks:
face_result = results["facial_landmarks"]
for face_landmarks in face_result.face_landmarks:
face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
face_landmarks_proto.landmark.extend([
landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
for lm in face_landmarks
])
solutions.drawing_utils.draw_landmarks(
image=frame,
landmark_list=face_landmarks_proto,
connections=solutions.face_mesh.FACEMESH_TESSELATION,
landmark_drawing_spec=None,
connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_tesselation_style(),
)
solutions.drawing_utils.draw_landmarks(
image=frame,
landmark_list=face_landmarks_proto,
connections=solutions.face_mesh.FACEMESH_CONTOURS,
landmark_drawing_spec=None,
connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_contours_style(
1
),
)
solutions.drawing_utils.draw_landmarks(
image=frame,
landmark_list=face_landmarks_proto,
connections=solutions.face_mesh.FACEMESH_IRISES,
landmark_drawing_spec=None,
connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_iris_connections_style(),
)
if abs(gaze.time - time) < 2e9 / fps:
color = (
(128, 128, 128)
if blink and blink.start_time < time < blink.end_time
else (0, 0, 255)
)
cv2.circle(frame, tuple(map(int, gaze_point)), 16, color, 4)
if results.get("depth") is not None:
depth_result = results["depth"]
cv2.putText(
frame,
f"{depth_result[int(gaze_point[1]), int(gaze_point[0])]:.2f} m",
(int(gaze_point[0] + 20), int(gaze_point[1] + 10)),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
(255, 255, 255),
2,
)
if (
abs(eye_frame.time - time) < 2e9 / fps
and eye_frame.bgr.shape[0] > 0
and eye_frame.bgr.shape[1] > 0
):
h, w = eye_frame.bgr.shape[:2]
mask = np.zeros((h, w), dtype=np.uint8)
radius = 5
cv2.rectangle(mask, (radius, 0), (w - radius, h), 255, -1)
cv2.rectangle(mask, (0, radius), (w, h - radius), 255, -1)
cv2.circle(mask, (radius, radius), radius, 255, -1)
cv2.circle(mask, (w - radius - 1, radius), radius, 255, -1)
cv2.circle(mask, (radius, h - radius - 1), radius, 255, -1)
cv2.circle(mask, (w - radius - 1, h - radius - 1), radius, 255, -1)
        # Take the background ROI from the same region the eye overlay is pasted into below
        roi = frame[10 : h + 10, 10 : w + 10]
eye_bgr_masked = cv2.bitwise_and(eye_frame.bgr, eye_frame.bgr, mask=mask)
roi_bg = cv2.bitwise_and(roi, roi, mask=cv2.bitwise_not(mask))
combined = cv2.add(roi_bg, eye_bgr_masked)
frame[10 : h + 10, 10 : w + 10] = combined
distances: dict[str, float] = {
"gaze_to_ball": cdist([gaze_point], [object_center])[0][0]
if object_center is not None
else np.nan,
"gaze_to_ball_angle": (
_calculate_visual_angle(gaze_point, object_center, calibration)
if object_center is not None and calibration is not None
else np.nan
),
"gaze_to_right_hand": right_hand_data.get("min_dist_to_gaze", np.nan),
"gaze_to_left_hand": left_hand_data.get("min_dist_to_gaze", np.nan),
"ball_to_right_hand": right_hand_data.get("min_dist_to_ball", np.nan),
"ball_to_left_hand": left_hand_data.get("min_dist_to_ball", np.nan),
}
return frame, distances
def _ns_to_vtt_time(time: np.int64) -> str:
total_ms = time // 1_000_000
h, rem = divmod(total_ms, 3_600_000)
m, rem = divmod(rem, 60_000)
s, ms = divmod(rem, 1_000)
return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"
def _build_subtitle_text(
time: np.int64,
initial_time: int,
gaze: GazeRecord,
eyeball: EyeballRecord,
pupil: PupilRecord,
eyelid: EyelidRecord,
distances: dict[str, float],
) -> str:
lines = [f"Time: {_ns_to_vtt_time(time - initial_time)}"]
lines.append(
f"Gaze: ({int(gaze.point[0] + OFFSET[0]):04d}, {int(gaze.point[1] + OFFSET[1]):04d})"
)
subtitle_data = {
"Pupil Size": (
f"{np.nanmean(pupil.diameter_left):.2f}",
"mm",
),
"Eyelid Aperture": (
f"{np.nanmean(eyelid.aperture_left):.2f}",
"mm",
),
"Object2Gaze Distance": (
(
f"{np.nanmean(distances.get('gaze_to_ball')):.2f}px / "
f"{np.nanmean(distances.get('gaze_to_ball_angle')):.2f}°"
),
"",
),
"Gaze2Hand Distances (R/L)": (
(
f"{np.nanmean(distances.get('gaze_to_right_hand')):.2f} / "
f"{np.nanmean(distances.get('gaze_to_left_hand')):.2f}"
),
"px",
),
"Object2Hand Distances (R/L)": (
(
f"{np.nanmean(distances.get('ball_to_right_hand')):.2f} / "
f"{np.nanmean(distances.get('ball_to_left_hand')):.2f}"
),
"px",
),
}
for label, (value, unit) in subtitle_data.items():
if value is not None:
lines.append(f"{label}: {value} {unit}")
else:
lines.append(f"{label}: N/A")
return "\n".join(lines)
@click.command()
@click.argument("recording_directory", type=click.Path(exists=True))
@click.option(
"--output-filename", default="output_vid", help="Base name for output files."
)
@click.option(
"--sub", is_flag=True, default=False, help="Include subtitles in output video."
)
@click.option(
"--preview",
is_flag=True,
default=False,
help="Show live preview of annotated frames.",
)
@click.option(
"--device",
"device_str",
default="cpu",
show_default=True,
help='Device for inference, e.g., "cpu", "cuda", "mps".',
)
@click.option(
"--detectors",
"detectors_str",
default="hand,pose,ultralytics,facial_landmarks",
show_default=True,
help=(
"Comma-separated list of detectors to use: 'hand', 'pose', 'ultralytics',"
"'depth', 'facial_landmarks'."
),
)
@click.option(
"--conf",
"conf_threshold",
default=0.25,
show_default=True,
help="Confidence threshold for object detection.",
)
@click.option(
"--model",
"model",
default="rtdetr-l.pt",
show_default=True,
help=(
"Ultralytics / RF-DETR model to use for detection. (yolo11x-seg.pt, "
"rtdetr-l.pt, rfdetr-large)"
),
)
@click.option(
"--class-id",
"class_id",
default=32,
show_default=True,
help="Class ID to detect. Default is 32 ('sports ball' in COCO, 37 in RF-DETR.)",
)
@click.option(
"--offset",
"offset_str",
default="0,0",
show_default=True,
help="Gaze offset in pixels, format: 'x,y'.",
)
@click.option(
"--start",
"start_event",
type=click.STRING,
default="recording.begin",
show_default=True,
help="Event name or timestamp to start processing from.",
)
@click.option(
"--end",
"end_event",
type=click.STRING,
default="recording.end",
show_default=True,
help="Event name or timestamp to stop processing at.",
)
@click.option(
"--fps-mode",
type=click.Choice(["scene", "gaze"], case_sensitive=False),
default="scene",
show_default=True,
help="Choose FPS mode: 'scene' or 'gaze'.",
)
@click.option(
"--mute",
is_flag=True,
default=False,
help="Mute audio in the output video.",
)
def main(
recording_directory: str,
output_filename: str,
fps_mode: str,
sub: bool,
preview: bool,
detectors_str: str,
model: str,
device_str: str,
conf_threshold: float,
class_id: int,
offset_str: str,
start_event: str,
end_event: str,
mute: bool,
) -> None:
try:
x, y = map(int, offset_str.split(","))
global OFFSET
OFFSET = [x, y]
except ValueError:
console.log("[bold red]Invalid offset format. Use 'x,y'. Exiting.[/bold red]")
return
detectors_dict = {}
detector_map = {
"hand": {},
"pose": {},
"facial_landmarks": {},
"ultralytics": {"model_name": model},
"rfdetr": {"model_name": model},
"depth": {},
}
for det in (d.strip().lower() for d in detectors_str.split(",")):
if det in detector_map:
detectors_dict[det] = detector_map[det]
detectors_dict = _fill_model_params(detectors_dict)
detectors = _initialize_detectors(device_str, detectors_dict)
_print_arguments_table()
rec = RecordingWithSampler(recording_directory)
ns_mean_isi = (
np.nanmean(np.diff(rec.gaze.time))
if fps_mode == "gaze"
else np.nanmean(np.diff(rec.scene.time))
)
fps = int(np.round(1 / (ns_mean_isi / 1e9)))
scene_fps = int(np.round(1 / (np.nanmean(np.diff(rec.scene.time)) / 1e9)))
combined, output_time = rec.sections_sampler(
rec.scene,
rec.gaze,
rec.blinks,
rec.eyeball,
rec.pupil,
rec.eyelid,
rec.eye,
InterSamples(rec.audio) if not mute else None,
fps=scene_fps,
start_event_name=start_event,
end_event_name=end_event,
)
subtitle_events = []
video_path = str(UPath(recording_directory) / f"{output_filename}.mp4")
progress_columns = [
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("{task.completed} of {task.total} frames"),
TimeRemainingColumn(),
]
with Progress(*progress_columns, console=console) as progress:
task = progress.add_task("[cyan]Processing frames...", total=len(output_time))
with plv.Writer(video_path, bit_rate=8_000_000) as writer:
for (
local_time_diff,
time,
scene_frame,
gaze,
blink,
eyeball,
pupil,
eyelid,
eye_frame,
audio_samples,
) in combined:
if abs(scene_frame.time - time) < 2e9 / scene_fps:
frame = scene_frame.bgr
else:
frame = GrayFrame(scene_frame.width, scene_frame.height).bgr
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
mp_image = mpipe.Image(
image_format=mpipe.ImageFormat.SRGB, data=rgb_frame
)
results: dict[str, Any] = {}
mp_detector_names = ["hand", "pose", "facial_landmarks"]
for name in mp_detector_names:
if name in detectors:
results[name] = detectors[name].detect_for_video(
                            mp_image, time // 1_000_000  # detect_for_video expects a timestamp in milliseconds
)
else:
results[name] = None
if "ultralytics" in detectors:
results["ultralytics"] = detectors["ultralytics"].predict(
rgb_frame,
verbose=False,
device=device_str,
classes=[class_id],
conf=conf_threshold,
)[0]
else:
results["ultralytics"] = None
results["object"] = (
sv.Detections.from_ultralytics(results.get("ultralytics"))
if results.get("ultralytics")
else None
)
if "rfdetr" in detectors:
results["object"] = detectors["rfdetr"].predict(
rgb_frame,
device=device_str
if device_str != "mps"
else "cpu", # mps not supported by rfdetr
threshold=conf_threshold,
)
if results["object"] is not None:
results["object"] = results["object"][
results["object"].class_id == class_id
]
else:
results["object"] = results.get("object")
if "depth" in detectors:
results["depth"] = _predict_depth(
                        rgb_frame,  # Depth Pro expects an RGB image
f_px=rec.calibration["scene_camera_matrix"][0, 0],
transform=detectors["depth"][1],
model=detectors["depth"][0],
)
else:
results["depth"] = None
frame, distances = _annotate_frame(
frame, gaze, blink, results, eye_frame, time, fps, rec.calibration
)
if preview:
cv2.imshow("Live Preview", frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
if not mute:
for a_time, audio_frame in zip(
audio_samples[0], audio_samples[1], strict=False
):
writer.write_frame(
audio_frame.plv_frame,
(a_time + local_time_diff) / 1e9,
)
writer.write_image(
frame,
(time + local_time_diff) / 1e9,
)
if sub:
text = _build_subtitle_text(
time, rec.start_time, gaze, eyeball, pupil, eyelid, distances
)
subtitle_events.append((
text,
int(time + local_time_diff),
int(time + local_time_diff + ns_mean_isi),
))
progress.update(task, advance=1)
console.log(f"Video saved to [cyan]{video_path}[/cyan]")
if preview:
cv2.destroyAllWindows()
if sub:
vtt_path = str(UPath(recording_directory) / f"{output_filename}.vtt")
with open(vtt_path, "w", encoding="utf-8") as vtt_file:
vtt_file.write("WEBVTT\n\n")
for text, start_ns, end_ns in subtitle_events:
vtt_file.write(
f"{_ns_to_vtt_time(start_ns)} --> {_ns_to_vtt_time(end_ns)}\n{text}\n\n"
)
console.log(f"VTT subtitles written to [cyan]{vtt_path}[/cyan]")
console.log("Embedding subtitles with FFMPEG...")
video_with_subs_path = str(
UPath(recording_directory) / f"{output_filename}_subbed.mp4"
)
ffmpeg_command = [
"ffmpeg",
"-i",
video_path,
"-i",
vtt_path,
"-c:v",
"copy",
"-c:a",
"copy",
"-c:s",
"mov_text",
"-map",
"0:v",
"-map",
"0:a?",
"-map",
"1",
"-metadata:s:s:0",
"language=eng",
"-y",
video_with_subs_path,
]
try:
subprocess.run(ffmpeg_command, check=True, capture_output=True, text=True)
shutil.move(video_with_subs_path, video_path)
# UPath(vtt_path).unlink()
console.log(
f"[green]Subtitles successfully embedded into {video_path}[/green]"
)
except FileNotFoundError:
console.log(
"[bold red]FFMPEG not found. Please install FFMPEG and ensure it is "
"in your system's PATH.[/bold red]"
)
except subprocess.CalledProcessError as e:
console.log(f"[bold red]FFMPEG error:\n{e.stderr}[/bold red]")
if __name__ == "__main__":
main()