@KazakovDenis
Last active October 24, 2022 14:52
[ Django | psycopg2 ] A DatabaseWrapper with a pool of database connections shared between threads
"""
Django isolates database connections for every thread.
If your application is multithreaded and has many instances
it may open a plenty of database connections. But every connection
is established via TCP and requires a time that may exceed an
execution time of SQL query itself.
To prevent establishing a new database connection we return it to
a pool, where the next thread can pick it up.
Important!
DatabaseWrapper cannot return a connection to a pool by itself.
You should wrap your thread or task for a ThreadPoolExecutor
with `close_db_conn_in_thread` decorator.
"""
from functools import wraps
from threading import RLock
from typing import Dict, Optional

import psycopg2
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.backends.postgresql.base import (
    DatabaseWrapper as PostgresqlWrapper,
)
from psycopg2._psycopg import connection as Psycopg2Connection  # noqa
from psycopg2.pool import ThreadedConnectionPool

# Django creates a new DatabaseWrapper for each
# new thread, so we register pools for them here
_pool_registry: Dict[str, 'PatchedConnectionPool'] = {}
_pool_lock = RLock()


def close_db_conn_in_thread(func):
    """Ensure that connections established in a thread will be returned to a pool."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        finally:
            for conn in connections.all():
                conn.close()
        return result
    return wrapper


def is_usable(conn: Psycopg2Connection) -> bool:
    """Check a connection is ready for queries."""
    if not conn:
        return False
    try:
        with conn.cursor() as cursor:
            cursor.execute('SELECT 1;')
        # By default, psycopg2 opens a transaction
        # before executing the first command
        conn.rollback()
    except psycopg2.Error:
        return False
    return True


class ReusableConnection:
    """
    The connection that is returned to its pool instead of being closed.

    This is a proxy to psycopg2.connection.
    """

    def __init__(self, pool: ThreadedConnectionPool, conn: Psycopg2Connection):
        self.pool = pool
        self.conn = conn

    def __getattr__(self, item):
        return getattr(self.conn, item)

    def __setattr__(self, name, value):
        if name in ('pool', 'conn'):
            self.__dict__[name] = value
        else:
            setattr(self.conn, name, value)

    def close(self, force: bool = False):
        # Instead of closing, return the underlying connection to the pool.
        self.pool.putconn(self.conn, close=force)


class PatchedConnectionPool(ThreadedConnectionPool):
    """The pool that patches connections to be reusable."""

    def __init__(self, min_conn: int, max_conn: int, *args, **kwargs):
        if min_conn > max_conn:
            raise ImproperlyConfigured('min_conn should not be greater than max_conn.')
        super().__init__(min_conn, max_conn, *args, **kwargs)

    def get_connection(self) -> ReusableConnection:
        conn = None
        while not conn:
            conn = self.getconn()
            if not is_usable(conn):
                # Close broken connections so they are not handed out again.
                self.putconn(conn, close=True)
                conn = None
        return ReusableConnection(self, conn)

    def drop_all(self):
        """Close all connections without closing the pool.

        This is useful for app startup preparation (e.g. in ready() methods).
        For example, you should drop all DB connections after uWSGI has loaded
        the app, but before it forks workers.
        """
        # Copy the values: putconn() mutates self._used while we iterate.
        for conn in list(self._used.values()):
            self.putconn(conn, close=True)
        for conn in self._pool:
            conn.close()
        self._pool.clear()


class DatabaseWrapper(PostgresqlWrapper):
    """The database wrapper that uses connections shared between threads."""

    pool: Optional[PatchedConnectionPool] = None
    isolation_level: Optional[int]

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        super().__init__(settings_dict, alias)
        min_conn = settings_dict.get('POOL_MIN_CONN', 0)
        max_conn = settings_dict.get('POOL_MAX_CONN', 100)
        self.init_pool(min_conn, max_conn)

    def init_pool(self, min_conn, max_conn):
        """Initialize a pool shared by all DatabaseWrappers with this db alias."""
        params = self.get_connection_params()
        with _pool_lock:
            if self.alias not in _pool_registry:
                _pool_registry[self.alias] = PatchedConnectionPool(min_conn, max_conn, **params)
            self.pool = _pool_registry[self.alias]

    def get_new_connection(self, conn_params) -> ReusableConnection:  # noqa: F841
        # The superclass uses the global psycopg2 module to open a new
        # connection, so we have to override the entire method.
        conn = self.pool.get_connection()
        options = self.settings_dict['OPTIONS']
        try:
            self.isolation_level = options['isolation_level']
        except KeyError:
            self.isolation_level = conn.isolation_level
        else:
            if self.isolation_level != conn.isolation_level:
                conn.set_session(isolation_level=self.isolation_level)
        return conn

    @property
    def allow_thread_sharing(self):
        return True

    def force_close(self):
        """Close all connections in the pool for this db alias."""
        self.close()
        self.pool.drop_all()
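
To plug the wrapper into a project, the ENGINE setting has to point at the package whose base module is the file above; POOL_MIN_CONN and POOL_MAX_CONN are the custom keys read in DatabaseWrapper.__init__. A minimal sketch, assuming the module is saved as myproject/db_pool/base.py (the package path and credentials are hypothetical):

# settings.py (sketch) -- 'myproject.db_pool' is a hypothetical package whose
# base.py is the module above; Django imports '<ENGINE>.base' and uses its
# DatabaseWrapper class.
DATABASES = {
    'default': {
        'ENGINE': 'myproject.db_pool',
        'NAME': 'mydb',
        'USER': 'myuser',
        'PASSWORD': 'secret',
        'HOST': 'localhost',
        'PORT': '5432',
        'POOL_MIN_CONN': 2,    # custom keys read by DatabaseWrapper.__init__
        'POOL_MAX_CONN': 20,
    },
}

With the backend configured, tasks that run in their own threads are wrapped as in the snippet below: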

from concurrent.futures import ThreadPoolExecutor
from threading import Thread

from .base import close_db_conn_in_thread


class SharedThreadPoolExecutor(ThreadPoolExecutor):
    """The executor that returns db connections to the pool."""

    def submit(self, fn, *args, **kwargs):
        fn = close_db_conn_in_thread(fn)
        return super().submit(fn, *args, **kwargs)


# or if you use threads directly
@close_db_conn_in_thread
def task():
    """Do some query to the db."""


thread = Thread(target=task)
thread.start()
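
As the drop_all() docstring notes, connections opened while the app loads can be dropped before a pre-forking server such as uWSGI forks its workers. A minimal sketch of doing that from AppConfig.ready(), assuming the backend above serves the 'default' alias (the app label 'myapp' is hypothetical):

# apps.py (sketch) -- 'myapp' is a hypothetical app label.
from django.apps import AppConfig
from django.db import connections


class MyAppConfig(AppConfig):
    name = 'myapp'

    def ready(self):
        # Return/close every pooled connection opened during app loading so
        # that forked workers do not inherit and share the same sockets.
        connections['default'].force_close()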