Last active
February 29, 2024 02:37
-
-
Save dummerbd/42bb425b3972df073d7a039b7ede47e3 to your computer and use it in GitHub Desktop.
Python asyncio event loop executor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import atexit
import logging
import queue
import threading
import weakref
from concurrent.futures import CancelledError, Executor, Future, InvalidStateError
from typing import Any, Callable, Mapping, Optional, Set, Tuple
logger = logging.getLogger("asyncio.EventLoopExecutor")

# One queued unit of work: (future returned to the caller, coroutine
# function, positional args, keyword args).
WorkItemTuple = Tuple[Future, Callable, tuple, Mapping]

# Set by the atexit hook; every worker loop stops once it observes it.
_process_shutdown = False
# Incremented per executor; only used to give worker threads unique names.
_worker_count = 0
# Worker threads, held weakly so the atexit hook can join the live ones.
_worker_threads = weakref.WeakSet()
# Executors, held weakly so the atexit hook can wake their sleeping loops.
_executors = weakref.WeakSet()


@atexit.register
def _process_exit() -> None:
    """Stop every worker event loop and wait for its thread at interpreter exit.

    Worker loops sleep until woken, so setting the flag alone is not enough:
    each executor is woken so that it re-checks ``_process_shutdown``.
    """
    logger.debug("process exit, shutting down worker threads")
    global _process_shutdown
    _process_shutdown = True
    for executor in list(_executors):
        executor._wake()
    for worker in list(_worker_threads):
        # join() raises on a thread that was never started, and a finished
        # thread needs no join.
        if worker.is_alive():
            worker.join()


class EventLoopExecutor(Executor):
    """Executor that runs coroutine functions on a private asyncio event loop.

    Each executor owns one daemon thread running the loop. ``submit`` may be
    called from any thread. At most ``max_tasks`` submitted coroutines run at
    once; the rest wait in a FIFO queue.

    On shutdown, running tasks are cancelled and queued work is cancelled.
    Queued work is not drained first.
    """

    def __init__(self, max_tasks: Optional[int] = None):
        """Create the executor and start its worker thread.

        Args:
            max_tasks: Maximum number of submitted coroutines running at the
                same time. ``None`` or ``0`` selects the default of 10.

        Raises:
            ValueError: If ``max_tasks`` is negative, since no work could
                ever be started.
        """
        if max_tasks is not None and max_tasks < 0:
            raise ValueError("max_tasks must not be negative")
        self._max_tasks = max_tasks or 10
        self._shutdown = False
        # Makes "check the shutdown flag, then enqueue" in submit() atomic
        # with respect to shutdown() setting the flag, so nothing is queued
        # after the worker's final pass over the queue.
        self._shutdown_lock = threading.Lock()
        self._work_queue: queue.SimpleQueue[WorkItemTuple] = queue.SimpleQueue()
        # Published by the worker thread once its loop is running; _wake()
        # uses them to signal the loop from other threads.
        self._wakeup: Optional[asyncio.Event] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        global _worker_count
        self._worker_thread = threading.Thread(
            target=self._start_worker_loop,
            name=f"EventLoopExecutor-{_worker_count}",
            daemon=True,
        )
        _worker_count += 1
        _worker_threads.add(self._worker_thread)
        _executors.add(self)
        self._worker_thread.start()

    def submit(self, fn: Callable, *args: Any, **kwargs: Any) -> Future:
        """Schedule ``fn(*args, **kwargs)`` on the worker loop.

        ``fn`` must return a coroutine (e.g. an ``async def`` function).

        Returns:
            A ``concurrent.futures.Future`` resolved with the coroutine's
            result or exception. If the task is cancelled after it started,
            the future's exception is ``concurrent.futures.CancelledError``.

        Raises:
            RuntimeError: After shutdown, or if the worker thread has died.
        """
        with self._shutdown_lock:
            if self._shutdown or _process_shutdown:
                raise RuntimeError("Cannot schedule new work after shutdown")
            if not self._worker_thread.is_alive():
                raise RuntimeError("Event loop worker thread died")
            fut = Future()
            self._work_queue.put_nowait((fut, fn, args, kwargs))
            self._wake()
        return fut

    def _wake(self) -> None:
        """Ask the worker loop to re-check its queue and shutdown flags.

        Safe from any thread. This is a no-op before the loop has published
        itself, because the loop checks the flags and drains the queue on
        startup anyway. It is also a no-op once the loop is closed.
        """
        # Read _loop first: the worker publishes _wakeup before _loop.
        loop = self._loop
        wakeup = self._wakeup
        if loop is None or wakeup is None:
            return
        try:
            loop.call_soon_threadsafe(wakeup.set)
        except RuntimeError:
            # The loop is already closed; the worker has exited.
            pass

    def _start_worker_loop(self) -> None:
        """Worker-thread entry point: run the event loop until it stops.

        Any exception escaping the loop is logged and re-raised. It also
        marks the executor as shut down, so submit() refuses new work.
        """
        logger.info("starting worker thread")
        try:
            asyncio.run(self._run_worker_tasks())
        except BaseException:
            logger.critical("exception in worker thread", exc_info=True)
            self._shutdown = True
            raise
        logger.info("worker thread shutdown done")

    async def _run_worker_tasks(self) -> None:
        """Main coroutine of the worker loop: start queued work up to the limit.

        Between passes it sleeps on an ``asyncio.Event`` instead of spinning.
        The event is set by submit() (via _wake), by every finished task, and
        on shutdown.

        On exit, still-running tasks are cancelled and awaited, and queued
        work is cancelled, so no future from submit() is left pending.
        """
        logger.info("worker event loop started")
        running_tasks: Set[asyncio.Task] = set()
        wakeup = asyncio.Event()
        # Publish _wakeup before _loop: _wake() treats a non-None _loop as
        # meaning both are ready.
        self._wakeup = wakeup
        self._loop = asyncio.get_running_loop()
        try:
            while not self._shutdown and not _process_shutdown:
                # Callbacks (including wakeup.set) only run at the await
                # below, so clearing before the drain cannot lose a wake-up.
                wakeup.clear()
                while len(running_tasks) < self._max_tasks:
                    try:
                        self._schedule_worker_task(running_tasks)
                    except queue.Empty:
                        break
                await wakeup.wait()
        finally:
            logger.info("worker event loop shutting down")
            pending = [task for task in running_tasks if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                # Let the cancelled tasks finish so their done-callbacks
                # resolve the matching futures before this thread exits.
                await asyncio.gather(*pending, return_exceptions=True)
            # Nothing will ever run the remaining work; don't leave it hanging.
            self._cancel_queued()

    def _schedule_worker_task(self, running_tasks: Set[asyncio.Task]) -> None:
        """Pop one work item and start it as a task on the running loop.

        If ``fn`` fails before producing a coroutine, the item's future gets
        that exception and the loop keeps running.

        Raises:
            queue.Empty: If no work is queued.
        """
        fut, fn, args, kwargs = self._work_queue.get_nowait()
        fn_name = repr(fn)
        if not fut.set_running_or_notify_cancel():
            # Future was already cancelled, no need to run the task
            logger.debug("future already cancelled fn=%s", fn_name)
            return
        wakeup = self._wakeup

        def task_done(t: asyncio.Task) -> None:
            """Copy the task's outcome onto ``fut`` and free its slot."""
            running_tasks.discard(t)
            # A slot has freed up; let the worker loop start more work.
            if wakeup is not None:
                wakeup.set()
            try:
                if t.cancelled():
                    logger.debug("task cancelled fn=%s", fn_name)
                    # fut is already RUNNING, so fut.cancel() would be
                    # refused; report the cancellation as its exception.
                    fut.set_exception(CancelledError())
                    return
                ex = t.exception()
                if ex is not None:
                    # Not inside an except block: pass the exception
                    # explicitly so its traceback is logged.
                    logger.error("task failed fn=%s", fn_name, exc_info=ex)
                    fut.set_exception(ex)
                else:
                    fut.set_result(t.result())
            except InvalidStateError:
                logger.debug("task already done fn=%s", fn_name)

        # Schedule work item to run in loop
        logger.debug("scheduling task fn=%s", fn_name)
        try:
            task = asyncio.create_task(fn(*args, **kwargs))
        except Exception as exc:
            # fn raised synchronously or did not return a coroutine: fail
            # this future instead of letting the error kill the worker loop.
            logger.error("failed to start task fn=%s", fn_name, exc_info=exc)
            fut.set_exception(exc)
            return
        running_tasks.add(task)
        task.add_done_callback(task_done)

    def _cancel_queued(self) -> None:
        """Cancel every work item still waiting in the queue."""
        while True:
            try:
                fut = self._work_queue.get_nowait()[0]
            except queue.Empty:
                return
            fut.cancel()

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Stop the worker loop; running and queued work is cancelled.

        Args:
            wait: Block until the worker thread has exited.
            cancel_futures: Also cancel queued work from the calling thread
                right away. The worker cancels leftovers as it exits anyway,
                so this matters mainly with ``wait=False``.
        """
        with self._shutdown_lock:
            if self._shutdown:
                return
            logger.info("starting shutdown")
            self._shutdown = True
            self._wake()
        if wait:
            logger.info("waiting for worker thread to join")
            self._worker_thread.join()
            logger.info("worker thread joined")
        if cancel_futures:
            logger.info("cancelling pending futures")
            self._cancel_queued()
        logger.info("shutdown done")
if __name__ == "__main__":
    import sys

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

    async def greetings(name: str, wait: float) -> None:
        """Print a greeting, sleep ``wait`` seconds, then print a goodbye."""
        print(f"Hello {name}")
        try:
            await asyncio.sleep(wait)
            print(f"Goodbye {name}")
        except asyncio.CancelledError:
            print(f"asyncio cancelled {name}")
            # Propagate so the task is really marked as cancelled rather
            # than appearing to have finished normally.
            raise

    def done(name: str) -> Callable[[Future], None]:
        """Return a done-callback that reports how the future for ``name`` ended."""

        def _done(fut: Future) -> None:
            if fut.cancelled():
                print(f"future cancelled {name}")
            else:
                print(f"future done {name}")

        return _done

    # With max_tasks=2, the slow task holds one slot while the quick tasks
    # take turns in the other, in submission order.
    executor = EventLoopExecutor(max_tasks=2)
    slow_fut = executor.submit(greetings, "slow", 0.3)
    for i in range(6):
        fut = executor.submit(greetings, str(i), 0.1)
        fut.add_done_callback(done(str(i)))
    slow_fut.result()
    # Whatever is still running or queued once the slow task ends is cancelled.
    executor.shutdown(cancel_futures=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment