Skip to content

Instantly share code, notes, and snippets.

@edelvalle
Created January 12, 2021 10:04
Show Gist options
  • Save edelvalle/5ac811b46370457837a9d428964b30eb to your computer and use it in GitHub Desktop.
Save edelvalle/5ac811b46370457837a9d428964b30eb to your computer and use it in GitHub Desktop.
An attempt to use PostgreSQL LISTEN/NOTIFY pub/sub as a channel layer for Django Channels (NOT production-ready; may contain bugs).
import asyncio
import base64
import logging
import threading
import time
from collections import defaultdict
from queue import SimpleQueue
from uuid import uuid4

import asyncpg
import msgpack
from channels.layers import BaseChannelLayer
class PostgresChannelLayer(BaseChannelLayer):
    """Channels layer that routes messages through PostgreSQL LISTEN/NOTIFY.

    Local receive queues live in a ``TTLCache``: creating a queue LISTENs on
    the channel name and expiring it stops listening. A group is LISTENed on
    by name while it has at least one local member. By default ``send`` to a
    channel with a local queue is delivered in-process; everything else is
    NOTIFYed as a msgpack + base64 payload and fanned out on arrival by
    ``_enqueue_message``.

    NOTE(review): experimental. Assumes all queues are used from a single
    event loop (the last one that created a queue) -- confirm for your
    deployment, e.g. sync consumers using async_to_sync.
    """

    extensions = ('groups', 'flush')

    def __init__(
        self, expiry=60, capacity=100, channel_capacity=None, **config
    ):
        super().__init__(
            expiry=expiry,
            capacity=capacity,
            channel_capacity=channel_capacity,
        )
        # Event loop that owns the asyncio.Queue objects, recorded when a
        # queue is created. Notifications arrive on the Connection's own
        # background thread and must be handed back to this loop.
        self._event_loop = None
        self._conn = Connection(self._enqueue_message, config)
        self._init()

    async def send(self, channel: str, message: dict, optimize=True):
        """Deliver ``message`` to ``channel``.

        With ``optimize`` (the default), a channel that already has a queue
        in this process is delivered to directly; otherwise the message is
        serialized and NOTIFYed through PostgreSQL.
        """
        if optimize and channel in self._queues:
            # Receiver lives in this process: skip the database round trip.
            queue = await self._queues.get(channel)
            queue.put_nowait(message)
        else:
            # NOTE(review): PostgreSQL caps NOTIFY payload size (8000 bytes
            # in the default configuration); larger messages will fail.
            self._conn.send(channel, self._serialize(message))

    async def receive(self, channel: str):
        """Wait for and return the next message for ``channel``.

        NOTE(review): TTLCache does not extend a queue's lifetime on access,
        so the queue is dropped ``expiry`` seconds after creation even while
        a receiver is awaiting it; that receiver then stops seeing messages.
        Confirm this is intended.
        """
        queue = await self._queues.get(channel)
        return await queue.get()

    async def new_channel(self, prefix: str = 'specific') -> str:
        """Return a fresh, unique channel name starting with ``prefix``."""
        return f'{prefix}.pgsql.{uuid4().hex}'

    # Flush extension

    async def flush(self):
        """Stop listening on every group and reset all local state."""
        subscriptions = getattr(self, '_subscriptions', {})
        # Keys are group names; values are member sets and must not be
        # unpacked as (name, members) pairs.
        for group_name in list(subscriptions):
            self._conn.stop_listening(group_name)
        self._init()

    def _init(self):
        """(Re)create the group membership map and the queue cache."""
        self._subscriptions = defaultdict(set)
        self._queues = TTLCache(
            create=self._create_queue,
            expire=self._delete_queue,
            expiry=self.expiry,
        )

    async def _create_queue(self, name):
        """TTLCache factory: LISTEN on ``name`` and return a new queue."""
        # Remember which loop owns the queues so notifications arriving on
        # the Connection thread can be delivered on it.
        self._event_loop = asyncio.get_running_loop()
        self._conn.listen_to(name)
        return asyncio.Queue()

    async def _delete_queue(self, name):
        """TTLCache expiry hook: stop listening on ``name``."""
        # The layer may already have been closed when the queue expires.
        if self._conn is not None:
            self._conn.stop_listening(name)

    async def close(self):
        """Close the database connection; safe to call more than once."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    # Groups extension

    async def group_add(self, group: str, channel: str):
        """Add ``channel`` to ``group``, LISTENing on the group if new."""
        new_group = group not in self._subscriptions
        self._subscriptions[group].add(channel)
        if new_group:
            self._conn.listen_to(group)

    async def group_discard(self, group, channel):
        """Remove ``channel`` from ``group``; UNLISTEN once it is empty."""
        # .get() so an unknown group is not created by the defaultdict.
        members = self._subscriptions.get(group)
        if members is None:
            return
        members.discard(channel)
        if not members:
            self._conn.stop_listening(group)
            del self._subscriptions[group]

    async def group_send(self, group: str, message: dict):
        """NOTIFY ``message`` on ``group`` (always via PostgreSQL)."""
        await self.send(group, message, optimize=False)

    # Connection receiver

    def _enqueue_message(self, connection, pid, channel, payload):
        """asyncpg listener callback; runs on the Connection's thread."""
        loop = self._event_loop
        if loop is None or loop.is_closed():
            # No local queue has ever been created: nobody to deliver to.
            return
        message = self._deserialize(payload)
        # asyncio.Queue is not thread-safe: deliver on the loop owning it.
        asyncio.run_coroutine_threadsafe(
            self._enqueue_message_async(channel, message), loop
        )

    async def _enqueue_message_async(self, channel, message):
        """Put ``message`` on the local queue(s) addressed by ``channel``."""
        is_group = channel in self._subscriptions
        # Snapshot the members: group_discard below mutates the set.
        if is_group:
            recipients = list(self._subscriptions[channel])
        else:
            recipients = [channel]
        for recipient in recipients:
            if recipient in self._queues:
                queue = await self._queues.get(recipient)
                queue.put_nowait(message)
            elif is_group:
                # Group clean-up based on queue expiration.
                await self.group_discard(channel, recipient)

    # Serialization

    def _serialize(self, message: dict) -> str:
        """Encode ``message`` as msgpack, then base64 text for NOTIFY."""
        value = msgpack.packb(message, use_bin_type=True)
        return base64.encodebytes(value).decode().strip()

    def _deserialize(self, message: str) -> dict:
        """Inverse of ``_serialize``."""
        message = base64.decodebytes(message.encode())
        return msgpack.unpackb(message, raw=False)
class Connection:
    """Run one asyncpg connection on a dedicated background thread.

    Any thread may queue commands through ``send``, ``listen_to``,
    ``stop_listening`` and ``close``. The background thread runs its own
    event loop, which executes them in FIFO order against the connection.
    asyncpg delivers incoming notifications to ``callback`` on that
    background loop.
    """

    # Seconds between periodic NOTIFYs on the 'ping' channel (presumably a
    # keep-alive; the command loop no longer depends on it to wake up).
    _PING_INTERVAL = 5

    def __init__(self, callback, config):
        self.config = config
        self._callback = callback
        self._queue = SimpleQueue()
        self._closed = False
        # Set by the background thread once its loop is running; used to
        # wake that loop whenever a command is queued.
        self._event_loop = None
        self._wakeup = None
        self._ping_timer = None
        # Daemon threads: an unclosed Connection must not block interpreter
        # exit (the ping timer re-arms itself indefinitely).
        self._thread = threading.Thread(target=self._start, daemon=True)
        self._thread.start()
        self._schedule_ping()

    def send(self, channel: str, payload: str):
        """Queue a NOTIFY of ``payload`` on ``channel``."""
        self._post({
            'type': 'send',
            'channel': channel,
            'payload': payload,
        })

    def listen_to(self, channel: str):
        """Queue registration of the callback as a listener on ``channel``."""
        self._post({'type': 'listen', 'channel': channel})

    def stop_listening(self, channel: str):
        """Queue removal of the callback's listener on ``channel``."""
        self._post({'type': 'stop_listening', 'channel': channel})

    def close(self):
        """Stop pinging and ask the background loop to close the connection.

        Safe to call more than once.
        """
        if self._closed:
            return
        self._closed = True
        self._post({'type': 'close'})
        if self._ping_timer is not None:
            self._ping_timer.cancel()

    def __del__(self):
        # __init__ may have failed before the command queue existed.
        if getattr(self, '_queue', None) is not None:
            self.close()

    def _post(self, command):
        """Enqueue ``command`` and wake the background loop if it is up."""
        self._queue.put(command)
        loop = self._event_loop
        if loop is None:
            # Loop not started yet; it drains the queue when it starts.
            return
        try:
            loop.call_soon_threadsafe(self._wakeup.set)
        except RuntimeError:
            # Background loop already shut down; nothing left to wake.
            pass

    def _schedule_ping(self):
        """Arm a daemon timer that calls ``_ping`` after the interval."""
        timer = threading.Timer(self._PING_INTERVAL, self._ping)
        timer.daemon = True
        self._ping_timer = timer
        timer.start()

    def _ping(self):
        # Checked on every firing so a close() racing with a firing timer
        # cannot keep the ping cycle alive.
        if self._closed:
            return
        self.send('ping', 'ping')
        self._schedule_ping()

    def _start(self):
        asyncio.run(self._loop())

    async def _loop(self):
        """Connect, then execute queued commands until ``close`` arrives."""
        self._wakeup = asyncio.Event()
        self._event_loop = asyncio.get_running_loop()
        conn = await asyncpg.connect(**self.config)
        try:
            while True:
                # Clear before draining: a command queued during the drain
                # sets the event again, so no wake-up is lost.
                self._wakeup.clear()
                if not await self._drain(conn):
                    break
                # Await instead of blocking on the SimpleQueue, so asyncpg
                # keeps dispatching incoming notifications in the meantime.
                await self._wakeup.wait()
        finally:
            await conn.close()

    async def _drain(self, conn):
        """Run every queued command; return False once ``close`` is seen."""
        # This thread is the only consumer, so a non-empty queue cannot be
        # emptied between empty() and get_nowait().
        while not self._queue.empty():
            message = self._queue.get_nowait()
            if message['type'] == 'close':
                return False
            try:
                await self._execute(conn, message)
            except Exception:
                # One failing command must not kill the connection thread.
                logging.getLogger(__name__).exception(
                    'Postgres channel layer: %r on channel %r failed',
                    message['type'],
                    message.get('channel'),
                )
        return True

    async def _execute(self, conn, message):
        """Apply one non-close command to ``conn``."""
        kind = message['type']
        channel = message.get('channel')
        if kind == 'send':
            # Parameterized pg_notify instead of string-built NOTIFY SQL,
            # so channel names and payloads cannot break or inject SQL.
            await conn.execute(
                'SELECT pg_notify($1, $2)', channel, message.get('payload')
            )
        elif kind == 'listen':
            await conn.add_listener(channel, self._callback)
        elif kind == 'stop_listening':
            await conn.remove_listener(channel, self._callback)
class TTLCache:
    """Lazily create per-key objects and release them after ``expiry`` seconds.

    ``create(key)`` is awaited to build the object on the first ``get``;
    ``expire(key)`` is awaited to release it once its deadline has passed.

    NOTE(review): the deadline is set when the object is created and is not
    extended by later ``get`` calls, so even a key in constant use expires
    ``expiry`` seconds after creation. Confirm that is intended.
    """

    def __init__(self, create, expire, expiry):
        self._create = create    # async callable: key -> object
        self._expire = expire    # async callable: key -> (result unused)
        self._expiry = expiry    # lifetime in seconds
        self._expires = {}       # key -> deadline on the time.monotonic() clock
        self._objects = {}       # key -> object returned by create(key)

    def __contains__(self, key):
        return key in self._objects

    async def get(self, key):
        """Return the object for ``key``, creating it on first access."""
        if key not in self._objects:
            self._objects[key] = await self._create(key)
            # Monotonic clock: wall-clock jumps must not move the deadline.
            self._expires[key] = time.monotonic() + self._expiry
            delay(self._expiry)(self.cleanup, key)
        return self._objects[key]

    async def cleanup(self, key):
        """Release ``key`` once its deadline has passed.

        If this runs before the deadline, it re-schedules itself rather than
        returning for good, which would leak the entry forever.
        """
        deadline = self._expires.get(key)
        if deadline is None:
            return  # already released
        remaining = deadline - time.monotonic()
        if remaining > 0:
            delay(remaining)(self.cleanup, key)
            return
        # Drop the entry before awaiting, so the removal cannot fail with
        # KeyError if the cache changes while ``expire`` runs.
        del self._expires[key]
        self._objects.pop(key, None)
        await self._expire(key)

    def __del__(self):
        # Best effort: release remaining entries, but only when an event loop
        # is running to execute the async ``expire`` callbacks. Without one,
        # create_task would raise and leave never-awaited coroutines behind.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        for key in list(getattr(self, '_expires', ())):
            loop.create_task(self._expire(key))
# Strong references to tasks started by ``delay``. The event loop keeps only
# weak references to tasks, so an otherwise unreferenced pending task could
# be garbage-collected before it runs.
_background_tasks = set()


def delay(seconds):
    """Return a scheduler that runs ``_f(*args, **kwargs)`` after ``seconds``.

    Usage: ``delay(5)(coro_fn, arg)``. Despite the two-step call this is not
    a decorator: the second call immediately schedules ``coro_fn`` as a task
    on the running event loop and returns that task. It must therefore be
    called from inside a running event loop.
    """
    def schedule(_f, *args, **kwargs):
        async def runner():
            await asyncio.sleep(seconds)
            await _f(*args, **kwargs)

        task = asyncio.create_task(runner())
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return task

    return schedule
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment