
j0sh/README Secret

Last active October 31, 2024 19:22

README

python main.py --publish-url=http://localhost:2939/bar --subscribe-url=http://localhost:2939/foo
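
A rough sketch of the flow, as inferred from the files below: main.py wires a trickle subscription to a trickle publisher through a local queue; trickle.py uses ffmpeg to decode incoming segments into JPEG frames (preprocess) and to re-encode frames into MPEG-TS segments (postprocess); segmenter.py drives the output-side ffmpeg and hands each finished segment to a callback via a rotating set of named pipes. The run_subscribe and run_publish entry points that main.py calls on the trickle module are assumed to live alongside the files shown here.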

jpeg_parser.py

import logging
class JPEGStreamParser:
def __init__(self, callback):
"""
Initializes the JPEGStreamParser.
:param callback: Function to be called when a complete JPEG is found. Receives bytes of the JPEG image.
"""
self.buffer = bytearray()
self.callback = callback
self.in_jpeg = False
self.start_idx = 0 # Keep track of the start index across feeds
def __enter__(self):
return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        logging.info("Closing jpeg parser via __exit__")
self.close()
def close(self):
logging.info("Closing jpeg parser")
        self.callback(None)  # signal end-of-stream to the consumer
def feed(self, data):
"""
Feed incoming data into the parser.
:param data: Incoming data bytes.
"""
self.buffer.extend(data)
while True:
# Find the start marker (0xFFD8) if not already in JPEG
if not self.in_jpeg:
self.start_idx = self.buffer.find(b'\xff\xd8', self.start_idx)
if self.start_idx == -1:
                    # No start marker found; keep only the final byte so a
                    # marker split across feeds can still be matched.
                    self.start_idx = max(0, len(self.buffer) - 1)
return
self.in_jpeg = True
self.start_idx += 2 # Move past the start marker
# Look for the end marker (0xFFD9)
end_idx = self.buffer.find(b'\xff\xd9', self.start_idx)
if end_idx == -1:
                # No end marker found; resume scanning from near the end of
                # the buffer (minus one byte for a split marker) on the next feed.
                self.start_idx = max(0, len(self.buffer) - 1)
return
# Extract the JPEG and call the callback
jpeg_data = self.buffer[:end_idx + 2]
self.callback(jpeg_data)
# Remove the processed JPEG from the buffer
self.buffer = self.buffer[end_idx + 2:]
self.in_jpeg = False
self.start_idx = 0 # Reset start_idx to begin from the start of the updated buffer
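
A minimal usage sketch for the parser, assuming a raw MJPEG capture at a hypothetical path; chunk boundaries are arbitrary, and markers split across feed() calls are still matched:

# --- usage sketch, not part of jpeg_parser.py ---
from jpeg_parser import JPEGStreamParser

frames = []

def on_jpeg(jpeg):
    # The parser invokes the callback with None once on close().
    if jpeg is not None:
        frames.append(bytes(jpeg))

with JPEGStreamParser(on_jpeg) as parser:
    with open("capture.mjpeg", "rb") as f:  # hypothetical MJPEG byte stream
        while True:
            chunk = f.read(4096)
            if not chunk:
                break
            parser.feed(chunk)

print(f"parsed {len(frames)} JPEG frames")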

main.py

import argparse
import asyncio
import json
import logging
import queue
import signal
import sys
import traceback

import trickle
async def main(subscribe_url: str, publish_url: str, params: dict):
    image_queue = queue.Queue()
# Set up async signal handling for SIGINT and SIGTERM
loop = asyncio.get_running_loop()
def stop_signal_handler():
logging.info("Stopping program due to received signal")
loop.stop()
loop.add_signal_handler(signal.SIGINT, stop_signal_handler)
loop.add_signal_handler(signal.SIGTERM, stop_signal_handler)
try:
subscribe_task = asyncio.create_task(trickle.run_subscribe(subscribe_url, image_queue.put))
publish_task = asyncio.create_task(trickle.run_publish(publish_url, image_queue))
    except Exception as e:
        logging.error(f"Error starting subscribe/publish tasks: {e}")
        logging.error(f"Stack trace:\n{traceback.format_exc()}")
        raise
try:
await asyncio.gather(subscribe_task, publish_task)
    except Exception as e:
        logging.error(f"Error in subscribe/publish tasks: {e}")
        logging.error(f"Stack trace:\n{traceback.format_exc()}")
        raise
if __name__ == "__main__":
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
parser = argparse.ArgumentParser(description="Infer process to run the AI pipeline")
parser.add_argument(
"--initial-params", type=str, default="{}", help="Initial parameters for the pipeline"
)
parser.add_argument(
"--subscribe-url", type=str, required=True, help="url to pull incoming streams"
)
parser.add_argument(
"--publish-url", type=str, required=True, help="url to push outgoing streams"
)
args = parser.parse_args()
try:
params = json.loads(args.initial_params)
except Exception as e:
logging.error(f"Error parsing --initial-params: {e}")
sys.exit(1)
try:
asyncio.run(
main(args.subscribe_url, args.publish_url, params)
)
except Exception as e:
logging.error(f"Fatal error in main: {e}")
logging.error(f"Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}")
sys.exit(1)

trickle.py

import aiohttp
import asyncio
import logging
import os
import threading

from trickle_subscriber import TrickleSubscriber
from trickle_publisher import TricklePublisher
from jpeg_parser import JPEGStreamParser
import segmenter
# target framerate
FRAMERATE = segmenter.FRAMERATE
GOP_SECS = segmenter.GOP_SECS
# TODO: make this more configurable
GPU = segmenter.GPU
async def preprocess(subscribe_url: str, image_callback):
# TODO add some pre-processing parameters, eg image size
try:
ffmpeg = await launch_ffmpeg()
logging_task = asyncio.create_task(log_pipe_async(ffmpeg.stderr))
subscribe_task = asyncio.create_task(subscribe(subscribe_url, ffmpeg.stdin))
jpeg_task = asyncio.create_task(parse_jpegs(ffmpeg.stdout, image_callback))
await asyncio.gather(ffmpeg.wait(), logging_task, subscribe_task, jpeg_task)
    except Exception as e:
        logging.error(f"preprocess got error: {e}", exc_info=True)
        raise
async def subscribe(subscribe_url, out_pipe):
subscriber = TrickleSubscriber(url=subscribe_url)
logging.info(f"JOSH - launching subscribe loop for {subscribe_url}")
while True:
segment = None
try:
segment = await subscriber.next()
if not segment:
break # complete
while True:
chunk = await segment.read()
if not chunk:
break # end of segment
out_pipe.write(chunk)
await out_pipe.drain()
except aiohttp.ClientError as e:
logging.info(f"Failed to read segment - {e}")
break # end of stream?
except Exception as e:
raise e
finally:
if segment:
await segment.close()
else:
# stream is complete
out_pipe.close()
break
async def launch_ffmpeg():
if GPU:
ffmpeg_cmd = [
'ffmpeg',
'-loglevel', 'warning',
'-hwaccel', 'cuda',
'-hwaccel_output_format', 'cuda',
'-i', 'pipe:0', # Read input from stdin
'-an',
            '-vf', f'scale_cuda=w=512:h=512:force_original_aspect_ratio=decrease:force_divisible_by=2,hwdownload,format=nv12,fps={FRAMERATE}',
            '-c:v', 'mjpeg',
'-start_number', '0',
'-q:v', '1',
'-f', 'image2pipe',
'pipe:1' # Output to stdout
]
else:
ffmpeg_cmd = [
'ffmpeg',
'-loglevel', 'warning',
'-i', 'pipe:0', # Read input from stdin
'-an',
'-vf', f'scale=w=512:h=512:force_original_aspect_ratio=decrease:force_divisible_by=2,fps={FRAMERATE}',
'-c:v', 'mjpeg',
'-start_number', '0',
'-q:v', '1',
'-f', 'image2pipe',
'pipe:1' # Output to stdout
]
logging.info(f"ffmpeg (input) {ffmpeg_cmd}")
# Launch FFmpeg process with stdin, stdout, and stderr as pipes
process = await asyncio.create_subprocess_exec(
*ffmpeg_cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
return process # Return the process handle
async def log_pipe_async(pipe):
"""Reads from a pipe and logs each line."""
while True:
line = await pipe.readline()
if not line:
break # Exit when the pipe is closed
# Decode the binary line and log it
logging.info(line.decode().strip())
async def parse_jpegs(in_pipe, image_callback):
chunk_size = 32 * 1024 # read in 32kb chunks
with JPEGStreamParser(image_callback) as parser:
        # TODO this does not work on asyncio streams - figure out how to
        # disable OS buffering on reads
        #pipe = os.fdopen(in_pipe.fileno(), 'rb', buffering=0)
while True:
chunk = await in_pipe.read(chunk_size)
if not chunk:
break
parser.feed(chunk)
def feed_ffmpeg(ffmpeg_fd, image_generator):
with os.fdopen(ffmpeg_fd, 'wb', buffering=0) as ffmpeg:
while True:
image = image_generator.get()
if image is None:
logging.info("Image generator empty, leaving feed_ffmpeg")
break
ffmpeg.write(image)
ffmpeg.flush()
async def postprocess(publish_url: str, image_generator):
try:
publisher = TricklePublisher(url=publish_url, mime_type="video/mp2t")
loop = asyncio.get_running_loop()
async def callback(pipe_file, pipe_name):
async with await publisher.next() as segment:
                # convert pipe_file into an asyncio-friendly StreamReader
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _ = await loop.connect_read_pipe(lambda: protocol, pipe_file)
while True:
sz = 32 * 1024 # read in chunks of 32KB
data = await reader.read(sz)
if not data:
break
await segment.write(data)
transport.close()
def sync_callback(pipe_fd, pipe_name):
# asyncio.connect_read_pipe expects explicit fd close
# so we have to manually read, detect eof, then close
r, w = os.pipe()
rf = os.fdopen(r, 'rb', buffering=0)
future = asyncio.run_coroutine_threadsafe(callback(rf, pipe_name), loop)
try:
while True:
data = pipe_fd.read(32 * 1024)
if not data:
break
os.write(w, data)
                os.close(w)  # close the write end so the StreamReader sees EOF
future.result() # This blocks in the thread until callback completes
rf.close() # also closes the read end of the pipe
# Ensure any exceptions in the coroutine are caught
except Exception as e:
logging.error(f"Error in sync_callback: {e}")
ffmpeg_read_fd, ffmpeg_write_fd = os.pipe()
segment_thread = threading.Thread(target=segmenter.segment_reading_process, args=(ffmpeg_read_fd, sync_callback))
ffmpeg_feeder = threading.Thread(target=feed_ffmpeg, args=(ffmpeg_write_fd, image_generator))
segment_thread.start()
ffmpeg_feeder.start()
def joins():
segment_thread.join()
ffmpeg_feeder.join()
await asyncio.to_thread(joins)
logging.info("postprocess complete")
    except Exception as e:
        logging.error(f"postprocess got error: {e}", exc_info=True)
        raise
finally:
await publisher.close()
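
A note on the shape of postprocess, as far as the code reads: feed_ffmpeg pushes frames from image_generator into the encoder on one thread while segmenter.segment_reading_process collects the segmented output on another; each finished segment is bridged back into asyncio through an os.pipe plus connect_read_pipe, then streamed out via TricklePublisher.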

segmenter.py

import errno
import logging
import os
import random
import select
import string
import subprocess
import sys
import threading
import time
# Constants and initial values
READ_TIMEOUT = 2       # seconds to wait for ffmpeg to open the next segment pipe
SLEEP_INTERVAL = 0.05  # seconds between polls
# TODO: make this more configurable
FRAMERATE = 24
GOP_SECS = 3
GPU = False
def create_named_pipe(pattern, pipe_index):
pipe_name = pattern % pipe_index
try:
os.mkfifo(pipe_name)
except OSError as e:
if e.errno != errno.EEXIST:
raise
return pipe_name
def remove_named_pipe(pipe_name):
try:
os.remove(pipe_name)
except OSError as e:
if e.errno != errno.ENOENT:
raise
def ffmpeg_cmd(in_pipe_fd, out_pattern):
if GPU:
cmd = [
'ffmpeg',
'-loglevel', 'warning',
'-f', 'image2pipe',
'-framerate', f"{FRAMERATE}",
'-i', f'pipe:{in_pipe_fd}',
'-c:v', 'h264_nvenc',
'-bf', '0', # disable bframes for webrtc
'-g', f'{GOP_SECS*FRAMERATE}',
'-preset', 'p1',
'-tune', 'ull',
'-f', 'segment',
out_pattern
]
else:
cmd = [
'ffmpeg',
'-loglevel', 'info',
'-f', 'image2pipe',
'-framerate', f"{FRAMERATE}",
'-i', f'pipe:{in_pipe_fd}',
#'-i', f'-', # stdin
'-c:v', 'libx264',
'-bf', '0', # disable bframes for webrtc
'-g', f'{GOP_SECS*FRAMERATE}',
'-preset', 'superfast',
'-tune', 'zerolatency',
'-f', 'segment',
out_pattern
]
logging.info(f"JOSH - ffmpeg (output) {cmd}")
return cmd
def read_from_pipe(pipe_name, callback, ffmpeg_proc):
fd = os.open(pipe_name, os.O_RDONLY | os.O_NONBLOCK)
# Polling to check if the pipe is ready
poller = select.poll()
poller.register(fd, select.POLLIN)
start_time = time.time()
while True:
# Wait for the pipe to become ready for reading
events = poller.poll(1000 * SLEEP_INTERVAL)
# If the pipe is ready, switch to blocking mode and read
if events:
os.set_blocking(fd, True)
break
# Check if ffmpeg has exited after polling
if ffmpeg_proc.poll() is not None:
logging.info(f"FFmpeg process has exited while waiting for pipe {pipe_name}")
os.close(fd)
return False
# Check if we've exceeded the timeout
if time.time() - start_time > READ_TIMEOUT:
logging.info(f"Timeout waiting for pipe {pipe_name}")
os.close(fd)
return False
# Sleep briefly before checking again
time.sleep(SLEEP_INTERVAL)
    # Now that the pipe is ready, invoke the callback.
    # fdopen will implicitly close the supplied fd.
with os.fdopen(fd, 'rb', buffering=0) as pipe_fd:
callback(pipe_fd, pipe_name)
remove_named_pipe(pipe_name)
return True
def segment_reading_process(in_fd, callback):
logging.info("JOSH - in segment reading process")
pipe_index = 0
out_pattern = generate_random_string() + "-%d.ts"
# Start by creating the first two named pipes
current_pipe = create_named_pipe(out_pattern, pipe_index)
next_pipe = create_named_pipe(out_pattern, pipe_index + 1)
# Launch FFmpeg process with stdin, stdout, and stderr as pipes
proc = subprocess.Popen(
ffmpeg_cmd(in_fd, out_pattern),
#stdin=subprocess.PIPE,
#stdin=in_fd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
pass_fds=(in_fd,),
)
# Create a thread to handle stderr redirection
thread = threading.Thread(target=print_proc, args=(proc.stderr,))
thread.start()
thread2 = threading.Thread(target=print_proc, args=(proc.stdout,))
thread2.start()
try:
while True:
# Read from the current pipe, exit the loop if there's a timeout or ffmpeg exit
if not read_from_pipe(current_pipe, callback, proc):
logging.info("Exiting ffmpeg (output) due to timeout or process exit.")
break
# Move to the next pipes in the sequence
pipe_index += 1
current_pipe = next_pipe
# Create the new next pipe in the sequence
next_pipe = create_named_pipe(out_pattern, pipe_index + 1)
except Exception as e:
logging.info(f"FFmpeg (output) error : {e} - {current_pipe}")
finally:
os.close(in_fd)
#proc.stdin.close()
logging.info("awaitng ffmpeg (output)")
proc.wait()
logging.info("proc complete ffmpeg (output)")
thread.join()
thread2.join()
logging.info("ffmpeg (output) complete")
#(stdout, stderr) = proc.communicate()
#logging.info("FFmpeg (output)")
#logging.info(stderr.decode())
#logging.info(stdout.decode())
# Cleanup remaining pipes
remove_named_pipe(current_pipe)
remove_named_pipe(next_pipe)
def print_proc(f):
    """Reads lines from a subprocess pipe and writes them to sys.stderr."""
    for line in iter(f.readline, b''):
        sys.stderr.write(line.decode())
def generate_random_string(length=5):
    """Generates a random alphanumeric string (default length 5)."""
    letters = string.ascii_letters + string.digits
    return ''.join(random.choice(letters) for _ in range(length))
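
A word on the named-pipe rotation above: ffmpeg's segment muxer writes each output segment to out_pattern, but since the targets are pre-created FIFOs rather than regular files, read_from_pipe can stream each .ts segment to the callback while ffmpeg is still writing it. The next FIFO is always created one step ahead so ffmpeg never stalls opening its next output.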

shell.nix

{ pkgs ? import <nixpkgs> {} }:

pkgs.mkShell {
  packages = [
    pkgs.ripgrep
    pkgs.go_1_22
    pkgs.python3
    pkgs.python3Packages.flask
    pkgs.python3Packages.numpy
    pkgs.ffmpeg-full
    pkgs.python3Packages.pillow
    pkgs.python3Packages.aiohttp
    pkgs.python3Packages.aiofiles
  ];
  nativeBuildInputs = [ pkgs.pkg-config pkgs.bzip2 pkgs.zlib pkgs.iconv ]
    ++ pkgs.lib.optionals pkgs.stdenv.isDarwin [
      pkgs.darwin.apple_sdk.frameworks.VideoToolbox
      pkgs.darwin.apple_sdk.frameworks.OpenGL
      pkgs.darwin.apple_sdk.frameworks.AppKit
    ];
  shellHook = ''
    export PKG_CONFIG_PATH=$HOME/livepeer/compiled/lib/pkgconfig
  '';
}
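
Assuming the above is saved as shell.nix, running nix-shell in the same directory enters the environment; note the PKG_CONFIG_PATH export points at a local Livepeer build and is specific to the author's machine.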

trickle_publisher.py

import asyncio
import aiohttp
import logging
class TricklePublisher:
def __init__(self, url: str, mime_type: str):
self.url = url
self.mime_type = mime_type
self.idx = 0 # Start index for POSTs
self.next_writer = None
self.lock = asyncio.Lock() # Lock to manage concurrent access
self.session = aiohttp.ClientSession()
async def __aenter__(self):
"""Enter context manager."""
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager and close the session."""
await self.close()
def streamIdx(self):
return f"{self.url}/{self.idx}"
async def preconnect(self):
"""Preconnect to the server by initiating a POST request to the current index."""
url = self.streamIdx()
logging.info(f"Preconnecting to URL: {url}")
try:
# we will be incrementally writing data into this queue
queue = asyncio.Queue()
asyncio.create_task(self._run_post(url, queue))
return queue
except aiohttp.ClientError as e:
logging.error(f"Failed to complete POST for {self.streamIdx()}: {e}")
return None
async def _run_post(self, url, queue):
try:
resp = await self.session.post(
url,
headers={'Connection': 'close', 'Content-Type': self.mime_type},
data=self._stream_data(queue)
)
# TODO propagate errors?
if resp.status != 200:
body = await resp.text()
logging.error(f"Trickle POST failed {self.streamIdx()}, status code: {resp.status}, msg: {body}")
except Exception as e:
logging.error(f"Trickle POST exception {self.streamIdx()} - {e}")
return None
async def _stream_data(self, queue):
"""Stream data from the queue for the POST request."""
while True:
chunk = await queue.get()
if chunk is None: # Stop signal
break
yield chunk
async def next(self):
"""Start or retrieve a pending POST request and preconnect for the next segment."""
async with self.lock:
if self.next_writer is None:
logging.info(f"No pending connection, preconnecting {self.streamIdx()}...")
self.next_writer = await self.preconnect()
writer = self.next_writer
self.next_writer = None
# Set up the next POST in the background
asyncio.create_task(self._preconnect_next_segment())
return SegmentWriter(writer)
async def _preconnect_next_segment(self):
"""Preconnect to the next POST in the background."""
logging.info(f"Setting up next connection for {self.streamIdx()}")
async with self.lock:
if self.next_writer is not None:
return
self.idx += 1 # Increment the index for the next POST
next_writer = await self.preconnect()
if next_writer:
self.next_writer = next_writer
async def close(self):
"""Close the session when done."""
logging.info(f"Closing {self.url}")
if self.next_writer:
s = SegmentWriter(self.next_writer)
await s.close()
await self.session.delete(self.url)
await self.session.close()
class SegmentWriter:
def __init__(self, queue: asyncio.Queue):
self.queue = queue
async def write(self, data):
"""Write data to the current segment."""
await self.queue.put(data)
async def close(self):
"""Ensure the request is properly closed when done."""
await self.queue.put(None) # Send None to signal end of data
async def __aenter__(self):
"""Enter context manager."""
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager and close the connection."""
await self.close()
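
A minimal publish-side sketch, assuming a trickle server at the hypothetical URL below; each next() call yields a SegmentWriter for one numbered segment while the following POST is preconnected in the background:

# --- usage sketch, not part of trickle_publisher.py ---
import asyncio

from trickle_publisher import TricklePublisher

async def demo():
    async with TricklePublisher("http://localhost:2939/foo", "video/mp2t") as publisher:
        for i in range(3):
            # Each segment becomes a separate streaming POST to <url>/<idx>.
            async with await publisher.next() as segment:
                await segment.write(f"segment {i} payload".encode())

asyncio.run(demo())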

trickle_subscriber.py

import asyncio
import aiohttp
import logging
class TrickleSubscriber:
def __init__(self, url: str):
self.base_url = url
self.idx = -1 # Start with -1 for 'latest' index
self.pending_get = None # Pre-initialized GET request
self.lock = asyncio.Lock() # Lock to manage concurrent access
self.session = aiohttp.ClientSession()
self.errored = False
async def get_index(self, resp):
"""Extract the index from the response headers."""
if resp is None:
return -1
idx_str = resp.headers.get("Lp-Trickle-Idx")
try:
idx = int(idx_str)
except (TypeError, ValueError):
return -1
return idx
async def preconnect(self):
"""Preconnect to the server by making a GET request to fetch the next segment."""
url = f"{self.base_url}/{self.idx}"
logging.info(f"Trickle sub Preconnecting to URL: {url}")
try:
resp = await self.session.get(url, headers={'Connection':'close'})
if resp.status != 200:
body = await resp.text()
resp.release()
logging.error(f"Trickle sub Failed GET segment, status code: {resp.status}, msg: {body}")
self.errored = True
return None
# Return the response for later processing
return resp
except aiohttp.ClientError as e:
logging.error(f"Trickle sub Failed to complete GET for next segment: {e}")
self.errored = True
return None
async def next(self):
"""Retrieve data from the current segment and set up the next segment concurrently."""
async with self.lock:
if self.errored:
logging.info("Trickle subscription closed or errored")
return None
# If we don't have a pending GET request, preconnect
if self.pending_get is None:
logging.info("Trickle sub No pending connection, preconnecting...")
self.pending_get = await self.preconnect()
# Extract the current connection to use for reading
conn = self.pending_get
self.pending_get = None
# Extract and set the next index from the response headers
idx = await self.get_index(conn)
if idx != -1:
self.idx = idx + 1
# Set up the next connection in the background
asyncio.create_task(self._preconnect_next_segment())
return Segment(conn)
async def _preconnect_next_segment(self):
"""Preconnect to the next segment in the background."""
logging.info(f"Trickle sub setting up next connection for index {self.idx}")
async with self.lock:
if self.pending_get is not None:
return
next_conn = await self.preconnect()
if next_conn:
self.pending_get = next_conn
next_idx = await self.get_index(next_conn)
if next_idx != -1:
self.idx = next_idx + 1
class Segment:
def __init__(self, response):
self.response = response
async def read(self, chunk_size=32 * 1024):
"""Read the next chunk of the segment."""
if not self.response:
await self.close()
return None
chunk = await self.response.content.read(chunk_size)
if not chunk:
await self.close()
return chunk
    async def close(self):
        """Ensure the response is properly closed when done."""
        if self.response is None:
            return
        if not self.response.closed:
            # ClientResponse.close()/release() are synchronous in aiohttp 3.x
            self.response.release()
        self.response = None
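
And a matching read-side sketch against the same hypothetical endpoint; next() preconnects the following segment while the current one is drained:

# --- usage sketch, not part of trickle_subscriber.py ---
import asyncio

from trickle_subscriber import TrickleSubscriber

async def demo():
    subscriber = TrickleSubscriber("http://localhost:2939/foo")
    while True:
        segment = await subscriber.next()
        if segment is None:
            break  # subscription closed or errored
        while True:
            chunk = await segment.read()
            if not chunk:
                break  # end of this segment
            print(f"received {len(chunk)} bytes")
        await segment.close()
    # The class defines no close(), so shut the underlying session directly.
    await subscriber.session.close()

asyncio.run(demo())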