Skip to content

Instantly share code, notes, and snippets.

@Tobi-De
Last active December 1, 2023 11:19
Show Gist options
  • Save Tobi-De/437717c792f90814d10eae31eb8d12a5 to your computer and use it in GitHub Desktop.
sse-starlette
starlette
redis
psycopg[c]
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Generator
import redis.asyncio as async_redis
from sse_starlette import EventSourceResponse, ServerSentEvent
import json
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
import redis
import time
import psycopg
import logging
from logging import getLogger
from starlette.routing import Route
# Verbose logging so subscribe/publish activity is visible while experimenting.
logging.basicConfig(level=logging.DEBUG)
logger = getLogger(__name__)
# Redis hash that maps channel name -> number of currently-connected listeners.
COUNT_KEY = "REDIS_COUNT_KEY"
class PostgresBroker:
    """Pub/sub broker backed by PostgreSQL LISTEN/NOTIFY.

    `listen()` streams notifications on a channel as ServerSentEvents;
    `notify()` publishes a payload of ServerSentEvent keyword arguments.
    """

    def __init__(self, dbname, user, password, host=None) -> None:
        self.db_params = {
            "client_encoding": "UTF8",
            "dbname": dbname,
            "user": user,
            "password": password,
            "host": host or "127.0.0.1",
        }

    async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]:
        """Yield a ServerSentEvent for every NOTIFY received on *channel*.

        The payload of each notification must be a JSON object of
        ServerSentEvent keyword arguments (as produced by notify()).
        """
        from psycopg import sql  # local import; psycopg is already a file dependency

        connection = await psycopg.AsyncConnection.connect(
            **self.db_params,
            autocommit=True,
        )
        try:
            async with connection.cursor() as cursor:
                logger.debug(f"Listening to {channel}")
                # The channel name comes from the request URL: LISTEN takes no
                # bind parameters, so quote it as an identifier instead of
                # splicing it into the statement (prevents SQL injection).
                await cursor.execute(
                    sql.SQL("LISTEN {}").format(sql.Identifier(channel))
                )
                async for notify_message in connection.notifies():
                    payload = json.loads(notify_message.payload)
                    logger.debug(f"Data received from {channel}")
                    yield ServerSentEvent(**payload)
        finally:
            # Close the connection when the client disconnects (the original
            # leaked one connection per subscriber).
            await connection.close()

    def notify(self, channel: str, sse_payload: dict) -> None:
        """Publish *sse_payload* to *channel* via pg_notify."""
        logger.debug(f"Publishing to {channel}: {sse_payload}")
        # `with` closes the connection afterwards (previously leaked), and
        # pg_notify() accepts bind parameters — unlike NOTIFY — so both the
        # channel and the payload are passed safely without string splicing.
        with psycopg.Connection.connect(**self.db_params, autocommit=True) as connection:
            with connection.cursor() as cursor:
                cursor.execute(
                    "SELECT pg_notify(%s, %s)",
                    (channel, json.dumps(sse_payload)),
                )
class RedisBroker:
    """Pub/sub broker backed by Redis.

    Relays published messages to SSE listeners and tracks the number of
    active listeners per channel in the Redis hash COUNT_KEY.
    """

    def __init__(self, redis_url: str) -> None:
        # Async client drives listen()/increment()/value(); the sync client
        # lets notify() be called from non-async code.
        self._client = async_redis.from_url(redis_url)
        self._sync_client = redis.from_url(redis_url)
        # Retained for backward compatibility; listen() now creates a
        # dedicated PubSub per subscriber (see below).
        self._pubsub = self._client.pubsub()

    @asynccontextmanager
    async def increment(self, channel: str) -> AsyncGenerator[None, None]:
        """Bump the listener count for *channel*; always decrement on exit."""
        logger.debug(f"Incrementing {channel}")
        await self._client.hincrby(COUNT_KEY, channel, 1)
        try:
            yield
        finally:
            logger.debug(f"Decrementing {channel}")
            await self._client.hincrby(COUNT_KEY, channel, -1)

    async def listen(self, channel: str) -> AsyncGenerator[ServerSentEvent, None]:
        """Yield a ServerSentEvent for every message published on *channel*.

        Message payloads must be JSON objects of ServerSentEvent kwargs
        (as produced by notify()).
        """
        async with self.increment(channel):
            logger.debug(f"Listening to {channel}")
            # A fresh PubSub per subscriber: the previous shared instance let
            # concurrent listeners steal each other's messages. The context
            # manager unsubscribes and releases the connection on exit.
            async with self._client.pubsub() as pubsub:
                await pubsub.subscribe(channel)
                # pubsub.listen() awaits the next message; the original
                # get_message() loop returned None immediately when idle and
                # busy-spun the event loop between messages.
                async for message in pubsub.listen():
                    if message["type"] != "message":
                        continue  # skip subscribe/unsubscribe confirmations
                    payload = json.loads(message["data"].decode())
                    logger.debug(f"Data received from {channel}")
                    yield ServerSentEvent(**payload)

    async def value(self) -> dict[str, int]:
        """Return {channel: active listener count} from the COUNT_KEY hash."""
        counts = await self._client.hgetall(COUNT_KEY)
        return {key.decode(): int(val) for key, val in counts.items()}

    def notify(self, channel: str, sse_payload: dict) -> None:
        """Publish *sse_payload* (ServerSentEvent kwargs) to *channel*."""
        logger.debug(f"Publishing to {channel}: {sse_payload}")
        self._sync_client.publish(channel=channel, message=json.dumps(sse_payload))
# Active broker backend; swap the comments to use PostgreSQL LISTEN/NOTIFY
# instead of Redis pub/sub. NOTE(review): the /count endpoint relies on
# RedisBroker.value(), which PostgresBroker does not implement.
broker = RedisBroker("redis://localhost:6379")
#broker = PostgresBroker(dbname="estate_sh", user="postgres", password="blumenkranz")
async def sse(request: Request):
    """SSE endpoint: stream the broker's events for the channel in the URL."""
    channel = request.path_params.get("channel")
    logger.info(f"New SSE connection to {channel}")
    event_stream = broker.listen(channel)
    return EventSourceResponse(event_stream)
async def count(_: Request):
    """Return the per-channel active-listener counts as JSON."""
    counts = await broker.value()
    return JSONResponse(counts)
# /count must come before the catch-all /{channel} route, or it would be
# matched as a channel named "count".
routes = [Route("/count", endpoint=count), Route("/{channel}", endpoint=sse)]
app = Starlette(routes=routes)
if __name__ == "__main__":
    # Standalone publisher: pushes an ever-incrementing counter to
    # "test_channel" every four seconds. Run the Starlette app separately
    # (e.g. with uvicorn) to consume the stream.
    n = 0
    while True:
        print("Sending message")
        broker.notify("test_channel", {"data": n})
        n += 1
        time.sleep(4)

# stream events with curl:
#   curl -N http://localhost:8001/test_channel
# send events with python:
#   python test.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment