Skip to content

Instantly share code, notes, and snippets.

@erewok
Last active February 20, 2024 20:39
Show Gist options
  • Save erewok/0db1bd9b131a5f59c30b9380d7b18c20 to your computer and use it in GitHub Desktop.
Save erewok/0db1bd9b131a5f59c30b9380d7b18c20 to your computer and use it in GitHub Desktop.
Azure Python SDK Connection Pooling (Sharing)
"""
connection_pooling.py
The idea for our Azure clients is that unlike database connections
*all of our clients can be reused* across multiple requesters.
What we really need to achieve is the following:
- A "pool" of connections, where
- Each connection may be shared by more than 1 requester, and
- Each connection has an idle lifespan.
The last of these is the most important because Azure will enforce
idle timeouts for *all sockets*. For this reason, we will do the following:
- lazily create connections as needed
- share connections between many requesters
- put connections back into an idle data structure when necessary
- when connections are dead (exception or idle timeout) then we'll lock and recreate
- when a connection has exceeded its "share" amount, we'll lock and create a new one.
"""
from abc import ABC, abstractmethod
import binascii
from contextlib import asynccontextmanager
from functools import total_ordering, wraps
import heapq
import logging
import os
import time
from typing import AsyncGenerator
import anyio
from anyio import create_task_group, move_on_after
# Defaults used by ConnectionPool / SharedTransportConnection; each one can be
# overridden through the corresponding constructor argument.
DEFAULT_MAX_SIZE = 10  # shared connections held by one pool
DEFAULT_CONNECTION_MAX_IDLE_SECONDS = 5 * 60  # idle time before a connection expires
DEFAULT_SHARED_TRANSPORT_CLIENT_LIMIT = 100  # concurrent clients per connection
# Seconds -> nanoseconds, to compare against time.monotonic_ns() readings.
NANO_TIME_MULT = 10**9

logger = logging.getLogger(__name__)
class ConnectionsExhausted(ValueError):
    """Raised by the pool when no shared connection can be checked out in time."""
class ConnectionFailed(ConnectionError):
    """Raised when a newly opened connection cannot be made ready for use."""
class AbstractConnection(ABC):
    """Minimal interface for a pooled connection: it must support an async close."""

    @abstractmethod
    async def close(self):
        """Close the underlying connection."""
class AbstractorConnector(ABC):
    """
    Factory interface used by the pool to open and validate connections.

    Implementations provide `create` to open a brand-new connection and
    `ready` to confirm that a freshly opened connection is usable.
    """

    @abstractmethod
    async def create(self) -> AbstractConnection:
        """Open and return a new connection."""

    @abstractmethod
    async def ready(self, connection: AbstractConnection) -> bool:
        """Return True when `connection` can be used, False otherwise."""
def send_time_deco(log=None, msg: str = ""):
    """
    Build a decorator that logs how long an async function takes to finish.

    Intended for async functions that acquire and use a connection-pool
    connection, so that slow acquisition shows up in the debug logs.

    Args:
        log: Logger that receives the timing message. When None, this
            module's `logger` is used.
        msg: Label placed in front of the timing. When empty, the generic
            label "Connection pool using function" is used instead.

    Returns:
        A decorator for coroutine functions. The wrapped coroutine returns
        whatever the original returns; when it returns normally, the elapsed
        time (nanoseconds, from `time.monotonic_ns`) is logged at DEBUG level.
    """
    target_logger = logger if log is None else log
    label = msg if msg else "Connection pool using function"

    def decorator(fn):
        @wraps(fn)
        async def timed(*args, **kwargs):
            started = time.monotonic_ns()
            outcome = await fn(*args, **kwargs)
            elapsed = time.monotonic_ns() - started
            target_logger.debug(f"{label} timing: {elapsed}ns")
            return outcome

        return timed

    return decorator
@total_ordering
class SharedTransportConnection:
    """
    Each connection can be shared by many clients.

    The most pressing problem we need to solve is idle timeouts, but
    we also have problems around *opening*, establishing, and *closing* connections.
    Thus, each connection has the following lifecycle phases:
      - Closed
      - Open and not ready
      - Open and ready

    These are also the critical sections of work, so transitioning from one phase
    to another involves *locking* the resource.

    The problem is that when a client first attempts to *use* a connection, it calls
    one of the network-communication methods, and at that point, the connection
    is established. To *other* clients who are `await`ing their turn, the connection
    *already looks open*, so they may try to use it early and fail. The same problem
    happens on closing: one client closes while another still thinks the connection is live.

    Outside of this, after we have sent for the first time, we're fine to share the connection
    as much as we want. Thus, we need to lock out all clients during its critical sections of work:
      - Lock this connection when *opening* an underlying connection
      - Lock this connection when *establishing "readiness"* (first usage)
      - Lock this connection when *closing* an underlying connection

    At all other points we can share it a whole lot (between 100 clients or more). To see what's
    happening, enable debug logs on this module's logger (it is named after the module's
    `__name__`), e.g.:

        logging.getLogger("connection_pooling").setLevel(logging.DEBUG)

    Footnote: most Azure SDK clients use aiohttp shared transports below
    the surface, which already do connection pooling with up to 100 connections.
    We wanted something more generic, though, which is why this class exists.

    Azure Python SDK has an example of a shared transport for Azure clients
    but we wanted to start simpler and more agnostic here:

    https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/samples/example_shared_transport_async.py#L51-L71

    Based on the Azure example, though, we should easily be able to share one of these "connection"
    objects with 100 requesters.

    Args:
        connector:
            An instance of an `AbstractorConnector` for creating connections.
        client_limit:
            The max clients allowed _per_ connection (default: 100).
        max_lifespan_seconds:
            Optional setting which controls how long a connection may live before recycling.
            `None` (the default) disables the lifespan limit.
        max_idle_seconds:
            Maximum duration allowed for an idle connection before recycling it.
    """

    __slots__ = (
        "last_idle_start",
        "max_idle_ns",
        "max_lifespan_ns",
        "max_clients_allowed",
        "client_limiter",
        "connection_created_ts",
        "_connector",
        "_connection",
        "_open_close_lock",
        "_id",
        "_ready",
    )

    def __init__(
        self,
        connector: AbstractorConnector,
        client_limit: int = DEFAULT_SHARED_TRANSPORT_CLIENT_LIMIT,
        max_lifespan_seconds: int | None = None,
        max_idle_seconds: int = DEFAULT_CONNECTION_MAX_IDLE_SECONDS,
    ) -> None:
        # When was this connection created (only tracked when a lifespan limit is set)
        self.connection_created_ts: int | None = None
        # When did this connection last become idle (None while in use / never idle)
        self.last_idle_start: int | None = None
        # What's the max lifespan (idle or active) allowed for this connection
        if max_lifespan_seconds is None:
            # If `None`, this feature is disabled
            self.max_lifespan_ns = None
        else:
            self.max_lifespan_ns = max_lifespan_seconds * NANO_TIME_MULT
        # How many clients are allowed to share this connection at once
        self.max_clients_allowed = client_limit
        # What's the max idle time allowed for this connection
        self.max_idle_ns = max_idle_seconds * NANO_TIME_MULT
        # Each client borrows one token while it holds the connection
        self.client_limiter = anyio.CapacityLimiter(total_tokens=client_limit)
        self._connector: AbstractorConnector = connector
        self._connection: AbstractConnection | None = None
        self._open_close_lock: anyio.Lock = anyio.Lock()
        # Short random id, only used to tell connections apart in log messages
        self._id = (binascii.hexlify(os.urandom(3))).decode()
        self._ready = anyio.Event()

    @property
    def available(self) -> bool:
        """True while fewer than `max_clients_allowed` clients hold this connection."""
        return self.current_client_count < self.max_clients_allowed

    @property
    def current_client_count(self) -> int:
        """Number of clients currently holding a token from `client_limiter`."""
        return self.client_limiter.borrowed_tokens

    @property
    def expired(self) -> bool:
        """
        True if the connection has been idle longer than allowed or, when a
        lifespan limit is configured, has existed longer than allowed.

        Always False while there is no underlying connection.
        """
        if self._connection is None:
            return False
        if self.max_lifespan_ns is not None and self.connection_created_ts is not None:
            lifetime_expired = self.lifetime > self.max_lifespan_ns
        else:
            lifetime_expired = False
        # Check if time limit has been exceeded
        return self.time_spent_idle > self.max_idle_ns or lifetime_expired

    @property
    def lifetime(self) -> int:
        """Nanoseconds since the connection was created (0 when not tracked)."""
        if self.connection_created_ts is None:
            return 0
        now = time.monotonic_ns()
        return now - self.connection_created_ts

    @property
    def time_spent_idle(self) -> int:
        """Nanoseconds since the last client checked in (0 while in use)."""
        if self.last_idle_start is None:
            return 0
        now = time.monotonic_ns()
        return now - self.last_idle_start

    def __str__(self):
        return f"[Connection {self._id}]"

    # The following comparison functions check the "freshness"
    # of a connection. Our logic is as follows: a connection is "fresher"
    # (compares smaller) than another if:
    #   - it has fewer clients connected
    #   - it's been idle longer (when client counts are equal)
    # `total_ordering` derives `__le__` and `__ge__` from `__lt__` and `__eq__`.
    def __gt__(self, other):
        if not isinstance(other, SharedTransportConnection):
            return NotImplemented
        if self.current_client_count == other.current_client_count:
            return self.time_spent_idle < other.time_spent_idle
        return self.current_client_count > other.current_client_count

    def __eq__(self, other):
        if not isinstance(other, SharedTransportConnection):
            return NotImplemented
        return (
            self.is_ready == other.is_ready
            and self.current_client_count == other.current_client_count
            and self.time_spent_idle == other.time_spent_idle
        )

    def __lt__(self, other):
        if not isinstance(other, SharedTransportConnection):
            return NotImplemented
        if self.current_client_count == other.current_client_count:
            return self.time_spent_idle > other.time_spent_idle
        return self.current_client_count < other.current_client_count

    async def checkout(self) -> AbstractConnection:
        """
        Reserve a client slot on this connection and return the ready connection.

        Keeps `last_idle_start` up to date, retires an expired connection when this
        caller is its only client, lazily creates the underlying connection, and
        waits until it has been marked ready.

        A successful checkout must be followed by `checkin`. If checkout raises
        (or is cancelled) after reserving its slot, the slot is released before
        the exception propagates, so a failed checkout never leaks capacity.

        Raises:
            ConnectionFailed: if the connector reports the new connection is not ready.
        """
        # Bookkeeping: we want to know how long it takes to acquire a connection
        now = time.monotonic_ns()
        # The capacity limiter enforces the per-connection client limit
        await self.client_limiter.acquire()
        try:
            if self.expired and self.current_client_count == 1:
                logger.debug(f"{self} Retiring Connection past its lifespan")
                # NOTE(review): while close() awaits the lock and the underlying close,
                # other tasks can still borrow tokens, see the old (set) readiness event,
                # and be handed the connection that is being closed. Not addressed here.
                await self.close()
            if self._connection is None:
                await self.create()
            # Make sure connection is ready (only one caller runs the connector check)
            await self.check_readiness()
            self.last_idle_start = None
            # We do not take the lock here to avoid lock contention on the hot path;
            # we only need to wait for the first successful readiness check.
            await self._ready.wait()
        except BaseException:
            # Callers do not check in after a failed checkout, so give the token back
            # here; otherwise every failure permanently shrinks this connection's capacity.
            self.client_limiter.release()
            raise
        # Debug timings to reveal the extent of lock contention
        logger.debug(
            f"{self} available in {time.monotonic_ns()-now}ns "
            f"semaphore state {repr(self.client_limiter)} "
            f"Active client count: {self.current_client_count}"
        )
        return self._connection

    async def checkin(self) -> None:
        """
        Release this caller's client slot after using the connection.

        When the last client leaves, the idle clock starts, and if a lifespan
        limit is configured and exceeded the connection is closed. It is never
        closed while other clients still hold it.
        """
        try:
            self.client_limiter.release()
        except RuntimeError:
            # The current task holds no token (e.g. `acquire` timed out before
            # `checkout` obtained one, or `checkout` already released it).
            pass
        logger.debug(f"{self}.current_client_count is now {self.current_client_count}")
        # We only consider idle time to start when *no* clients are connected
        if self.current_client_count == 0:
            logger.debug(f"{self} is now idle")
            self.last_idle_start = time.monotonic_ns()
            # Check if TTL exceeded for this connection; safe to close, nobody holds it
            if self.max_lifespan_ns is not None and self.expired:
                await self.close()

    @asynccontextmanager
    async def acquire(
        self, timeout=10
    ) -> AsyncGenerator[AbstractConnection | None, None]:
        """
        Acquire a connection with a timeout.

        Yields the ready connection, or `None` if it could not be checked out
        within `timeout` seconds. The client slot is checked back in on exit.
        """
        acquired_conn = None
        async with create_task_group():
            with move_on_after(timeout) as scope:
                acquired_conn = await self.checkout()
        # If this were nested under `create_task_group` then any exceptions
        # get thrown under `BaseExceptionGroup`, which is surprising for clients.
        # See: https://github.com/agronholm/anyio/issues/141
        if not scope.cancelled_caught and acquired_conn:
            try:
                yield acquired_conn
            finally:
                await self.checkin()
        else:
            await self.checkin()
            yield None

    async def create(self) -> AbstractConnection:
        """
        Establish the underlying connection, or reuse the existing one.

        Opening happens under `_open_close_lock`. The existence check is repeated
        after taking the lock so callers that queued on it reuse the connection
        the first caller opened instead of each opening (and leaking) another one.
        """
        if self._connection is not None:
            return self._connection
        # We use a lock on *opening* a connection
        async with self._open_close_lock:
            if self._connection is not None:
                return self._connection
            logger.debug(f"{self} Creating a new connection")
            self._connection = await self._connector.create()
            # Check if we need to expire connections based on lifespan
            if self.max_lifespan_ns is not None:
                self.connection_created_ts = time.monotonic_ns()
            return self._connection

    @property
    def is_ready(self) -> bool:
        """Proxy for whether our readiness Event has been set."""
        return self._ready.is_set()

    async def check_readiness(self) -> None:
        """
        Mark the connection ready once the connector confirms it.

        Only one caller runs `connector.ready`: readiness is re-checked after
        taking `_open_close_lock`, so callers that queued on the lock return
        without repeating the check. Does nothing when there is no connection.

        Raises:
            ConnectionFailed: if `connector.ready` returns a falsy value.
        """
        if self._ready.is_set() or self._connection is None:
            return None
        # Our goal is to set the readiness Event once, from one client.
        async with self._open_close_lock:
            if self._ready.is_set() or self._connection is None:
                return None
            logger.debug(f"{self} Setting readiness")
            is_ready = await self._connector.ready(self._connection)
            if is_ready:
                self._ready.set()
            else:
                raise ConnectionFailed("Failed readying connection")

    async def close(self) -> None:
        """
        Close the underlying connection and return this wrapper to the "Closed" phase.

        Closing is best-effort: an `Exception` from the connection's own `close`
        is logged and not re-raised. The wrapper state is reset even if closing
        fails or is cancelled.
        """
        if self._connection is None:
            return None
        # We use a lock on *closing* a connection
        async with self._open_close_lock:
            connection = self._connection
            # Another task may have closed it while we waited for the lock
            if connection is None:
                return None
            logger.debug(f"{self} Closing the Connection")
            try:
                await connection.close()
            except Exception:
                logger.warning(
                    f"{self} Error while closing the connection", exc_info=True
                )
            finally:
                self._connection = None
                self._ready = anyio.Event()
                self.last_idle_start = None
                if self.max_lifespan_ns is not None:
                    self.connection_created_ts = None
class ConnectionPool:
    """
    Our goal here is to allow many clients to share connections,
    but to expire them when they've reached their idle time limits.

    Most clients can call this with the default values below.

    Args:
        connector:
            An instance of an `AbstractorConnector` for creating connections.
        client_limit:
            The max clients allowed _per_ connection (default: 100).
        max_size:
            The max size for the connection pool or max connections held (default: 10).
        max_idle_seconds:
            Maximum duration allowed for an idle connection before recycling it.
        max_lifespan_seconds:
            Optional setting which controls how long a connection may live before recycling.

    Raises:
        ValueError: if `max_size` is less than 1.
    """

    def __init__(
        self,
        connector: AbstractorConnector,
        client_limit: int = DEFAULT_SHARED_TRANSPORT_CLIENT_LIMIT,
        max_size: int = DEFAULT_MAX_SIZE,
        max_idle_seconds: int = DEFAULT_CONNECTION_MAX_IDLE_SECONDS,
        max_lifespan_seconds: int | None = None,
    ):
        # Each shared connection allows up to this many clients
        self.client_limit = client_limit
        # Pool's max size
        self.max_size = max_size
        if self.max_size < 1:
            raise ValueError("max_size must be a positive integer")
        # Number of available connections.
        # NOTE(review): not updated anywhere in this class; kept for compatibility.
        self.connections: int = 0
        self.connector = connector
        # A pool is just a heap of connection-managing things.
        # All synchronization primitives live in the connections themselves.
        self._pool = [
            SharedTransportConnection(
                self.connector,
                client_limit=self.client_limit,
                max_idle_seconds=max_idle_seconds,
                max_lifespan_seconds=max_lifespan_seconds,
            )
            for _ in range(self.max_size)
        ]
        heapq.heapify(self._pool)
        self.max_lifespan_seconds = max_lifespan_seconds

    @asynccontextmanager
    async def get(
        self,
        timeout=10.0,
    ) -> AsyncGenerator[AbstractConnection, None]:
        """
        Pull out an idle connection.

        `heapq.nsmallest` ranks the pool using `SharedTransportConnection`'s
        comparison methods, so the "freshest" connections (fewest clients, then
        longest idle) are tried first.

        `timeout` bounds the total wall-clock time spent looking for a
        connection, including the time spent waiting on each candidate's checkout.

        Raises:
            ConnectionsExhausted: if no connection could be checked out within `timeout`.
        """
        deadline = time.monotonic() + timeout
        # We'll check roughly half the pool per pass; `+ 1` so it is never zero.
        conn_check_n = (self.max_size // 2) + 1
        while time.monotonic() < deadline:
            for candidate in heapq.nsmallest(conn_check_n, self._pool):
                if not candidate.available:
                    continue
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                connection_reached = False
                # Give each candidate only the time left, not the full timeout again.
                async with candidate.acquire(timeout=remaining) as conn:
                    if conn is not None:
                        connection_reached = True
                        yield conn
                if connection_reached:
                    # Client counts / idle times changed; restore the heap invariant.
                    heapq.heapify(self._pool)
                    return
            if time.monotonic() >= deadline:
                break
            await anyio.sleep(0.01)
        raise ConnectionsExhausted(
            "No connections available: consider using a larger value for `client_limit`"
        )

    async def closeall(self) -> None:
        """Close all connections in the pool concurrently."""
        async with create_task_group() as tg:
            for conn in self._pool:
                tg.start_soon(conn.close)
"""
service_bus.py
This is an example of using connection_pooling to send ServiceBus messages.
"""
import datetime
import logging
from azure.core import exceptions
from azure.identity.aio import DefaultAzureCredential
from azure.servicebus.aio import ServiceBusClient, ServiceBusReceiver, ServiceBusSender
from azure.servicebus import ServiceBusMessage, ServiceBusReceiveMode
from . import connection_pooling
# Default max idle time (seconds) for pooled ServiceBus sender connections.
SERVICE_BUS_SEND_TTL_SECONDS = 5 * 60

logger = logging.getLogger(__name__)
class AzureServiceBus:
    """
    Basic AzureServiceBus client without connection pooling.

    For connection pooling see `ManagedAzureServiceBusSender` below.
    """

    def __init__(
        self,
        service_bus_namespace_url: str,
        service_bus_queue_name: str,
        credential: DefaultAzureCredential,
    ):
        self.namespace_url = service_bus_namespace_url
        self.queue_name = service_bus_queue_name
        self.credential = credential
        # Handles are created lazily and cached until `close()`
        self._client: ServiceBusClient | None = None
        self._receiver_client: ServiceBusReceiver | None = None
        self._sender_client: ServiceBusSender | None = None

    def _validate_access_settings(self):
        """Raise ValueError unless the namespace URL, queue name and credential are all set."""
        if not all((self.namespace_url, self.queue_name, self.credential)):
            raise ValueError("Invalid configuration for AzureServiceBus")
        return None

    @property
    def client(self):
        """Lazily build (after validating settings) and cache the `ServiceBusClient`."""
        if self._client is None:
            self._validate_access_settings()
            self._client = ServiceBusClient(self.namespace_url, self.credential)
        return self._client

    def get_receiver(self) -> ServiceBusReceiver:
        """Return the cached PEEK_LOCK queue receiver, creating it on first use."""
        if self._receiver_client is not None:
            return self._receiver_client
        self._receiver_client = self.client.get_queue_receiver(
            queue_name=self.queue_name, receive_mode=ServiceBusReceiveMode.PEEK_LOCK
        )
        return self._receiver_client

    def get_sender(self) -> ServiceBusSender:
        """Return the cached queue sender, creating it on first use."""
        if self._sender_client is not None:
            return self._sender_client
        self._sender_client = self.client.get_queue_sender(queue_name=self.queue_name)
        return self._sender_client

    async def close(self):
        """
        Close the receiver, sender and client (in that order) and forget them.

        Each cached handle is detached before it is closed, so a handle whose close
        failed is never handed out again by `get_receiver`/`get_sender`/`client`.
        Every close is attempted even if an earlier one raises; in that case an
        exception propagates after the remaining closes have run.
        """
        receiver, self._receiver_client = self._receiver_client, None
        sender, self._sender_client = self._sender_client, None
        client, self._client = self._client, None
        try:
            if receiver is not None:
                await receiver.close()
        finally:
            try:
                if sender is not None:
                    await sender.close()
            finally:
                if client is not None:
                    await client.close()

    async def send_message(self, msg: str, delay: int = 0):
        """Schedule `msg` on the queue for `delay` seconds from now (UTC)."""
        message = ServiceBusMessage(msg)
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        scheduled_time_utc = now + datetime.timedelta(seconds=delay)
        sender = self.get_sender()
        await sender.schedule_messages(message, scheduled_time_utc)
class ManagedAzureServiceBusSender(connection_pooling.AbstractorConnector):
    """
    Azure ServiceBus Sender client with connection pooling built in.

    Acts as the pool's connector: `create` opens a new `ServiceBusSender` and
    `ready` validates it, while `send_message` borrows a pooled sender.
    """

    def __init__(
        self,
        service_bus_namespace_url: str,
        service_bus_queue_name: str,
        credential: DefaultAzureCredential,
        client_limit: int = connection_pooling.DEFAULT_SHARED_TRANSPORT_CLIENT_LIMIT,
        max_size: int = connection_pooling.DEFAULT_MAX_SIZE,
        max_idle_seconds: int = SERVICE_BUS_SEND_TTL_SECONDS,
        ready_message: str = "Connection established",
    ):
        self.service_bus_namespace_url = service_bus_namespace_url
        self.service_bus_queue_name = service_bus_queue_name
        self.credential = credential
        # This object is its own connector: the pool calls back into create()/ready()
        self.pool = connection_pooling.ConnectionPool(
            self,
            client_limit=client_limit,
            max_size=max_size,
            max_idle_seconds=max_idle_seconds,
        )
        # Body of the message scheduled by `ready()` to validate a new connection
        self.ready_message = ready_message

    def get_sender(self) -> ServiceBusSender:
        """
        Build a new sender from a fresh, unpooled `AzureServiceBus` wrapper.

        NOTE(review): the wrapper (and its ServiceBusClient) is not retained, so
        closing only the returned sender may leave that client open — confirm.
        """
        client = AzureServiceBus(
            self.service_bus_namespace_url,
            self.service_bus_queue_name,
            self.credential,
        )
        return client.get_sender()

    async def create(self) -> ServiceBusSender:
        """Creates a new connection for our pool"""
        return self.get_sender()

    def get_receiver(self) -> ServiceBusReceiver:
        """
        Proxy for AzureServiceBus.get_receiver. Here
        for consistency with above class.
        """
        client = AzureServiceBus(
            self.service_bus_namespace_url,
            self.service_bus_queue_name,
            self.credential,
        )
        return client.get_receiver()

    async def close(self):
        """Closes all connections in our pool"""
        await self.pool.closeall()

    @connection_pooling.send_time_deco(logger, "ServiceBus.ready")
    async def ready(self, conn: ServiceBusSender) -> bool:
        """
        Establish readiness for a new connection by scheduling `ready_message`.

        Note that this sends a real message to the queue. Failures raising
        `AttributeError` or `AzureError` are logged and retried immediately,
        for up to three attempts in total.

        Returns:
            True once a send succeeds, False after three failed attempts.
        """
        message = ServiceBusMessage(self.ready_message)
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                await conn.schedule_messages(message, now)
                return True
            except (AttributeError, exceptions.AzureError):
                retry_note = "; trying again." if attempt < max_attempts else "."
                logger.warning(
                    f"ServiceBus readiness check #{attempt} failed{retry_note}",
                    exc_info=True,
                )
        logger.error("ServiceBus readiness check failed. Not ready.")
        return False

    @connection_pooling.send_time_deco(logger, "ServiceBus.send_message")
    async def send_message(self, msg: str, delay: int = 0):
        """
        Schedule `msg` on the queue for `delay` seconds from now (UTC), using a
        sender borrowed from the pool (which may raise `ConnectionsExhausted`).
        """
        message = ServiceBusMessage(msg)
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        scheduled_time_utc = now + datetime.timedelta(seconds=delay)
        async with self.pool.get() as conn:
            await conn.schedule_messages(message, scheduled_time_utc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment