Skip to content

Instantly share code, notes, and snippets.

@amencke
Last active October 23, 2021 21:47
Show Gist options
  • Save amencke/0cffa2c2df55825af0b94e13dd316738 to your computer and use it in GitHub Desktop.
Save amencke/0cffa2c2df55825af0b94e13dd316738 to your computer and use it in GitHub Desktop.
threadsafe connection pool implementation
import asyncio
import concurrent.futures
import threading
import uuid
from timeit import default_timer as timer
class NoConnectionAvailableError(Exception):
    """Signals that no pooled connection could be checked out in time."""

    def __init__(self, msg):
        super().__init__(msg)
class ClientConnectionError(Exception):
    """Signals that the I/O performed by a connection failed."""

    def __init__(self, msg):
        super().__init__(msg)
# Running tally incremented by Connection.work() and printed by the
# __main__ driver once all jobs have finished.
completed_work: int = 0
# Held by threadsafe_print() so concurrent callers don't interleave output.
print_lock = threading.Lock()
def threadsafe_print(*values, **print_kwargs):
    """Forward to the built-in print() while holding ``print_lock``.

    Output from concurrent threadsafe_print() callers therefore never
    interleaves; every argument is passed to print() unchanged.
    """
    with print_lock:
        print(*values, **print_kwargs)
class Connection(object):
    """A stand-in client connection whose work() simulates one I/O call.

    work() may run on several threads at once (in the __main__ driver every
    executor thread runs its own event loop), so the shared module-level
    counter is only updated under a lock.
    """

    # Serializes updates to the global ``completed_work``: ``+=`` on a global
    # is a non-atomic read-modify-write, and work() runs on many threads.
    _counter_lock = threading.Lock()

    def __init__(self, id_):
        """Remember the identifier shown in this connection's log output."""
        self._id = id_

    async def work(self):
        """Perform one unit of simulated I/O on this connection.

        The global ``completed_work`` counter is incremented only after the
        I/O finishes successfully, so failed attempts are not counted.

        Raises:
            ClientConnectionError: the underlying I/O raised; the original
                exception is chained as ``__cause__``.
        """
        global completed_work
        threadsafe_print(f"Connection {self._id} doing work...")
        try:
            await asyncio.sleep(1)  # I/O bound network operation
        except Exception as err:
            raise ClientConnectionError("Something went wrong") from err
        with Connection._counter_lock:
            completed_work += 1
class ConnectionPool(object):
    """A bounded, thread-safe pool of lazily created Connection objects.

    At most ``max_connections`` connections are checked out at any time. A
    new connection is created only when a caller holds a free slot and no
    idle connection is available, so the pool never creates more than
    ``max_connections`` connections in total.
    """

    def __init__(self, max_connections, timeout=1):
        """Create an empty pool.

        Args:
            max_connections: upper bound on simultaneously checked-out
                connections (and therefore on connections ever created).
            timeout: seconds get() waits for a free slot before raising
                NoConnectionAvailableError; None waits indefinitely.
        """
        # One permit per free checkout slot. BoundedSemaphore raises
        # ValueError if released more often than acquired.
        self._pool_sema = threading.BoundedSemaphore(max_connections)
        # Guards _connections and _connection_tracker, which many threads
        # touch; get()'s "reuse an idle one or create one" step is not
        # atomic without it.
        self._lock = threading.Lock()
        self._connections = []  # idle connections ready for reuse
        self._connection_tracker = set()  # every connection this pool created
        self._timeout = timeout

    def get(self):
        """Check out a connection, reusing an idle one when possible.

        Waits up to ``timeout`` seconds for a free slot.

        Raises:
            NoConnectionAvailableError: no slot became free within the
                timeout.
        """
        if not self._pool_sema.acquire(blocking=True, timeout=self._timeout):
            raise NoConnectionAvailableError("No connection available!")
        try:
            with self._lock:
                if self._connections:
                    return self._connections.pop()
                conn = Connection(uuid.uuid4())
                self._connection_tracker.add(conn)
                return conn
        except BaseException:
            # We hold a slot but could not hand out a connection: give the
            # slot back, otherwise it is lost for the lifetime of the pool.
            self._pool_sema.release()
            raise

    def release(self, conn):
        """Return a connection obtained from get() to the idle pool.

        Raises:
            ValueError: ``conn`` was not created by this pool, or it is
                already idle (released twice).
        """
        with self._lock:
            if conn not in self._connection_tracker:
                raise ValueError("connection does not belong to this pool")
            if conn in self._connections:
                raise ValueError("connection released more than once")
            # Make the connection visible to get() *before* freeing its slot,
            # so the waiter woken by the semaphore reuses it instead of
            # creating a new one.
            self._connections.append(conn)
        self._pool_sema.release()
async def work_and_release(pool, tasks_per_connection=1000):
    """Check out one connection from *pool* and run concurrent work on it.

    Runs ``tasks_per_connection`` concurrent ``conn.work()`` calls on the
    single checked-out connection. The connection is always returned to the
    pool, even when a task fails with an unexpected error, and only after
    every task has finished with it.

    Args:
        pool: object exposing blocking ``get()`` and ``release(conn)``
            (a ConnectionPool in this file).
        tasks_per_connection: number of concurrent work() calls to run.

    Raises:
        The first unexpected (non-ClientConnectionError) exception raised by
        a task, re-raised after the connection has been released.
    """
    # NOTE: pool.get() blocks this event loop's thread for up to the pool
    # timeout. That is acceptable when, as in the __main__ driver, each call
    # runs on its own event loop via asyncio.run().
    try:
        conn = pool.get()
    except NoConnectionAvailableError:
        threadsafe_print("Handling connection pool error...")
        return

    async def _do_work():
        try:
            await conn.work()
        except ClientConnectionError:
            threadsafe_print("Handling client connection error...")

    try:
        # return_exceptions=True makes gather wait for *every* task before
        # returning, so the connection is never released while sibling
        # tasks are still using it.
        results = await asyncio.gather(
            *(_do_work() for _ in range(tasks_per_connection)),
            return_exceptions=True,
        )
    finally:
        pool.release(conn)
    for result in results:
        if isinstance(result, BaseException):
            raise result
if __name__ == '__main__':
    pool = ConnectionPool(max_connections=10, timeout=3)
    start = timer()
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        # Every job runs work_and_release() on its own event loop in a worker
        # thread; all jobs share the single pool above.
        futures = [
            executor.submit(lambda: asyncio.run(work_and_release(pool)))
            for _ in range(40)
        ]
    end = timer()
    # Roughly jobs / min(max_connections, max_workers) seconds: each job holds
    # one connection for about one second of simulated I/O.
    print(f"time taken: {end - start}")
    print(f"completed work: {completed_work}")
    # Re-raise the first exception any job hit; otherwise it would sit on its
    # discarded future and the failure would go unnoticed.
    for future in futures:
        future.result()
# Sample run: 40 jobs, 1000 coroutines per job, max_connections=10, max_workers=8, 1-second sleep
# ...
# Connection 938aa726-57cf-4737-86a5-a30fb02f5668 doing work...
# Connection fad94190-84a8-4497-b53a-308f24c1d1f0 doing work...
# Connection ff21e605-da6c-4748-90d1-89c9ef60d9db doing work...
# Connection 1d224283-905c-4abf-ac96-bc9301128450 doing work...
# Connection 41afe594-055a-4293-b240-14959467c5bc doing work...
# time taken: 6.165483556
# completed work: 40000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment