Skip to content

Instantly share code, notes, and snippets.

@ash2shukla
Created May 7, 2023 22:18
Show Gist options
  • Save ash2shukla/2f341cfa28bc3e058378b22def1977c0 to your computer and use it in GitHub Desktop.
Save ash2shukla/2f341cfa28bc3e058378b22def1977c0 to your computer and use it in GitHub Desktop.
Scale WebSockets with Redis Pub/Sub in FastAPI
import asyncio
import logging
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.websockets import WebSocket, WebSocketDisconnect
import aioredis
import async_timeout
import json
# ASGI application object; the route decorators further down register
# their handlers on it.
app = FastAPI()

# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Single-page chat client returned by the "/" route.
# The page picks a random "user-N" id, opens a WebSocket to the "/ws" route on
# the same host that served the page, sends {"to", "content"} JSON frames and
# renders every {"from", "content"} frame it receives as a list item.
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<h3 id="userid"></h3>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="toUser" autocomplete="off"/>
<input type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
const userID = "user-" + Math.floor(Math.random() * 100);
document.getElementById("userid").innerHTML = "You are " + userID;
// Build the socket URL from the page's own origin so the client also works
// when it is not served from localhost:8000 or is served over HTTPS.
const wsScheme = window.location.protocol === "https:" ? "wss://" : "ws://";
var ws = new WebSocket(wsScheme + window.location.host + "/ws?userid=" + encodeURIComponent(userID));
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
const event_data = JSON.parse(event.data);
const formattedMessage = "From: " + event_data.from + " > " + event_data.content
var content = document.createTextNode(formattedMessage)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage(event) {
// Cancel the form submission first so a failing send cannot reload the page.
event.preventDefault()
const input = document.getElementById("messageText")
const toUser = document.getElementById("toUser")
ws.send(JSON.stringify({"to": toUser.value,"content": input.value}))
input.value = ''
}
</script>
</body>
</html>
"""
# Connection URL of the Redis server used for pub/sub message delivery.
redis_uri = "redis://localhost:6379"

# Module-wide Redis client used by redis_connector for publishing and for
# creating pub/sub objects.
# NOTE(review): decode_responses=True should make replies arrive as str
# rather than bytes -- confirm against the installed aioredis version.
redis = aioredis.from_url(
    redis_uri,
    decode_responses=True,
)
@app.get("/")
async def get():
    """Serve the chat client page stored in the module-level ``html`` string."""
    page = HTMLResponse(content=html)
    return page
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, userid: str):
    """Accept a chat WebSocket and hand it to ``redis_connector``.

    ``userid`` is the identifier the page sends as ``?userid=...`` when it
    opens the socket.
    """
    # Stash the id on the socket object: redis_connector reads it back from
    # this attribute to pick the subscription channel and the "from" field.
    websocket._userid = userid
    await websocket.accept()
    await redis_connector(websocket)
async def redis_connector(websocket: WebSocket):
    """Relay messages between one WebSocket and Redis pub/sub.

    Two tasks run concurrently until the first one finishes:

    * the consumer reads JSON frames ``{"to": ..., "content": ...}`` from the
      socket and publishes ``{"from": <this user>, "content": ...}`` on the
      Redis channel ``chat:<to>``;
    * the producer subscribes to ``chat:<this user>`` and forwards every
      message published there to the socket.

    The user id is read from ``websocket._userid``, which the caller must set
    before calling this function. On exit (normal, error or cancellation) the
    remaining task is cancelled and awaited, and the pub/sub object is closed.
    """

    async def consumer_handler(ws: WebSocket, r):
        """Publish each well-formed client frame to the recipient's channel."""
        try:
            while True:
                message = await ws.receive_json()
                if not message:
                    continue
                try:
                    target_channel = f"chat:{message['to']}"
                    payload = json.dumps(
                        {"from": ws._userid, "content": message["content"]}
                    )
                except (KeyError, TypeError):
                    # A frame without "to"/"content" (or not a JSON object)
                    # must not tear down the whole connection.
                    logger.warning(
                        "Ignoring malformed message from %s: %r",
                        ws._userid,
                        message,
                    )
                    continue
                await r.publish(target_channel, payload)
        except WebSocketDisconnect as exc:
            logger.error(exc)

    async def producer_handler(channel, ws: WebSocket):
        """Forward messages published on this user's channel to the socket."""
        await channel.subscribe(f"chat:{ws._userid}")
        while True:
            try:
                async with async_timeout.timeout(1):
                    message = await channel.get_message(
                        ignore_subscribe_messages=True
                    )
                    if message is not None:
                        data = json.loads(message["data"])
                        await ws.send_json(
                            {"from": data["from"], "content": data["content"]}
                        )
                    # Yield briefly between polls so the loop does not spin.
                    await asyncio.sleep(0.01)
            except asyncio.TimeoutError:
                pass

    ps = redis.pubsub()
    # asyncio.wait() needs Task objects: passing bare coroutines was
    # deprecated in Python 3.8 and raises TypeError from 3.11 onwards.
    tasks = [
        asyncio.create_task(consumer_handler(websocket, redis)),
        asyncio.create_task(producer_handler(ps, websocket)),
    ]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            if not task.done():
                task.cancel()
        # Await every task so cancellation actually completes and any
        # exception is retrieved instead of surfacing as an
        # "exception was never retrieved" warning later.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception) and not isinstance(
                result, asyncio.CancelledError
            ):
                logger.error(
                    "Relay task for %s failed", websocket._userid, exc_info=result
                )
        # Release the pub/sub object's resources; without this each
        # WebSocket connection leaves its subscription open.
        await ps.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment