Last active: January 17, 2023 07:29
Save vertigg/be2c8758e2d92be0777c47f3a6ba2f26 to your computer and use it in GitHub Desktop.
Starlette class-based Pub/Sub WebSocket Endpoint
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Heavily inspired by https://gist.github.com/timhughes/313c89a0d587a25506e204573c8017e4 | |
This is an example of class-based WebSocketEndpoint. | |
To use - just subclass your endpoint with RedisPubSubWebSocketEndpoint and use it as is. | |
You can also override on_connect, on_disconnect, on_receive and new on_published_message methods | |
New internal method "publish" acts like "broadcast" for any subscribed websocket | |
""" | |
import asyncio
import json
import logging
import os
import typing

from redis import asyncio as aioredis
from redis.asyncio.client import PubSub, Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from starlette import status
from starlette.applications import Starlette
from starlette.authentication import requires
from starlette.endpoints import WebSocketEndpoint
from starlette.exceptions import WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket
from starlette_jwt.middleware import JWTWebSocketAuthenticationBackend
# Module-wide logger used by the endpoint classes below.
logger: logging.Logger = logging.getLogger(name='test.logger')

async def get_redis_pool(url: str = 'redis://localhost:6379/0') -> 'Redis':
    """Return an asyncio Redis client for ``url`` whose replies are decoded to ``str``.

    Args:
        url: Redis connection URL. Defaults to ``redis://localhost:6379/0``
            (local instance, database 0).

    Returns:
        The awaited ``redis.asyncio`` client built by ``aioredis.from_url`` with
        ``encoding='utf-8'`` and ``decode_responses=True``.
    """
    return await aioredis.from_url(url, encoding='utf-8', decode_responses=True)


class RedisPubSubWebSocketEndpoint(WebSocketEndpoint):
    """WebSocket endpoint that relays messages between clients through a Redis channel.

    Every connection subscribes to ``channel_name``. Frames from the client are
    decoded according to ``encoding`` and passed to ``on_receive``. Anything passed
    to ``publish`` is sent to the channel. Every message arriving on the channel is
    decoded the same way and passed to ``on_published_message``.
    """

    channel_name = 'chat:lobby'
    encoding = 'json'
    # Seconds each pub/sub poll waits for a message. A positive value stops the
    # publisher loop from spinning while the channel is idle.
    pubsub_timeout: float = 1.0
    redis: Redis

    async def publish(self, data: typing.Any) -> None:
        """Wrap ``data`` in a websocket-style message and publish it to ``channel_name``.

        The published payload is a JSON object shaped like an ASGI websocket message
        (``{"text": ...}`` or ``{"bytes": ...}``), so subscribers can feed it back
        through ``decode``:

        * ``json`` encoding: ``data`` is JSON-serialised into ``text``;
        * ``text`` encoding: ``data`` is stored in ``text`` unchanged;
        * ``bytes`` encoding: ``data`` is UTF-8 decoded into ``bytes``;
        * any other encoding: ``str`` / ``bytes`` data is wrapped as above, and
          anything else is published as-is (it must already be message-shaped).

        Raises:
            UnicodeDecodeError: if bytes data is not valid UTF-8.
            TypeError: if the payload is not JSON-serialisable.
        """
        if self.encoding == 'json':
            message = {'text': json.dumps(data)}
        elif self.encoding == 'text':
            message = {'text': data}
        elif self.encoding == 'bytes':
            message = {'bytes': data.decode('utf-8')}
        elif isinstance(data, str):
            # encoding=None: decode() yields str or bytes. Publishing a bare string
            # would later hand decode() a str instead of a message dict.
            message = {'text': data}
        elif isinstance(data, (bytes, bytearray)):
            message = {'bytes': data.decode('utf-8')}
        else:
            message = data
        await self.redis.publish(self.channel_name, json.dumps(message))

    async def consumer_handler(self, websocket: WebSocket) -> int:
        """Read client frames until the client disconnects.

        Each ``websocket.receive`` frame is decoded with ``decode``. Truthy results
        are passed to ``on_receive``. Falsy results (``None``, ``""``, ``0``, empty
        containers, ``False``) are dropped.

        Returns:
            The close code from the client's disconnect message, or 1000 (normal
            closure) when the message carries none.
        """
        try:
            while True:
                message = await websocket.receive()
                if message['type'] == 'websocket.receive':
                    data = await self.decode(websocket, message)
                    if data:
                        await self.on_receive(websocket, data)
                elif message['type'] == 'websocket.disconnect':
                    return int(message.get('code') or status.WS_1000_NORMAL_CLOSURE)
        except WebSocketException as exc:
            logger.error(exc)
            raise

    async def publisher_handler(self, websocket: WebSocket, pubsub: PubSub) -> None:
        """Subscribe to ``channel_name`` and forward every channel message to the client.

        Each Redis message is parsed as JSON and turned back into a websocket-style
        message. It is then run through ``decode`` and passed to
        ``on_published_message``. Messages that are not valid JSON are logged and
        skipped instead of tearing down the connection.
        """
        await pubsub.subscribe(self.channel_name)
        try:
            while True:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=self.pubsub_timeout
                )
                if message is None:
                    continue
                raw = message.get('data')
                try:
                    payload = json.loads(raw)
                except (TypeError, ValueError):
                    logger.warning(
                        'Skipping malformed message on %s: %r', self.channel_name, raw
                    )
                    continue
                if isinstance(payload, dict) and isinstance(payload.get('bytes'), str):
                    # publish() stores bytes as UTF-8 text; restore the bytes type
                    # so a 'bytes' endpoint gets back what was published.
                    payload['bytes'] = payload['bytes'].encode('utf-8')
                decoded_data = await self.decode(websocket, payload)
                await self.on_published_message(websocket, decoded_data)
        except Exception as exc:
            logger.error(exc)
            raise

    @staticmethod
    async def _cancel_tasks(tasks: typing.Iterable[asyncio.Task]) -> None:
        """Cancel every unfinished task in ``tasks`` and wait until they have stopped."""
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            logger.debug('cancelling task %s', task)
            task.cancel()
        # A cancelled task only stops at its next resumption. Awaiting it here
        # guarantees the publisher is no longer reading from the pub/sub
        # connection by the time that connection is closed.
        await asyncio.gather(*pending, return_exceptions=True)

    async def dispatch(self) -> None:
        """Serve one websocket connection (overrides ``WebSocketEndpoint.dispatch``).

        Creates a Redis client and pub/sub object and calls ``on_connect``. It then
        runs two tasks concurrently until either one finishes:

        * ``consumer_handler``: client -> ``on_receive``;
        * ``publisher_handler``: Redis channel -> ``on_published_message``.

        The other task is then cancelled and awaited, and the Redis resources are
        closed. ``on_disconnect`` is always called, with this close code:

        * the client's close code when the client disconnected;
        * 1011 (internal error) when a connection error or any other exception
          ended the session.

        Connection errors (builtin or redis-py) are logged and swallowed. Any other
        exception is re-raised.
        """
        websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
        close_code = status.WS_1000_NORMAL_CLOSURE
        tasks: list[asyncio.Task] = []
        redis_client = None
        pubsub = None
        try:
            # NOTE(review): a new Redis client is created for every websocket
            # connection; sharing one client across connections may scale better.
            redis_client = self.redis = await get_redis_pool()
            pubsub = redis_client.pubsub()
            await self.on_connect(websocket)
            consumer = asyncio.create_task(self.consumer_handler(websocket))
            publisher = asyncio.create_task(self.publisher_handler(websocket, pubsub))
            tasks = [consumer, publisher]
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            logger.debug('Task finished: %s', done)
            for task in done:
                exception = task.exception()
                if isinstance(exception, Exception):
                    raise exception
            if consumer in done:
                close_code = consumer.result() or status.WS_1000_NORMAL_CLOSURE
        except (ConnectionError, RedisConnectionError) as exc:
            # redis-py raises its own ConnectionError, which does not subclass the
            # builtin one, so both must be listed.
            close_code = status.WS_1011_INTERNAL_ERROR
            logger.error("Can't connect to Redis instance, %s", exc)
        except Exception:
            close_code = status.WS_1011_INTERNAL_ERROR
            raise
        finally:
            try:
                # Also runs when dispatch itself is cancelled mid-wait, so the
                # child tasks never outlive the connection.
                await self._cancel_tasks(tasks)
                if pubsub is not None:
                    await pubsub.close()
                if redis_client is not None:
                    await redis_client.close()
            finally:
                await self.on_disconnect(websocket, close_code)

    async def on_published_message(self, websocket: WebSocket, data: typing.Any) -> None:
        """Handle one message received from the Redis channel; override in subclasses.

        ``data`` is the channel message after ``decode``. The default does nothing.
        """


# Usage example
class ChatLobby(RedisPubSubWebSocketEndpoint):
    """Example chat endpoint that relays every client message to all lobby clients.

    Uses the inherited ``chat:lobby`` channel and ``json`` encoding.
    """

    @requires('authenticated')
    async def on_connect(self, websocket: WebSocket) -> None:
        # The body runs only when @requires('authenticated') lets the connection through.
        await websocket.accept()
        print('connected')

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        print('disconnected')

    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
        print('Received from client:', data)
        await self.publish(data)

    async def on_published_message(self, websocket: WebSocket, data: typing.Any) -> None:
        """Send a channel message to this client, tagged with ``type: chat_message``."""
        print('Received from redis:', data)
        if isinstance(data, dict):
            payload = {'type': 'chat_message'} | data
        else:
            # Non-object JSON (e.g. a bare string or list) cannot be merged into a
            # dict. Without this branch, `dict | str` raises TypeError in every
            # subscriber, and every subscriber drops its connection.
            payload = {'type': 'chat_message', 'data': data}
        await websocket.send_json(payload)


middleware = [
    # NOTE(review): Starlette's CORSMiddleware only processes HTTP requests, so it
    # presumably has no effect on the websocket route below; confirm it is needed.
    Middleware(CORSMiddleware, allow_origins=['*']),
    Middleware(
        AuthenticationMiddleware,
        backend=JWTWebSocketAuthenticationBackend(
            # Read the JWT signing key from the environment so the real secret does
            # not live in source control. 'some-key' is only a development fallback
            # and must not be relied on in production.
            secret_key=os.environ.get('JWT_SECRET_KEY', 'some-key'),
            query_param_name='token',
            username_field='user_id',
        ),
    ),
]

# NOTE(review): debug=True exposes tracebacks in error responses; disable it
# outside local development.
app = Starlette(
    debug=True,
    routes=[WebSocketRoute('/ws/chat/lobby/', ChatLobby)],
    middleware=middleware,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.