Skip to content

Instantly share code, notes, and snippets.

@Flushot
Last active May 6, 2021 05:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Flushot/f81e1f0db479d115d491a3e034f5b91d to your computer and use it in GitHub Desktop.
Save Flushot/f81e1f0db479d115d491a3e034f5b91d to your computer and use it in GitHub Desktop.
Wrapper for psycopg2 that handles connection pooling, transactions, cursors, and makes the API easier to deal with
"""
PostgreSQL database utilities.
Wrapper for psycopg2 that handles connection pooling, transactions, cursors, and makes the API
easier to deal with.
"""
from collections import defaultdict
import contextlib
import functools
import logging
import re
import threading
import time
from typing import Any, Callable, Dict, Generator, Type, List, Optional, Sequence, Tuple, Union  # Generator[yields, emits, returns]

import psycopg2
import psycopg2.extensions
import psycopg2.pool
import psycopg2.extras

from tqdm.auto import tqdm
from deprecated import deprecated
LOG = logging.getLogger(__name__)
FORMAT_SQL_QUERY_PAT = re.compile(r'(?:\s|\t|\n){2,}')
_default_pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None
_default_pool_args: Optional[Tuple] = None
# TODO: rename to configure_default_pool
def configure_pool(*args, **kwargs) -> None:
"""
Configure default connection pool options. These options will be set lazily,
then used when get_pool() is called for the first time.
If your application needs to use the default connection pool (i.e. calls get_pool())
then you MUST call this function during app initialization, before get_pool().
For available pool args/kwargs, see:
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
:param args: pool args.
:param kwargs: pool keyword args.
"""
global _default_pool_args
if _default_pool_args is not None:
LOG.warning('configure_pool() should only be called once')
_default_pool_args = (args, kwargs)
# TODO: rename to get_default_pool
def get_pool() -> psycopg2.pool.ThreadedConnectionPool:
"""
Get the default database connection pool.
This pool is the default connection pool used by all functions in this module,
where a connection isn't explicitly passed as an argument.
You generally don't need to call this directly to get connections, and should
almost always use the connection() context manager instead (so that your connection
is automatically returned to the pool when you're finished).
:return: default connection pool.
"""
global _default_pool
if _default_pool is None:
if _default_pool_args is None:
raise RuntimeError('You must call configure_pool() before get_pool()')
_default_pool = psycopg2.pool.ThreadedConnectionPool(*_default_pool_args[0],
**_default_pool_args[1])
return _default_pool
def get_conn(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
retry_limit: Optional[int] = 5) -> \
Tuple[psycopg2.extensions.connection, psycopg2.pool.ThreadedConnectionPool]:
"""
Check out a database connection from a pool and tests it to ensure it's working.
If the connection is bad, it will be discarded and a new connection will be attempted
up to retry_limit times.
You generally don't need to call this directly to get connections, and should
almost always use the connection() context manager instead (so that your connection
is automatically returned to the pool when you're finished).
:param pool: optional connection pool to use (if unspecified, the default pool will be used).
:param retry_limit: max number of times to get a connection from the pool if it is bad.
if None, the limit is infinite.
:return: tuple of:
- connection that was fetched from the pool.
- pool the connection came from.
"""
# TODO: Handle PoolError('connection pool exhausted') with optional blocking
if pool is None:
pool = get_pool()
@retry((psycopg2.Error,), retry_limit)
def try_get_conn() -> psycopg2.extensions.connection:
conn = pool.getconn()
# Test connection to ensure it's alive
cur = conn.cursor()
try:
cur.execute('select 1')
except psycopg2.Error:
# Connection is bad: Return it to the pool to be discarded.
if conn is not None:
pool.putconn(conn, close=True)
if not cur.closed:
cur.close()
return conn
return try_get_conn(), pool
@contextlib.contextmanager
def connection(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None) -> \
Generator[psycopg2.extensions.connection, None, None]:
"""
Context manager that returns a connection, then cleans it up and returns it to the pool
when finished. Use this instead of get_conn().
:param pool: optional connection pool to use (will fallback to default pool if unspecified).
:yield: connection.
"""
conn, pool = get_conn(pool)
try:
yield conn
finally:
if conn is not None:
pool.putconn(conn)
@contextlib.contextmanager
def _ensure_connection(conn: Optional[psycopg2.extensions.connection] = None) -> \
Generator[psycopg2.extensions.connection, None, None]:
"""
Internal context manager that will ensure the block is supplied with a user-defined or default
connection.
:param conn: optional user-defined connection (if unspecified, default connection will be used).
:yield: connection.
"""
if conn:
yield conn
else:
with connection() as conn:
yield conn
class cursor:
"""
Context manager for database cursor that has the following behavior:
- Yields a DictCursor that returns dict-like rows.
- Handles transaction behavior:
- Commits transaction upon exit (or rolls back if there was an exception).
- When nested (and when a user-defined connection is passed), supports nested
transaction behavior using savepoints.
- Closes the cursor when finished.
"""
use_savepoints: bool = True
class ConnOpts:
"""
Per-connection options.
"""
def __init__(self):
self.nest_level = 0 # Nested transaction level
self.thread_lock = threading.Lock() # Thread-safe access to vars
_conn_opts: Dict[psycopg2.extensions.connection, ConnOpts] = defaultdict(ConnOpts)
def __init__(self,
conn: Optional[psycopg2.extensions.connection] = None,
pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
**cursor_args):
"""
:param conn: optional connection (if unspecified, default connection will be used).
:param pool: optional connection pool (if conn is None and you need to pass a custom pool).
"""
if conn is not None and pool is not None:
raise ValueError('conn and pool are mutually exclusive')
if conn is None:
conn, pool = get_conn(pool)
self._is_managed_connection = True
else:
self._is_managed_connection = False
self._conn = conn
self._pool = pool
self._opts = self._conn_opts[conn] # TODO: convert keys to weakref.ref()
self._cursor_args = cursor_args
if 'cursor_factory' not in self._cursor_args:
self._cursor_args['cursor_factory'] = psycopg2.extras.DictCursor
def __enter__(self) -> psycopg2.extensions.cursor:
"""
:return: cursor.
"""
with self._opts.thread_lock:
self._opts.nest_level += 1
self._cur = self._conn.cursor(**self._cursor_args)
if self._cur.closed:
raise RuntimeError('Connection returned a closed cursor')
return self._cur
def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_val: Exception,
exc_tb):
with self._opts.thread_lock:
nest_level = self._opts.nest_level
try:
if exc_type is None:
# Success
if not self._is_managed_connection and nest_level > 1:
# Nested transaction (create new savepoint)
if self.use_savepoints:
with _ensure_cursor(self._cur) as cur:
cur.execute(f'savepoint level_{nest_level}')
else:
# Topmost transaction (commit transaction)
self._conn.commit()
else:
# Failure
LOG.error(f'Rolling back {"transaction" if nest_level > 2 else "to previous savepoint"} because of {exc_type} error: {exc_val!r}', exc_info=exc_val)
if not self.use_savepoints or nest_level <= 2:
# Level 1 or 2 transaction (rollback transaction; no previous savepoint to rollback to)
try:
self._conn.rollback()
except psycopg2.Error as ex:
LOG.error(f'Rollback failed because of error: {ex!r}', exc_info=True)
elif self.use_savepoints and not self._is_managed_connection:
# Nested transaction (roll back to previous savepoint)
try:
with cursor(conn=self._conn) as cur: # Create new cursor (current may now be invalid)
cur.execute(f'rollback to savepoint level_{nest_level - 1}')
except psycopg2.Error as ex:
# Compound failure (rollback entire transaction to be safe)
LOG.error(f'Falling back to transaction rollback because savepoint rollback failed: {ex!r}', exc_info=True)
try:
self._conn.rollback()
except psycopg2.Error as ex:
LOG.error(f'Fallback rollback failed because of error: {ex!r}', exc_info=True)
finally:
# Cleanup
if not self._cur.closed:
self._cur.close()
with self._opts.thread_lock:
self._opts.nest_level -= 1
if self._is_managed_connection and self._conn is not None:
self._pool.putconn(self._conn)
return False # Re-raise exceptions
@contextlib.contextmanager
def _ensure_cursor(cur: Optional[psycopg2.extensions.cursor] = None, **cursor_args) -> \
Generator[psycopg2.extensions.cursor, None, None]:
"""
Internal context manager that will ensure the block is supplied with a user-defined or default
cursor.
:param cursor: optional user-defined cursor (if unspecified, default cursor/connection will be used).
:yield: cursor.
"""
if cur:
yield cur
else:
with cursor(**cursor_args) as cur:
yield cur
def fetchmany(statement: str,
params: Optional[Tuple] = None,
use_tqdm: Union[bool, dict] = False,
cur: Optional[psycopg2.extensions.cursor] = None) -> \
Generator[psycopg2.extras.DictRow, None, int]:
"""
Executes SQL and returns multi-row results.
Example:
for row in fetchmany('select id, name from foo where bar = %s', (some_var,)):
print(row['id'])
:param statement: SQL statement.
:param params: SQL statement parameters.
:param use_tqdm: whether ot not to use tqdm progress bar (can also be an options dict for tqdm).
:param cur: optional user-defined cursor.
:yield: row.
:return: row count.
"""
with _ensure_cursor(cur) as cur:
query = cur.mogrify(statement, params)
LOG.debug(f'SQL query: {_format_sql_query(query)}')
cur.execute(query)
if (use_tqdm is True or isinstance(use_tqdm, dict)) and cur.rowcount > 0:
# Show progress bar
tqdm_opts = {}
if isinstance(use_tqdm, dict):
tqdm_opts = use_tqdm
progress = tqdm(total=cur.rowcount, **tqdm_opts)
else:
# Hide progress bar
progress = None
while True:
rows = cur.fetchmany(cur.arraysize)
if len(rows) == 0:
break
for row in rows:
yield row
if progress is not None:
progress.update()
if progress is not None:
progress.close()
return cur.rowcount
def fetchone(statement: str,
params: Optional[Tuple] = None,
cur: Optional[psycopg2.extensions.cursor] = None) -> psycopg2.extras.DictRow:
"""
Execute SQL and returns single row result.
:param statement: SQL statement.
:param params: SQL statement parameters.
:param cur: optional user-defined cursor.
:return: row.
"""
with _ensure_cursor(cur) as cur:
query = cur.mogrify(statement, params)
LOG.debug(f'SQL query: {_format_sql_query(query)}')
cur.execute(query)
return cur.fetchone()
def execute(statement: str,
params: Optional[Union[Tuple, Dict]] = None,
cur: Optional[psycopg2.extensions.cursor] = None) -> None:
"""
Execute a database operation (query or command).
Parameters may be provided as sequence or mapping and will be bound to variables in the operation.
Variables are specified either with positional (%s) or named (%(name)s) placeholders.
The method returns None. If a query was executed, the returned values can be retrieved using
fetch*() methods.
:param statement: SQL statement.
:param params: SQL statement parameters.
:param cur: optional user-defined cursor.
"""
with _ensure_cursor(cur) as cur:
query = cur.mogrify(statement, params)
LOG.debug(f'SQL query: {_format_sql_query(query)}')
cur.execute(query)
def execute_values(statement: str,
values: List[Tuple],
template: Optional[str] = None,
page_size: Optional[int] = None,
fetch: bool = False,
cur: Optional[psycopg2.extensions.cursor] = None) -> \
Generator[psycopg2.extras.DictRow, None, int]:
"""
Execute a statement using VALUES with a sequence of parameters.
:param statement: SQL statement to execute. It must contain a single %s placeholder, which will
be replaced by a VALUES list.
Example: "INSERT INTO mytable (id, f1, f2) VALUES %s".
:param values: sequence of sequences or dictionaries with the arguments to send to the query.
The type and content must be consistent with template.
:param template: the snippet to merge to every item in argslist to compose the query.
- If the argslist items are sequences it should contain positional placeholders
(e.g. "(%s, %s, %s)", or "(%s, %s, 42)” if there are constants value…).
- If the argslist items are mappings it should contain named placeholders
(e.g. "(%(id)s, %(f1)s, 42)").
If not specified, assume the arguments are sequence and use a simple positional template
(i.e. (%s, %s, ...)), with the number of placeholders sniffed by the first element in argslist.
:param page_size: maximum number of argslist items to include in every statement.
If there are more items the function will execute more than one statement.
Defaults to the length of the values parameter.
:param fetch: if True return the query results into a list (like in a fetchall()).
Useful for queries with RETURNING clause.
:param cur: optional user-defined cursor.
:yield: row (if fetch parameter is True).
:return: row count (if fetch parameter is True).
"""
# TODO: Add tqdm_opts parameter like with fetchmany()
if page_size is None:
page_size = len(values)
with _ensure_cursor(cur) as cur:
LOG.debug(f'SQL query: {_format_sql_query(statement.encode("utf-8"))} -> {values!r}')
result = psycopg2.extras.execute_values(cur,
statement,
values,
template=template,
page_size=page_size,
fetch=fetch)
if fetch:
for row in result:
yield row
return cur.rowcount
def upsert(row: Dict[str, Any],
table_name: str,
primary_key: List[str],
include_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
cur: Optional[psycopg2.extensions.cursor] = None) -> Any:
if include_keys is None:
include_keys = row.keys()
if exclude_keys is None:
exclude_keys = []
item_keys = [k for k in include_keys if k not in exclude_keys]
with _ensure_cursor(cur) as cur:
return fetchone(
f'''
insert into {table_name} ({', '.join(item_keys)})
values ({', '.join(['%s' for _ in item_keys])})
on conflict ({', '.join(primary_key)}) do update set
{', '.join([f'{k} = excluded.{k}' for k in item_keys if k not in primary_key])}
returning *
''',
tuple([row.get(key) for key in item_keys]),
cur=cur)
def _format_sql_query(query: bytes) -> str:
"""
Strips extra whitespace and newlines from SQL queries, so that they are easier to read in logs.
:param query: query to format.
:return: formatted query.
"""
return FORMAT_SQL_QUERY_PAT.sub(' ', normalize_line_endings(query.decode('utf-8'))).strip()
def normalize_line_endings(s: str) -> str:
"""
Converts various line ending characters/pairs into \n
:param s: string with possibly abnormal line endings.
:return: normalized string.
"""
return s.replace('\r\n', '\n').replace('\r', '\n')
def retry(exc_types: Sequence[Type],
max_attempts: Optional[int] = None,
delay: int = 0,
error_fn: Optional[Callable[[BaseException], None]] = None) -> Callable:
"""
Decorator that automatically re-calls a function if it throws a set of expected exception types.
:param exc_types: exception classes to retry on.
:param max_attempts: max number of attempts to retry before re-throwing.
if None, there is no limit.
:param delay: optional time delay between retry attempts (in seconds).
:param error_fn: optional function to call (with exception) when an error occurs.
"""
def retry_decorator(f: Callable) -> Callable:
def retryable_func(*args, **kwargs):
for attempt in range(max_attempts):
try:
return f(*args, **kwargs)
except tuple(exc_types) as ex:
if error_fn is not None:
error_fn(ex)
if attempt >= max_attempts:
raise
LOG.warning(f'Retrying because of {ex.__class__.__name__} error: {ex!r}')
if delay > 0:
time.sleep(delay)
return functools.wraps(f)(retryable_func)
return retry_decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment