Skip to content

Instantly share code, notes, and snippets.

@aeros
Last active November 4, 2019 12:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aeros/8a86de6b13f17b9f717ea539ee1ee78f to your computer and use it in GitHub Desktop.
Save aeros/8a86de6b13f17b9f717ea539ee1ee78f to your computer and use it in GitHub Desktop.
asyncio.ThreadPool prototype implementation and demo
# Authored by: Kyle Stanley (https://github.com/aeros)
import asyncio
import concurrent.futures
import threading
import functools
import weakref
import os
# Only needed for demo.
import time
# Registry mapping each worker thread spawned by ThreadPool._do_spawn() to the
# work queue it consumes; keys are held weakly, so finished threads drop out.
# NOTE(review): this is a fresh module-local dict and nothing in this file
# reads it. It presumably mirrors concurrent.futures.thread._threads_queues,
# but being a separate object, any cleanup the stdlib drives from its own
# registry will not see these workers — confirm this is intended.
_threads_queues = weakref.WeakKeyDictionary()
class ThreadPool(concurrent.futures.ThreadPoolExecutor):
    """Asyncio-friendly thread pool that eagerly spawns its workers.

    Prototype of an ``asyncio.ThreadPool`` API built on
    :class:`concurrent.futures.ThreadPoolExecutor`. Typical use::

        async with ThreadPool(concurrency=5) as pool:
            result = await pool.run(func, *args, **kwargs)

    :meth:`start` creates ``concurrency`` worker threads up front, and both
    spawning and shutdown are performed in a short-lived helper thread so the
    blocking work does not run on the event loop thread.

    NOTE(review): relies on private ``concurrent.futures.thread`` internals
    (``_worker``, ``_threads``, ``_work_queue``, ``_initializer``,
    ``_initargs``, ``_thread_name_prefix``); confirm they match the target
    interpreter version.
    """

    def __init__(self, concurrency=None):
        """Create the pool.

        Args:
            concurrency: Number of worker threads to run. ``None`` means
                ``min(32, (os.cpu_count() or 1) + 4)``.
        """
        if concurrency is None:
            concurrency = min(32, (os.cpu_count() or 1) + 4)
        super().__init__(concurrency)
        self.concurrency = concurrency
        # Event loop this pool is bound to; set by start(), or lazily by
        # aclose() when start() was never called.
        self._loop = None

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, *args):
        await self.aclose()

    async def start(self):
        """Bind the pool to the running event loop and spawn all workers."""
        self._loop = asyncio.get_running_loop()
        await self._spawn_threads()

    async def aclose(self):
        """Shut the pool down, waiting (off-loop) for submitted work to finish."""
        await self._shutdown_executor()

    async def _run_in_helper_thread(self, target):
        """Run blocking ``target(future)`` in a new thread and await ``future``.

        ``target`` must resolve ``future`` via ``_resolve_future_threadsafe``
        as its final step.
        """
        if self._loop is None:
            # aclose() can be reached without start() having bound a loop.
            self._loop = asyncio.get_running_loop()
        future = self._loop.create_future()
        thread = threading.Thread(target=target, args=(future,))
        thread.start()
        try:
            await future
        finally:
            # target resolves the future as its last step, so this normally
            # returns at once; if the await was cancelled, however, the join
            # blocks the event loop until target finishes.
            thread.join()

    def _resolve_future_threadsafe(self, future, exc=None):
        """Schedule resolution of ``future`` on the pool's loop (thread-safe).

        Resolution is skipped when ``future`` is already done (e.g. the
        awaiting coroutine was cancelled); calling set_result() or
        set_exception() on a done future would raise InvalidStateError.
        """
        def resolve():
            if future.done():
                return
            if exc is None:
                future.set_result(None)
            else:
                future.set_exception(exc)

        self._loop.call_soon_threadsafe(resolve)

    async def _spawn_threads(self):
        """Spawn the worker threads from a helper thread."""
        await self._run_in_helper_thread(self._do_spawn)

    # Based on concurrent.futures.ThreadPoolExecutor._adjust_thread_count().
    def _do_spawn(self, future):
        """Create workers until ``concurrency`` exist, then resolve ``future``.

        Runs in a helper thread; any exception is forwarded to ``future``.
        """
        # When the executor gets lost, the weakref callback will wake up
        # the worker threads.
        def weakref_cb(_, q=self._work_queue):
            q.put(None)

        try:
            # The parentheses matter: without them ':=' has the lowest
            # precedence and would bind the boolean result of '<'.
            while (num_threads := len(self._threads)) < self.concurrency:
                thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                         num_threads)
                thread = threading.Thread(
                    name=thread_name,
                    target=concurrent.futures.thread._worker,
                    args=(weakref.ref(self, weakref_cb),
                          self._work_queue,
                          self._initializer,
                          self._initargs))
                thread.daemon = True
                thread.start()
                self._threads.add(thread)
                _threads_queues[thread] = self._work_queue
        except Exception as ex:
            self._resolve_future_threadsafe(future, ex)
        else:
            self._resolve_future_threadsafe(future)

    async def _shutdown_executor(self):
        """Run the blocking executor shutdown from a helper thread."""
        await self._run_in_helper_thread(self._do_shutdown)

    def _do_shutdown(self, future):
        """Call ``shutdown(wait=True)``, then resolve ``future``.

        Runs in a helper thread; any exception is forwarded to ``future``.
        """
        try:
            self.shutdown(wait=True)
        except Exception as ex:
            self._resolve_future_threadsafe(future, ex)
        else:
            self._resolve_future_threadsafe(future)

    async def run(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` on a pool thread and await its result.

        The outcome of the call (value or exception) is delivered through the
        executor future wrapped for the pool's event loop.
        """
        call = functools.partial(func, *args, **kwargs)
        loop = self._loop if self._loop is not None else asyncio.get_running_loop()
        return await asyncio.wrap_future(self.submit(call), loop=loop)
# Optimally, a CPU-bound function should be run in a ProcessPool instead of a
# ThreadPool (with the GIL, threads cannot execute Python bytecode in
# parallel), but this is just for demonstration purposes.
def fib(n, start_time):
    """Return the n-th Fibonacci number (fib(0) == 0), printing progress.

    Prints a start line and an end line, each showing the seconds elapsed
    since ``start_time`` (a ``time.perf_counter()`` reading).
    """
    elapsed = time.perf_counter() - start_time
    print(f"start fib({n}) at {elapsed} seconds")
    current, following = 0, 1
    for _step in range(n):
        current, following = following, current + following
    elapsed = time.perf_counter() - start_time
    print(f"end fib({n}) at {elapsed} seconds")
    return current
async def main():
    """Demo: compute three Fibonacci numbers concurrently on a 5-thread pool."""
    start_time = time.perf_counter()
    async with ThreadPool(concurrency=5) as pool:
        jobs = [pool.run(fib, n, start_time) for n in (500000, 400000, 200000)]
        await asyncio.gather(*jobs)
if __name__ == "__main__":
    # Run the (multi-second) demo only when executed as a script, so that
    # importing this module just provides ThreadPool.
    asyncio.run(main())

"""Output (OS: Arch Linux 5.3.7, CPU: Intel i5-4460):
start fib(500000) at 0.0011481469991849735 seconds
start fib(400000) at 0.00118773800204508 seconds
start fib(200000) at 0.006400163998478092 seconds
end fib(200000) at 1.439069856001879 seconds
end fib(400000) at 4.553958568008966 seconds
end fib(500000) at 5.3428641200007405 seconds
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment