Skip to content

Instantly share code, notes, and snippets.

@charbonnierg
Last active March 12, 2024 17:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save charbonnierg/2807f34905a3099137f905d107489376 to your computer and use it in GitHub Desktop.
Save charbonnierg/2807f34905a3099137f905d107489376 to your computer and use it in GitHub Desktop.
broadcaster
"""Usage:
def get_listener(request: Request) -> Listener:
'''Function used to access listener from the application state.'''
return request.app.state.listener
def set_listener(app: FastAPI, url: str) -> None:
'''Function to be called once on application startup.'''
app.state.listener = Listener(url)
async def example(listener: Listener = Depends(get_listener)) -> None:
async with listener.subscribe("test") as subscription:
async for event in subscription:
print(event)
async def other_example(listener: Listener = Depends(get_listener)) -> None:
async with listener.subscribe("test") as subscription:
while True:
event = await subscription.next()
if event is None:
break
print(event)
"""
from __future__ import annotations
import asyncio
from typing import Any, AsyncContextManager
from broadcaster import Broadcast, Event
from broadcaster._base import Subscriber, Unsubscribed
from fastapi import Depends, FastAPI, Request, WebSocket, WebSocketDisconnect
class Subscription:
    """Async-iterable handle on a single broadcast channel.

    Instances are created by ``Listener.subscribe`` and must be used as an
    async context manager: entering asks the listener to start the
    subscription, leaving closes the underlying broadcast subscription and
    unregisters it from the listener.

    Inside the context, events can be consumed either with ``async for``
    (iteration stops once the subscription is closed) or with :meth:`next`
    (which returns ``None`` once the subscription is closed).
    """

    def __init__(
        self,
        channel: str,
        listener: Listener,
    ) -> None:
        self.channel = channel
        self._listener = listener
        # Both are filled in by the listener when the subscription starts:
        # the (already entered) broadcast subscribe context manager, and the
        # subscriber object it yielded.
        self._context: AsyncContextManager[Subscriber] | None = None
        self._subscriber: Subscriber | None = None

    def _started_subscriber(self) -> Subscriber:
        """Return the underlying subscriber, or raise RuntimeError if not started."""
        subscriber = self._subscriber
        if not subscriber:
            raise RuntimeError("Subscription not started")
        return subscriber

    async def __aenter__(self) -> Subscription:
        """Ask the listener to start this subscription, then return it."""
        await self._listener._connect_subscription(self)
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the broadcast subscription and unregister from the listener."""
        context = self._context
        try:
            if context:
                await context.__aexit__(exc_type, exc, tb)
        finally:
            # Unregister even when closing the broadcast subscription failed.
            await self._listener._remove_subscription(self)

    def __aiter__(self) -> Subscription:
        return self

    async def __anext__(self) -> Event:
        """Return the next event; end the iteration once the subscription is closed."""
        subscriber = self._started_subscriber()
        try:
            return await subscriber.get()
        except Unsubscribed:
            raise StopAsyncIteration

    async def next(self) -> Event | None:
        """Return the next event, or None if the subscription is closed."""
        subscriber = self._started_subscriber()
        try:
            event = await subscriber.get()
        except Unsubscribed:
            return None
        return event
class Listener:
    """A lazily connected wrapper around ``broadcaster.Broadcast``.

    Unlike the original implementation, this listener only connects its
    broadcast when the first subscription starts, and disconnects it when
    the last subscription is removed.

    Use :meth:`subscribe` to obtain a ``Subscription``, which must be used as
    an asynchronous context manager.
    """

    def __init__(self, url: str) -> None:
        self.url = url
        # Task connecting the current broadcast. Every caller that needs the
        # broadcast while the connection is in progress waits on this task.
        self._connect_task: asyncio.Task[Broadcast] | None = None
        # The connected broadcast, or None while not connected.
        self._broadcast: Broadcast | None = None
        # Subscriptions that are active OR still starting. A starting
        # subscription is registered before any await, so that a concurrent
        # removal of the last active subscription cannot disconnect the
        # broadcast it is about to use.
        self._subscriptions: list[Subscription] = []

    async def _connect_broadcast(self) -> Broadcast:
        """Create and connect a new broadcast, save it and return it."""
        broadcast = Broadcast(self.url)
        await broadcast.connect()
        self._broadcast = broadcast
        return broadcast

    async def _get_or_create_broadcaster(self) -> Broadcast:
        """Return the connected broadcast, connecting a new one if needed.

        No lock is used: asyncio is concurrent but not parallel, and the
        check-then-create below contains no ``await``. So at most one connect
        task exists at a time, and concurrent callers all wait on it.

        Raises:
            Whatever connecting the broadcast raised. A failed (or cancelled)
            attempt is forgotten, so the next call starts a fresh attempt
            instead of re-raising the stored error forever.
        """
        if (broadcast := self._broadcast) is not None:
            return broadcast
        connect_task = self._connect_task
        if connect_task is None:
            connect_task = asyncio.create_task(self._connect_broadcast())
            self._connect_task = connect_task
        # Unlike ``await connect_task``, asyncio.wait() does not cancel the
        # task it waits on when this caller is cancelled. One cancelled
        # subscriber therefore cannot abort the connection the others share.
        await asyncio.wait([connect_task])
        if connect_task.cancelled() or connect_task.exception() is not None:
            # Only reset if no newer attempt has replaced this one meanwhile.
            if self._connect_task is connect_task:
                self._connect_task = None
        return connect_task.result()

    async def _remove_subscription(self, subscription: Subscription) -> None:
        """Unregister ``subscription``; disconnect the broadcast if none remain.

        Removing a subscription that is not registered is harmless: only the
        disconnect check runs.
        """
        try:
            self._subscriptions.remove(subscription)
        except ValueError:
            # Already removed, or never registered: nothing to forget.
            pass
        if self._subscriptions or self._broadcast is None:
            return
        broadcast = self._broadcast
        # Reset the state *before* awaiting. A subscription starting during
        # the disconnect then creates a fresh broadcast instead of reusing
        # this one.
        self._broadcast = None
        self._connect_task = None
        await broadcast.disconnect()

    async def _connect_subscription(self, subscription: Subscription) -> None:
        """Start ``subscription`` on its channel, connecting the broadcast if needed.

        On success, ``subscription._context`` holds the broadcast's (already
        entered) subscribe context manager, and ``subscription._subscriber``
        holds the subscriber it yielded.

        On failure or cancellation, the subscription is unregistered again,
        which disconnects the broadcast if no other subscription is left. The
        error is then re-raised.
        """
        # Register first: while the awaits below are pending, this
        # subscription must count as active.
        self._subscriptions.append(subscription)
        try:
            broadcast = await self._get_or_create_broadcaster()
            provider = broadcast.subscribe(subscription.channel)
            subscriber = await provider.__aenter__()
        except BaseException:
            # BaseException so that cancellation is rolled back too.
            # NOTE(review): if the shared connection is still in progress, it
            # keeps running. Once connected, that broadcast stays open until a
            # later subscription is started and removed.
            await self._remove_subscription(subscription)
            raise
        subscription._context = provider
        subscription._subscriber = subscriber

    def subscribe(self, channel: str) -> Subscription:
        """Return a new, not yet started, subscription to ``channel``.

        The subscription must be used as an asynchronous context manager:
        it is started on enter and removed on exit.
        """
        return Subscription(channel=channel, listener=self)
def get_listener(request: Request) -> Listener:
    """FastAPI dependency returning the Listener stored on the application state.

    Expects ``set_listener`` to have been called on application startup.
    """
    app_state = request.app.state
    return app_state.listener
def set_listener(app: FastAPI, url: str) -> None:
    """Create a Listener for ``url`` and store it on ``app.state``.

    Call this once, on application startup; ``get_listener`` reads it back.
    """
    listener = Listener(url)
    app.state.listener = listener
async def callback(
    websocket: WebSocket,
    listener: Listener = Depends(get_listener),
) -> None:
    """Example websocket endpoint forwarding events from the "test" channel.

    The subscription is started *before* the websocket is accepted, so the
    client cannot miss events published after the handshake. Every event is
    then forwarded to the client until one of the following happens:

    - the client disconnects (``WebSocketDisconnect``): return silently;
    - the client sends a message (none is expected): close the websocket
      with reason "unexpected message";
    - the subscription is closed: close the websocket with reason
      "not available".

    Any other error raised while receiving from the websocket, or while
    reading the subscription, is propagated.
    """
    async with listener.subscribe("test") as subscription:
        await websocket.accept()
        # A single receive task lives across loop iterations. It only
        # completes when the client sends something or disconnects, so
        # recreating and cancelling it for every forwarded event would just
        # churn tasks and interrupt an in-flight receive.
        receive_task = asyncio.create_task(websocket.receive_text())
        try:
            while True:
                event_task = asyncio.create_task(subscription.next())
                try:
                    await asyncio.wait(
                        [event_task, receive_task],
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                finally:
                    # No-op if already finished. Otherwise this avoids leaking
                    # the task when the websocket wins the race, or when this
                    # handler is itself cancelled.
                    event_task.cancel()
                if receive_task.done():
                    # The websocket side wins ties: a disconnect or an
                    # unexpected message ends the endpoint.
                    error = receive_task.exception()
                    if error is None:
                        await websocket.close(reason="unexpected message")
                        return
                    if isinstance(error, WebSocketDisconnect):
                        return
                    raise error
                event = event_task.result()
                if event is None:
                    # The subscription was closed. This is not expected while
                    # the application is running, hence "not available".
                    await websocket.close(reason="not available")
                    return
                # send_bytes() requires bytes, while broadcaster events carry
                # a text message (NOTE(review): confirm for the backend in
                # use), so text is sent as a text frame.
                message = event.message
                if isinstance(message, bytes):
                    await websocket.send_bytes(message)
                else:
                    await websocket.send_text(message)
        finally:
            receive_task.cancel()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment