Created
January 12, 2021 10:04
-
-
Save edelvalle/5ac811b46370457837a9d428964b30eb to your computer and use it in GitHub Desktop.
An attempt to use PostgreSQL pub/sub (LISTEN/NOTIFY) as a channel layer for django-channels (NOT production-ready; may contain bugs).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import base64
import logging
import threading
import time
from collections import defaultdict
from queue import SimpleQueue
from uuid import uuid4

import asyncpg
import msgpack
from channels.layers import BaseChannelLayer
class PostgresChannelLayer(BaseChannelLayer):
    """django-channels layer built on PostgreSQL LISTEN/NOTIFY (experimental).

    Every channel and group name doubles as a Postgres notification channel.
    A channel gets a local ``asyncio.Queue`` (and is LISTENed on) the first
    time it is looked up in ``self._queues``; a group is LISTENed on while it
    has local members. Outgoing messages are msgpack-encoded, base64-encoded
    to text and published with NOTIFY, so every process LISTENing on the name
    receives them.

    Extra keyword arguments (``**config``) are forwarded to
    ``asyncpg.connect``.

    NOTE(review): Postgres limits NOTIFY payload size (8000 bytes in the
    default configuration), local queues do not enforce ``capacity``, and the
    layer assumes all of its coroutines run on a single event loop.
    """

    extensions = ('groups', 'flush')

    def __init__(
        self, expiry=60, capacity=100, channel_capacity=None, **config
    ):
        super().__init__(
            expiry=expiry,
            capacity=capacity,
            channel_capacity=channel_capacity,
        )
        # Event loop that owns the local queues. Captured lazily (see
        # _remember_loop) because the layer may be built outside any loop.
        self._loop = None
        self._conn = Connection(self._enqueue_message, config)
        self._init()

    async def send(self, channel: str, message: dict, optimize=True):
        """Send *message* to *channel*.

        With *optimize* (the default) a channel that already has a local
        queue receives the message directly, skipping the database round
        trip; otherwise the message is serialized and published with NOTIFY.
        """
        if optimize and channel in self._queues:
            queue = await self._queues.get(channel)
            queue.put_nowait(message)
            return
        self._conn.send(channel, self._serialize(message))

    async def receive(self, channel: str):
        """Wait for and return the next message for *channel*."""
        queue = await self._queues.get(channel)
        return await queue.get()

    async def new_channel(self, prefix: str = 'specific') -> str:
        """Return a new, unique channel name starting with *prefix*."""
        return f'{prefix}.pgsql.{uuid4().hex}'

    # Flush extension

    async def flush(self):
        """Stop listening to every group and reset the local state."""
        # The registry maps group name -> member set; iterate the names.
        for group_name in list(getattr(self, '_subscriptions', {})):
            self._conn.stop_listening(group_name)
        self._init()

    def _init(self):
        """(Re)create the group registry and the local queue cache."""
        self._subscriptions = defaultdict(set)  # group -> local channel names
        self._queues = TTLCache(
            create=self._create_queue,
            expire=self._delete_queue,
            expiry=self.expiry,
        )

    def _remember_loop(self):
        """Record the running event loop as the owner of the local queues."""
        self._loop = asyncio.get_running_loop()

    async def _create_queue(self, name):
        """TTLCache factory: LISTEN on *name* and return its local queue."""
        self._remember_loop()
        self._conn.listen_to(name)
        return asyncio.Queue()

    async def _delete_queue(self, name):
        """TTLCache eviction hook: stop LISTENing on *name*."""
        # Evictions can fire after close() has released the connection.
        if self._conn is not None:
            self._conn.stop_listening(name)

    async def close(self):
        """Shut down the database connection; safe to call more than once."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    # Groups extension

    async def group_add(self, group: str, channel: str):
        """Add *channel* to *group*, LISTENing on the group if it is new."""
        self._remember_loop()
        is_new_group = group not in self._subscriptions
        self._subscriptions[group].add(channel)
        if is_new_group:
            self._conn.listen_to(group)

    async def group_discard(self, group: str, channel: str):
        """Remove *channel* from *group*; UNLISTEN once the group is empty.

        Unknown groups and non-members are ignored.
        """
        members = self._subscriptions.get(group)
        if members is None:
            return
        members.discard(channel)
        if not members:
            self._conn.stop_listening(group)
            del self._subscriptions[group]

    async def group_send(self, group: str, message: dict):
        """Publish *message* to *group* via NOTIFY (never short-circuited)."""
        await self.send(group, message, optimize=False)

    # Connection receiver

    def _enqueue_message(self, connection, pid, channel, payload):
        """asyncpg listener callback; runs on the Connection's worker loop.

        asyncio queues are not thread-safe, so delivery is handed over to the
        loop that owns them rather than scheduled on the worker loop.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            return  # no live local loop to deliver to
        message = self._deserialize(payload)
        asyncio.run_coroutine_threadsafe(
            self._enqueue_message_async(channel, message), loop
        )

    async def _enqueue_message_async(self, channel, message):
        """Put *message* on the local queue(s) addressed by *channel*.

        *channel* is either a group (fan out to its local members) or a
        single channel. A group member without a live local queue is assumed
        gone (its queue expired) and is removed from the group.
        """
        is_group = channel in self._subscriptions
        # Snapshot the members: group_discard below mutates the live set.
        recipients = list(self._subscriptions[channel]) if is_group else [channel]
        for recipient in recipients:
            if recipient in self._queues:
                queue = await self._queues.get(recipient)
                queue.put_nowait(message)
            elif is_group:
                await self.group_discard(channel, recipient)

    # Serialization

    def _serialize(self, message: dict) -> str:
        """Encode *message* as msgpack, then base64 text for a NOTIFY payload."""
        value = msgpack.packb(message, use_bin_type=True)
        return base64.encodebytes(value).decode().strip()

    def _deserialize(self, message: str) -> dict:
        """Inverse of ``_serialize``: base64 text -> msgpack -> dict."""
        message = base64.decodebytes(message.encode())
        return msgpack.unpackb(message, raw=False)
class Connection:
    """A dedicated asyncpg connection driven from a background thread.

    Any thread may enqueue commands (``send``, ``listen_to``,
    ``stop_listening``, ``close``); the worker thread runs its own event loop
    and executes them in FIFO order. Notifications are passed to *callback*
    (asyncpg listener signature ``(connection, pid, channel, payload)``) on
    that worker loop. A NOTIFY on the ``ping`` channel is sent every
    ``PING_INTERVAL`` seconds as a keepalive until ``close()`` is called.

    *config* is forwarded as keyword arguments to ``asyncpg.connect``.
    """

    PING_INTERVAL = 5  # seconds between keepalive NOTIFYs

    def __init__(self, callback, config):
        self.config = config
        self._callback = callback
        self._queue = SimpleQueue()
        # Guards _closed/_ping_timer so close() cannot race the keepalive
        # timer re-arming itself.
        self._lock = threading.Lock()
        self._closed = False
        self._ping_timer = None
        self._thread = threading.Thread(target=self._start)
        self._thread.start()
        self._schedule_ping()

    def send(self, channel: str, payload: str):
        """Queue a NOTIFY of *payload* on *channel*."""
        self._queue.put({
            'type': 'send',
            'channel': channel,
            'payload': payload,
        })

    def listen_to(self, channel: str):
        """Queue a LISTEN on *channel*."""
        self._queue.put({'type': 'listen', 'channel': channel})

    def stop_listening(self, channel: str):
        """Queue an UNLISTEN of *channel*."""
        self._queue.put({'type': 'stop_listening', 'channel': channel})

    def close(self):
        """Stop the keepalive and ask the worker to close the connection.

        Idempotent: only the first call enqueues the close command.
        """
        with self._lock:
            if self._closed:
                return
            self._closed = True
            if self._ping_timer is not None:
                self._ping_timer.cancel()
            self._queue.put({'type': 'close'})

    def _ping(self):
        """Timer callback: send one keepalive, then re-arm the timer."""
        with self._lock:
            if self._closed:
                return
            self.send('ping', 'ping')
        self._schedule_ping()

    def _schedule_ping(self):
        """Arm the keepalive timer unless the connection is closed."""
        with self._lock:
            if self._closed:
                return
            self._ping_timer = threading.Timer(self.PING_INTERVAL, self._ping)
            self._ping_timer.start()

    def __del__(self):
        # Nothing to close if __init__ failed before the lock existed.
        if hasattr(self, '_lock'):
            self.close()

    def _start(self):
        """Worker-thread entry point: run the command loop to completion."""
        asyncio.run(self._loop())

    async def _loop(self):
        """Connect, then execute queued commands until a close command."""
        loop = asyncio.get_running_loop()
        conn = await asyncpg.connect(**self.config)
        try:
            while True:
                # SimpleQueue.get() blocks. Waiting for it in an executor
                # thread keeps this event loop free to read incoming
                # notifications in the meantime.
                command = await loop.run_in_executor(None, self._queue.get)
                if command['type'] == 'close':
                    break
                try:
                    await self._execute(conn, command)
                except Exception:
                    # Keep serving later commands: one failed command (e.g.
                    # an oversized NOTIFY payload) must not kill the worker.
                    logging.getLogger(__name__).exception(
                        'Postgres channel layer: %r command failed',
                        command['type'],
                    )
        finally:
            await conn.close()

    async def _execute(self, conn, command):
        """Run a single non-close command on *conn*."""
        kind = command['type']
        channel = command['channel']
        if kind == 'send':
            # pg_notify() takes bind parameters, so neither the channel name
            # nor the payload is interpolated into SQL text.
            await conn.execute(
                'SELECT pg_notify($1, $2)', channel, command['payload']
            )
        elif kind == 'listen':
            await conn.add_listener(channel, self._callback)
        elif kind == 'stop_listening':
            await conn.remove_listener(channel, self._callback)
class TTLCache:
    """Async get-or-create cache whose entries expire after a period of disuse.

    ``create(key)`` is awaited to build a missing entry and ``expire(key)`` is
    awaited when an entry is evicted. Every ``get`` pushes the entry's
    deadline ``expiry`` seconds into the future; eviction is checked by a
    background task (scheduled with ``delay``) that re-arms itself while the
    entry keeps being used.
    """

    def __init__(self, create, expire, expiry):
        self._create = create
        self._expire = expire
        self._expiry = expiry
        self._expires = {}  # key -> deadline on the time.monotonic() clock
        self._objects = {}  # key -> cached object

    def __contains__(self, key):
        return key in self._objects

    async def get(self, key):
        """Return the entry for *key*, creating it first if needed."""
        if key not in self._objects:
            self._objects[key] = await self._create(key)
            delay(self._expiry)(self.cleanup, key)
        # Sliding expiry: each access postpones eviction.
        self._expires[key] = time.monotonic() + self._expiry
        return self._objects[key]

    async def cleanup(self, key):
        """Evict *key* if its deadline has passed, else check again later."""
        deadline = self._expires.get(key)
        if deadline is None:
            return  # already evicted
        remaining = deadline - time.monotonic()
        if remaining > 0:
            # Accessed since this check was scheduled; re-check at the new
            # deadline so a single pending task per key suffices.
            delay(remaining)(self.cleanup, key)
            return
        # Drop the entry before awaiting the hook so no caller can be handed
        # an object that is being torn down.
        del self._expires[key]
        del self._objects[key]
        await self._expire(key)

    def __del__(self):
        pending = getattr(self, '_expires', None)
        if not pending:
            return
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return  # no running loop: async expire hooks cannot be scheduled
        for key in list(pending):
            delay(0)(self._expire, key)
# Strong references to in-flight delayed tasks. The event loop only keeps
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes sleeping.
_delayed_tasks = set()


def delay(seconds):
    """Return a scheduler that runs a coroutine function after *seconds*.

    ``delay(s)(f, *args, **kwargs)`` must be called from a running event
    loop. It starts a task that sleeps *s* seconds, then awaits
    ``f(*args, **kwargs)``, and returns that task.
    """
    def schedule(_f, *args, **kwargs):
        async def runner():
            await asyncio.sleep(seconds)
            await _f(*args, **kwargs)

        task = asyncio.create_task(runner())
        _delayed_tasks.add(task)
        task.add_done_callback(_delayed_tasks.discard)
        return task

    return schedule
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment