# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "click",
#     "opencv-python",
#     "pupil-labs-neon-recording",
#     "tqdm",
#     "ultralytics",
#     "rich",
#     "scipy",
#     "mediapipe",
#     "requests",
#     "depth-pro",
#     "rfdetr",
# ]
#
# [tool.uv.sources]
# depth-pro = { git = "https://github.com/apple/ml-depth-pro.git" }
# ///
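#
# Illustrative invocation (script filename and recording path are placeholders):
# thanks to the inline metadata above, `uv` can resolve the dependencies on the fly.
#
#   uv run script.py /path/to/neon_recording \
#       --detectors hand,pose,ultralytics,facial_landmarks --model rtdetr-l.pt --sub
#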
import collections.abc
import shutil
import subprocess
from collections.abc import Generator
from itertools import pairwise
from typing import Any

import click
import cv2
import depth_pro
import mediapipe as mpipe
import numpy as np
import pupil_labs.video as plv
import supervision as sv
import torch
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from pupil_labs.neon_recording import NeonRecording
from pupil_labs.neon_recording.calib import Calibration
from pupil_labs.neon_recording.timeseries.av.video import GrayFrame
from pupil_labs.neon_recording.timeseries.blinks import BlinkRecord
from pupil_labs.neon_recording.timeseries.eyeball import EyeballRecord
from pupil_labs.neon_recording.timeseries.eyelid import EyelidRecord
from pupil_labs.neon_recording.timeseries.gaze import GazeRecord
from pupil_labs.neon_recording.timeseries.pupil import PupilRecord
from pupil_labs.video.frame import VideoFrame
from rfdetr import RFDETRBase
from rfdetr.util.coco_classes import COCO_CLASSES
from rich.console import Console
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from scipy.spatial.distance import cdist
from ultralytics import RTDETR, YOLO
from ultralytics.utils.downloads import download
from upath import UPath

console: Console = Console()
OFFSET = np.array([0, 0])  # x,y offset to apply to gaze points


class InterSamples:
    def __init__(self, timeseries):
        self.timeseries = timeseries


class RecordingWithSampler(NeonRecording):
    """An extended version of NeonRecording that includes a custom timeseries sampler method."""

    def sections_sampler(  # noqa: C901
        self,
        *timeseries,
        fps: int,
        start_event_name: str = "recording.begin",
        end_event_name: str = "recording.end",
    ) -> tuple[Generator[tuple[Any, ...], None, None], np.ndarray]:
        """Creates a synchronized generator for the specified data timeseries.

        This method samples 'standard' timeseries at a fixed FPS and collects all data
        points of 'grouped' timeseries (wrapped in `InterSamples`) that fall between
        the fixed samples.

        Args:
            *timeseries: A sequence of stream objects from the recording (e.g., self.gaze).
                Wrap a timeseries in `InterSamples()` to get all values between frames.
            fps: The frames per second at which to sample the standard timeseries.
            start_event_name: The event name marking the beginning of a section.
            end_event_name: The event name marking the end of a section.

        Returns:
            A generator yielding tuples of the format:
            (section_start_diff, timestamp, std_stream_1_value, std_stream_2_value,
            ..., grouped_stream_1_data, grouped_stream_2_data, ...)

            `section_start_diff` is the difference in timestamps from the start of
            the section to the current timestamp.
            `timestamp` is the timestamp of the current frame in nanoseconds.
            `std_stream_*_value` are the sampled values from the standard timeseries.
            `grouped_stream_*_data` are tuples of (timestamps, values) for the
            grouped timeseries.
            `combined_output_time` is a numpy array of timestamps for all frames to be
            sampled.
        """
        std_timeseries = []
        grouped_timeseries = []
        for s in timeseries:
            if isinstance(s, InterSamples):
                grouped_timeseries.append(s.timeseries)
            else:
                std_timeseries.append(s)

        if (
            start_event_name not in self.events.by_name
            or end_event_name not in self.events.by_name
        ):
            raise ValueError(
                f"Event '{start_event_name}' or '{end_event_name}' not found."
            )
        start_events_time = self.events.by_name[start_event_name]
        end_events_time = self.events.by_name[end_event_name]

        sections = []
        for start_time in start_events_time:
            end_time = next((et for et in end_events_time if et > start_time), None)
            if end_time is not None:
                sections.append((start_time, end_time))
        if not sections:
            raise ValueError("No valid start/end event pairs found.")
        sections.sort()

        all_time_arrays = [
            np.arange(start_time, end_time, 1e9 / fps, dtype=np.int64)
            for start_time, end_time in sections
        ]
        combined_output_time = (
            np.concatenate(all_time_arrays)
            if any(len(arr) > 0 for arr in all_time_arrays)
            else np.array([], dtype=np.int64)
        )

        def _generator_fn() -> Generator[tuple[Any, ...], None, None]:
            total_prev_length = 0
            for section_idx, (start_time, end_time) in enumerate(sections):
                section_output_time = all_time_arrays[section_idx]
                if not len(section_output_time):
                    continue
                section_start_diff = -start_time + total_prev_length

                section_grouped_data = []
                for timeserie in grouped_timeseries:
                    if timeserie and len(timeserie.time) > 0:
                        sec_start_idx = np.searchsorted(
                            timeserie.time, start_time, side="left"
                        )
                        sec_end_idx = np.searchsorted(
                            timeserie.time, end_time, side="right"
                        )
                        ts_slice = timeserie.time[sec_start_idx:sec_end_idx]
                        if len(ts_slice) > 0:
                            frames_slice = timeserie.sample(ts_slice)
                            section_grouped_data.append({
                                "time": ts_slice,
                                "frames": frames_slice,
                                "search_idx": 0,
                            })
                        else:
                            section_grouped_data.append(None)
                    else:
                        section_grouped_data.append(None)

                std_samples_list = [
                    s.sample(section_output_time)
                    if s
                    else [None] * len(section_output_time)
                    for s in std_timeseries
                ]

                def _get_section_grouped_samples(end_frame_time: np.int64) -> list:
                    yield_values = []
                    for g_data in section_grouped_data:
                        if g_data:
                            start_idx = g_data["search_idx"]
                            end_idx = np.searchsorted(
                                g_data["time"], end_frame_time, side="right"
                            )
                            ts_slice = g_data["time"][start_idx:end_idx]
                            frames_slice = g_data["frames"][start_idx:end_idx]
                            yield_values.append((ts_slice, frames_slice))
                            g_data["search_idx"] = end_idx
                        else:
                            yield_values.append((np.array([], dtype=np.int64), []))
                    return yield_values

                for i, (t_start, t_end) in enumerate(pairwise(section_output_time)):
                    current_std_values = [s[i] for s in std_samples_list]
                    grouped_yield_values = _get_section_grouped_samples(t_end)
                    yield (
                        section_start_diff,
                        t_start,
                        *current_std_values,
                        *grouped_yield_values,
                    )

                last_t = section_output_time[-1]
                last_std_values = [s[-1] for s in std_samples_list]
                grouped_yield_values = _get_section_grouped_samples(end_time)
                yield (
                    section_start_diff,
                    last_t,
                    *last_std_values,
                    *grouped_yield_values,
                )
                total_prev_length += end_time - start_time

        return _generator_fn(), combined_output_time
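# Illustrative usage sketch for `RecordingWithSampler.sections_sampler` (the
# recording path is a placeholder); the yielded tuple follows the order in which
# the timeseries are passed, after the section offset and timestamp:
#
#   rec = RecordingWithSampler("/path/to/recording")
#   sampler, output_times = rec.sections_sampler(rec.scene, rec.gaze, fps=30)
#   for section_offset, ts, scene_frame, gaze in sampler:
#       ...  # one scene frame and one gaze sample per output timestamp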


MODEL_CONFIG: dict[str, Any] = {
    "hand": {
        "name": "hand_landmarker.task",
        "url": "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task",
        "options": {
            "num_hands": 4,
            "min_hand_detection_confidence": 0.25,
            "min_hand_presence_confidence": 0.25,
        },
        "model_name": "MediaPipe Hand Landmarker",
    },
    "pose": {
        "name": "pose_landmarker_heavy.task",
        "url": "https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_heavy/float16/1/pose_landmarker_heavy.task",
        "options": {
            "num_poses": 1,
            "min_pose_detection_confidence": 0.5,
            "min_pose_presence_confidence": 0.25,
        },
        "model_name": "MediaPipe Pose Landmarker",
    },
    "depth": {
        "name": "depth_pro.pt",
        "url": "https://ml-site.cdn-apple.com/models/depth-pro/depth_pro.pt",
        "model_name": "ml-depth-pro",
    },
    "facial_landmarks": {
        "name": "face_landmarker.task",
        "url": "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task",
        "options": {
            "num_faces": 1,
            "min_face_detection_confidence": 0.5,
            "min_face_presence_confidence": 0.5,
            "min_tracking_confidence": 0.5,
        },
        "model_name": "MediaPipe Face Landmarker",
    },
    "ultralytics": {},
    "rfdetr": {},
}


def _deep_merge(base: dict, updates: dict) -> dict:
    """Recursively merges dictionaries, with `updates` taking precedence."""
    merged = base.copy()
    for key, value in updates.items():
        if isinstance(merged.get(key), dict) and isinstance(
            value, collections.abc.Mapping
        ):
            merged[key] = _deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def _fill_model_params(params: dict[str, Any]) -> dict:
    filled = {}
    for key, value in params.items():
        filled[key] = _deep_merge(MODEL_CONFIG[key], value)
    return filled
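# Illustrative example of the merge semantics: user-supplied values override the
# MODEL_CONFIG defaults while untouched keys (name, url, other options) survive.
#
#   _fill_model_params({"hand": {"options": {"min_hand_detection_confidence": 0.5}}})
#   # -> {"hand": {"name": "hand_landmarker.task",
#   #              "url": "https://storage.googleapis.com/...",
#   #              "options": {"num_hands": 4,
#   #                          "min_hand_detection_confidence": 0.5,
#   #                          "min_hand_presence_confidence": 0.25},
#   #              "model_name": "MediaPipe Hand Landmarker"}}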


def _initialize_detectors(
    device: str, model_params: dict[str, Any], base_path: UPath = UPath("./checkpoints")
) -> dict[
    str,
    vision.HandLandmarker
    | vision.PoseLandmarker
    | vision.FaceLandmarker
    | YOLO
    | RTDETR
    | tuple[depth_pro.depth_pro.DepthPro, depth_pro.depth_pro.Compose]
    | RFDETRBase
    | None,
]:
    detectors = {}
    for model_key, config in model_params.items():
        console.log(f"Initializing {config['model_name']}...")
        if model_key in ["hand", "pose", "facial_landmarks", "depth"]:
            model_path = UPath(base_path / config["name"])
            if not model_path.exists():
                console.log(
                    f"Downloading {config['model_name']} model to {model_path}..."
                )
                download(
                    config["url"],
                    dir=base_path,
                )
            if model_key in ["hand", "pose", "facial_landmarks"]:
                delegate = (
                    mpipe.tasks.BaseOptions.Delegate.GPU
                    if "cuda" in device or "gpu" in device
                    else mpipe.tasks.BaseOptions.Delegate.CPU
                )
                base_options = python.BaseOptions(
                    model_asset_path=str(model_path), delegate=delegate
                )
            elif model_key == "depth":
                model, depth_transform = depth_pro.create_model_and_transforms(
                    device=device, precision=torch.half
                )
                model.eval()
        if model_key == "hand":
            hand_options = vision.HandLandmarkerOptions(
                base_options=base_options,
                running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
                **config["options"],
            )
            model = vision.HandLandmarker.create_from_options(hand_options)
        elif model_key == "pose":
            pose_options = vision.PoseLandmarkerOptions(
                base_options=base_options,
                running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
                **config["options"],
            )
            model = vision.PoseLandmarker.create_from_options(pose_options)
        elif model_key == "facial_landmarks":
            face_options = vision.FaceLandmarkerOptions(
                base_options=base_options,
                running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
                **config["options"],
            )
            model = vision.FaceLandmarker.create_from_options(face_options)
        elif model_key == "ultralytics":
            model_path = UPath(base_path / config["model_name"])
            if not model_path.exists():
                model = YOLO(config["model_name"])
            else:
                model = YOLO(model_path)
        elif model_key == "rfdetr":
            model = RFDETRBase(
                model_name=config["model_name"],
                device=device,
                classes=COCO_CLASSES,
            )
            model.optimize_for_inference()
        detectors[model_key] = (
            model if model_key != "depth" else (model, depth_transform)
        )
    if len(detectors) < len(model_params):
        console.log(
            "[bold red]Failed to initialize all detectors. Some models may not be available.[/bold red]"
        )
    return detectors
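# For orientation: depending on the requested detectors, the returned mapping
# roughly looks like {"hand": HandLandmarker, "pose": PoseLandmarker,
# "facial_landmarks": FaceLandmarker, "ultralytics": YOLO, "rfdetr": RFDETRBase,
# "depth": (DepthPro, transform)}; the "depth" entry is the (model, transform)
# pair that _predict_depth() expects.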


def _print_arguments_table():
    ctx = click.get_current_context()
    table = Table(
        title="Arguments Provided",
        show_header=True,
        header_style="bold magenta",
    )
    table.add_column("Parameter", style="cyan")
    table.add_column("Value Provided", style="white")
    provided_params = [
        (param_name, ctx.params[param_name]) for param_name in ctx.params
    ]
    if not provided_params:
        console.print(
            "[yellow]No custom parameters were provided. Using all defaults.[/yellow]"
        )
        return
    for key, value in provided_params:
        table.add_row(key, str(value))
    console.print(table)


def _predict_depth(
    frame: np.ndarray,
    f_px: float,
    transform: depth_pro.depth_pro.Compose,
    model: depth_pro.depth_pro.DepthPro,
) -> np.ndarray:
    """Predicts depth based on single frame data using ml-depth-pro."""
    transformed = transform(frame[:, :, :3])
    if not isinstance(transformed, torch.Tensor):
        transformed = torch.from_numpy(transformed)
    prediction = model.infer(transformed, f_px=f_px)
    depth = prediction["depth"].detach().cpu().numpy().squeeze()
    return depth


def _calculate_visual_angle(
    p1: tuple[float, float], p2: tuple[float, float], calibration: Calibration
) -> float:
    K = calibration.scene_camera_matrix
    D = calibration.scene_distortion_coefficients
    points_distorted = np.array([p1, p2], dtype=np.float32).reshape(-1, 1, 2)
    points_undistorted_normalized = cv2.undistortPoints(points_distorted, K, D)
    vec1 = np.array([
        points_undistorted_normalized[0, 0, 0],
        points_undistorted_normalized[0, 0, 1],
        1.0,
    ])
    vec2 = np.array([
        points_undistorted_normalized[1, 0, 0],
        points_undistorted_normalized[1, 0, 1],
        1.0,
    ])
    vec1_normalized = vec1 / np.linalg.norm(vec1)
    vec2_normalized = vec2 / np.linalg.norm(vec2)
    dot_product = np.clip(np.dot(vec1_normalized, vec2_normalized), -1.0, 1.0)
    angle_rad = np.arccos(dot_product)
    return np.rad2deg(angle_rad)
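# The value returned above is the angle between the two undistorted viewing rays,
# i.e. theta = arccos(v1 . v2 / (|v1| * |v2|)). As a rough sanity check with an
# ideal pinhole camera (no distortion, focal length 1000 px): a point at the
# principal point and a point 1000 px to its right subtend
# arctan(1000 / 1000) = 45 degrees.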


def _annotate_frame(
    frame: np.ndarray,
    gaze: GazeRecord,
    blink: BlinkRecord | None,
    results: dict[
        str,
        vision.HandLandmarkerResult
        | vision.PoseLandmarkerResult
        | vision.FaceLandmarkerResult
        | sv.Detections
        | np.ndarray
        | None,
    ],
    eye_frame: VideoFrame,
    time: np.int64,
    fps: int,
    calibration: Calibration | None = None,
) -> tuple[np.ndarray, dict[str, float]]:
    """Annotates the frame with gaze, blink, hand landmarks, and ball detection."""
    global OFFSET
    object_center = None
    right_hand_data = {}
    left_hand_data = {}
    gaze_point = np.array([gaze.point[0] + OFFSET[0], gaze.point[1] + OFFSET[1]])

    if results.get("depth") is not None:
        depth_vis = np.nan_to_num(results.get("depth").astype(np.float32), nan=0.0)
        depth_vis_clipped = np.clip(depth_vis, 0, 10)  # For visualization purposes
        depth_vis_norm = ((depth_vis_clipped / 10.0) * 255).astype(np.uint8)
        depth_vis_color = cv2.applyColorMap(depth_vis_norm, cv2.COLORMAP_JET)
        # Half - blend the depth visualization with the original frame
        # h, w = frame.shape[:2]
        # margin_width = 100
        # center_x = w // 2
        # margin_start = center_x - (margin_width // 2)
        # margin_end = center_x + (margin_width // 2)
        # alpha_mask = np.full((h, w), 0.7, dtype=np.float32)
        # gradient = np.linspace(0.7, 0.0, margin_width)
        # alpha_mask[:, margin_start:margin_end] = gradient
        # alpha_mask[:, margin_end:] = 0.0
        # alpha_3d = alpha_mask[:, :, np.newaxis]
        # frame_float = frame.astype(np.float32)
        # depth_vis_float = depth_vis_color.astype(np.float32)
        # blended_float = depth_vis_float * alpha_3d + frame_float * (1 - alpha_3d)
        # frame = blended_float.astype(np.uint8)
        frame = cv2.addWeighted(depth_vis_color, 0.9, frame, 0.1, 0)
    if results.get("object") is not None:
        object_result = results["object"]
        frame = sv.BoxCornerAnnotator(color=sv.Color.ROBOFLOW).annotate(
            scene=frame, detections=object_result
        )
        frame = sv.DotAnnotator(color=sv.Color.BLUE, radius=10).annotate(
            scene=frame, detections=object_result
        )
        frame = sv.RichLabelAnnotator(
            text_position=sv.Position.TOP_RIGHT, color=sv.Color.ROBOFLOW
        ).annotate(
            scene=frame,
            detections=object_result,
            # `data` is a plain dict, so check membership with `in`; build one
            # label per detection from the stored class names when available.
            labels=[
                f"{class_name} {confidence:.2f}"
                for class_name, confidence in zip(
                    object_result.data["class_name"],
                    object_result.confidence,
                    strict=False,
                )
            ]
            if "class_name" in object_result.data
            else [
                f"{COCO_CLASSES[class_id]} {confidence:.2f}"
                for class_id, confidence in zip(
                    object_result.class_id, object_result.confidence, strict=False
                )
            ],
        )
        object_center = (
            object_result.get_anchors_coordinates(sv.Position.CENTER)[0]
            if object_result.xyxy.size != 0
            else None
        )
        if object_result.mask is not None:
            frame = sv.MaskAnnotator(color=sv.ColorPalette.ROBOFLOW).annotate(
                scene=frame, detections=object_result
            )
    if results.get("hand") and results["hand"].hand_landmarks:
        hand_result = results["hand"]
        for i, hand_landmarks in enumerate(hand_result.hand_landmarks):
            hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            hand_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
                for lm in hand_landmarks
            ])
            solutions.drawing_utils.draw_landmarks(
                frame,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style(),
            )
            landmark_coords = np.array([
                (lm.x * frame.shape[1], lm.y * frame.shape[0]) for lm in hand_landmarks
            ])
            handedness = hand_result.handedness[i][0].category_name.lower()
            dists_to_gaze = cdist(landmark_coords, [gaze_point])
            closest_gaze_idx = np.argmin(dists_to_gaze)
            closest_gaze_coord = tuple(map(int, landmark_coords[closest_gaze_idx]))
            current_hand_data = {"min_dist_to_gaze": dists_to_gaze[closest_gaze_idx][0]}
            if object_center is not None:
                dists_to_obj = cdist(landmark_coords, [object_center])
                closest_obj_idx = np.argmin(dists_to_obj)
                closest_obj_coord = tuple(map(int, landmark_coords[closest_obj_idx]))
                current_hand_data["min_dist_to_obj"] = dists_to_obj[closest_obj_idx][0]
                cv2.line(
                    frame,
                    tuple(map(int, object_center)),
                    closest_obj_coord,
                    (0, 255, 255),
                    2,
                )
                cv2.line(
                    frame,
                    closest_gaze_coord,
                    tuple(map(int, object_center)),
                    (0, 128, 255),
                    2,
                )
                cv2.line(
                    frame,
                    closest_obj_coord,
                    tuple(map(int, gaze_point)),
                    (128, 0, 255),
                    2,
                )
            cv2.line(
                frame, tuple(map(int, gaze_point)), closest_gaze_coord, (255, 0, 0), 2
            )
            if handedness == "right":
                right_hand_data = current_hand_data
            elif handedness == "left":
                left_hand_data = current_hand_data
    if results.get("pose") and results["pose"].pose_landmarks:
        pose_result = results["pose"]
        for _, pose_landmarks in enumerate(pose_result.pose_landmarks):
            pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            pose_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
                for lm in pose_landmarks
            ])
            solutions.drawing_utils.draw_landmarks(
                frame,
                pose_landmarks_proto,
                solutions.pose.POSE_CONNECTIONS,
                solutions.drawing_styles.get_default_pose_landmarks_style(),
            )
    if results.get("facial_landmarks") and results["facial_landmarks"].face_landmarks:
        face_result = results["facial_landmarks"]
        for face_landmarks in face_result.face_landmarks:
            face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            face_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z)
                for lm in face_landmarks
            ])
            solutions.drawing_utils.draw_landmarks(
                image=frame,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_tesselation_style(),
            )
            solutions.drawing_utils.draw_landmarks(
                image=frame,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_CONTOURS,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_contours_style(
                    1
                ),
            )
            solutions.drawing_utils.draw_landmarks(
                image=frame,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_IRISES,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_iris_connections_style(),
            )
    if abs(gaze.time - time) < 2e9 / fps:
        color = (
            (128, 128, 128)
            if blink and blink.start_time < time < blink.end_time
            else (0, 0, 255)
        )
        cv2.circle(frame, tuple(map(int, gaze_point)), 16, color, 4)
        if results.get("depth") is not None:
            depth_result = results["depth"]
            cv2.putText(
                frame,
                f"{depth_result[int(gaze_point[1]), int(gaze_point[0])]:.2f} m",
                (int(gaze_point[0] + 20), int(gaze_point[1] + 10)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                (255, 255, 255),
                2,
            )
    if (
        abs(eye_frame.time - time) < 2e9 / fps
        and eye_frame.bgr.shape[0] > 0
        and eye_frame.bgr.shape[1] > 0
    ):
        h, w = eye_frame.bgr.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)
        radius = 5
        cv2.rectangle(mask, (radius, 0), (w - radius, h), 255, -1)
        cv2.rectangle(mask, (0, radius), (w, h - radius), 255, -1)
        cv2.circle(mask, (radius, radius), radius, 255, -1)
        cv2.circle(mask, (w - radius - 1, radius), radius, 255, -1)
        cv2.circle(mask, (radius, h - radius - 1), radius, 255, -1)
        cv2.circle(mask, (w - radius - 1, h - radius - 1), radius, 255, -1)
        roi = frame[15 : h + 15, 15 : w + 15]
        eye_bgr_masked = cv2.bitwise_and(eye_frame.bgr, eye_frame.bgr, mask=mask)
        roi_bg = cv2.bitwise_and(roi, roi, mask=cv2.bitwise_not(mask))
        combined = cv2.add(roi_bg, eye_bgr_masked)
        frame[10 : h + 10, 10 : w + 10] = combined
    distances: dict[str, float] = {
        "gaze_to_ball": cdist([gaze_point], [object_center])[0][0]
        if object_center is not None
        else np.nan,
        "gaze_to_ball_angle": (
            _calculate_visual_angle(gaze_point, object_center, calibration)
            if object_center is not None and calibration is not None
            else np.nan
        ),
        "gaze_to_right_hand": right_hand_data.get("min_dist_to_gaze", np.nan),
        "gaze_to_left_hand": left_hand_data.get("min_dist_to_gaze", np.nan),
        # Read the same key the hand loop writes ("min_dist_to_obj").
        "ball_to_right_hand": right_hand_data.get("min_dist_to_obj", np.nan),
        "ball_to_left_hand": left_hand_data.get("min_dist_to_obj", np.nan),
    }
    return frame, distances


def _ns_to_vtt_time(time: np.int64) -> str:
    total_ms = time // 1_000_000
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1_000)
    return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"
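# Example: 3_725_123_000_000 ns -> "01:02:05.123" (1 h, 2 min, 5.123 s).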


def _build_subtitle_text(
    time: np.int64,
    initial_time: int,
    gaze: GazeRecord,
    eyeball: EyeballRecord,
    pupil: PupilRecord,
    eyelid: EyelidRecord,
    distances: dict[str, float],
) -> str:
    lines = [f"Time: {_ns_to_vtt_time(time - initial_time)}"]
    lines.append(
        f"Gaze: ({int(gaze.point[0] + OFFSET[0]):04d}, {int(gaze.point[1] + OFFSET[1]):04d})"
    )
    subtitle_data = {
        "Pupil Size": (
            f"{np.nanmean(pupil.diameter_left):.2f}",
            "mm",
        ),
        "Eyelid Aperture": (
            f"{np.nanmean(eyelid.aperture_left):.2f}",
            "mm",
        ),
        "Object2Gaze Distance": (
            (
                f"{np.nanmean(distances.get('gaze_to_ball')):.2f}px / "
                f"{np.nanmean(distances.get('gaze_to_ball_angle')):.2f}°"
            ),
            "",
        ),
        "Gaze2Hand Distances (R/L)": (
            (
                f"{np.nanmean(distances.get('gaze_to_right_hand')):.2f} / "
                f"{np.nanmean(distances.get('gaze_to_left_hand')):.2f}"
            ),
            "px",
        ),
        "Object2Hand Distances (R/L)": (
            (
                f"{np.nanmean(distances.get('ball_to_right_hand')):.2f} / "
                f"{np.nanmean(distances.get('ball_to_left_hand')):.2f}"
            ),
            "px",
        ),
    }
    for label, (value, unit) in subtitle_data.items():
        if value is not None:
            lines.append(f"{label}: {value} {unit}")
        else:
            lines.append(f"{label}: N/A")
    return "\n".join(lines)


@click.command()
@click.argument("recording_directory", type=click.Path(exists=True))
@click.option(
    "--output-filename", default="output_vid", help="Base name for output files."
)
@click.option(
    "--sub", is_flag=True, default=False, help="Include subtitles in output video."
)
@click.option(
    "--preview",
    is_flag=True,
    default=False,
    help="Show live preview of annotated frames.",
)
@click.option(
    "--device",
    "device_str",
    default="cpu",
    show_default=True,
    help='Device for inference, e.g., "cpu", "cuda", "mps".',
)
@click.option(
    "--detectors",
    "detectors_str",
    default="hand,pose,ultralytics,facial_landmarks",
    show_default=True,
    help=(
        "Comma-separated list of detectors to use: 'hand', 'pose', 'ultralytics', "
        "'rfdetr', 'depth', 'facial_landmarks'."
    ),
)
@click.option(
    "--conf",
    "conf_threshold",
    default=0.25,
    show_default=True,
    help="Confidence threshold for object detection.",
)
@click.option(
    "--model",
    "model",
    default="rtdetr-l.pt",
    show_default=True,
    help=(
        "Ultralytics / RF-DETR model to use for detection. (yolo11x-seg.pt, "
        "rtdetr-l.pt, rfdetr-large)"
    ),
)
@click.option(
    "--class-id",
    "class_id",
    default=32,
    show_default=True,
    help="Class ID to detect. Default is 32 ('sports ball' in COCO, 37 in RF-DETR).",
)
@click.option(
    "--offset",
    "offset_str",
    default="0,0",
    show_default=True,
    help="Gaze offset in pixels, format: 'x,y'.",
)
@click.option(
    "--start",
    "start_event",
    type=click.STRING,
    default="recording.begin",
    show_default=True,
    help="Event name or timestamp to start processing from.",
)
@click.option(
    "--end",
    "end_event",
    type=click.STRING,
    default="recording.end",
    show_default=True,
    help="Event name or timestamp to stop processing at.",
)
@click.option(
    "--fps-mode",
    type=click.Choice(["scene", "gaze"], case_sensitive=False),
    default="scene",
    show_default=True,
    help="Choose FPS mode: 'scene' or 'gaze'.",
)
@click.option(
    "--mute",
    is_flag=True,
    default=False,
    help="Mute audio in the output video.",
)
def main(
    recording_directory: str,
    output_filename: str,
    fps_mode: str,
    sub: bool,
    preview: bool,
    detectors_str: str,
    model: str,
    device_str: str,
    conf_threshold: float,
    class_id: int,
    offset_str: str,
    start_event: str,
    end_event: str,
    mute: bool,
) -> None:
    try:
        x, y = map(int, offset_str.split(","))
        global OFFSET
        OFFSET = [x, y]
    except ValueError:
        console.log("[bold red]Invalid offset format. Use 'x,y'. Exiting.[/bold red]")
        return

    detectors_dict = {}
    detector_map = {
        "hand": {},
        "pose": {},
        "facial_landmarks": {},
        "ultralytics": {"model_name": model},
        "rfdetr": {"model_name": model},
        "depth": {},
    }
    for det in (d.strip().lower() for d in detectors_str.split(",")):
        if det in detector_map:
            detectors_dict[det] = detector_map[det]
    detectors_dict = _fill_model_params(detectors_dict)
    detectors = _initialize_detectors(device_str, detectors_dict)
    _print_arguments_table()

    rec = RecordingWithSampler(recording_directory)
    ns_mean_isi = (
        np.nanmean(np.diff(rec.gaze.time))
        if fps_mode == "gaze"
        else np.nanmean(np.diff(rec.scene.time))
    )
    fps = int(np.round(1 / (ns_mean_isi / 1e9)))
    scene_fps = int(np.round(1 / (np.nanmean(np.diff(rec.scene.time)) / 1e9)))
    combined, output_time = rec.sections_sampler(
        rec.scene,
        rec.gaze,
        rec.blinks,
        rec.eyeball,
        rec.pupil,
        rec.eyelid,
        rec.eye,
        InterSamples(rec.audio) if not mute else None,
        fps=scene_fps,
        start_event_name=start_event,
        end_event_name=end_event,
    )
    subtitle_events = []
    video_path = str(UPath(recording_directory) / f"{output_filename}.mp4")
    progress_columns = [
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TextColumn("{task.completed} of {task.total} frames"),
        TimeRemainingColumn(),
    ]
    with Progress(*progress_columns, console=console) as progress:
        task = progress.add_task("[cyan]Processing frames...", total=len(output_time))
        with plv.Writer(video_path, bit_rate=8_000_000) as writer:
            for (
                local_time_diff,
                time,
                scene_frame,
                gaze,
                blink,
                eyeball,
                pupil,
                eyelid,
                eye_frame,
                audio_samples,
            ) in combined:
                if abs(scene_frame.time - time) < 2e9 / scene_fps:
                    frame = scene_frame.bgr
                else:
                    frame = GrayFrame(scene_frame.width, scene_frame.height).bgr
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                mp_image = mpipe.Image(
                    image_format=mpipe.ImageFormat.SRGB, data=rgb_frame
                )

                results: dict[str, Any] = {}
                mp_detector_names = ["hand", "pose", "facial_landmarks"]
                for name in mp_detector_names:
                    if name in detectors:
                        results[name] = detectors[name].detect_for_video(
                            mp_image, time // 1000
                        )
                    else:
                        results[name] = None

                if "ultralytics" in detectors:
                    results["ultralytics"] = detectors["ultralytics"].predict(
                        rgb_frame,
                        verbose=False,
                        device=device_str,
                        classes=[class_id],
                        conf=conf_threshold,
                    )[0]
                else:
                    results["ultralytics"] = None
                results["object"] = (
                    sv.Detections.from_ultralytics(results.get("ultralytics"))
                    if results.get("ultralytics")
                    else None
                )
                if "rfdetr" in detectors:
                    results["object"] = detectors["rfdetr"].predict(
                        rgb_frame,
                        device=device_str
                        if device_str != "mps"
                        else "cpu",  # mps not supported by rfdetr
                        threshold=conf_threshold,
                    )
                    if results["object"] is not None:
                        results["object"] = results["object"][
                            results["object"].class_id == class_id
                        ]
                else:
                    results["object"] = results.get("object")

                if "depth" in detectors:
                    results["depth"] = _predict_depth(
                        frame,
                        f_px=rec.calibration["scene_camera_matrix"][0, 0],
                        transform=detectors["depth"][1],
                        model=detectors["depth"][0],
                    )
                else:
                    results["depth"] = None

                frame, distances = _annotate_frame(
                    frame, gaze, blink, results, eye_frame, time, fps, rec.calibration
                )
                if preview:
                    cv2.imshow("Live Preview", frame)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break
                if not mute:
                    for a_time, audio_frame in zip(
                        audio_samples[0], audio_samples[1], strict=False
                    ):
                        writer.write_frame(
                            audio_frame.plv_frame,
                            (a_time + local_time_diff) / 1e9,
                        )
                writer.write_image(
                    frame,
                    (time + local_time_diff) / 1e9,
                )
                if sub:
                    text = _build_subtitle_text(
                        time, rec.start_time, gaze, eyeball, pupil, eyelid, distances
                    )
                    subtitle_events.append((
                        text,
                        int(time + local_time_diff),
                        int(time + local_time_diff + ns_mean_isi),
                    ))
                progress.update(task, advance=1)
    console.log(f"Video saved to [cyan]{video_path}[/cyan]")
    if preview:
        cv2.destroyAllWindows()

    if sub:
        vtt_path = str(UPath(recording_directory) / f"{output_filename}.vtt")
        with open(vtt_path, "w", encoding="utf-8") as vtt_file:
            vtt_file.write("WEBVTT\n\n")
            for text, start_ns, end_ns in subtitle_events:
                vtt_file.write(
                    f"{_ns_to_vtt_time(start_ns)} --> {_ns_to_vtt_time(end_ns)}\n{text}\n\n"
                )
        console.log(f"VTT subtitles written to [cyan]{vtt_path}[/cyan]")
        console.log("Embedding subtitles with FFMPEG...")
        video_with_subs_path = str(
            UPath(recording_directory) / f"{output_filename}_subbed.mp4"
        )
        ffmpeg_command = [
            "ffmpeg",
            "-i",
            video_path,
            "-i",
            vtt_path,
            "-c:v",
            "copy",
            "-c:a",
            "copy",
            "-c:s",
            "mov_text",
            "-map",
            "0:v",
            "-map",
            "0:a?",
            "-map",
            "1",
            "-metadata:s:s:0",
            "language=eng",
            "-y",
            video_with_subs_path,
        ]
        try:
            subprocess.run(ffmpeg_command, check=True, capture_output=True, text=True)
            shutil.move(video_with_subs_path, video_path)
            # UPath(vtt_path).unlink()
            console.log(
                f"[green]Subtitles successfully embedded into {video_path}[/green]"
            )
        except FileNotFoundError:
            console.log(
                "[bold red]FFMPEG not found. Please install FFMPEG and ensure it is "
                "in your system's PATH.[/bold red]"
            )
        except subprocess.CalledProcessError as e:
            console.log(f"[bold red]FFMPEG error:\n{e.stderr}[/bold red]")


if __name__ == "__main__":
    main()