Skip to content

Instantly share code, notes, and snippets.

@maksymx
Forked from jasonwalkeryung/sqlalchemy_replica.py
Created July 14, 2020 20:24
Show Gist options
  • Save maksymx/bc6feb3525da4ed9f3961936d2d76893 to your computer and use it in GitHub Desktop.
Save maksymx/bc6feb3525da4ed9f3961936d2d76893 to your computer and use it in GitHub Desktop.
SQLAlchemy read replica wrapper
"""This is not the full code. We do a lot of stuff to clean up connections, particularly for unit testing."""
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Generator, List, Union

import sqlalchemy
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Query, Session, scoped_session, sessionmaker
# Config key whose value maps extra read-write connection names to connection URLs.
CONFIG_KEY_SQLALCHEMY_BINDS = 'SQLALCHEMY_BINDS'
# Config key whose value maps read-replica connection names to connection URL(s).
CONFIG_KEY_SQLALCHEMY_RO_BINDS = 'SQLALCHEMY_READ_ONLY_BINDS'
class Config:
    """Default settings read by ``DBSessionFactory``.

    Settings are exposed as attributes (``config.SQLALCHEMY_DATABASE_URI``)
    and, on instances, by subscript (``config['SQLALCHEMY_BINDS']``), because
    ``DBSessionFactory.register`` reads the config both ways.
    """

    # These default values are for testing. In a deployed environment, they would be three separate instances.
    SQLALCHEMY_DATABASE_URI = 'postgresql://localhost/branded_dev'
    # Additional read-write databases (name -> URL); none by default.
    SQLALCHEMY_BINDS = {}
    SQLALCHEMY_READ_ONLY_BINDS = {
        'replica': 'postgresql://localhost/branded_dev',
        'replica_analytics': 'postgresql://localhost/branded_dev',
    }
    # Engine options passed to sqlalchemy.create_engine(); the values below
    # are the same as create_engine()'s own defaults.
    SQLALCHEMY_POOL_SIZE = 5
    SQLALCHEMY_POOL_RECYCLE = -1
    SQLALCHEMY_ECHO = False
    SQLALCHEMY_POOL_PRE_PING = False

    def __getitem__(self, key):
        """Return the setting named *key*.

        :param key: name of the setting, e.g. ``'SQLALCHEMY_BINDS'``
        :return: the setting's value
        :raises KeyError: if no such setting is defined
        """
        try:
            return getattr(self, key)
        except AttributeError:
            # Behave like a mapping so callers can treat a missing setting uniformly.
            raise KeyError(key) from None
class DBSessionFactory:
    """
    A wrapper for getting db sessions from the primary and read replicas.

    Engines, session factories and scoped sessions are tracked by connection
    name. A read-write connection holds one engine per name; a read-only
    connection holds a list of engines per name, and ``read_only_session``
    picks one of the matching scoped sessions at random.

    Call ``register`` with a config object before requesting sessions.
    """

    # create_engine() keyword -> config key that supplies its value.
    _ENGINE_OPTIONS = (
        ('pool_size', 'SQLALCHEMY_POOL_SIZE'),
        ('pool_recycle', 'SQLALCHEMY_POOL_RECYCLE'),
        ('echo', 'SQLALCHEMY_ECHO'),
        ('pool_pre_ping', 'SQLALCHEMY_POOL_PRE_PING'),
    )

    def __init__(self) -> None:
        """Start with empty registries so ``add_engine`` works before ``register``."""
        self._reset_registries()

    def _reset_registries(self) -> None:
        """(Re)create the empty per-name registries."""
        self.engines = dict()  # type: Dict[str, Engine]
        self.read_only_engines = defaultdict(list)  # type: Dict[str, List[Engine]]
        # The session factories to be used by scoped_session to connect
        self.session_factories = dict()  # type: Dict[str, sessionmaker]
        # The scoped sessions for each connection.
        self.scoped_sessions = dict()  # type: Dict[str, scoped_session]
        # The scoped sessions for each read only connection.
        self.read_only_scoped_sessions = defaultdict(list)  # type: Dict[str, List[scoped_session]]

    @staticmethod
    def _config_value(config, key, default=None):
        """Read *key* from a mapping-style or attribute-style config.

        :param config: a mapping (``config[key]``) or an object/class exposing
            settings as attributes (``config.KEY``)
        :param key: name of the setting
        :param default: value returned when the setting is absent
        :return: the setting's value, or *default*
        """
        try:
            return config[key]
        except (KeyError, TypeError):
            # Key absent, or config is not subscriptable: fall back to attributes.
            return getattr(config, key, default)

    def register(self, config: "Config") -> None:
        """Create engines and scoped sessions for every database in *config*.

        Any previously registered connections are forgotten first.

        :param config: settings holder; read via ``_config_value`` so both
            mapping-style and attribute-style configs are accepted
        :raises AttributeError: if ``SQLALCHEMY_DATABASE_URI`` is not configured
        """
        self._reset_registries()

        # The primary connection
        primary_uri = self._config_value(config, 'SQLALCHEMY_DATABASE_URI')
        if primary_uri is None:
            raise AttributeError("config does not define 'SQLALCHEMY_DATABASE_URI'")
        self.add_engine('primary', primary_uri, config=config)

        # Other read-write dbs (``or {}``: the setting may be missing or None)
        binds = self._config_value(config, CONFIG_KEY_SQLALCHEMY_BINDS) or {}
        for name, connect_url in binds.items():
            self.add_engine(name, connect_url, config=config)

        # Read replica binds
        read_only_binds = self._config_value(config, CONFIG_KEY_SQLALCHEMY_RO_BINDS) or {}
        for name, connect_url in read_only_binds.items():
            self.add_engine(name, connect_url, config=config, read_only=True)

    def add_engine(self, name: "DBInstance", uri: Union[str, List[str]], config: "Config",
                   read_only=False) -> None:
        """Initialize a database connection and register it in the appropriate internal dicts.

        :param name: connection name used to look the engine/session up later
        :param uri: one connection URL, or a list/tuple of URLs (one engine is
            created per URL)
        :param config: settings holder passed on to ``_create_engine``
        :param read_only: register under the read-only registries instead of
            the read-write ones
        """
        # Clean up existing engine if present
        if self.engines.get(name) or self.read_only_engines.get(name):
            self.session_factories[name].close_all()
            # Drop every stale entry for this name; otherwise replaced replicas
            # would stay in the random rotation of read_only_session().
            stale_engines = self.read_only_engines.pop(name, [])
            self.read_only_scoped_sessions.pop(name, None)
            self.scoped_sessions.pop(name, None)
            replaced = self.engines.pop(name, None)
            if replaced is not None:
                stale_engines = stale_engines + [replaced]
            for stale in stale_engines:
                # Release the pooled connections of engines we no longer reference.
                stale.dispose()

        uris = list(uri) if isinstance(uri, (list, tuple)) else [uri]
        engines = [self._create_engine(u, config) for u in uris]
        for engine in engines:
            # One factory per engine; session_factories keeps the last one for the name.
            factory = sessionmaker(bind=engine, expire_on_commit=False)
            self.session_factories[name] = factory
            scoped_session_instance = scoped_session(factory)
            if read_only:
                self.read_only_engines[name].append(engine)
                self.read_only_scoped_sessions[name].append(scoped_session_instance)
            else:
                # A read-write name maps to a single engine: the last URL wins.
                self.engines[name] = engine
                self.scoped_sessions[name] = scoped_session_instance

    def _create_engine(self, url: str, config: "Config"):  # pylint: disable=no-self-use
        """wrapper to set up our connections

        :param url: database connection URL
        :param config: settings holder; only the engine options it actually
            defines (with a non-None value) are passed to ``create_engine``
        :return: the engine created by ``sqlalchemy.create_engine``
        """
        options = {}
        for keyword, config_key in self._ENGINE_OPTIONS:
            value = self._config_value(config, config_key)
            if value is not None:
                options[keyword] = value
        return sqlalchemy.create_engine(url, **options)

    def raw_scoped_session(self, engine: "DBInstance" = None) -> "scoped_session":
        """Return the scoped session registered for a read-write connection.

        NOTE(review): falling back to ``'primary'`` when *engine* is None is an
        assumption based on the name ``register`` gives the main connection.
        NOTE(review): DBConfigurationException is not defined in the visible
        code -- confirm it is imported where this class lives.

        :param engine: connection to use
        :return: the scoped_session for that connection
        :raises DBConfigurationException: if no such connection is registered
        """
        name = 'primary' if engine is None else engine
        try:
            return self.scoped_sessions[name]
        except KeyError:
            raise DBConfigurationException(
                "Requested session for '{}', which is not bound in this app. Try: [{}]".
                format(name, ','.join(str(key) for key in self.scoped_sessions))
            ) from None

    @contextmanager
    def session(self, engine: "DBInstance" = None) -> "Generator[scoped_session, None, None]":
        """
        Generate a session and yield it out.
        After resuming, commit, unless an exception happens,
        in which case we roll back.
        The scoped session is removed afterwards in either case.
        :param engine: connection to use
        :return: a generator for a scoped session
        """
        session = self.raw_scoped_session(engine)
        try:
            yield session
            session.commit()
        except BaseException:
            # Deliberately broad: roll back on any interruption, then re-raise.
            session.rollback()
            raise
        finally:
            session.remove()

    def read_only_session(self, engine: str = None) -> "scoped_session":
        """
        Return a session for a read-only db
        When several engines are registered under the name, one of their
        scoped sessions is chosen at random.
        NOTE(review): DBConfigurationException is not defined in the visible
        code -- confirm it is imported where this class lives.
        :param engine: connection to use
        :return: a Session via scoped_session
        :raises DBConfigurationException: if no read-only connection is bound to that name
        """
        candidates = self.read_only_scoped_sessions.get(engine)
        if not candidates:
            raise DBConfigurationException(
                "Requested session for '{}', which is not bound in this app. Try: [{}]".
                format(engine, ','.join(str(key) for key in self.read_only_engines))
            )
        return random.choice(candidates)
# The global db factory instance. Creating it opens no connections; engines
# and sessions are only set up once DBSessionFactory.register() is called.
db = DBSessionFactory()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment