Skip to content

Instantly share code, notes, and snippets.

@monyone
Created April 9, 2022 07:03
Show Gist options
  • Save monyone/5b0fa6fe905815ee0e882f0d09e30d2e to your computer and use it in GitHub Desktop.
Save monyone/5b0fa6fe905815ee0e882f0d09e30d2e to your computer and use it in GitHub Desktop.
Pro-MPEG (RTP FEC) on HTTP Datagrams over QUIC DATAGRAM frames
#!/usr/bin/env python3
import argparse
import asyncio
import sys
from aioquic.asyncio import QuicConnectionProtocol, serve
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import HeadersReceived, WebTransportStreamDataReceived, DatagramReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import stream_is_unidirectional
from aioquic.quic.events import ProtocolNegotiated, StreamReset, QuicEvent, ConnectionTerminated, StreamDataReceived
from aioquic.quic.logger import QuicLogger
from pprint import pprint
# Network endpoint of the HTTP/3 (WebTransport) server.  '::1' is the IPv6
# loopback address, so only clients on this host can connect.
BIND_ADDRESS: str = '::1'
BIND_PORT: int = 4433

# QUIC logger handed to QuicConfiguration in the __main__ section; nothing in
# this script reads it back.  NOTE(review): presumably it keeps every trace in
# memory for the lifetime of the process -- confirm whether it is still needed.
quic_logger = QuicLogger()
def rtp(payload, payload_type, sequence_number, timestamp=0, ssrc=0, marker=False):
    """Wrap *payload* in an RTP packet with a 12-byte fixed header (RFC 3550).

    The header has version 2, no padding, no extension and no CSRC entries.

    Args:
        payload: bytes-like data appended after the header.
        payload_type: 7-bit RTP payload type (0-127).
        sequence_number: RTP sequence number.  It is reduced modulo 2**16
            because the header field is 16 bits wide and RTP sequence
            numbers wrap.
        timestamp: RTP timestamp, reduced modulo 2**32.  Defaults to 0, which
            is what this script has always sent.
        ssrc: synchronisation source identifier, reduced modulo 2**32.
            Defaults to 0.
        marker: value of the M bit.  Defaults to False.

    Returns:
        The complete packet as ``bytes``.

    Raises:
        ValueError: if *payload_type* is outside 0-127.  Previously a value of
            128-255 silently set the marker bit instead.
    """
    if not 0 <= payload_type <= 0x7F:
        raise ValueError(f'RTP payload type must be in 0..127, got {payload_type!r}')
    packet = bytearray()
    packet.append(0x80)  # V=2, P=0, X=0, CC=0
    packet.append((0x80 if marker else 0x00) | payload_type)  # M | PT
    packet += (sequence_number & 0xFFFF).to_bytes(2, byteorder='big')
    packet += (timestamp & 0xFFFFFFFF).to_bytes(4, byteorder='big')
    packet += (ssrc & 0xFFFFFFFF).to_bytes(4, byteorder='big')
    packet += payload
    return bytes(packet)
def fec(payload, payload_type, SNBase_low_bits, D, Offset, NA):
    """Prefix *payload* with a 16-byte Pro-MPEG / SMPTE 2022-1 style FEC header.

    Header layout (big-endian):
        SNBase low bits (2) | length recovery (2) | E|PT recovery (1) |
        mask (3) | TS recovery (4) | N|D|type|index (1) | Offset (1) |
        NA (1) | SNBase ext bits (1)

    The length-recovery field is written as ``len(payload)``.  The PT-recovery
    field is written as *payload_type* with the E bit left at 0.  D is
    shifted into bit 6, and N, type and index stay 0.  The mask, TS recovery
    and SNBase extension fields are always zero.

    NOTE(review): SMPTE 2022-1 appears to call for E=1 and for the
    recovery fields to be XORs of the protected packets' values; confirm what
    the receiving side expects before changing this.

    Returns:
        Header followed by *payload*, as ``bytes``.
    """
    header_fields = (
        SNBase_low_bits.to_bytes(2, byteorder='big'),      # SNBase low bits
        len(payload).to_bytes(2, byteorder='big'),         # length recovery
        (0x00 | payload_type).to_bytes(1, byteorder='big'),  # E=0 | PT recovery
        b'\x00' * 3,                                       # mask
        b'\x00' * 4,                                       # TS recovery
        (D << 6).to_bytes(1, byteorder='big'),             # N=0 | D | type=0 | index=0
        Offset.to_bytes(1, byteorder='big'),               # Offset
        NA.to_bytes(1, byteorder='big'),                   # NA
        b'\x00',                                           # SNBase ext bits
    )
    out = bytearray(b''.join(header_fields))
    out += payload
    return bytes(out)
class Client:
    """A WebTransport session that datagrams are broadcast to.

    Attributes:
        stream_id: ID of the CONNECT request stream the session was opened on.
        http3: HTTP/3 connection object used to send the datagrams.
        protocol: connection protocol whose ``transmit()`` flushes queued data.
    """

    def __init__(self, stream_id, http3, protocol):
        self.stream_id, self.http3, self.protocol = stream_id, http3, protocol

    def datagram(self, payload):
        """Queue *payload* as an HTTP datagram for this session, then flush it."""
        send = self.http3.send_datagram
        send(self.stream_id, payload)
        self.protocol.transmit()
# WebTransport sessions currently subscribed to the broadcast (Client objects).
streamingClients = []

# FEC matrix geometry: R rows by C columns of media packets.
R = 10
C = 10
# Media payloads of the current R x C block, filled row by row.
FECS = [[None] * C for _ in range(R)]
# Row-major position of the next media packet inside FECS.
INDEX = 0
# RTP sequence number of the first media packet in the current block
# (None until the first packet has been sent).
SNBase = None
class WebTransportProtocol(QuicConnectionProtocol):
    """Per-connection QUIC protocol that accepts WebTransport sessions on ``/stream``.

    Each accepted session is appended to the module-level ``streamingClients``
    list as a ``Client``.  It is removed again when its CONNECT stream is reset
    or finished, or when this QUIC connection terminates.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # HTTP/3 layer; created only once ALPN negotiation has completed.
        self._http = None

    def _forget_clients(self, stream_id=None):
        """Unregister this connection's clients (only those on *stream_id*, if given).

        QUIC stream IDs are only unique within one connection, so clients are
        matched on ``client.protocol is self`` as well.  Otherwise a reset on
        this connection would evict another connection's session that uses
        the same stream ID.
        """
        global streamingClients
        streamingClients = [
            client for client in streamingClients
            if client.protocol is not self
            or (stream_id is not None and client.stream_id != stream_id)
        ]

    def quic_event_received(self, event):
        """Track session lifetime from QUIC events, then hand them to HTTP/3."""
        if isinstance(event, ProtocolNegotiated):
            self._http = H3Connection(self._quic, enable_webtransport=True)
        elif isinstance(event, StreamReset):
            self._forget_clients(event.stream_id)
        elif isinstance(event, StreamDataReceived):
            # A FIN on a session's CONNECT stream ends that session.
            if event.end_stream:
                self._forget_clients(event.stream_id)
        elif isinstance(event, ConnectionTerminated):
            # Without this, sessions of a dead connection would stay
            # registered and keep being sent datagrams forever.
            self._forget_clients()
        if self._http is not None:
            for h3_event in self._http.handle_event(event):
                self._h3_event_received(h3_event)

    def _h3_event_received(self, event):
        """Handshake WebTransport CONNECT requests and answer anything else with 400."""
        if isinstance(event, HeadersReceived):
            # Duplicate header names keep the last value.
            headers = dict(event.headers)
            if (headers.get(b":method") == b"CONNECT"
                    and headers.get(b":protocol") == b"webtransport"):
                self._handshake_webtransport(event.stream_id, headers)
            else:
                self._send_response(event.stream_id, 400, end_stream=True)

    def _handshake_webtransport(self, stream_id, request_headers):
        """Accept a WebTransport session on ``/stream``; reject any other request."""
        authority = request_headers.get(b":authority")
        path = request_headers.get(b":path")
        if authority is None or path is None:
            # `:authority` and `:path` must be provided.
            self._send_response(stream_id, 400, end_stream=True)
            # Stop here: the stream was just finished with a 400, and falling
            # through would try to send a second (404) response on it.
            return
        if path == b'/stream':
            self._send_response(stream_id, 200)
            streamingClients.append(Client(stream_id, self._http, self))
        else:
            self._send_response(stream_id, 404, end_stream=True)

    def _send_response(self, stream_id, status_code, end_stream=False):
        """Send a response HEADERS frame with *status_code* on *stream_id*.

        A 200 also carries ``sec-webtransport-http3-draft: draft02``, which is
        the header sent when a WebTransport session is accepted.
        """
        headers = [(b":status", str(status_code).encode())]
        if status_code == 200:
            headers.append((b"sec-webtransport-http3-draft", b"draft02"))
        self._http.send_headers(stream_id=stream_id, headers=headers, end_stream=end_stream)
# MPEG-TS packets are 188 bytes long and start with the sync byte 0x47.
_TS_SYNC_BYTE = b'\x47'
_TS_PACKET_SIZE = 188
# TS packets carried per media RTP datagram (6 * 188 = 1128 bytes).
_TS_PACKETS_PER_DATAGRAM = 6


async def _read_ts_packet(reader):
    """Return the next 188-byte MPEG-TS packet from *reader*, or None at EOF.

    Bytes are discarded until a 0x47 sync byte is seen, so the stream
    re-synchronises after garbage.
    """
    try:
        while await reader.readexactly(1) != _TS_SYNC_BYTE:
            pass
        return _TS_SYNC_BYTE + await reader.readexactly(_TS_PACKET_SIZE - 1)
    except asyncio.IncompleteReadError:
        # End of input, possibly in the middle of a packet.
        return None


def _xor_payloads(payloads):
    """Return the byte-wise XOR of the equally sized byte strings in *payloads*."""
    acc = 0
    for payload in payloads:
        acc ^= int.from_bytes(payload, 'big')
    # Fixed-width conversion keeps leading zero bytes.
    return acc.to_bytes(len(payloads[0]), 'big')


def _broadcast(datagram):
    """Send *datagram* to every WebTransport client currently registered."""
    for client in streamingClients:
        client.datagram(datagram)


async def streaming(stream):
    """Read MPEG-TS from *stream* and broadcast it with 2-D FEC to all clients.

    Every 6 TS packets (1128 bytes) are sent as one RTP datagram with payload
    type 33 (MP2T) and an incrementing sequence number.  The same payloads
    are collected into the R x C matrix ``FECS``:

    * when a row fills, a row FEC packet is sent (D=0, Offset=1, NA=C);
    * while the last row fills, each column's FEC packet is sent
      (D=1, Offset=C, NA=R).

    FEC packets are RTP payload type 96 with sequence number 0.

    NOTE(review): SMPTE 2022-1 appears to assign D=0 to column FEC and D=1 to
    row FEC, the opposite of this code.  A constant FEC sequence number also
    presumably means the receiver relies on SNBase/Offset/NA alone.  Confirm
    both against the receiver before changing them.

    Returns when *stream* reaches EOF; a partially filled datagram and FEC
    block are discarded.

    Args:
        stream: pipe-like readable file object, as accepted by
            ``loop.connect_read_pipe()``.
    """
    global FECS, INDEX, SNBase

    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    # Use the running loop instead of a module-level `loop` global, which
    # only exists when this file is run as a script.
    await asyncio.get_running_loop().connect_read_pipe(lambda: protocol, stream)

    sequence_number = 0
    buffer = bytearray()
    while True:
        packet = await _read_ts_packet(reader)
        if packet is None:
            break
        buffer += packet
        if len(buffer) < _TS_PACKET_SIZE * _TS_PACKETS_PER_DATAGRAM:
            continue

        media = bytes(buffer)
        buffer = bytearray()
        _broadcast(rtp(media, 33, sequence_number))

        if SNBase is None:
            SNBase = sequence_number
        r, c = divmod(INDEX, C)
        FECS[r][c] = media

        if c == C - 1:
            # Row complete: protects packets SNBase + r*C .. SNBase + r*C + C - 1.
            row_fec = fec(_xor_payloads(FECS[r]), 33,
                          (SNBase + r * C) % (2 ** 16), 0, 1, C)
            _broadcast(rtp(row_fec, 96, 0))
        if r == R - 1:
            # Last row: column c is now complete (every C-th packet from SNBase + c).
            column = [FECS[p][c] for p in range(R)]
            column_fec = fec(_xor_payloads(column), 33,
                             (SNBase + c) % (2 ** 16), 1, C, R)
            _broadcast(rtp(column_fec, 96, 0))

        INDEX += 1
        sequence_number = (sequence_number + 1) % (2 ** 16)
        if INDEX == R * C:
            # Block finished: the next media packet starts a fresh matrix.
            SNBase = sequence_number
            INDEX = 0
            FECS = [[None] * C for _ in range(R)]
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('certificate')
    parser.add_argument('key')
    # No nargs='?': a bare `-i` is now rejected by argparse instead of
    # yielding None and crashing later when the input is opened as a pipe.
    # NOTE: asyncio's connect_read_pipe() accepts pipes, sockets and character
    # devices, not regular files; feed a file through a pipe (`cat x.ts | ...`).
    parser.add_argument('-i', '--input', type=argparse.FileType('rb'), default=sys.stdin.buffer)
    args = parser.parse_args()

    configuration = QuicConfiguration(
        alpn_protocols=H3_ALPN,
        is_client=False,
        max_datagram_frame_size=65536,
        quic_logger=quic_logger,
    )
    configuration.load_cert_chain(args.certificate, args.key)

    # asyncio.get_event_loop() with no running loop is deprecated since
    # Python 3.10, so create and install a loop explicitly.  It is still bound
    # to the module-level name `loop` for code that looks it up globally.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server = loop.run_until_complete(
        serve(
            BIND_ADDRESS,
            BIND_PORT,
            configuration=configuration,
            create_protocol=WebTransportProtocol,
        ))
    try:
        loop.run_until_complete(streaming(args.input))
    except KeyboardInterrupt:
        pass
    finally:
        server.close()
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment