# asyncio.ThreadPool prototype implementation and demo.
# Originally published as GitHub gist aeros/8a86de6b13f17b9f717ea539ee1ee78f
# (last active November 4, 2019).
# Authored by: Kyle Stanley (https://github.com/aeros)
import asyncio
import concurrent.futures
import concurrent.futures.thread
import functools
import os
import threading
import weakref

# Only needed for demo.
import time
# Alias (not copy) the stdlib's registry of ThreadPoolExecutor worker threads.
# concurrent.futures.thread's interpreter-exit hook wakes and joins every
# thread registered in this mapping; a separate WeakKeyDictionary would hide
# this module's workers from that hook.
_threads_queues = concurrent.futures.thread._threads_queues
class ThreadPool(concurrent.futures.ThreadPoolExecutor):
    """Prototype asyncio front end for a ThreadPoolExecutor.

    All worker threads are created eagerly by ``start()`` (or on entering
    ``async with``) instead of lazily on ``submit()``.  Spawning the workers
    and the blocking ``shutdown(wait=True)`` both run on a short-lived helper
    thread that the calling coroutine awaits.
    """

    def __init__(self, concurrency=None):
        """Create an unstarted pool of ``concurrency`` worker threads.

        ``concurrency`` defaults to ``min(32, (os.cpu_count() or 1) + 4)``.
        """
        if concurrency is None:
            concurrency = min(32, (os.cpu_count() or 1) + 4)
        super().__init__(concurrency)
        self.concurrency = concurrency
        # Event loop the helper threads report back to; bound by start().
        self._loop = None

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, *args):
        await self.aclose()

    async def start(self):
        """Bind the pool to the running loop and spawn every worker thread.

        Raises RuntimeError if the pool has already been shut down.
        """
        self._loop = asyncio.get_running_loop()
        await self._spawn_threads()

    async def aclose(self):
        """Shut the pool down, waiting for its workers on a helper thread."""
        await self._shutdown_executor()

    async def _run_in_helper_thread(self, target):
        """Run ``target(future)`` on a helper thread and await ``future``.

        ``target`` reports completion by scheduling ``self._resolve`` on
        ``self._loop`` via ``call_soon_threadsafe``.  The helper thread is
        joined even if the await is cancelled.
        """
        if self._loop is None:
            # aclose() may be called on a pool that was never start()ed.
            self._loop = asyncio.get_running_loop()
        future = self._loop.create_future()
        helper = threading.Thread(target=target, args=(future,))
        helper.start()
        try:
            await future
        finally:
            helper.join()

    @staticmethod
    def _resolve(future, exc=None):
        """Complete ``future`` (on its loop's thread) unless it is already done."""
        # If the awaiting coroutine was cancelled, setting a result here would
        # raise InvalidStateError inside the loop callback.
        if future.done():
            return
        if exc is None:
            future.set_result(None)
        else:
            future.set_exception(exc)

    async def _spawn_threads(self):
        await self._run_in_helper_thread(self._do_spawn)

    def _do_spawn(self, future):
        """Helper-thread body: add workers until ``_max_workers`` are running."""
        try:
            # ThreadPoolExecutor.submit() also grows the pool while holding
            # this lock, so holding it keeps the count check-and-add atomic.
            with self._shutdown_lock:
                if self._shutdown:
                    raise RuntimeError('cannot start ThreadPool after shutdown')
                # Delegate to the stdlib so each worker gets the stdlib's
                # name ('<prefix>_<index>'), worker function, and exit-time
                # registration for the running Python version.  Each call
                # adds at most one worker.
                while len(self._threads) < self._max_workers:
                    self._adjust_thread_count()
        except Exception as ex:
            self._loop.call_soon_threadsafe(self._resolve, future, ex)
        else:
            self._loop.call_soon_threadsafe(self._resolve, future)

    async def _shutdown_executor(self):
        await self._run_in_helper_thread(self._do_shutdown)

    def _do_shutdown(self, future):
        """Helper-thread body: ``shutdown(wait=True)``, then report back."""
        try:
            self.shutdown(wait=True)
        except Exception as ex:
            self._loop.call_soon_threadsafe(self._resolve, future, ex)
        else:
            self._loop.call_soon_threadsafe(self._resolve, future)

    async def run(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` on a worker thread; return its result."""
        call = functools.partial(func, *args, **kwargs)
        return await asyncio.wrap_future(self.submit(call), loop=self._loop)
# Optimally, a CPU-bound function should be run in a ProcessPool instead of a
# ThreadPool (the GIL stops threads from executing Python bytecode in
# parallel), but this is just for demonstration purposes.
def fib(n, start_time):
    """Return the n-th Fibonacci number (fib(0) == 0), printing timestamps.

    ``start_time`` is a ``time.perf_counter()`` reading; the start and end
    messages report the seconds elapsed since it.
    """
    def elapsed():
        return time.perf_counter() - start_time

    print(f"start fib({n}) at {elapsed()} seconds")
    prev, curr = 0, 1
    for _ in range(n):
        prev, curr = curr, prev + curr
    print(f"end fib({n}) at {elapsed()} seconds")
    return prev
async def main():
    """Demo: run three Fibonacci computations on a 5-thread ThreadPool."""
    start_time = time.perf_counter()
    sizes = (500000, 400000, 200000)
    async with ThreadPool(concurrency=5) as pool:
        # Coroutines are created (and then scheduled by gather) in the same
        # order as the sizes above.
        await asyncio.gather(*(pool.run(fib, size, start_time) for size in sizes))
# Run the demo only when executed as a script, so importing this module
# does not spend several seconds computing Fibonacci numbers.
if __name__ == "__main__":
    asyncio.run(main())

"""Output (OS: Arch Linux 5.3.7, CPU: Intel i5-4460):
start fib(500000) at 0.0011481469991849735 seconds
start fib(400000) at 0.00118773800204508 seconds
start fib(200000) at 0.006400163998478092 seconds
end fib(200000) at 1.439069856001879 seconds
end fib(400000) at 4.553958568008966 seconds
end fib(500000) at 5.3428641200007405 seconds
"""
# (End of gist; sign in to GitHub to comment on it.)