/README Secret
Last active
October 31, 2024 19:22
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
python main.py --publish-url=http://localhost:2939/bar --subscribe-url=http://localhost:2939/foo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
import os | |
import sys | |
import logging | |
class JPEGStreamParser:
    """Incrementally split a byte stream of concatenated JPEG images.

    Bytes are fed in arbitrary chunks via :meth:`feed`; every complete image,
    delimited by the SOI (0xFFD8) and EOI (0xFFD9) markers, is passed to the
    callback. Use as a context manager so :meth:`close` runs on exit.
    """

    _SOI = b'\xff\xd8'
    _EOI = b'\xff\xd9'

    def __init__(self, callback):
        """
        Initializes the JPEGStreamParser.
        :param callback: Function to be called when a complete JPEG is found. Receives bytes of the
            JPEG image, and ``None`` once when the parser is closed.
        """
        self.buffer = bytearray()
        self.callback = callback
        self.in_jpeg = False
        # Offset in self.buffer where the next marker search starts; carried
        # across feeds so already-scanned bytes are not rescanned.
        self.start_idx = 0

    def __enter__(self):
        return self

    def __exit__(self, exec_type, exec_val, exec_tb):
        logging.info("JOSH closing jpeg parser via exit")
        self.close()

    def close(self):
        """Signal end-of-stream by invoking the callback with ``None``."""
        logging.info("Closing jpeg parser")
        self.callback(None)

    def feed(self, data):
        """
        Feed incoming data into the parser.

        Every JPEG completed by this data is passed to the callback, in order.
        Bytes that precede an SOI marker are discarded.
        NOTE(review): an EOI inside an embedded thumbnail would end the frame
        early; this assumes the producer emits plain JPEGs.
        :param data: Incoming data bytes.
        """
        self.buffer.extend(data)
        while True:
            if not self.in_jpeg:
                soi = self.buffer.find(self._SOI, self.start_idx)
                if soi == -1:
                    # No image starts here: drop the junk, but keep a trailing
                    # 0xFF that may be the first half of a marker split across feeds.
                    keep = 1 if self.buffer.endswith(b'\xff') else 0
                    del self.buffer[:len(self.buffer) - keep]
                    self.start_idx = 0
                    return
                # Trim so the buffer always begins at the SOI of the current image.
                del self.buffer[:soi]
                self.in_jpeg = True
                self.start_idx = len(self._SOI)  # search for EOI past the SOI
            eoi = self.buffer.find(self._EOI, self.start_idx)
            if eoi == -1:
                # Keep buffering; resume at the last byte in case the marker is split.
                self.start_idx = max(len(self._SOI), len(self.buffer) - 1)
                return
            end = eoi + len(self._EOI)
            self.callback(bytes(self.buffer[:end]))
            # Remove the processed JPEG and look for the next one.
            del self.buffer[:end]
            self.in_jpeg = False
            self.start_idx = 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import asyncio | |
import json | |
import logging | |
import signal | |
import sys | |
import os | |
import traceback | |
import queue | |
from typing import List | |
import trickle | |
async def main(subscribe_url: str, publish_url: str, params: dict):
    """Run the subscribe -> process -> publish pipeline until it ends or a signal arrives.

    :param subscribe_url: URL to pull incoming streams from.
    :param publish_url: URL to push outgoing streams to.
    :param params: initial pipeline parameters from ``--initial-params``
        (not consumed here yet).
    :raises Exception: re-raises the first failure from either pipeline task.
    """
    # Thread-safe queue: the subscribe side pushes images via image_queue.put,
    # the publish side consumes the queue.
    # NOTE(review): assumes None is the consumer's end-of-stream sentinel
    # (JPEGStreamParser.close() sends it) — confirm for trickle.run_publish.
    image_queue = queue.Queue()
    loop = asyncio.get_running_loop()

    try:
        # NOTE(review): the trickle module in this gist visibly defines
        # preprocess()/postprocess() with these signatures; confirm that
        # run_subscribe/run_publish exist under these names.
        subscribe_task = asyncio.create_task(trickle.run_subscribe(subscribe_url, image_queue.put))
        publish_task = asyncio.create_task(trickle.run_publish(publish_url, image_queue))
    except Exception as e:
        logging.error(f"Error starting pipeline tasks: {e}")
        logging.error(f"Stack trace:\n{traceback.format_exc()}")
        raise

    tasks = (subscribe_task, publish_task)
    stopping = False

    def request_shutdown():
        # Cancel both sides and wake any consumer thread blocked on image_queue.get().
        for task in tasks:
            task.cancel()
        image_queue.put(None)

    def stop_signal_handler():
        nonlocal stopping
        logging.info("Stopping program due to received signal")
        stopping = True
        # loop.stop() would make asyncio.run() fail with "Event loop stopped
        # before Future completed"; cancelling the tasks lets main() unwind.
        request_shutdown()

    # Set up async signal handling for SIGINT and SIGTERM
    loop.add_signal_handler(signal.SIGINT, stop_signal_handler)
    loop.add_signal_handler(signal.SIGTERM, stop_signal_handler)
    try:
        await asyncio.gather(*tasks)
    except asyncio.CancelledError:
        if not stopping:
            raise  # main() itself was cancelled, not by our signal handler
        logging.info("Pipeline tasks cancelled after signal")
    except Exception as e:
        logging.error(f"Pipeline task failed: {e}")
        logging.error(f"Stack trace:\n{traceback.format_exc()}")
        request_shutdown()
        raise
    finally:
        loop.remove_signal_handler(signal.SIGINT)
        loop.remove_signal_handler(signal.SIGTERM)
if __name__ == "__main__":
    # Timestamped, level-aligned log lines for the whole process.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    # Command line: two required trickle URLs plus optional JSON parameters.
    parser = argparse.ArgumentParser(description="Infer process to run the AI pipeline")
    parser.add_argument("--initial-params", type=str, default="{}", help="Initial parameters for the pipeline")
    parser.add_argument("--subscribe-url", type=str, required=True, help="url to pull incoming streams")
    parser.add_argument("--publish-url", type=str, required=True, help="url to push outgoing streams")
    args = parser.parse_args()

    # Invalid JSON is a usage error: report it and exit non-zero.
    try:
        params = json.loads(args.initial_params)
    except Exception as e:
        logging.error(f"Error parsing --initial-params: {e}")
        sys.exit(1)

    # Any exception escaping the pipeline is fatal; log it with its traceback.
    try:
        asyncio.run(main(args.subscribe_url, args.publish_url, params))
    except Exception as e:
        logging.error(f"Fatal error in main: {e}")
        logging.error(f"Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}")
        sys.exit(1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import aiohttp | |
import asyncio | |
import logging | |
import os | |
import threading | |
import subprocess | |
from trickle_subscriber import TrickleSubscriber | |
from trickle_publisher import TricklePublisher | |
from jpeg_parser import JPEGStreamParser | |
import segmenter | |
# Encoding settings are taken from the segmenter module so that the input
# ffmpeg (launch_ffmpeg: target frame rate, GPU choice) matches the output side.
# TODO make this better configurable
FRAMERATE, GOP_SECS, GPU = segmenter.FRAMERATE, segmenter.GOP_SECS, segmenter.GPU
async def preprocess(subscribe_url: str, image_callback):
    """Decode the subscribed stream into JPEG frames and hand each to *image_callback*.

    Launches the input ffmpeg, pipes every trickle segment into its stdin,
    logs its stderr, and splits its MJPEG stdout into images. The callback
    receives each JPEG as bytes and finally ``None`` when parsing ends.

    :param subscribe_url: trickle URL to subscribe to.
    :param image_callback: called with each JPEG image, then ``None``.
    :raises Exception: the first failure of any stage, after ffmpeg is killed.
    """
    # TODO add some pre-processing parameters, eg image size
    ffmpeg = None
    try:
        ffmpeg = await launch_ffmpeg()
        logging_task = asyncio.create_task(log_pipe_async(ffmpeg.stderr))
        subscribe_task = asyncio.create_task(subscribe(subscribe_url, ffmpeg.stdin))
        jpeg_task = asyncio.create_task(parse_jpegs(ffmpeg.stdout, image_callback))
        await asyncio.gather(ffmpeg.wait(), logging_task, subscribe_task, jpeg_task)
    except Exception as e:
        # logging.exception records the traceback; the old extra positional
        # argument had no %-placeholder and broke log formatting.
        logging.exception(f"preprocess got error {e}")
        raise
    finally:
        # gather() does not cancel siblings on failure; killing ffmpeg closes its
        # pipes so the reader tasks hit EOF instead of waiting forever.
        if ffmpeg is not None and ffmpeg.returncode is None:
            try:
                ffmpeg.kill()
            except ProcessLookupError:
                pass  # already exited
async def subscribe(subscribe_url, out_pipe):
    """Copy every trickle segment from *subscribe_url* into *out_pipe*.

    Stops when the subscriber has no further segment, or when reading fails
    with an aiohttp.ClientError. *out_pipe* (ffmpeg's stdin StreamWriter, as
    passed by preprocess) is closed on every exit path so ffmpeg sees EOF.

    :param subscribe_url: trickle URL to read from.
    :param out_pipe: asyncio StreamWriter that receives the raw segment bytes.
    :raises Exception: any failure other than aiohttp.ClientError propagates.
    """
    subscriber = TrickleSubscriber(url=subscribe_url)
    # NOTE(review): the subscriber's HTTP session is never closed here.
    logging.info(f"JOSH - launching subscribe loop for {subscribe_url}")
    try:
        while True:
            segment = None
            try:
                segment = await subscriber.next()
                if not segment:
                    break  # complete
                while chunk := await segment.read():
                    out_pipe.write(chunk)
                    await out_pipe.drain()
            except aiohttp.ClientError as e:
                logging.info(f"Failed to read segment - {e}")
                break  # end of stream?
            finally:
                if segment:
                    await segment.close()
    finally:
        # stream is complete (or failed): EOF on stdin lets ffmpeg flush and exit
        out_pipe.close()
async def launch_ffmpeg():
    """Start the input ffmpeg that turns the subscribed stream into MJPEG frames.

    ffmpeg reads the incoming stream from stdin and drops audio. It scales the
    video to fit within 512x512 (keeping the aspect ratio, with even
    dimensions) and resamples it to FRAMERATE fps. It then writes
    concatenated JPEG images to stdout.
    When GPU is set, decoding and scaling run on CUDA and frames are
    downloaded as nv12 before JPEG encoding.

    :returns: the asyncio subprocess, with stdin, stdout and stderr as pipes.
    """
    scale_opts = 'w=512:h=512:force_original_aspect_ratio=decrease:force_divisible_by=2'
    if GPU:
        decode_args = ['-hwaccel', 'cuda', '-hwaccel_output_format', 'cuda']
        video_filter = f'scale_cuda={scale_opts},hwdownload,format=nv12,fps={FRAMERATE}'
    else:
        decode_args = []
        video_filter = f'scale={scale_opts},fps={FRAMERATE}'
    ffmpeg_cmd = [
        'ffmpeg',
        '-loglevel', 'warning',
        *decode_args,
        '-i', 'pipe:0',  # Read input from stdin
        '-an',
        '-vf', video_filter,
        '-c:v', 'mjpeg',
        '-start_number', '0',
        '-q:v', '1',
        '-f', 'image2pipe',
        'pipe:1',  # Output to stdout
    ]
    logging.info(f"ffmpeg (input) {ffmpeg_cmd}")
    # Launch FFmpeg process with stdin, stdout, and stderr as pipes
    process = await asyncio.create_subprocess_exec(
        *ffmpeg_cmd,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    return process  # Return the process handle
async def log_pipe_async(pipe):
    """Log every line read from *pipe* at INFO level, stopping at EOF.

    Each line is decoded as UTF-8 and stripped of surrounding whitespace.
    """
    while line := await pipe.readline():
        logging.info(line.decode().strip())
async def parse_jpegs(in_pipe, image_callback):
    """Split the MJPEG byte stream on *in_pipe* into images for *image_callback*.

    Reads until EOF. Leaving the ``with`` block closes the parser, which
    invokes the callback with ``None``.
    """
    read_size = 32 * 1024  # read in 32kb chunks
    # TODO: find a way to disable OS-level read buffering on asyncio streams;
    # os.fdopen(in_pipe.fileno(), 'rb', buffering=0) does not work for them.
    with JPEGStreamParser(image_callback) as jpeg_parser:
        while chunk := await in_pipe.read(read_size):
            jpeg_parser.feed(chunk)
def feed_ffmpeg(ffmpeg_fd, image_generator):
    """Write images from *image_generator* into the file descriptor *ffmpeg_fd*.

    Blocks on ``image_generator.get()`` until it yields ``None``, then closes
    the descriptor (the ``with`` block owns it) so the reader sees EOF.

    :param ffmpeg_fd: writable OS file descriptor (e.g. the write end of os.pipe()).
    :param image_generator: object whose ``get()`` returns image bytes, or ``None`` to stop.
    """
    with os.fdopen(ffmpeg_fd, 'wb', buffering=0) as ffmpeg:
        while True:
            image = image_generator.get()
            if image is None:
                logging.info("Image generator empty, leaving feed_ffmpeg")
                break
            # An unbuffered (raw) write may accept only part of the data;
            # keep writing the remainder until the whole image is out.
            remaining = memoryview(image)
            while remaining:
                remaining = remaining[ffmpeg.write(remaining):]
async def postprocess(publish_url: str, image_generator):
    """Encode images from *image_generator* into MPEG-TS segments and publish them.

    Two threads do the blocking work. feed_ffmpeg writes images into an
    os.pipe. segmenter.segment_reading_process reads the other end of that
    pipe and calls ``sync_callback(pipe_file, pipe_name)`` once per segment
    file. sync_callback relays each file to the event loop, where it is
    streamed into one publisher segment.

    :param publish_url: trickle URL to publish to.
    :param image_generator: object whose ``get()`` returns JPEG bytes, or ``None`` at end.
    :raises Exception: any failure, after the publisher is closed.
    """
    # Created before the try so the finally below never sees an unbound name.
    publisher = TricklePublisher(url=publish_url, mime_type="video/mp2t")
    try:
        loop = asyncio.get_running_loop()

        async def callback(pipe_file, pipe_name):
            # Runs on the event loop: stream one segment file into one publisher segment.
            transport = None
            try:
                async with await publisher.next() as segment:
                    # convert pipe_file into an asyncio friendly StreamReader
                    reader = asyncio.StreamReader()
                    protocol = asyncio.StreamReaderProtocol(reader)
                    transport, _ = await loop.connect_read_pipe(lambda: protocol, pipe_file)
                    while True:
                        data = await reader.read(32 * 1024)  # read in chunks of 32KB
                        if not data:
                            break
                        await segment.write(data)
            finally:
                if transport is not None:
                    transport.close()
                else:
                    # Never attached: close the read end ourselves so the writer
                    # thread gets EPIPE instead of blocking on a full pipe.
                    pipe_file.close()

        def sync_callback(pipe_fd, pipe_name):
            # Runs on the segmenter thread. asyncio.connect_read_pipe expects an
            # explicit close for EOF, so relay the data through our own os.pipe
            # and close its write end once the segment file is exhausted.
            r, w = os.pipe()
            rf = os.fdopen(r, 'rb', buffering=0)
            future = None
            try:
                future = asyncio.run_coroutine_threadsafe(callback(rf, pipe_name), loop)
                while True:
                    data = pipe_fd.read(32 * 1024)
                    if not data:
                        break
                    # os.write may write only part of the buffer.
                    remaining = memoryview(data)
                    while remaining:
                        remaining = remaining[os.write(w, remaining):]
            except Exception as e:
                logging.error(f"Error in sync_callback: {e}")
            finally:
                os.close(w)  # streamreader is very sensitive about this; always signal EOF
            try:
                if future is not None:
                    future.result()  # This blocks in the thread until callback completes
            except Exception as e:
                logging.error(f"Error in sync_callback: {e}")
            finally:
                rf.close()  # also closes the read end of the pipe

        ffmpeg_read_fd, ffmpeg_write_fd = os.pipe()
        segment_thread = threading.Thread(target=segmenter.segment_reading_process, args=(ffmpeg_read_fd, sync_callback))
        ffmpeg_feeder = threading.Thread(target=feed_ffmpeg, args=(ffmpeg_write_fd, image_generator))
        segment_thread.start()
        ffmpeg_feeder.start()

        def joins():
            segment_thread.join()
            ffmpeg_feeder.join()

        await asyncio.to_thread(joins)
        logging.info("postprocess complete")
    except Exception as e:
        # The old extra positional argument had no %-placeholder and broke log formatting.
        logging.exception(f"postprocess got error {e}")
        raise
    finally:
        await publisher.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import errno | |
import time | |
import logging | |
import os | |
import select | |
import string | |
import subprocess | |
import sys | |
import random | |
import threading | |
from datetime import datetime | |
# Waiting on ffmpeg's output FIFOs (seconds).
READ_TIMEOUT = 2  # give up on a segment FIFO after this long without a writer
SLEEP_INTERVAL = 0.05  # readiness poll interval while waiting

# Output encoding.
# TODO make this better configurable
FRAMERATE = 24  # input frame rate passed to ffmpeg as -framerate
GOP_SECS = 3  # keyframe interval; -g is GOP_SECS * FRAMERATE frames
GPU = False  # True selects h264_nvenc, False selects libx264
def create_named_pipe(pattern, pipe_index):
    """Create the FIFO at ``pattern % pipe_index`` and return that path.

    A path that already exists is accepted as-is; any other OSError propagates.
    """
    path = pattern % pipe_index
    try:
        os.mkfifo(path)
    except FileExistsError:  # errno.EEXIST
        pass
    return path
def remove_named_pipe(pipe_name):
    """Delete *pipe_name*; a path that is already gone is not an error."""
    try:
        os.remove(pipe_name)
    except FileNotFoundError:  # errno.ENOENT
        pass
def ffmpeg_cmd(in_pipe_fd, out_pattern):
    """Build the output ffmpeg argv: JPEG frames on fd *in_pipe_fd* -> H.264 segments.

    Frames are read as image2pipe at FRAMERATE and encoded without B-frames,
    with a GOP of GOP_SECS * FRAMERATE frames. They are written by the segment
    muxer to files named by *out_pattern*. GPU selects NVENC, otherwise libx264.
    """
    if GPU:
        log_level, encoder, preset, tune = 'warning', 'h264_nvenc', 'p1', 'ull'
    else:
        log_level, encoder, preset, tune = 'info', 'libx264', 'superfast', 'zerolatency'
    cmd = [
        'ffmpeg',
        '-loglevel', log_level,
        '-f', 'image2pipe',
        '-framerate', f"{FRAMERATE}",
        '-i', f'pipe:{in_pipe_fd}',
        '-c:v', encoder,
        '-bf', '0',  # disable bframes for webrtc
        '-g', f'{GOP_SECS*FRAMERATE}',
        '-preset', preset,
        '-tune', tune,
        '-f', 'segment',
        out_pattern,
    ]
    logging.info(f"JOSH - ffmpeg (output) {cmd}")
    return cmd
def read_from_pipe(pipe_name, callback, ffmpeg_proc):
    """Wait for a writer on FIFO *pipe_name*, then hand the open pipe to *callback*.

    The FIFO is opened non-blocking so the open itself does not wait for a
    writer. Readiness is then polled every SLEEP_INTERVAL seconds.

    :param pipe_name: path of a FIFO (see create_named_pipe).
    :param callback: called as ``callback(pipe_file, pipe_name)`` with an
        unbuffered binary file in blocking mode; the file is closed afterwards.
    :param ffmpeg_proc: subprocess.Popen of the writer; we stop waiting if it exits.
    :returns: True after the callback ran and the FIFO was removed. False if
        ffmpeg exited, or READ_TIMEOUT seconds passed, before the pipe became
        readable; the FIFO is left on disk in that case.
    NOTE(review): if ffmpeg opens this FIFO only after the timeout (e.g. a slow
    first frame), it may then block opening it for writing with no reader —
    confirm READ_TIMEOUT suits the upstream frame latency.
    """
    fd = os.open(pipe_name, os.O_RDONLY | os.O_NONBLOCK)
    # fdopen takes ownership of fd, so every exit path below closes it.
    with os.fdopen(fd, 'rb', buffering=0) as pipe_fd:
        poller = select.poll()
        poller.register(fd, select.POLLIN)
        # Monotonic clock: wall-clock jumps must not shorten or extend the wait.
        deadline = time.monotonic() + READ_TIMEOUT
        # poll() takes milliseconds and itself waits SLEEP_INTERVAL, so no extra sleep is needed.
        while not poller.poll(1000 * SLEEP_INTERVAL):
            if ffmpeg_proc.poll() is not None:
                logging.info(f"FFmpeg process has exited while waiting for pipe {pipe_name}")
                return False
            if time.monotonic() > deadline:
                logging.info(f"Timeout waiting for pipe {pipe_name}")
                return False
        # The pipe is ready: switch to blocking reads for the callback.
        os.set_blocking(fd, True)
        callback(pipe_fd, pipe_name)
    remove_named_pipe(pipe_name)
    return True
def segment_reading_process(in_fd, callback):
    """Encode JPEG frames from *in_fd* with ffmpeg and pass each segment to *callback*.

    ffmpeg writes segments to a sequence of FIFOs named
    ``<random>-<index>.ts``; the current and next FIFO always exist, because
    ffmpeg opens the next file by name. Each FIFO is handed to
    read_from_pipe() in turn until ffmpeg exits or a FIFO times out.

    :param in_fd: readable fd carrying JPEG frames; inherited by ffmpeg and
        always closed here.
    :param callback: called as ``callback(pipe_file, pipe_name)`` per segment.
    """
    logging.info("JOSH - in segment reading process")
    pipe_index = 0
    out_pattern = generate_random_string() + "-%d.ts"
    current_pipe = next_pipe = None
    proc = None
    log_threads = []
    try:
        # Start by creating the first two named pipes
        current_pipe = create_named_pipe(out_pattern, pipe_index)
        next_pipe = create_named_pipe(out_pattern, pipe_index + 1)
        # Launch FFmpeg reading frames from the inherited in_fd; stdout/stderr are piped
        proc = subprocess.Popen(
            ffmpeg_cmd(in_fd, out_pattern),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            pass_fds=(in_fd,),
        )
        # Drain both output pipes so ffmpeg never blocks on a full pipe
        for stream in (proc.stderr, proc.stdout):
            log_thread = threading.Thread(target=print_proc, args=(stream,))
            log_thread.start()
            log_threads.append(log_thread)
        while True:
            # Read from the current pipe, exit the loop if there's a timeout or ffmpeg exit
            if not read_from_pipe(current_pipe, callback, proc):
                logging.info("Exiting ffmpeg (output) due to timeout or process exit.")
                break
            # Move to the next pipes in the sequence
            pipe_index += 1
            current_pipe = next_pipe
            # Create the new next pipe in the sequence
            next_pipe = create_named_pipe(out_pattern, pipe_index + 1)
    except Exception as e:
        logging.exception(f"FFmpeg (output) error : {e} - {current_pipe}")
    finally:
        # Runs even when FIFO creation or Popen fails, so nothing leaks
        os.close(in_fd)
        if proc is not None:
            logging.info("awaitng ffmpeg (output)")
            proc.wait()
            logging.info("proc complete ffmpeg (output)")
        for log_thread in log_threads:
            log_thread.join()
        logging.info("ffmpeg (output) complete")
        # Cleanup remaining pipes
        for pipe_name in (current_pipe, next_pipe):
            if pipe_name is not None:
                remove_named_pipe(pipe_name)
def print_proc(f):
    """Copy lines from the subprocess pipe *f* to sys.stderr until EOF.

    Used for both stdout and stderr of ffmpeg. Bytes that are not valid UTF-8
    are replaced instead of raising, so this drain thread cannot die early
    and leave the child blocked on a full pipe.
    """
    for line in iter(f.readline, b''):
        sys.stderr.write(line.decode(errors='replace'))
def generate_random_string():
    """Return a random 5-character string of ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(5):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ pkgs ? import <nixpkgs> {} }:
# Development shell: Go, Python (flask, numpy, pillow, aiohttp, aiofiles) and ffmpeg.
pkgs.mkShell {
  packages = [
    pkgs.ripgrep
    pkgs.go_1_22
    pkgs.python3
    pkgs.python3Packages.flask
    pkgs.python3Packages.numpy
    pkgs.ffmpeg-full
    pkgs.python3Packages.pillow
    pkgs.python3Packages.aiohttp
    pkgs.python3Packages.aiofiles
  ];
  nativeBuildInputs = [ pkgs.pkg-config pkgs.bzip2 pkgs.zlib pkgs.iconv ] ++ pkgs.lib.optionals pkgs.stdenv.isDarwin [
    pkgs.darwin.apple_sdk.frameworks.VideoToolbox
    pkgs.darwin.apple_sdk.frameworks.OpenGL
    pkgs.darwin.apple_sdk.frameworks.AppKit
  ];
  # mkShell runs `shellHook` (singular) on entry; `shellHooks` was silently
  # ignored. Prepend the locally compiled libraries instead of overwriting
  # whatever pkg-config search path the environment already has.
  shellHook = ''
    export PKG_CONFIG_PATH=$HOME/livepeer/compiled/lib/pkgconfig''${PKG_CONFIG_PATH:+:$PKG_CONFIG_PATH}
  '';
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import aiohttp | |
import logging | |
from contextlib import asynccontextmanager | |
class TricklePublisher:
    """Publish a trickle stream as a sequence of streaming HTTP POSTs.

    Each segment is POSTed to ``{url}/{idx}`` with a body streamed from an
    asyncio.Queue (``None`` ends the body). The POST for the following index
    is opened in the background so the next segment can start immediately.
    """

    def __init__(self, url: str, mime_type: str):
        self.url = url
        self.mime_type = mime_type
        self.idx = 0  # Index the next preconnect() will POST to
        self.next_writer = None  # Body queue of an open POST not yet handed out
        self.lock = asyncio.Lock()  # Lock to manage concurrent access
        self.session = aiohttp.ClientSession()
        # Strong references to background tasks; the event loop only keeps weak ones.
        self._tasks = set()

    async def __aenter__(self):
        """Enter context manager."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit context manager and close the session."""
        await self.close()

    def streamIdx(self):
        """Return the URL of the current index."""
        return f"{self.url}/{self.idx}"

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def preconnect(self):
        """Start a streaming POST to the current index and return its body queue."""
        url = self.streamIdx()
        logging.info(f"Preconnecting to URL: {url}")
        # we will be incrementally writing data into this queue
        queue = asyncio.Queue()
        self._spawn(self._run_post(url, queue))
        return queue

    async def _run_post(self, url, queue):
        """Run one POST to *url*, streaming its body from *queue*; errors are logged."""
        try:
            async with self.session.post(
                url,
                headers={'Connection': 'close', 'Content-Type': self.mime_type},
                data=self._stream_data(queue),
            ) as resp:
                # TODO propagate errors?
                if resp.status != 200:
                    body = await resp.text()
                    # Log *url*: self.idx may already point at a later segment.
                    logging.error(f"Trickle POST failed {url}, status code: {resp.status}, msg: {body}")
        except Exception as e:
            logging.error(f"Trickle POST exception {url} - {e}")
        return None

    async def _stream_data(self, queue):
        """Stream data from the queue for the POST request."""
        while True:
            chunk = await queue.get()
            if chunk is None:  # Stop signal
                break
            yield chunk

    async def next(self):
        """Return a writer for the next segment and preconnect the one after it."""
        async with self.lock:
            if self.next_writer is None:
                logging.info(f"No pending connection, preconnecting {self.streamIdx()}...")
                self.next_writer = await self.preconnect()
            writer = self.next_writer
            self.next_writer = None
            # Advance now, so that a second next() arriving before the background
            # preconnect runs cannot POST to the same index again.
            self.idx += 1
        # Set up the next POST in the background
        self._spawn(self._preconnect_next_segment())
        return SegmentWriter(writer)

    async def _preconnect_next_segment(self):
        """Preconnect to the next POST in the background."""
        logging.info(f"Setting up next connection for {self.streamIdx()}")
        async with self.lock:
            if self.next_writer is not None:
                return
            next_writer = await self.preconnect()
            if next_writer:
                self.next_writer = next_writer

    async def close(self):
        """End any preconnected segment, DELETE the stream and close the session."""
        logging.info(f"Closing {self.url}")
        try:
            if self.next_writer:
                s = SegmentWriter(self.next_writer)
                self.next_writer = None
                await s.close()
            # async with releases the DELETE response back to the session.
            async with self.session.delete(self.url):
                pass
        finally:
            await self.session.close()
class SegmentWriter:
    """Async writer that pushes one segment's data into a queue; ``None`` ends it."""

    def __init__(self, queue: asyncio.Queue):
        self.queue = queue

    async def __aenter__(self):
        """Return this writer for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Finish the segment on exit, whether or not an exception occurred."""
        await self.close()

    async def write(self, data):
        """Queue *data* as the next chunk of the segment."""
        await self.queue.put(data)

    async def close(self):
        """Queue the ``None`` sentinel that marks the end of the segment's data."""
        await self.queue.put(None)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import aiohttp | |
import logging | |
import sys | |
class TrickleSubscriber:
    """Read a trickle stream as a sequence of HTTP GET segments.

    The index of each segment comes from the ``Lp-Trickle-Idx`` response
    header. The GET for the following segment is opened in the background
    while the caller reads the current one.
    """

    def __init__(self, url: str):
        self.base_url = url
        self.idx = -1  # Start with -1 for 'latest' index
        self.pending_get = None  # Pre-initialized GET request
        self.lock = asyncio.Lock()  # Lock to manage concurrent access
        self.session = aiohttp.ClientSession()
        self.errored = False
        # Strong references to background tasks; the event loop only keeps weak ones.
        self._tasks = set()

    async def __aenter__(self):
        """Enter context manager."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit context manager and close the session."""
        await self.close()

    @staticmethod
    async def _release(resp):
        """Release *resp*, awaiting only if this aiohttp version returns an awaitable."""
        import inspect
        result = resp.release()
        if inspect.isawaitable(result):
            await result

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def get_index(self, resp):
        """Extract the index from the response headers, or -1 if absent/invalid."""
        if resp is None:
            return -1
        idx_str = resp.headers.get("Lp-Trickle-Idx")
        try:
            idx = int(idx_str)
        except (TypeError, ValueError):
            return -1
        return idx

    async def preconnect(self):
        """Open a GET for the current index; return the response, or None after flagging an error."""
        url = f"{self.base_url}/{self.idx}"
        logging.info(f"Trickle sub Preconnecting to URL: {url}")
        try:
            resp = await self.session.get(url, headers={'Connection': 'close'})
            if resp.status != 200:
                body = await resp.text()
                await self._release(resp)
                logging.error(f"Trickle sub Failed GET segment, status code: {resp.status}, msg: {body}")
                self.errored = True
                return None
            # Return the response for later processing
            return resp
        except aiohttp.ClientError as e:
            logging.error(f"Trickle sub Failed to complete GET for next segment: {e}")
            self.errored = True
            return None

    async def next(self):
        """Return the next Segment (None once errored) and preconnect the following one."""
        async with self.lock:
            if self.errored:
                logging.info("Trickle subscription closed or errored")
                return None
            # If we don't have a pending GET request, preconnect
            if self.pending_get is None:
                logging.info("Trickle sub No pending connection, preconnecting...")
                self.pending_get = await self.preconnect()
            # Extract the current connection to use for reading
            conn = self.pending_get
            self.pending_get = None
            if conn is None:
                # preconnect failed and set self.errored; there is nothing to read.
                return None
            # Extract and set the next index from the response headers
            idx = await self.get_index(conn)
            if idx != -1:
                self.idx = idx + 1
        # Set up the next connection in the background
        self._spawn(self._preconnect_next_segment())
        return Segment(conn)

    async def _preconnect_next_segment(self):
        """Preconnect to the next segment in the background."""
        logging.info(f"Trickle sub setting up next connection for index {self.idx}")
        async with self.lock:
            if self.pending_get is not None:
                return
            next_conn = await self.preconnect()
            if next_conn:
                self.pending_get = next_conn
                next_idx = await self.get_index(next_conn)
                if next_idx != -1:
                    self.idx = next_idx + 1

    async def close(self):
        """Cancel background preconnects, release any pending GET and close the session."""
        for task in list(self._tasks):
            task.cancel()
        pending, self.pending_get = self.pending_get, None
        if pending is not None:
            await self._release(pending)
        await self.session.close()
class Segment:
    """One trickle segment, read incrementally from an HTTP response."""

    def __init__(self, response):
        self.response = response

    async def read(self, chunk_size=32 * 1024):
        """Read the next chunk of the segment; returns an empty/None value at the end."""
        if not self.response:
            await self.close()
            return None
        chunk = await self.response.content.read(chunk_size)
        if not chunk:
            await self.close()
        return chunk

    async def close(self):
        """Ensure the response is properly closed when done (safe to call repeatedly)."""
        import inspect
        response = self.response
        if response is None:
            return
        if not response.closed:
            # In aiohttp 3.x close() is synchronous and release() returns an
            # awaitable for compatibility; await only when one is returned.
            for finish in (response.release, response.close):
                result = finish()
                if inspect.isawaitable(result):
                    await result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment