Skip to content

Instantly share code, notes, and snippets.

@vertigg
Last active January 17, 2023 07:29
Show Gist options
  • Save vertigg/be2c8758e2d92be0777c47f3a6ba2f26 to your computer and use it in GitHub Desktop.
Save vertigg/be2c8758e2d92be0777c47f3a6ba2f26 to your computer and use it in GitHub Desktop.
Starlette class-based PubSub Websocket Endpoint
"""
Heavily inspired by https://gist.github.com/timhughes/313c89a0d587a25506e204573c8017e4
This is an example of a class-based WebSocketEndpoint backed by Redis pub/sub.
To use it, subclass RedisPubSubWebSocketEndpoint and use the subclass as a regular endpoint.
You can also override on_connect, on_disconnect, on_receive, and the new on_published_message method.
The new internal method "publish" acts as a "broadcast" to every websocket subscribed to the channel.
"""
import asyncio
import json
import logging
import os
import typing

from redis import asyncio as aioredis
from redis.asyncio.client import PubSub, Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from starlette import status
from starlette.applications import Starlette
from starlette.authentication import requires
from starlette.endpoints import WebSocketEndpoint
from starlette.exceptions import WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket
from starlette_jwt.middleware import JWTWebSocketAuthenticationBackend
# Module-level logger used by RedisPubSubWebSocketEndpoint.
# NOTE(review): 'test.logger' looks like a placeholder name; logging.getLogger(__name__)
# is the usual convention -- confirm no logging config depends on this name first.
logger = logging.getLogger('test.logger')
async def get_redis_pool(url: str = 'redis://localhost:6379/0'):
    """Create an asyncio Redis client that returns decoded (UTF-8) strings.

    Despite the name, each call creates a new client (``from_url`` builds its
    own connection pool), so the caller is responsible for closing it.

    Args:
        url: Redis connection URL. Defaults to the local instance, database 0.

    Returns:
        A ``redis.asyncio.Redis`` client configured with ``decode_responses=True``.
    """
    return await aioredis.from_url(
        url,
        encoding='utf-8',
        decode_responses=True,
    )
class RedisPubSubWebSocketEndpoint(WebSocketEndpoint):
    """WebSocket endpoint that relays messages through a Redis pub/sub channel.

    Each connection runs two concurrent tasks:

    * ``consumer_handler`` reads frames from the client and hands them to
      ``on_receive``;
    * ``publisher_handler`` subscribes to ``channel_name`` and hands every
      published message to ``on_published_message``.

    Call :meth:`publish` (typically from ``on_receive``) to broadcast data to
    every connection subscribed to the same channel. Messages are stored in
    Redis as a JSON object with a single ``text`` or ``bytes`` key -- the
    same shape as an ASGI ``websocket.receive`` message -- so that
    :meth:`decode` can be reused on the way back out.
    """

    channel_name = 'chat:lobby'
    encoding = 'json'
    # Seconds publisher_handler waits for a Redis message on each poll; a
    # bounded wait keeps the loop from busy-spinning when the channel is idle.
    pubsub_poll_timeout = 1.0
    redis: Redis

    async def publish(self, data: typing.Any) -> None:
        """Publish ``data`` to ``channel_name`` for every subscriber.

        ``data`` is wrapped according to ``encoding``:

        * ``'json'``: serialised to JSON text;
        * ``'text'``: stored as-is;
        * ``'bytes'`` (or bytes data when ``encoding`` is ``None``): stored
          as its UTF-8 decoding.

        Raises:
            UnicodeDecodeError: if bytes data is not valid UTF-8.
        """
        if self.encoding == 'json':
            envelope = {'text': json.dumps(data)}
        elif self.encoding == 'text':
            envelope = {'text': data}
        elif self.encoding == 'bytes' or isinstance(data, bytes):
            envelope = {'bytes': data.decode('utf-8')}
        else:
            # encoding is None and data is text: decode() expects a message
            # dict, so wrap it instead of publishing a bare JSON string.
            envelope = {'text': data}
        await self.redis.publish(self.channel_name, json.dumps(envelope))

    async def consumer_handler(self, websocket: WebSocket) -> int:
        """Read frames from the client until it disconnects.

        Every ``websocket.receive`` frame is decoded with :meth:`decode` and
        passed to ``on_receive`` -- including falsy payloads such as ``0`` or
        ``""``, matching Starlette's own ``WebSocketEndpoint`` contract.

        Returns:
            The close code from the client's ``websocket.disconnect`` message,
            or 1000 (normal closure) if it carried none.
        """
        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.receive":
                    data = await self.decode(websocket, message)
                    await self.on_receive(websocket, data)
                elif message["type"] == "websocket.disconnect":
                    return int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
        except WebSocketException as exc:
            logger.error(exc)
            raise

    async def publisher_handler(self, websocket: WebSocket, pubsub: PubSub) -> None:
        """Relay messages from ``channel_name`` to ``on_published_message``.

        Runs until cancelled. Each Redis message is expected to be a JSON
        object produced by :meth:`publish`. It is passed through
        :meth:`decode` so subscribers receive data in the same form
        ``on_receive`` does. Payloads that are not JSON objects (e.g. written
        by another publisher) are logged and skipped rather than tearing down
        the connection.
        """
        await pubsub.subscribe(self.channel_name)
        try:
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    # The default timeout (0) returns immediately, which makes
                    # this loop a busy spin whenever the channel is idle.
                    timeout=self.pubsub_poll_timeout,
                )
                if not message:
                    continue
                raw = message.get('data')
                try:
                    envelope = json.loads(raw)
                except (TypeError, ValueError):
                    envelope = None
                if not isinstance(envelope, dict):
                    logger.warning(
                        'Skipping malformed message on %s: %r', self.channel_name, raw
                    )
                    continue
                if isinstance(envelope.get('bytes'), str):
                    # publish() stores bytes as UTF-8 text; restore real bytes
                    # so decode() returns what it would for a binary frame.
                    envelope['bytes'] = envelope['bytes'].encode('utf-8')
                decoded_data = await self.decode(websocket, envelope)
                await self.on_published_message(websocket, decoded_data)
        except Exception as exc:
            logger.error(exc)
            raise

    async def dispatch(self) -> None:
        """Drive one WebSocket connection (overrides ``WebSocketEndpoint.dispatch``).

        Steps:

        1. Open a Redis client and call ``on_connect``.
        2. Run :meth:`consumer_handler` and :meth:`publisher_handler`
           concurrently until either finishes.
        3. Cancel and await the other task, then close the Redis resources.
        4. Always call ``on_disconnect`` with the final close code: the
           client's code on a normal disconnect, or 1011 (internal error) if
           anything failed.

        Redis connection errors are logged and swallowed; any other exception
        is re-raised after cleanup.
        """
        websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
        close_code = status.WS_1000_NORMAL_CLOSURE
        # Pre-bind everything the finally block touches, so cleanup is safe
        # even when setup fails part-way through.
        client: typing.Optional[Redis] = None
        pubsub: typing.Optional[PubSub] = None
        tasks: list[asyncio.Task] = []
        try:
            client = self.redis = await get_redis_pool()
            pubsub = client.pubsub()
            await self.on_connect(websocket)
            consumer = asyncio.create_task(self.consumer_handler(websocket))
            publisher = asyncio.create_task(self.publisher_handler(websocket, pubsub))
            tasks = [consumer, publisher]
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            logger.debug('Task finished: %s', done)
            for task in done:
                exception = task.exception()
                if isinstance(exception, Exception):
                    raise exception
            if consumer in done:
                result = consumer.result()
                # Guard against subclass overrides that still return None.
                if isinstance(result, int):
                    close_code = result
        except (ConnectionError, RedisConnectionError) as exc:
            # redis-py raises its own ConnectionError, not the builtin one.
            close_code = status.WS_1011_INTERNAL_ERROR
            logger.error("Can't connect to Redis instance, %s", exc)
        except Exception:
            close_code = status.WS_1011_INTERNAL_ERROR
            raise
        finally:
            for task in tasks:
                if not task.done():
                    logger.debug('cancelling task %s', task)
                    task.cancel()
            # Let cancelled tasks unwind before closing the connection they
            # use; return_exceptions also marks stored exceptions as retrieved.
            await asyncio.gather(*tasks, return_exceptions=True)
            if pubsub is not None:
                await pubsub.close()
            if client is not None:
                await client.close()
            await self.on_disconnect(websocket, close_code)

    async def on_published_message(self, websocket: WebSocket, data: typing.Any) -> None:
        """Override to handle a message received from the Redis channel.

        ``data`` has already been passed through :meth:`decode`, so it has the
        same form ``on_receive`` receives. The default implementation ignores it.
        """
# Usage example
class ChatLobby(RedisPubSubWebSocketEndpoint):
    """Example endpoint: an authenticated chat lobby on the default channel.

    Every message a client sends is re-published to the Redis channel. Every
    message read from the channel is forwarded to the client as JSON tagged
    with ``'type': 'chat_message'``.
    """

    @requires('authenticated')
    async def on_connect(self, websocket: WebSocket) -> None:
        await websocket.accept()
        print('connected')

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        print('disconnected')

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        print('Received from client:', data)
        await self.publish(data)

    async def on_published_message(self, websocket: WebSocket, data: typing.Any) -> None:
        print('Received from redis:', data)
        if isinstance(data, dict):
            # Keys in data win over the default 'type' tag.
            payload = {'type': 'chat_message'} | data
        else:
            # Lists/scalars cannot be merged into a dict; nest them instead of
            # letting `dict | list` raise TypeError and drop the connection.
            payload = {'type': 'chat_message', 'data': data}
        await websocket.send_json(payload)
# Middleware stack for the app.
# The JWT signing secret is read from the JWT_SECRET_KEY environment variable,
# so it does not have to live in source control. 'some-key' is only a
# development fallback and must not be relied on in production.
middleware = [
    Middleware(CORSMiddleware, allow_origins=['*']),
    Middleware(
        AuthenticationMiddleware,
        backend=JWTWebSocketAuthenticationBackend(
            secret_key=os.environ.get('JWT_SECRET_KEY', 'some-key'),
            query_param_name='token',
            username_field='user_id',
        ),
    ),
]
# ASGI application: a single WebSocket route for the chat lobby, wrapped in
# the module's middleware stack.
app = Starlette(
    routes=[
        WebSocketRoute('/ws/chat/lobby/', ChatLobby),
    ],
    middleware=middleware,
    debug=True,  # NOTE(review): presumably should be disabled outside development.
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment