Created
October 16, 2019 12:27
-
-
Save leonhard-s/a2aeabceeca82b1cdd1815e59c9ffea0 to your computer and use it in GitHub Desktop.
A `discord.py` extension adding support for awaiting multiple websocket events
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Adds support for multi-event waiting. | |
This is particularly useful for catching different but closely related
events, such as `reaction_add` and `reaction_remove`, but it should
support any event type that the original `wait_for` method can handle.
Note that only the first matching event is honoured; any additional
events are discarded. Also make sure your check can handle each
event's payload, as payloads differ from event to event.
"""
import asyncio | |
from types import MethodType | |
from typing import Any, Callable, Optional, Union, Set | |
import discord | |
from discord.ext import commands | |
async def wait_for_any(
        self: "discord.Client", event_name: str, *args: str,
        check: Optional[Callable[..., bool]] = None,
        timeout: Optional[Union[float, int]] = None) -> Any:
    """Wait for the first of several events to fire.

    Works like ``Client.wait_for`` but accepts more than one event name.
    One future is appended to ``self._listeners[name]`` for every distinct
    lower-cased event name, all sharing the same ``check``. The first
    future to complete wins; every other future is cancelled and its
    event is discarded.

    Different events may produce different payloads, so both ``check``
    and the caller must handle every shape the listed events can yield.

    Args:
        event_name: First event name to wait for, e.g. ``"reaction_add"``.
            Names are lower-cased before registration.
        *args: Further event names to wait for. Duplicates (after
            lower-casing) are registered only once.
        check: Predicate stored alongside each future; it is called with
            the event's arguments. Defaults to one that always returns
            ``True``.
        timeout: Seconds to wait, or ``None`` to wait indefinitely.

    Returns:
        The result of whichever future completed first.

    Raises:
        asyncio.TimeoutError: If no future completed within ``timeout``.
        Any exception stored on the winning future is re-raised by
        ``Future.result()`` (presumably one raised by ``check`` inside
        the client's dispatcher — verify against the discord.py version).
    """
    if check is None:
        def _always_true(*_: Any) -> bool:
            return True
        check = _always_true

    # Must be a set literal: set("abc") would split the name into single
    # characters and register listeners for nonexistent events.
    event_names = {event_name.lower()}
    event_names.update(name.lower() for name in args)

    # Register one (future, check) pair per event name, creating the
    # listener list for that name if it does not exist yet.
    futures = []
    for name in event_names:
        future = self.loop.create_future()
        self._listeners.setdefault(name, []).append((future, check))
        futures.append(future)

    try:
        done, _ = await asyncio.wait(
            futures, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Cancel every future that has not completed. This also runs when
        # this coroutine is itself cancelled while waiting, so no listener
        # is left behind to complete a future nobody will await.
        # Future.cancel() is a no-op on futures that are already done.
        # NOTE(review): the cancelled entries stay in self._listeners; this
        # relies on the client's dispatcher skipping/removing cancelled
        # futures, as Client.wait_for does — confirm for your version.
        for future in futures:
            future.cancel()

    if not done:
        raise asyncio.TimeoutError
    # If several futures completed simultaneously, one is picked arbitrarily.
    return done.pop().result()
def setup(bot: commands.Bot) -> None:
    """Extension-loader hook: attach ``wait_for_any`` to *bot*.

    Binds the module-level ``wait_for_any`` coroutine function to this
    specific bot instance, so it can be awaited as
    ``bot.wait_for_any(...)``.
    """
    bound_method = MethodType(wait_for_any, bot)
    setattr(bot, "wait_for_any", bound_method)
def teardown(bot: commands.Bot) -> None:
    """Extension-unloader hook: detach ``wait_for_any`` from *bot*.

    Deletes the ``wait_for_any`` instance attribute; raises
    AttributeError if it is not present on *bot*.
    """
    delattr(bot, "wait_for_any")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment