Skip to content

Instantly share code, notes, and snippets.

@ThirVondukr
Last active November 6, 2023 03:13
Show Gist options
  • Save ThirVondukr/7bf7880ffbed1b445b4d573d51de8bd5 to your computer and use it in GitHub Desktop.
Save ThirVondukr/7bf7880ffbed1b445b4d573d51de8bd5 to your computer and use it in GitHub Desktop.
Redis/in-memory pub-sub with multicast and client-side routing
import asyncio
import contextlib
import random
import traceback
from collections import defaultdict
from collections.abc import AsyncIterator, Callable, Hashable
from types import TracebackType
from typing import (
TYPE_CHECKING,
Generic,
Protocol,
Self,
TypeVar,
)
import anyio
import redis.asyncio as redis
from anyio.streams.stapled import StapledObjectStream
from pydantic import BaseModel
# Generic payload type carried by multicast streams, transports and the exchange.
TValue = TypeVar("TValue")
# Payload type for transports that (de)serialize pydantic models.
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)

# The parametrized alias is only seen by type checkers; at runtime the bare
# class is used (presumably because redis.Redis is not subscriptable at
# runtime in the targeted redis-py version — confirm before simplifying).
if TYPE_CHECKING:
    RedisClient = redis.Redis[bytes]
else:
    RedisClient = redis.Redis
# Event addressed to a single user. In this file `user_id_key` uses
# `user_id` as the routing key. (Kept as a comment rather than a docstring
# because pydantic copies class docstrings into the model's JSON schema.)
class UserEvent(BaseModel):
    # Identifier of the user the event is addressed to.
    user_id: int


# Any UserEvent (or subclass); lets `user_id_key` accept specialised events.
TUserEvent = TypeVar("TUserEvent", bound=UserEvent)
class MulticastReceiver(Generic[TValue]):
    """Async iterator over the values queued for one subscriber of a stream.

    The receiver shares ``container`` (this subscriber's inbox) and ``event``
    (its wake-up flag) with the ``MulticastStream`` that created it: the
    stream appends to the inbox and sets the event, and the receiver takes
    values out in the order they were appended, clearing the event once the
    inbox is drained.

    Args:
        event: Set by the stream whenever a value is appended to ``container``.
        container: Inbox of values not yet delivered to this subscriber.
    """

    def __init__(self, event: asyncio.Event, container: list[TValue]) -> None:
        self._event = event
        self._container = container

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> TValue:
        # Never raises StopAsyncIteration: iteration only ends by cancellation.
        return await self.recv()

    async def recv(self) -> TValue:
        """Wait for and return the oldest undelivered value (FIFO order)."""
        # Re-check the inbox after every wake-up: a set event is only a hint,
        # so an empty inbox must never reach pop() (it would raise IndexError).
        while not self._container:
            self._event.clear()
            await self._event.wait()
        # pop(0), not pop(): pop() took the newest value first and delivered
        # queued messages in reverse order. The inbox is a plain list because
        # it is shared with the stream that fills it.
        value = self._container.pop(0)
        if not self._container:
            self._event.clear()
        return value
class MulticastStream(Generic[TValue]):
    """Fan-out stream: each sent value is queued for every registered receiver.

    Each receiver owns a private inbox (a list) and an ``asyncio.Event`` used
    to wake it. A receiver only sees values sent while it is registered.
    """

    def __init__(self) -> None:
        # Maps each receiver's wake-up event to its inbox of pending values.
        self._receivers: dict[asyncio.Event, list[TValue]] = {}

    async def send(self, value: TValue) -> None:
        """Queue ``value`` for every registered receiver and wake them."""
        for event, container in self._receivers.items():
            container.append(value)
            event.set()
        # Yield to the event loop once so woken receivers get a chance to run.
        await asyncio.sleep(0)

    @contextlib.asynccontextmanager
    async def _get_event(
        self,
    ) -> AsyncIterator[tuple[asyncio.Event, list[TValue]]]:
        """Register a new (event, inbox) pair for the lifetime of the context."""
        event = asyncio.Event()
        container: list[TValue] = []
        self._receivers[event] = container
        try:
            yield event, container
        finally:
            # Unregister even when the subscriber is cancelled or raises;
            # otherwise its inbox keeps growing with every value sent.
            self._receivers.pop(event, None)

    @contextlib.asynccontextmanager
    async def recv(self) -> AsyncIterator[MulticastReceiver[TValue]]:
        """Yield a receiver that gets every value sent while the context is open."""
        async with self._get_event() as (event, container):
            yield MulticastReceiver(event, container)
class ExchangeTransport(Protocol[TValue]):
    """Structural interface for the message channel underneath an EventExchange."""

    async def send(self, item: TValue) -> None:
        """Publish ``item`` on the transport."""

    def recv(self) -> AsyncIterator[TValue]:
        """Return an async iterator over items arriving on the transport."""

    async def close(self) -> None:
        """Release transport resources.

        The base body is empty, so a subclass that does not override it
        inherits a no-op at runtime.
        """
class InMemoryExchangeTransport(ExchangeTransport[TValue]):
    """Single-process transport backed by an anyio memory object stream.

    Items sent are received by whoever iterates ``recv()``. Every call returns
    the same receive stream, so concurrent iterators compete for items.
    """

    def __init__(self) -> None:
        # Default max_buffer_size (0): send() waits until the item is received.
        self._stream = StapledObjectStream(*anyio.create_memory_object_stream())

    async def send(self, item: TValue) -> None:
        """Hand ``item`` to the receive side of the memory stream."""
        await self._stream.send(item)

    def recv(self) -> AsyncIterator[TValue]:
        """Return the shared receive side of the memory stream."""
        return self._stream.receive_stream

    async def close(self) -> None:
        """Close both the send and receive ends of the memory stream.

        Previously the inherited no-op left the streams open after the
        exchange shut down.
        """
        await self._stream.aclose()
class RedisExchangeTransport(ExchangeTransport[TBaseModel]):
    """Transport carrying pydantic models as JSON over one Redis pub/sub channel.

    Args:
        client: Redis client used both to publish and to subscribe.
        channel: Name of the pub/sub channel.
        model: Model class used to parse incoming payloads.
    """

    def __init__(
        self,
        client: RedisClient,
        channel: str,
        model: type[TBaseModel],
    ) -> None:
        self._model_cls = model
        self._channel = channel
        self._client = client

    async def send(self, item: TBaseModel) -> None:
        """Serialize ``item`` to JSON and publish it on the channel."""
        payload = item.model_dump_json()
        await self._client.publish(channel=self._channel, message=payload)

    async def recv(self) -> AsyncIterator[TBaseModel]:
        """Subscribe to the channel and yield each incoming payload as a model."""
        async with self._client.pubsub() as subscription:
            await subscription.subscribe(self._channel)
            while True:
                # timeout=None: wait for the next message without a deadline.
                raw = await subscription.get_message(
                    ignore_subscribe_messages=True,
                    timeout=None,  # type: ignore[arg-type]
                )
                if not raw:
                    continue
                yield self._model_cls.model_validate_json(raw["data"])

    async def close(self) -> None:
        """Close the underlying Redis client."""
        await self._client.aclose()  # type: ignore[attr-defined]
# Type of the key used to route events to local subscribers; bound to
# Hashable because EventExchange keys its per-key streams by it.
RoutingKey = TypeVar("RoutingKey", bound=Hashable)
# Generic alias for a function that extracts the routing key from an event.
RouteKeyRouter = Callable[[TValue], RoutingKey]
def user_id_key(event: TUserEvent) -> int:
    """Return the id of the user ``event`` is addressed to, for use as a routing key."""
    routing_key = event.user_id
    return routing_key
class EventExchange(Generic[RoutingKey, TValue]):
    """Client-side router on top of an ExchangeTransport.

    One background task reads every message from the transport, computes its
    routing key with ``route_key``, and multicasts it to the local subscribers
    of that key. Use as an async context manager: entering starts the
    background task; exiting cancels it, waits for it to stop, then closes
    the transport.

    Args:
        transport: Channel messages are published to and read back from.
        route_key: Function mapping a message to its routing key.
    """

    def __init__(
        self,
        transport: ExchangeTransport[TValue],
        route_key: RouteKeyRouter[TValue, RoutingKey],
    ) -> None:
        self._routing_key = route_key
        self._transport = transport
        self._streams: dict[Hashable, MulticastStream[object]] = defaultdict(
            MulticastStream,
        )
        # Active `subscribe` iterators per key. A key's stream is dropped when
        # its count returns to zero, so `_streams` does not grow without bound.
        self._subscriber_counts: dict[Hashable, int] = {}
        self._consumer_task: asyncio.Task[None] | None = None

    async def _consume(self) -> None:
        """Route every message from the transport to the stream of its key."""
        try:
            async for message in self._transport.recv():
                key = self._routing_key(message)
                # .get() rather than [] so keys nobody subscribes to do not
                # each leave a permanent MulticastStream in the defaultdict.
                stream = self._streams.get(key)
                if stream is not None:
                    await stream.send(message)
        except Exception as e:
            traceback.print_exception(e)
            raise

    async def __aenter__(self) -> Self:
        self._consumer_task = asyncio.create_task(self._consume())
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        task = self._consumer_task
        try:
            if task is not None:
                task.cancel()
                # Let the consumer finish unwinding (e.g. leaving its transport
                # subscription) before the transport is closed underneath it.
                # asyncio.wait does not re-raise the task's own outcome.
                await asyncio.wait([task])
                self._consumer_task = None
        finally:
            await self._transport.close()

    def _get_receiver(self, key: RoutingKey) -> MulticastStream[TValue]:
        return self._streams[key]  # type: ignore[return-value]

    async def publish(self, event: TValue) -> None:
        """Send ``event`` through the transport.

        Local subscribers receive it once the background consumer reads it
        back from the transport.
        """
        await self._transport.send(event)

    async def subscribe(
        self,
        routing_key: RoutingKey,
    ) -> AsyncIterator[TValue]:
        """Yield every message routed to ``routing_key`` after subscription starts."""
        stream = self._streams[routing_key]
        self._subscriber_counts[routing_key] = (
            self._subscriber_counts.get(routing_key, 0) + 1
        )
        try:
            async with stream.recv() as recv:
                async for message in recv:
                    yield message  # type: ignore[misc]
        finally:
            remaining = self._subscriber_counts[routing_key] - 1
            if remaining:
                self._subscriber_counts[routing_key] = remaining
            else:
                # Last subscriber for this key left: forget its stream.
                del self._subscriber_counts[routing_key]
                self._streams.pop(routing_key, None)
async def worker(
    exchange: EventExchange[int, UserEvent],
    test_duration: float,
    num_clients: int,
) -> None:
    """Publish events for random users as fast as possible for ``test_duration`` s.

    User ids are drawn uniformly from ``range(num_clients)``, matching the ids
    ``main`` subscribes consumers to; ``num_clients`` must therefore be > 0.
    Returns normally when the time budget runs out.
    """
    try:
        async with asyncio.timeout(test_duration):
            while True:
                await exchange.publish(
                    # randrange excludes num_clients; the previous
                    # randint(0, num_clients) could pick an id nobody consumes.
                    UserEvent(user_id=random.randrange(num_clients)),
                )
    except asyncio.TimeoutError:
        return
async def consumer(
    exchange: EventExchange[int, UserEvent],
    user_id: int,
    test_duration: float,
) -> int:
    """Count the events delivered to ``user_id`` during ``test_duration`` seconds.

    Returns:
        Number of events received before the time budget ran out.

    Raises:
        RuntimeError: if the subscription ends before the timeout fires, in
            which case the count would not cover the full test duration.
    """
    count = 0
    try:
        async with asyncio.timeout(test_duration):
            async for _ in exchange.subscribe(user_id):
                count += 1
    except asyncio.TimeoutError:
        return count
    # Only reachable if the subscription stops on its own; previously this
    # raised a bare NotImplementedError, which misdescribed the failure.
    raise RuntimeError(
        f"subscription for user {user_id} ended before the test duration elapsed",
    )
async def main() -> None:
    """Run the pub/sub benchmark and print its results.

    Prints the per-consumer event counts, then the mean number of events per
    consumer per second, then the total number of events delivered.
    """
    # Fixed start-up delay before connecting (presumably to let the "redis"
    # host come up first — confirm).
    await asyncio.sleep(10)

    transport = RedisExchangeTransport[UserEvent](
        client=RedisClient(host="redis", db=1),
        channel="events",
        model=UserEvent,
    )
    # Single-process alternative:
    # transport = InMemoryExchangeTransport[UserEvent]()
    test_duration = 60
    num_clients = 1000

    async with (
        EventExchange(transport=transport, route_key=user_id_key) as exchange,
        asyncio.TaskGroup() as tg,
    ):
        tg.create_task(
            worker(exchange, test_duration=test_duration, num_clients=num_clients),
        )
        consumer_tasks = []
        for user_id in range(num_clients):
            consumer_tasks.append(
                tg.create_task(
                    consumer(exchange, user_id=user_id, test_duration=test_duration),
                ),
            )

    # All tasks are finished once the TaskGroup has exited, so result() is safe.
    counts = [task.result() for task in consumer_tasks]
    total = sum(counts)
    print(counts)
    print(total / len(consumer_tasks) / test_duration)
    print(total)


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment