Created
March 18, 2022 11:27
-
-
Save papr/1341a7f8badbb2da8e0e7dddda126c99 to your computer and use it in GitHub Desktop.
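Partial simulation of the Pupil Labs Realtime API "simple" interface (https://pupil-labs-realtime-api.readthedocs.io/en/stable/api/simple.html) that replays gaze and scene video data from an existing recording.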
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc | |
import collections | |
import json | |
import logging | |
import pathlib | |
import threading | |
import time | |
from typing import Deque, Mapping, NamedTuple, Optional, Type, Union | |
try: | |
from typing import Literal | |
except ImportError: | |
# FIXME: Remove when dropping py3.7 support | |
from typing_extensions import Literal | |
import av | |
import numpy as np | |
import pandas as pd | |
from pupil_labs.realtime_api.base import DeviceBase, DeviceType | |
from pupil_labs.realtime_api.models import APIPath, DiscoveredDeviceInfo, Event, Sensor | |
from pupil_labs.realtime_api.simple._utils import _StreamManager | |
from pupil_labs.realtime_api.simple.device import Device | |
from pupil_labs.realtime_api.simple.models import ( | |
MATCHED_ITEM_LABEL, | |
GazeData, | |
MatchedItem, | |
SimpleVideoFrame, | |
VideoFrame, | |
) | |
# Module-level logger used by the simulated device's background worker and the
# example script below.
logger = logging.getLogger(__name__)
class SimulatedDeviceFromRecording(Device):
    """
    Simulates the data receiving functionality of pupil_labs.realtime_api.simple.Device
    by reading data from an existing recording.

    A background thread replays the recording at its original pace. For each stream
    (gaze, scene video, and matched scene-frame/gaze pairs) only the most recent
    item is kept; receive_*() calls return and consume that item or wait for a new one.

    Caveats:
    - does not simulate any of the control functions, e.g. recording_start()
    - does not simulate network delay
    - stops providing data when reaching the end of recording
    - gaze data `worn` field is always True
    """

    def __init__(
        self,
        address: str,
        port: int = 0,
        full_name: Optional[str] = None,
        dns_name: Optional[str] = None,
        start_streaming_by_default: bool = False,
        suppress_decoding_warnings: bool = True,
    ) -> None:
        """
        :param address: path to the recording folder that is replayed
        :param port: only forwarded to DeviceBase.__init__
        :param full_name: only forwarded to DeviceBase.__init__
        :param dns_name: only forwarded to DeviceBase.__init__
        :param start_streaming_by_default: start the replay immediately
        :param suppress_decoding_warnings: only forwarded to DeviceBase.__init__
        :raises FileNotFoundError: if streaming starts immediately and the
            recording cannot be found
        """
        # Calls DeviceBase.__init__ directly and skips Device.__init__
        # (presumably to avoid the real device's network setup).
        DeviceBase.__init__(
            self, address, port, full_name, dns_name, suppress_decoding_warnings
        )
        self._background_worker_thread: Optional[threading.Thread] = None
        self._background_worker_thread_should_stop = threading.Event()
        self._background_worker_thread_should_stop.set()

        sensor_names = [
            Sensor.Name.GAZE.value,
            Sensor.Name.WORLD.value,
            MATCHED_ITEM_LABEL,
        ]
        # (timestamp_ns, GazeData) pairs waiting to be matched with scene frames
        self._cached_gaze_for_matching: Deque[tuple] = collections.deque(maxlen=200)
        self._most_recent_item: Mapping[str, Deque] = {
            name: collections.deque(maxlen=1) for name in sensor_names
        }
        self._event_new_item: Mapping[str, threading.Event] = {
            name: threading.Event() for name in sensor_names
        }
        if start_streaming_by_default:
            self.streaming_start()

    def receive_scene_video_frame(
        self, timeout_seconds: Optional[float] = None
    ) -> Optional[SimpleVideoFrame]:
        """Return the next scene video frame, or None if ``timeout_seconds`` expires."""
        return self._receive_item(Sensor.Name.WORLD.value, timeout_seconds)

    def receive_gaze_datum(
        self, timeout_seconds: Optional[float] = None
    ) -> Optional[GazeData]:
        """Return the next gaze datum, or None if ``timeout_seconds`` expires."""
        return self._receive_item(Sensor.Name.GAZE.value, timeout_seconds)

    def receive_matched_scene_video_frame_and_gaze(
        self, timeout_seconds: Optional[float] = None
    ) -> Optional[MatchedItem]:
        """Return the next scene frame paired with its closest gaze datum, or None on timeout."""
        return self._receive_item(MATCHED_ITEM_LABEL, timeout_seconds)

    def _receive_item(self, sensor: str, timeout_seconds: Optional[float] = None):
        """Pop the cached item of ``sensor`` or wait up to ``timeout_seconds`` for one.

        Starts streaming if it is not running yet. Returns None on timeout.
        """
        if not self.is_currently_streaming:
            self.streaming_start()
        try:
            return self._most_recent_item[sensor].popleft()
        except IndexError:
            # no cached frame available, waiting for new one
            event_new_item = self._event_new_item[sensor]
            event_new_item.clear()
            if event_new_item.wait(timeout=timeout_seconds):
                return self._most_recent_item[sensor].popleft()
            return None

    def __repr__(self) -> str:
        return f"SimulatedDeviceFromRecording(path={self.address})"

    def streaming_start(self):
        """Start replaying the recording in a background thread (no-op if running)."""
        if self.is_currently_streaming:
            return
        self._start_background_worker()

    def streaming_stop(self):
        """Stop the replay thread and wait for it to finish (no-op if not running)."""
        if not self.is_currently_streaming:
            return
        self._background_worker_thread_should_stop.set()
        self._background_worker_thread.join()
        # Reset so is_currently_streaming reports False and a later
        # streaming_start() can begin a new replay.
        self._background_worker_thread = None

    @property
    def is_currently_streaming(self) -> bool:
        # True from streaming_start() until streaming_stop(), even if the replay
        # already reached the end of the recording.
        return self._background_worker_thread is not None

    def close(self) -> None:
        self.streaming_stop()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()

    def _start_background_worker(self):
        # Load the recording in the calling thread: errors such as a missing
        # recording folder are raised to the caller instead of killing the worker
        # thread and leaving the caller waiting for a start signal forever.
        recording = Recording(self.address)
        self._background_worker_thread_should_stop.clear()
        self._background_worker_thread = threading.Thread(
            target=self._background_loop,
            args=(recording,),
            name=f"{self} background worker",
        )
        self._background_worker_thread.start()

    def _background_loop(self, recording: "Recording"):
        """Replay the samples of ``recording`` with their original relative timing."""
        should_stop = self._background_worker_thread_should_stop
        streaming_start_time_ns = time.time_ns()
        try:
            for sample in recording.samples():
                sample_age_ns = sample.timestamp_ns - recording.start_time
                time_since_start_ns = time.time_ns() - streaming_start_time_ns
                time_to_sleep_ns = max(sample_age_ns - time_since_start_ns, 0)
                # Waiting on the stop event (instead of time.sleep) lets
                # streaming_stop() interrupt the wait for the next sample.
                if should_stop.wait(time_to_sleep_ns * 1e-9):
                    break
                self._publish_sample(sample)
        except Stream.EndOfFileError:
            logger.warning(
                "Stream data ended before its timestamps; stopping replay of %s",
                recording.path,
            )

    def _publish_sample(self, sample: "Recording.Sample") -> None:
        """Make ``sample`` available to receive_*() and match scene frames with gaze."""
        self._most_recent_item[sample.stream].append(sample.data)
        self._event_new_item[sample.stream].set()
        if sample.stream == Sensor.Name.GAZE.value:
            self._cached_gaze_for_matching.append((sample.timestamp_ns, sample.data))
        elif sample.stream == Sensor.Name.WORLD.value:
            self._match_scene_frame_with_gaze(sample)

    def _match_scene_frame_with_gaze(self, sample: "Recording.Sample") -> None:
        """Pair a scene-frame sample with the closest cached gaze datum, if any."""
        try:
            logger.debug(
                "Searching closest gaze datum in cache (len=%d)...",
                len(self._cached_gaze_for_matching),
            )
            gaze = _StreamManager._get_closest_item(
                self._cached_gaze_for_matching,
                sample.timestamp_ns,
            )
        except IndexError:
            logger.debug("No cached gaze data available for matching")
            return
        match_time_difference = (
            sample.data.timestamp_unix_seconds - gaze.timestamp_unix_seconds
        )
        logger.debug(
            "Found matching sample (time difference: %.3f seconds)",
            match_time_difference,
        )
        self._most_recent_item[MATCHED_ITEM_LABEL].append(
            MatchedItem(sample.data, gaze)
        )
        self._event_new_item[MATCHED_ITEM_LABEL].set()

    # remaining methods are not implemented

    def api_url(
        self, path: APIPath, protocol: str = "http", prefix: str = "/api"
    ) -> str:
        raise NotImplementedError

    @classmethod
    def from_discovered_device(
        cls: Type[DeviceType], device: DiscoveredDeviceInfo
    ) -> DeviceType:
        raise NotImplementedError

    @property
    def phone_name(self) -> str:
        raise NotImplementedError

    @property
    def phone_id(self) -> str:
        raise NotImplementedError

    @property
    def phone_ip(self) -> str:
        raise NotImplementedError

    @property
    def battery_level_percent(self) -> int:
        raise NotImplementedError

    @property
    def battery_state(self) -> Literal["OK", "LOW", "CRITICAL"]:
        raise NotImplementedError

    @property
    def memory_num_free_bytes(self) -> int:
        raise NotImplementedError

    @property
    def memory_state(self) -> Literal["OK", "LOW", "CRITICAL"]:
        raise NotImplementedError

    @property
    def version_glasses(self) -> str:
        raise NotImplementedError

    @property
    def serial_number_glasses(self) -> Union[str, None, Literal["default"]]:
        raise NotImplementedError

    @property
    def serial_number_scene_cam(self) -> Optional[str]:
        raise NotImplementedError

    def world_sensor(self) -> Optional[Sensor]:
        raise NotImplementedError

    def gaze_sensor(self) -> Optional[Sensor]:
        raise NotImplementedError

    def recording_start(self) -> str:
        raise NotImplementedError

    def recording_stop_and_save(self):
        raise NotImplementedError

    def recording_cancel(self):
        raise NotImplementedError

    def send_event(
        self, event_name: str, event_timestamp_unix_ns: Optional[int] = None
    ) -> Event:
        raise NotImplementedError
class Recording:
    """Gaze and scene video streams of a recording folder, merged in time order."""

    class Sample(NamedTuple):
        stream: str  # key of Recording.streams this sample belongs to
        timestamp_ns: int
        data: Union[GazeData, SimpleVideoFrame]

    def __init__(self, path: str) -> None:
        """
        :param path: recording folder containing info.json and the stream files
        :raises FileNotFoundError: if the folder or its info.json does not exist
        """
        self.path = pathlib.Path(path).resolve()
        if not self.path.exists():
            raise FileNotFoundError(f"Recording not found: {self.path}")
        with (self.path / "info.json").open("r", encoding="utf-8") as fh:
            self.info = json.load(fh)
        self.streams: Mapping[str, Stream] = {
            Sensor.Name.GAZE.value: GazeStream(self.path),
            Sensor.Name.WORLD.value: VideoStream(self.path),
        }

    @property
    def start_time(self) -> int:
        """Value of ``start_time`` in info.json (compared against sample timestamps_ns)."""
        return self.info["start_time"]

    def samples(self):
        """Yield a Recording.Sample for every timestamp of every stream, in time order.

        Each stream's data is read sequentially, so this generator can be
        consumed only once per Recording instance.

        :raises Stream.EndOfFileError: if a stream has more timestamps than data
        """
        stream_names = list(self.streams.keys())
        ts = (
            pd.concat(
                [self.streams[name].ts for name in stream_names],
                keys=stream_names,
                names=["stream"],
            )
            .reset_index(level="stream")
            # A stable sort makes the order of equal timestamps deterministic
            # and never reorders samples of the same stream with equal timestamps.
            .sort_values(by="time", kind="stable", ignore_index=True)
        )
        for row in ts.itertuples(index=False):
            timestamp_ns = int(row.time)
            yield Recording.Sample(
                row.stream,
                timestamp_ns,
                self.streams[row.stream].next_data_sample(timestamp_ns),
            )
class Stream(abc.ABC):
    """Base class for one recorded data stream.

    A stream has a ``<name>.time`` file holding little-endian uint64 timestamps
    and a ``<name>.<data_suffix>`` data file whose reading is left to subclasses.
    """

    class EndOfFileError(IOError):
        """Signals that the stream's data file holds no further samples."""

    def __init__(self, folder: pathlib.Path) -> None:
        self.folder = folder
        self.ts = self.load_time()
        self.init_data_loading()

    def load_time(self):
        """Read all timestamps of the stream into a Series named "time"."""
        raw_timestamps = np.fromfile(self.time_path, dtype="<u8")
        return pd.Series(raw_timestamps, name="time")

    def __repr__(self) -> str:
        class_name = type(self).__name__
        sample_count = len(self.ts)
        return f"{class_name}(num_samples={sample_count})"

    @property
    def time_path(self):
        """Location of the binary timestamp file."""
        return self.folder.joinpath(f"{self.name}.time")

    @property
    def data_path(self):
        """Location of the data file."""
        return self.folder.joinpath(f"{self.name}.{self.data_suffix}")

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Base file name shared by the time and data files."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def data_suffix(self) -> str:
        """File extension of the data file."""
        raise NotImplementedError

    @abc.abstractmethod
    def init_data_loading(self):
        """Prepare sequential reading of the data file."""
        raise NotImplementedError

    @abc.abstractmethod
    def next_data_sample(self, timestamp_ns: int):
        """Return the next data sample, stamped with ``timestamp_ns``."""
        raise NotImplementedError
class GazeStream(Stream):
    """Gaze positions stored as consecutive little-endian float32 (x, y) pairs."""

    @property
    def name(self) -> str:
        return "gaze ps1"

    @property
    def data_suffix(self) -> str:
        return "raw"

    def init_data_loading(self):
        # One row per sample: (x, y).
        self.data = np.fromfile(self.data_path, dtype="<f4").reshape(-1, 2)
        self.current_index = 0

    def next_data_sample(self, timestamp_ns: int):
        """Return the next gaze datum stamped with ``timestamp_ns``.

        ``worn`` is always True because the raw file holds only x/y positions.

        :raises Stream.EndOfFileError: if all gaze samples were already returned
        """
        try:
            x, y = self.data[self.current_index]
        except IndexError as err:
            raise Stream.EndOfFileError from err
        self.current_index += 1
        # Convert numpy float32 scalars to plain floats, matching GazeData's fields.
        return GazeData(float(x), float(y), True, timestamp_ns * 1e-9)
class VideoStream(Stream):
    """Scene video frames decoded sequentially from the recording's mp4 file."""

    @property
    def name(self) -> str:
        return "PI world v1 ps1"

    @property
    def data_suffix(self) -> str:
        return "mp4"

    def init_data_loading(self):
        # Keep the container so it can be closed once the video is exhausted.
        self.container = av.open(str(self.data_path))
        self.frames = self.container.decode(video=0)

    def next_data_sample(self, timestamp_ns: int):
        """Decode the next frame and stamp it with ``timestamp_ns``.

        :raises Stream.EndOfFileError: if the video has no more frames
        """
        try:
            frame = next(self.frames)
        except StopIteration as err:
            self.container.close()
            raise Stream.EndOfFileError from err
        return SimpleVideoFrame.from_video_frame(
            VideoFrame(frame, timestamp_unix_seconds=timestamp_ns * 1e-9)
        )
if __name__ == "__main__":
    import argparse

    import cv2
    from rich.logging import RichHandler

    MODE_MATCHED = "matched"
    MODE_SCENE_ONLY = "scene-only"
    MODE_GAZE_ONLY = "gaze-only"

    parser = argparse.ArgumentParser(
        description="Replay a recording via SimulatedDeviceFromRecording and display it."
    )
    parser.add_argument("recording", help="path to the recording folder")
    parser.add_argument(
        "--mode",
        "-m",
        choices=[MODE_MATCHED, MODE_SCENE_ONLY, MODE_GAZE_ONLY],
        default=MODE_MATCHED,
    )
    parser.add_argument("--debug", action="store_true", help="enable debug logging")
    parser.add_argument(
        "--timeout",
        type=float,
        default=1.0,
        help="seconds to wait for new data before showing a blank image",
    )
    args = parser.parse_args()

    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level, handlers=[RichHandler(level=level)])

    with SimulatedDeviceFromRecording(
        args.recording, start_streaming_by_default=True
    ) as device:
        logger.info(f"Start displaying {args.mode} data...")
        # Grey 3-channel placeholder (1080 rows x 1088 columns) so the gaze circle
        # is drawn in red, like on real BGR scene frames.
        blank_image = np.full((1080, 1088, 3), 125, dtype=np.uint8)
        window_name = f"Mode {args.mode}"
        try:
            while True:
                if args.mode == MODE_MATCHED:
                    matched = device.receive_matched_scene_video_frame_and_gaze(
                        args.timeout
                    )
                    # matched is None if the timeout hits
                    frame, gaze = matched if matched is not None else (None, None)
                elif args.mode == MODE_SCENE_ONLY:
                    frame, gaze = device.receive_scene_video_frame(args.timeout), None
                else:  # MODE_GAZE_ONLY (argparse restricts the choices)
                    frame, gaze = None, device.receive_gaze_datum(args.timeout)

                image = frame.bgr_pixels if frame is not None else blank_image.copy()
                if gaze is not None:
                    # Draw on `image`: `frame` is None in gaze-only mode.
                    cv2.circle(
                        image,
                        (int(gaze.x), int(gaze.y)),
                        radius=80,
                        color=(0, 0, 255),
                        thickness=15,
                    )
                cv2.imshow(window_name, image)
                if cv2.waitKey(1) & 0xFF == 27:  # Esc key
                    break
        finally:
            cv2.destroyAllWindows()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pupil-labs-realtime-api | |
numpy | |
pandas | |
# only required to run the example in the script's __main__ section:
opencv-python | |
rich |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment