Last active
July 19, 2022 08:34
-
-
Save abetkin/863dbf79fb55b17f851afdd68d33f668 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
import inspect | |
from asgiref.sync import async_to_sync | |
from channels.exceptions import StopConsumer | |
from channels.generic.websocket import AsyncWebsocketConsumer | |
from channels.layers import get_channel_layer | |
class BaseConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that pumps client and channel-layer messages into ``dispatch``."""

    async def __call__(self, scope, receive, send):
        """Serve one ASGI connection until the consumer stops.

        Records *scope*, optionally allocates a channel on the configured
        channel layer, stores the send callable, and then hands every message
        from *receive* (and from the channel layer, when one is configured)
        to ``self.dispatch``. A ``StopConsumer`` raised while dispatching
        ends the connection quietly.
        """
        self.scope = scope

        layer = get_channel_layer(self.channel_layer_alias)
        self.channel_layer = layer

        # The client's receive callable is always a message source; the
        # channel layer's per-consumer channel is added only if a layer exists.
        sources = [receive]
        if layer is not None:
            self.channel_name = await layer.new_channel()
            self.channel_receive = functools.partial(layer.receive, self.channel_name)
            sources.append(self.channel_receive)

        # Sync consumers call send from a worker thread, so wrap it.
        self.base_send = async_to_sync(send) if self._sync else send

        try:
            await await_many_dispatch(sources, self.dispatch)
        except StopConsumer:
            # Normal termination path requested by a handler.
            pass
async def await_many_dispatch(consumer_callables, dispatch):
    """
    Given a set of consumer callables, awaits on them all and passes results
    from them to the dispatch awaitable as they come in.

    Each callable is called with no arguments and must return an awaitable
    (a coroutine, a Future/Task, or any other awaitable). Whenever one of
    them completes, its result is passed to ``await dispatch(result)`` and
    the callable is invoked again to produce the next awaitable. Results
    that complete together are dispatched in the order of
    *consumer_callables*.

    The loop only ends by an exception, raised either by ``dispatch`` or by
    one of the awaitables. That exception propagates to the caller after
    every outstanding awaitable has been cancelled and awaited.
    """
    # ensure_future wraps coroutines in Tasks on the running loop and passes
    # Futures/awaitables through as Futures, so every entry supports
    # done()/result()/cancel() and is accepted by asyncio.wait.
    tasks = [
        asyncio.ensure_future(consumer_callable())
        for consumer_callable in consumer_callables
    ]
    try:
        while True:
            # Wait for any of them to complete
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # Find the completed one(s), dispatch results, and replace them
            for i, task in enumerate(tasks):
                if task.done():
                    result = task.result()
                    await dispatch(result)
                    tasks[i] = asyncio.ensure_future(consumer_callables[i]())
    finally:
        # Cancel everything first so no task is left running while we wait
        # on another one during cleanup.
        for task in tasks:
            task.cancel()
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                # A task that had already failed on its own. Its error is
                # secondary: re-raising it here would mask the exception that
                # actually ended the loop and skip cleanup of later tasks.
                pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment