Skip to content

Instantly share code, notes, and snippets.

@dikkini
Forked from danielrichman/connection_factory.py
Created October 18, 2017 20:00
Show Gist options
  • Save dikkini/f455324a6c218da49698a7071423810d to your computer and use it in GitHub Desktop.
Save dikkini/f455324a6c218da49698a7071423810d to your computer and use it in GitHub Desktop.
nicer postgres connection class & flask postgres
class PostgreSQLConnection(psycopg2.extensions.connection):
    """
    A custom `connection_factory` for :func:`psycopg2.connect`.

    Compared to a plain :class:`psycopg2.extensions.connection`, this class

    * registers the unicode type casters on the connection, so text comes
      back as unicode
    * replaces :meth:`cursor` with a version that takes a single flag and
      can hand out cursors made from
      :class:`psycopg2.extras.RealDictCursor`.
    """

    # this may be omitted in py3k
    def __init__(self, *args, **kwargs):
        """Set up the connection, then register the unicode casters on it."""
        super(PostgreSQLConnection, self).__init__(*args, **kwargs)
        psycopg2.extensions.register_type(
            psycopg2.extensions.UNICODE, self)
        psycopg2.extensions.register_type(
            psycopg2.extensions.UNICODEARRAY, self)

    def cursor(self, real_dict_cursor=False):
        """
        Get a new cursor.

        If `real_dict_cursor` is true the cursor is created with
        ``cursor_factory=psycopg2.extras.RealDictCursor``; otherwise the
        base class default cursor is returned.
        """
        parent = super(PostgreSQLConnection, self)
        if not real_dict_cursor:
            return parent.cursor()
        return parent.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
from __future__ import unicode_literals
import logging
import threading
import flask
from werkzeug.local import LocalProxy
import psycopg2
import psycopg2.extras
import psycopg2.extensions
# Proxy to the PostgreSQL helper attached to the current Flask app.
# PostgreSQL.init_app() stores the helper as ``app.postgresql``, so that is
# the attribute that must be looked up here (``app.postgres`` is never set).
postgres = LocalProxy(lambda: flask.current_app.postgresql)
class PostgreSQL(object):
    """
    A PostgreSQL helper extension for Flask apps

    On initialisation it adds an after_request function that commits the
    transaction (so that if the transaction rolls back the request will
    fail) and a app context teardown function that disconnects any active
    connection.

    You can of course (and indeed should) use :meth:`commit` if you need to
    ensure some changes have made it to the database before performing
    some other action. :meth:`teardown` is also available to be called
    directly.

    Connections are created by ``psycopg2.connect(**app.config["POSTGRES"])``
    (e.g., ``app.config["POSTGRES"] = {"database": "mydb"}``),
    are pooled (you can adjust the pool size with `pool_size`) and are
    tested for server shutdown before being given to the request.
    """

    def __init__(self, app=None, pool_size=2):
        """
        :param app: optional Flask app; if given, :meth:`init_app` is called
        :param pool_size: maximum number of idle connections kept for reuse
        """
        self.app = app
        self._pool = []
        self.pool_size = pool_size
        # guards every read/modification of self._pool
        self._lock = threading.RLock()
        self.logger = logging.getLogger(__name__ + ".PostgreSQL")

        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """
        Initialises the app by adding hooks

        * Hook: ``app.after_request(self.commit)``
        * Hook: ``app.teardown_appcontext(self.teardown)``

        The helper is also stored on the app as ``app.postgresql``.
        """
        app.after_request(self.commit)
        app.teardown_appcontext(self.teardown)
        app.postgresql = self

    @staticmethod
    def _describe_settings(settings):
        """Return a ``key=value ...`` summary of `settings` for logging.

        Values are formatted with ``%s`` so non-string settings (e.g. an
        integer port) do not raise, and the ``password`` value is masked
        so credentials never reach the log.
        """
        return " ".join(
            "%s=%s" % (key, "***" if key == "password" else value)
            for key, value in settings.items())

    def _close_quietly(self, connection):
        """Close `connection`, ignoring ``psycopg2.OperationalError``."""
        try:
            connection.close()
        except psycopg2.OperationalError:
            pass

    def _connect(self):
        """Returns a connection to the database

        A pooled connection is preferred. It is checked with ``reset()``
        first; if that raises ``OperationalError`` the whole pool is
        assumed dead, every pooled connection is closed, and a fresh
        connection is made instead.
        """
        with self._lock:
            if self._pool:
                c = self._pool.pop()
                try:
                    # This tests if the connection is still alive.
                    c.reset()
                except psycopg2.OperationalError:
                    self.logger.debug("assuming pool dead", exc_info=True)
                    # assume that the entire pool is dead
                    self._close_quietly(c)
                    for stale in self._pool:
                        self._close_quietly(stale)
                    self._pool = []
                else:
                    self.logger.debug("got connection from pool")
                    return c

        # Nothing usable in the pool. Connecting does not touch the pool,
        # so it is done without holding the lock.
        return self._new_connection()

    def _new_connection(self):
        """Create a new connection to the database

        Settings are read from ``flask.current_app.config["POSTGRES"]`` and
        passed to ``psycopg2.connect`` together with
        ``connection_factory=PostgreSQLConnection``.
        """
        s = flask.current_app.config["POSTGRES"]
        self.logger.debug("connecting (%s)", self._describe_settings(s))
        return psycopg2.connect(connection_factory=PostgreSQLConnection, **s)

    @property
    def connection(self):
        """
        Gets the PostgreSQL connection for this Flask request

        If no connection has been used in this request, it connects to the
        database. Further use of this property will reference the same
        connection

        The connection is committed and closed at the end of the request.
        """
        g = flask.g
        if not hasattr(g, '_postgresql'):
            g._postgresql = self._connect()
        return g._postgresql

    def cursor(self, real_dict_cursor=False):
        """
        Get a new postgres cursor for immediate use during a request

        If a cursor has not yet been used in this request, it connects to the
        database. Further cursors re-use the per-request connection.

        The connection is committed and closed at the end of the request.

        If real_dict_cursor is set, a RealDictCursor is returned
        """
        return self.connection.cursor(real_dict_cursor)

    def commit(self, response=None):
        """
        (Almost an) alias for self.connection.commit()

        ... except if self.connection has never been used this is a noop
        (i.e., it does nothing)

        Returns `response` unmodified, so that this may be used as an
        :meth:`flask.after_request` function.
        """
        g = flask.g
        if hasattr(g, '_postgresql'):
            self.logger.debug("committing")
            g._postgresql.commit()
        return response

    def teardown(self, exception):
        """Either return the connection to the pool or close it

        Does nothing if no connection was used. `exception` is accepted so
        this can be registered as a teardown hook; it is not inspected.

        A connection is only pooled after a successful ``reset()``. If the
        reset raises ``OperationalError`` the connection is closed and
        dropped, rather than leaking it and failing the teardown.
        """
        g = flask.g
        if not hasattr(g, '_postgresql'):
            return

        c = g._postgresql
        del g._postgresql

        with self._lock:
            s = len(self._pool)
            if s >= self.pool_size:
                self.logger.debug("teardown: pool size %i - closing", s)
                c.close()
                return

            self.logger.debug("teardown: adding to pool, new size %i",
                              s + 1)
            try:
                c.reset()
            except psycopg2.OperationalError:
                self.logger.debug("teardown: reset failed - closing",
                                  exc_info=True)
                self._close_quietly(c)
            else:
                self._pool.append(c)
import psycopg2
import psycopg2.extras
import psycopg2.extensions
import flask
class MockConnectionBase(object):
    """Stand-in connection base whose constructor opens no connection.

    It records registered types in ``registered_types`` and returns marker
    strings from :meth:`cursor` instead of real cursors.
    """

    def __init__(self):
        # prevent connection setup
        self.registered_types = []

    def cursor(self, cursor_factory=None):
        """Return a marker string describing which cursor was requested.

        Only ``None`` and ``psycopg2.extras.RealDictCursor`` are accepted
        as `cursor_factory`; anything else fails the assertion.
        """
        if cursor_factory is None:
            return "stubbed cursor"
        assert cursor_factory == psycopg2.extras.RealDictCursor
        return "stubbed dict cursor"
class ConnectionRebaser(type):
    """Metaclass that slots MockConnectionBase into a class's ancestry.

    ``MockConnectionBase`` is appended to the bases, and the MRO is forced
    to ``(cls, utils.PostgreSQLConnection, MockConnectionBase)`` followed by
    the remaining ancestors of ``utils.PostgreSQLConnection``, so the mock
    is the class that ``PostgreSQLConnection``'s ``super()`` calls reach.
    """

    def __new__(mcs, name, bases, dict):
        rebased = bases + (MockConnectionBase, )
        return type.__new__(mcs, name, rebased, dict)

    def mro(cls):
        real = utils.PostgreSQLConnection
        head = (cls, real, MockConnectionBase)
        return head + real.__mro__[1:]
class RebasedPostgreSQLConnection(utils.PostgreSQLConnection):
    """utils.PostgreSQLConnection built through the ConnectionRebaser metaclass."""
    # Python 2 metaclass declaration; Python 3 ignores this attribute.
    __metaclass__ = ConnectionRebaser
class FakeExtensions(object):
    """Replacement for ``psycopg2.extensions`` used by the tests.

    Exposes the real ``UNICODE`` / ``UNICODEARRAY`` type objects, while
    ``register_type`` only records the type on the connection it is given.
    """

    UNICODE = psycopg2.extensions.UNICODE
    UNICODEARRAY = psycopg2.extensions.UNICODEARRAY

    @classmethod
    def register_type(cls, what, connection):
        """Append `what` to ``connection.registered_types``."""
        recorded = connection.registered_types
        recorded.append(what)
class TestPostgreSQLConnection(object):
    """Tests for utils.PostgreSQLConnection with psycopg2.extensions faked."""

    def setup(self):
        """Swap ``psycopg2.extensions`` for a FakeExtensions instance."""
        real_extensions = psycopg2.extensions
        assert real_extensions.__name__ == "psycopg2.extensions"
        self.fakes = FakeExtensions()
        self.original_extensions = real_extensions
        psycopg2.extensions = self.fakes

    def teardown(self):
        """Put the real ``psycopg2.extensions`` module back."""
        assert isinstance(psycopg2.extensions, FakeExtensions)
        psycopg2.extensions = self.original_extensions

    def test_only_affects_cursor(self):
        # apart from dunder names, the class must define nothing but cursor
        public_names = [name for name in utils.PostgreSQLConnection.__dict__
                        if not name.startswith("__")]
        assert public_names == ["cursor"]

    def test_cursor(self):
        conn = RebasedPostgreSQLConnection()
        assert conn.cursor() == "stubbed cursor"
        for flag, expected in ((False, "stubbed cursor"),
                               (True, "stubbed dict cursor")):
            assert conn.cursor(flag) == expected

    def test_register_types(self):
        conn = RebasedPostgreSQLConnection()
        expected = [psycopg2.extensions.UNICODE,
                    psycopg2.extensions.UNICODEARRAY]
        assert conn.registered_types == expected
class FakePsycopg2(object):
    """In-memory substitute for the ``psycopg2`` module.

    ``connect`` counts how many connections were requested and returns
    ``connection`` objects that merely count commit/reset/close calls.
    """

    class connection(object):
        """Fake connection that tallies calls in ``self.calls``."""

        class _cursor(object):
            """Fake cursor that records executed queries."""

            def __init__(self, connection, real_dict_cursor):
                self.connection = connection
                self.real_dict_cursor = real_dict_cursor
                self.queries = []

            def __enter__(self):
                return self

            def __exit__(self, *args):
                return None

            def execute(self, query, args=None):
                # executing on a closed connection is an error
                if self.connection.calls["close"]:
                    raise psycopg2.OperationalError
                self.queries.append((query, args))

        # when set on an instance, close() raises after counting the call
        close_error = False
        autocommit = False

        def __init__(self, **settings):
            self.settings = settings
            self.types = []
            self.cursors = []
            self.calls = dict.fromkeys(("commit", "reset", "close"), 0)

        def cursor(self, real_dict_cursor=False):
            made = self._cursor(self, real_dict_cursor)
            self.cursors.append(made)
            return made

        def commit(self):
            assert self.calls["close"] == 0
            self.calls["commit"] += 1

        def reset(self):
            # the call is counted even when it then fails
            self.calls["reset"] += 1
            if self.calls["close"]:
                raise psycopg2.OperationalError

        def close(self):
            self.calls["close"] += 1
            if self.close_error:
                raise psycopg2.OperationalError

    # number of connect() calls made through this fake
    connections = 0

    def connect(self, **settings):
        self.connections = self.connections + 1
        return self.connection(**settings)

    OperationalError = psycopg2.OperationalError
    extras = psycopg2.extras
class TestPostgreSQL(object):
    """Tests for utils.PostgreSQL, run against FakePsycopg2.

    The fake connections count their commit/reset/close calls, so most
    assertions check ``calls`` either inside a request context or right
    after it has been torn down.
    """

    def setup(self):
        # replace the psycopg2 module seen by utils with the fake
        assert utils.psycopg2 is psycopg2
        self.fakes = utils.psycopg2 = FakePsycopg2()
        self.app = flask.Flask(__name__)
        self.app.config["POSTGRES"] = {"database": "mydb", "user": "steve"}
        self.postgres = utils.PostgreSQL(self.app)

    def teardown(self):
        # restore the real psycopg2 module in utils
        assert isinstance(utils.psycopg2, FakePsycopg2)
        utils.psycopg2 = psycopg2

    def test_adds_hooks(self):
        assert self.app.after_request_funcs == {None: [self.postgres.commit]}
        assert self.app.teardown_appcontext_funcs == [self.postgres.teardown]

    def test_connect_new(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            assert isinstance(c, self.fakes.connection)
            assert c.settings == \
                {"database": "mydb", "user": "steve",
                 "connection_factory": utils.PostgreSQLConnection}
            # a brand new connection has not been reset yet
            assert c.calls == {"commit": 0, "reset": 0, "close": 0}
            assert c.autocommit is False

    def test_connect_once(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            d = self.postgres.connection
            assert c is d
            assert self.fakes.connections == 1
            assert c.calls == {"commit": 0, "reset": 0, "close": 0}

    def test_teardown_resets_before_store(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        # leaving the context ran teardown, which resets before pooling
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_connect_from_pool(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.test_request_context("/"):
            d = self.postgres.connection
            assert d is c
            # second reset: the liveness check when taken from the pool
            assert d.calls == {"commit": 0, "reset": 2, "close": 0}
        # third reset: teardown returning it to the pool again
        assert d.calls == {"commit": 0, "reset": 3, "close": 0}

    def test_removes_from_pool(self):
        # put a connection in the pool
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        # now get two connections from the pool
        # must explicitly create two app contexts.
        # in normal usage, flask promises to never share an app context
        # between requests. When testing, it will only create an app context
        # when test_request_context is __enter__'d and there is no existing
        # app context
        with self.app.app_context(), self.app.test_request_context("/1"):
            d = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                e = self.postgres.connection
        assert d is c
        assert e is not d
        assert d.calls == {"commit": 0, "reset": 3, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_connect_from_pool_bad(self):
        # put two distinct connections in the pool
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
        assert c is not d
        assert c.calls == d.calls == {"commit": 0, "reset": 1, "close": 0}
        c.close()
        with self.app.test_request_context("/"):
            e = self.postgres.connection
            # it should try c.reset, which will fail, and then destroy the
            # pool by closing d as well
            # one close call by uut, one close call from above
            assert c.calls == {"commit": 0, "reset": 2, "close": 1 + 1}
            assert d.calls == {"commit": 0, "reset": 1, "close": 1}
            # e should be a new connection
            assert e is not c and e is not d
            assert e.calls == {"commit": 0, "reset": 0, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_absorbs_close_errors(self):
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
        c.close()
        # make d.close() raise; connecting below must not propagate it
        d.close_error = True
        with self.app.test_request_context("/"):
            e = self.postgres.connection

    def test_teardown_closes_if_pool_full(self):
        # default pool size is 2
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
                with self.app.app_context(), \
                        self.app.test_request_context("/3"):
                    e = self.postgres.connection
        assert len(set([c, d, e])) == 3
        # contexts unwind innermost first, so e and d were pooled and c,
        # torn down last, found the pool full and was closed
        assert c.calls == {"commit": 0, "reset": 0, "close": 1}
        assert d.calls == {"commit": 0, "reset": 1, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.app_context(), self.app.test_request_context("/1"):
            f = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                g = self.postgres.connection
                with self.app.app_context(), \
                        self.app.test_request_context("/3"):
                    h = self.postgres.connection
        assert f is d
        assert g is e
        assert len(set([d, e, h])) == 3

    def test_cursor(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            x = self.postgres.cursor()
            assert isinstance(x, c._cursor)
            assert x.connection is c
            assert len(c.cursors) == 1
            y = self.postgres.cursor()
            assert y.connection is c
            assert len(c.cursors) == 2
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.test_request_context("/"):
            # cursor first without asking for connection explicitly
            x = self.postgres.cursor()
            assert isinstance(x, utils.psycopg2.connection._cursor)
            c = x.connection
            assert isinstance(c, utils.psycopg2.connection)
            assert c.calls == {"commit": 0, "reset": 2, "close": 0}
            assert c.settings == \
                {"database": "mydb", "user": "steve",
                 "connection_factory": utils.PostgreSQLConnection}
            assert c is self.postgres.connection

    def test_dict_cursor(self):
        with self.app.test_request_context("/"):
            c = self.postgres.cursor(True)
            assert len(self.postgres.connection.cursors) == 1
            assert c.real_dict_cursor
            c = self.postgres.cursor()
            assert not c.real_dict_cursor

    def test_commit(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.commit()
            assert c.calls == {"commit": 1, "reset": 0, "close": 0}

    def test_commit_as_hook(self):
        # as an after request hook, commit must return the response object
        # it is passed
        response = object()
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            assert self.postgres.commit(response) is response
        # now check it works as a hook
        with self.app.test_request_context("/"):
            d = self.postgres.connection
            assert d is c
            assert self.app.process_response(response) is response
        assert c.calls == {"commit": 2, "reset": 3, "close": 0}

    def test_commit_nop_if_no_connection(self):
        with self.app.test_request_context("/"):
            self.postgres.commit()
        assert utils.psycopg2.connections == 0
        with self.app.test_request_context("/"):
            self.app.process_response(None)
        assert utils.psycopg2.connections == 0
        # should nop if teardown puts the connection in the pool
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.teardown(None)
            self.postgres.commit()
            assert c.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_teardown_nop_if_no_connection(self):
        with self.app.test_request_context("/"):
            self.postgres.teardown(None)
        assert utils.psycopg2.connections == 0
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.teardown(None)
            assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        # the automatic teardown on context exit must not reset it again
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment