|
import contextlib |
|
import json |
|
import logging |
|
import os |
|
import pathlib |
|
from sys import flags |
|
import threading |
|
import time |
|
import typing as T |
|
|
|
import click |
|
import msgpack |
|
import numpy as np |
|
import zmq |
|
|
|
|
|
@click.command()
@click.argument("capture_recording_loc", type=click.Path(exists=True))
@click.option("--repeat", default=False, is_flag=True, type=bool)
@click.option("--pupil-remote-port", default=50020, type=int)
@click.option("--pub-socket-port", default=50021, type=int)
def main(capture_recording_loc, repeat, pupil_remote_port, pub_socket_port):
    # NOTE: deliberately no docstring here -- click would show it as --help text.
    # Replays the recording once per loop iteration after an `R` request has set
    # should_publish_flag. Only an `r` request clears that flag, so with --repeat
    # the replays follow each other until `r` arrives; without --repeat the
    # program leaves after the first replay.
    recording_dir = pathlib.Path(capture_recording_loc)
    with NetworkAPI(pupil_remote_port, pub_socket_port) as frontend:
        keep_going = True
        while keep_going:
            logging.info("Waiting for `R` Pupil Remote command")
            frontend.should_publish_flag.wait()
            publish_recording(recording_dir, frontend.pub_socket)
            keep_going = repeat
|
|
|
|
|
def publish_recording(capture_recording_loc, pub_socket):
    """Send a recording's data on ``pub_socket`` at its original pace.

    Data come from ``yield_datum_with_smallest_timestamp`` in timestamp order.
    Each datum's timestamps are shifted in place onto the ``time.perf_counter()``
    clock, the call sleeps until that moment, and the datum is then sent as a
    two-frame message: the topic string, followed by the msgpack-encoded datum.
    Returns only after the whole recording has been sent.

    Args:
        capture_recording_loc: recording directory (a ``pathlib.Path``).
        pub_socket: bound ZMQ PUB socket to publish on.
    """
    logger = logging.getLogger(__name__ + ".publish_recording")
    recording_to_realtime_clock = None
    for datum in yield_datum_with_smallest_timestamp(capture_recording_loc):
        if recording_to_realtime_clock is None:
            # Anchor the clock only once the first datum exists. The generator
            # loads every data file on its first step; anchoring before that
            # would let the load time eat into the replay schedule and send the
            # start of the recording as one unpaced burst.
            recording_to_realtime_clock = recording_clock_transform(
                capture_recording_loc
            )
            logger.debug("Recording->realtime %s", recording_to_realtime_clock)

        recording_to_realtime_clock.transform_timestamps_to_target(datum)
        # Lazy %-formatting: this runs once per datum.
        logger.debug("Publishing %s at %s", datum["topic"], datum["timestamp"])

        time_to_wait = datum["timestamp"] - time.perf_counter()
        if time_to_wait > 0.0:
            time.sleep(time_to_wait)

        pub_socket.send_string(datum["topic"], flags=zmq.SNDMORE)
        pub_socket.send(msgpack.packb(datum))
|
|
|
|
|
def recording_clock_transform(capture_recording_loc):
    """Map the recording's start time onto the current ``time.perf_counter()``.

    Reads ``start_time_synced_s`` from ``info.player.json`` inside
    ``capture_recording_loc`` (a ``pathlib.Path``). It returns a ``ClockTransform``
    that maps that recording time to "now" on the perf-counter clock.
    """
    info_path = capture_recording_loc / "info.player.json"
    with open(info_path) as info_file:
        recording_start = json.load(info_file)["start_time_synced_s"]
    now = time.perf_counter()
    return ClockTransform(source_clock_start=recording_start, target_clock_start=now)
|
|
|
|
|
class ClockTransform:
    """Constant-offset mapping between a source clock and a target clock.

    Only an offset is applied (no rate/drift correction): a source time ``t``
    maps to ``t + (target_clock_start - source_clock_start)``.
    """

    def __init__(self, source_clock_start, target_clock_start):
        # Adding this offset to a source-clock time yields the target-clock time.
        self.source_to_target_clock_offset = target_clock_start - source_clock_start

    def __str__(self):
        return (
            f"<{type(self).__name__} source_to_target_clock_offset="
            f"{self.source_to_target_clock_offset:.3f}>"
        )

    def transform_timestamps_to_target(self, datum):
        """Rewrite the timestamps in ``datum`` in place onto the target clock.

        Converts the top-level ``"timestamp"`` if present. For data whose topic
        starts with ``"surface"``, it also converts the timestamp half of each
        ``gaze_on_surfaces[i]["base_data"]`` ``(topic, timestamp)`` pair.
        A surface datum without (or with an empty/None) ``gaze_on_surfaces``
        only gets its top-level timestamp converted.
        """
        if "timestamp" in datum:
            datum["timestamp"] = self.to_target_clock(datum["timestamp"])

        if datum.get("topic", "").startswith("surface"):
            # `or ()` -- a missing key would otherwise mean iterating None.
            for gaze_on_surface in datum.get("gaze_on_surfaces") or ():
                gaze_topic, gaze_timestamp = gaze_on_surface["base_data"]
                gaze_on_surface["base_data"] = (
                    gaze_topic,
                    self.to_target_clock(gaze_timestamp),
                )

    def to_target_clock(self, source_clock_time):
        """Convert a source-clock time to the target clock."""
        return source_clock_time + self.source_to_target_clock_offset

    def to_source_clock(self, target_clock_time):
        """Convert a target-clock time back to the source clock."""
        return target_clock_time - self.source_to_target_clock_offset
|
|
|
|
|
@contextlib.contextmanager
def NetworkAPI(pupil_remote_server=50020, pub_port=50021):
    """Run the network backend for the duration of a ``with`` block.

    Binds a ZMQ PUB socket on ``pub_port`` and starts a thread serving Pupil
    Remote requests on port ``pupil_remote_server``. Yields a ``Frontend``
    named tuple with the PUB socket and the event that an ``R`` request sets.

    On exit (normal or via exception, including a failed bind), the server
    thread is told to stop and joined, and the ZMQ context is destroyed.
    """

    class Frontend(T.NamedTuple):
        pub_socket: zmq.Socket
        should_publish_flag: threading.Event

    logger = logging.getLogger(__name__ + ".IPC_backend")
    logger.debug("Starting IPC backend")
    ctx = zmq.Context()
    should_shutdown_flag = threading.Event()
    should_publish_flag = threading.Event()
    pupil_remote_thread = threading.Thread(
        target=_pupil_remote_server,
        args=(
            ctx,
            should_shutdown_flag,
            pupil_remote_server,
            # The server reports this port to clients as their SUB port.
            pub_port,
            should_publish_flag,
        ),
    )
    # Everything that can fail after the context exists is inside try/finally,
    # so e.g. a port already in use still tears the context down.
    try:
        pub_socket = ctx.socket(zmq.PUB)
        pub_socket.bind(f"tcp://*:{pub_port}")
        logger.debug(f"PUB socket bound to {pub_socket.last_endpoint}")
        pupil_remote_thread.start()
        yield Frontend(pub_socket, should_publish_flag)
    finally:
        logger.debug("Shutting down IPC backend")
        should_shutdown_flag.set()
        if pupil_remote_thread.is_alive():
            logger.debug("Requesting Pupil Remote server shutdown")
            pupil_remote_thread.join()
            logger.debug("Pupil Remote server shutdown confirmed")
        # destroy() closes every socket still open on this context (the PUB
        # socket, plus any socket the server thread left open) and then
        # terminates it. A bare term() would block while any socket is open.
        # The server thread has exited at this point, so closing from here is safe.
        ctx.destroy()
        logger.debug("IPC backend shutdown")
|
|
|
|
|
def _pupil_remote_server(
    ctx, should_shutdown_flag, pupil_remote_port, sub_port, should_publish_flag
):
    """Thread target: serve Pupil Remote requests until shutdown is requested.

    Binds a REP socket on ``pupil_remote_port`` and handles one request at a
    time until ``should_shutdown_flag`` is set. ``sub_port`` is the port
    reported to clients for ``SUB_PORT`` requests. ``should_publish_flag`` is
    set/cleared by ``R``/``r`` requests. Errors (including a failed bind) are
    logged rather than raised. The socket is always closed, so the owning
    context can be terminated afterwards.
    """
    logger = logging.getLogger(__name__ + "._pupil_remote_server")
    pupil_remote = ctx.socket(zmq.REP)
    try:
        pupil_remote.bind(f"tcp://*:{pupil_remote_port}")
        logger.debug(f"Bound to {pupil_remote.last_endpoint}")
        while not should_shutdown_flag.is_set():
            _handle_incoming_messages(
                pupil_remote, logger, sub_port, should_publish_flag
            )
    except KeyboardInterrupt:
        logger.debug("Caught KeyboardInterrupt")
    except Exception:
        logger.exception("Unhandled exception")
    finally:
        pupil_remote.close()
|
|
|
|
|
def _handle_incoming_messages(pupil_remote, logger, sub_port, should_publish_flag):
    """Serve at most one request on the REP socket, waiting up to 200 ms for it.

    A multi-frame request is drained and answered with a fixed message. A
    single-frame request is answered via ``_response_for_single_frame_message``.
    Every received request gets exactly one reply, as a REP socket requires.
    """
    if not pupil_remote.poll(timeout=200):
        return

    # Decode leniently: this is network input, and a UnicodeDecodeError here
    # would escape the caller's loop and stop the server thread while the REP
    # socket still owes a reply.
    msg = pupil_remote.recv().decode("utf-8", errors="replace")
    if pupil_remote.get(zmq.RCVMORE):
        _drop_multi_frame_message(pupil_remote)
        response = "Multi-frame message dropped"
    else:
        response = _response_for_single_frame_message(
            msg, sub_port, should_publish_flag
        )
    logger.debug(f"Received {msg}. Responding {response}")
    pupil_remote.send_string(response)
|
|
|
|
|
def _response_for_single_frame_message(msg, sub_port, should_publish_flag):
    """Return the reply string for a single-frame Pupil Remote request.

    Side effects: any request starting with ``R`` sets ``should_publish_flag``;
    ``r`` clears it.

    Args:
        msg: decoded request string.
        sub_port: port reported for ``SUB_PORT``.
        should_publish_flag: ``threading.Event`` controlling publishing.

    Returns:
        The reply to send back to the client.
    """
    if msg == "SUB_PORT":
        return str(sub_port)
    if msg == "t":
        return str(time.perf_counter())
    if msg.startswith("R"):
        # "R" optionally followed by a recording name.
        should_publish_flag.set()
        return "OK"
    if msg == "r":
        should_publish_flag.clear()
        return "OK"
    if msg in ("c", "C", "T") or msg.startswith("T "):
        # Time sync arrives as "T <timestamp>". Like calibration it is
        # acknowledged but not applied: this server's clock is always
        # time.perf_counter().
        return "OK"
    if msg in ("PUB_PORT", "v"):
        return "NOT IMPLEMENTED"
    return "Unknown command"
|
|
|
|
|
def _drop_multi_frame_message(socket):
    """Discard the frames left over from a partially received multi-part message.

    Reads and throws away frames for as long as the socket reports RCVMORE.
    """
    more_frames = socket.get(zmq.RCVMORE)
    while more_frames:
        socket.recv(flags=zmq.NOBLOCK)
        more_frames = socket.get(zmq.RCVMORE)
|
|
|
|
|
def yield_datum_with_smallest_timestamp(capture_recording_loc):
    """Yield all data of a recording, merged across topics by timestamp.

    Every ``*.pldata`` file in ``capture_recording_loc`` whose name does not
    start with ``.`` is loaded (via ``load_pldata_file``) when the generator is
    first advanced. Files with an empty timestamp list are skipped.

    Each step yields the next datum of the topic whose next timestamp (from
    ``<topic>_timestamps.npy``) is smallest. Ties go to the topic globbed
    first. Only the head of each topic is compared, so the output is globally
    ordered only if each topic's own timestamps are sorted.

    Yields:
        The datum dicts themselves (not ``(topic, datum)`` tuples).

    Raises:
        FileNotFoundError: if no non-empty ``*.pldata`` file is found.
    """
    import collections

    # topic -> (pending timestamps, pending data). Deques make dropping the
    # consumed head O(1); list.pop(0) would make the replay O(n^2).
    recording = {}
    for file in capture_recording_loc.glob("[!.]*.pldata"):
        pldata = load_pldata_file(capture_recording_loc, file.stem)
        if len(pldata.timestamps) > 0:
            recording[file.stem] = (
                collections.deque(pldata.timestamps),
                collections.deque(pldata.data),
            )

    if not recording:
        raise FileNotFoundError(f"No recording files found at {capture_recording_loc}")

    while recording:
        topic = min(recording, key=lambda name: recording[name][0][0])
        timestamps, data = recording[topic]
        yield data.popleft()
        if data:
            timestamps.popleft()
        else:
            # topic is exhausted, stop considering it
            del recording[topic]
|
|
|
|
|
def load_pldata_file(directory, topic):
    """Load ``<topic>.pldata`` and ``<topic>_timestamps.npy`` from ``directory``.

    Each entry of the ``.pldata`` msgpack stream is a ``(topic, payload)``
    pair. The payload is itself msgpack-encoded and is decoded into a datum.
    Nested msgpack ext objects with code 13 are decoded recursively the same
    way; other ext codes are kept as ``msgpack.ExtType``.

    Returns:
        ``PLData(data, timestamps, topics)``: the decoded data, the timestamps
        list from the ``.npy`` file, and the per-datum topic strings. If either
        file is missing, all three lists are empty.
    """

    class PLData(T.NamedTuple):
        data: T.List[T.Any]
        timestamps: T.List[float]
        topics: T.List[str]

    SERIALIZED_DICT_MSGPACK_EXT_CODE = 13

    def serialized_dict_from_msgpack_bytes(data):
        return msgpack.unpackb(
            data,
            raw=False,
            use_list=False,
            ext_hook=msgpack_unpacking_ext_hook,
        )

    def msgpack_unpacking_ext_hook(code, data):
        # msgpack invokes ext_hook(code, data) with exactly two arguments.
        if code == SERIALIZED_DICT_MSGPACK_EXT_CODE:
            return serialized_dict_from_msgpack_bytes(data)
        return msgpack.ExtType(code, data)

    ts_file = os.path.join(directory, topic + "_timestamps.npy")
    msgpack_file = os.path.join(directory, topic + ".pldata")
    try:
        data_ts = np.load(ts_file).tolist()
        data = []
        topics = []
        with open(msgpack_file, "rb") as fh:
            # Use a separate name so the `topic` parameter is not shadowed.
            for datum_topic, payload in msgpack.Unpacker(
                fh, raw=False, use_list=True
            ):
                data.append(serialized_dict_from_msgpack_bytes(payload))
                topics.append(datum_topic)
    except FileNotFoundError:
        data, data_ts, topics = [], [], []

    return PLData(data, data_ts, topics)
|
|
|
|
|
if __name__ == "__main__":
    # Log at DEBUG so the module's logger.debug messages are visible.
    logging.basicConfig(level="DEBUG")
    main()