Skip to content

Instantly share code, notes, and snippets.

@abetkin
Last active July 19, 2022 08:34
Show Gist options
  • Save abetkin/863dbf79fb55b17f851afdd68d33f668 to your computer and use it in GitHub Desktop.
Save abetkin/863dbf79fb55b17f851afdd68d33f668 to your computer and use it in GitHub Desktop.
import asyncio
import functools
import inspect
from asgiref.sync import async_to_sync
from channels.exceptions import StopConsumer
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.layers import get_channel_layer
class BaseConsumer(AsyncWebsocketConsumer):
    """
    Websocket consumer whose ASGI entry point pumps incoming messages through
    the module-level ``await_many_dispatch`` helper.
    """

    async def __call__(self, scope, receive, send):
        """
        ASGI entry point: forward every message coming from the client, and
        from the channel layer when one is configured, to ``self.dispatch``.
        """
        self.scope = scope

        # The client's receive callable is always a message source; a
        # configured channel layer contributes a second one bound to a
        # freshly allocated channel name.
        sources = [receive]
        self.channel_layer = layer = get_channel_layer(self.channel_layer_alias)
        if layer is not None:
            self.channel_name = await layer.new_channel()
            self.channel_receive = functools.partial(
                layer.receive, self.channel_name
            )
            sources.append(self.channel_receive)

        # Keep the send callable, wrapped with async_to_sync when _sync is set
        self.base_send = async_to_sync(send) if self._sync else send

        # Pump messages into dispatch until the consumer signals a stop
        try:
            await await_many_dispatch(sources, self.dispatch)
        except StopConsumer:
            # Exit cleanly
            pass
async def await_many_dispatch(consumer_callables, dispatch):
    """
    Given a set of consumer callables, awaits on them all and passes results
    from them to the dispatch awaitable as they come in.

    Args:
        consumer_callables: Indexable sequence of zero-argument callables,
            each returning an awaitable (coroutine, future/task, or any
            object implementing ``__await__``). A callable is invoked again
            every time its previous result has been dispatched.
        dispatch: Callable taking one result and returning an awaitable.

    Raises:
        Whatever a consumer callable, one of its awaitables, or ``dispatch``
        raises. The loop has no normal exit; it only ends by exception or
        cancellation, and outstanding tasks are cancelled on the way out.
    """
    tasks = []
    try:
        # Schedule every source inside the try block so that tasks already
        # started are still cleaned up if a later callable raises.
        # ensure_future accepts coroutines, futures and plain awaitables
        # alike, so every entry is a real Future with done()/result()/cancel().
        for consumer_callable in consumer_callables:
            tasks.append(asyncio.ensure_future(consumer_callable()))
        while True:
            # Wait for any of them to complete
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # Find the completed one(s), dispatch results, and replace them
            for i, task in enumerate(tasks):
                if task.done():
                    result = task.result()
                    await dispatch(result)
                    tasks[i] = asyncio.ensure_future(consumer_callables[i]())
    finally:
        # Request cancellation of every task before awaiting any of them:
        # awaiting a task that already failed re-raises its exception, and
        # that must not leave the remaining tasks running.
        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment