Last active
September 27, 2022 08:55
-
-
Save thehesiod/524a1f005d0f3fb61a8952f272d8709e to your computer and use it in GitHub Desktop.
asyncio cancel all tasks on first task's exception
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
from typing import Coroutine, List, Union
def _ignore_task_exception(task: asyncio.Future, logger: logging.Logger): | |
# noinspection PyBroadException | |
try: | |
task.result() | |
except BaseException: | |
# These may be things like boto 404s | |
pass | |
# logger.info("Ignoring exception", exc_info=sys.exc_info()) | |
async def _wait_task_cancellations(cancelled_tasks: List[asyncio.Task], logger: logging.Logger, loop: asyncio.BaseEventLoop=None, log_cancellation: bool=False, log_exceptions: bool=True): | |
if not len(cancelled_tasks): | |
return | |
# now wait for cancelled tasks to finish and log exceptions encountered | |
done_futures, _ = await asyncio.wait(cancelled_tasks, loop=loop) | |
for fut in done_futures: | |
try: | |
fut.result() | |
except asyncio.CancelledError: | |
if log_cancellation: | |
logger.exception('Task was cancelled') | |
except: | |
if log_exceptions: | |
logger.exception("Task raised unexpected exception during cancellation") | |
async def gather_cancel_on_raise(*tasks: Union[asyncio.Task, Coroutine, asyncio.Future], loop: asyncio.BaseEventLoop = None, logger: logging.Logger):
    """
    Like ``asyncio.gather`` with ``return_exceptions=False`` (its default), but
    if any awaitable raises -- or this coroutine is itself cancelled -- every
    task that has not finished yet is cancelled and awaited before the
    exception is re-raised.

    Plain ``asyncio.gather`` propagates the first exception but lets the other
    awaitables keep running, leaving them without an owner.

    :param tasks: coroutines/futures/tasks to run concurrently; each one is
        wrapped with ``asyncio.ensure_future``
    :param loop: loop to schedule coroutines on; defaults to the running loop
    :param logger: logger used during cleanup. Unless the propagating
        exception is a ``CancelledError``, every task cancelled here is logged
        (with traceback) as "Task was cancelled". Exceptions raised by tasks
        while being cancelled are always logged.
    :return: list of results, in the order the awaitables were given
    :raises: whatever exception ended the gather, re-raised unchanged
    """
    # Inside a coroutine the running loop is the correct default;
    # get_event_loop() is deprecated for this purpose.
    loop = loop or asyncio.get_running_loop()
    futures: List[asyncio.Future] = []
    try:
        # Wrap inside the try: if wrapping a later argument fails, the tasks
        # already started are still cancelled below instead of leaking.
        for task in tasks:
            futures.append(asyncio.ensure_future(task, loop=loop))
        # gather() no longer accepts ``loop`` (removed in Python 3.10); the
        # futures are already bound to ``loop`` by ensure_future().
        return await asyncio.gather(*futures)
    except BaseException as e:
        outer_cancelled = isinstance(e, asyncio.CancelledError)
        # Cancel everything still running. Finished futures that failed or
        # were cancelled have their outcome consumed so asyncio does not warn
        # that an exception was never retrieved.
        cancelled_tasks = []
        for fut in futures:
            if not fut.done():
                fut.cancel()
                cancelled_tasks.append(fut)
            elif fut.cancelled() or fut.exception():
                _ignore_task_exception(fut, logger)
        await _wait_task_cancellations(cancelled_tasks, logger, loop, not outer_cancelled)
        # re-raise the original exception
        raise
@mvolfik cool thanks for tip, updated above
Thanks for this code ! 👍
Thanks a lot! Works as expected.
For others: you may want to add
from typing import List, Union, Coroutine
👍
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for this! Just a note on line 48: `asyncio.ensure_future(x)` itself checks `isfuture(x)`, so the condition there is unnecessary. Just use `[asyncio.ensure_future(task, loop=loop) for task in tasks]`.