Skip to content

Instantly share code, notes, and snippets.

@amir-shehzad
Forked from HacKanCuBa/sqlalchemy_helpers.py
Created March 17, 2024 07:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amir-shehzad/0d95f1cebffef32230108169acdf086d to your computer and use it in GitHub Desktop.
Save amir-shehzad/0d95f1cebffef32230108169acdf086d to your computer and use it in GitHub Desktop.
SQLAlchemy handy helper functions
import functools
from contextlib import asynccontextmanager, contextmanager
from time import monotonic
from typing import Annotated, Any, AsyncGenerator, Generator, Hashable, Iterable, Literal, Optional, Sized, Union, overload
from sqlalchemy import event
from sqlalchemy.dialects.mysql.asyncmy import AsyncAdapt_asyncmy_cursor
from sqlalchemy.engine import URL, Connection, Engine, Row, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool
# Annotation for keyword arguments that end up in a ``functools.cache`` key (see ``_get_db_engine``):
# every value must be hashable, so e.g. ``connect_args`` has to be passed as pairs rather than a dict.
AnyCacheable = Annotated[Hashable, "Any type that works well with functools.cache, meaning hashable (i.e., not dicts!)"]
@event.listens_for(Engine, "before_cursor_execute")
def _before_cursor_execute(conn: Connection, *_: Any) -> None:
    """Record the moment a cursor execution starts, for every ``Engine``.

    Start times are pushed onto a per-connection list stored in ``conn.info`` under
    ``"query_start_time"``; the after-execute listener pops the most recent entry.
    """
    start_times = conn.info.setdefault("query_start_time", [])
    start_times.append(monotonic())
# noinspection PyUnusedLocal
@event.listens_for(Engine, "after_cursor_execute")
def _after_cursor_execute(
    conn: Connection,
    cursor: AsyncAdapt_asyncmy_cursor,
    statement: str,
    parameters: tuple[dict[str, Any], ...] | dict[str, Any] | None,
    *_: Any,
) -> None:
    """Pop the latest start time for ``conn`` and compute how long the query took.

    The elapsed time is currently discarded: the debug log call is left commented out.
    ``statement`` and ``parameters`` are only accepted so they are at hand for that log line.
    """
    # NOTE(review): this listener is attached to every Engine, so ``cursor`` is presumably whatever
    # DBAPI cursor the active driver provides; the asyncmy annotation only fits asyncmy engines — confirm.
    finished_at = monotonic()
    elapsed = finished_at - conn.info["query_start_time"].pop()
    # To enable (beware: parameters may hold sensitive values):
    # logger.debug('DB query\n\t%s\n\tparams: %s\n\tfinished in %f seconds', statement.replace("\n", ""), parameters, elapsed)
@overload
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[True], **kwargs: AnyCacheable) -> Engine:
    ...


@overload
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[False], **kwargs: AnyCacheable) -> AsyncEngine:
    ...


@functools.cache
def _get_db_engine(db_url: Union[str, URL], *, sync: bool, **kwargs: AnyCacheable) -> Union[Engine, AsyncEngine]:
    """Create a pooled sync or async engine, memoized per distinct set of arguments.

    Because of ``functools.cache`` every argument must be hashable, and repeated calls with the
    same arguments return the very same engine object (shared by every caller).

    :param db_url: Database URL, as a string or ``sqlalchemy.engine.URL``.
    :param sync: ``True`` for an ``Engine`` using ``QueuePool``; ``False`` for an ``AsyncEngine``
        using ``AsyncAdaptedQueuePool``.
    :param kwargs: Extra keyword arguments for ``create_engine``/``create_async_engine``; they
        override the defaults below. ``connect_args`` must be an iterable of ``(key, value)``
        pairs (e.g. a tuple of tuples) because a dict is not hashable.
    :raises TypeError: If ``connect_args`` is not iterable.
    :raises ValueError: If an item of ``connect_args`` is not a 2-item pair.
    """
    if "connect_args" in kwargs:
        connect_args_raw = kwargs.pop("connect_args")
        # Explicit checks instead of ``assert``: asserts are stripped when Python runs with -O.
        if not isinstance(connect_args_raw, Iterable):
            raise TypeError(
                f"connect_args must be an iterable of (key, value) pairs, got {type(connect_args_raw).__name__}"
            )
        # Materialize once: a one-shot iterator (hashable, so it passes the cache) would otherwise be
        # exhausted by the validation below and silently turn into an empty dict.
        connect_args_pairs = tuple(connect_args_raw)
        if not all(isinstance(pair, Sized) and len(pair) == 2 for pair in connect_args_pairs):
            raise ValueError("connect_args must be an iterable of (key, value) pairs")
        connect_args = dict(connect_args_pairs)
    else:
        connect_args = {"connect_timeout": 5}  # Some dialects use "timeout"
    poolclass = QueuePool if sync else AsyncAdaptedQueuePool
    # You may want to move some of this to some sort of global constant
    params: dict[str, Any] = {
        "isolation_level": "READ COMMITTED",  # See https://docs.sqlalchemy.org/en/20/core/connections.html#dbapi-autocommit
        "echo": kwargs.pop("echo", False),  # Don't be so verbose unless this is true
        "future": True,
        "connect_args": connect_args,
        "poolclass": poolclass,
    }
    # Caller-supplied kwargs win over every default above (including poolclass/isolation_level).
    params.update(kwargs)
    if sync:
        return create_engine(db_url, **params)
    return create_async_engine(db_url, **params)
@asynccontextmanager
async def async_db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> AsyncGenerator[AsyncEngine, None]:
    """Yield a pooled async engine for ``db_url`` and dispose of it on exit.

    The engine comes from the ``functools.cache``-memoized ``_get_db_engine``, so calls with the
    same arguments share one ``AsyncEngine`` instance and ``kwargs`` must be hashable.
    """
    pooled_engine: AsyncEngine = _get_db_engine(db_url, sync=False, **kwargs)
    try:
        yield pooled_engine
    finally:
        # ``finally`` ensures dispose() runs even when the caller's block raises.
        await pooled_engine.dispose()
@asynccontextmanager
async def async_db_session(engine: AsyncEngine, **kwargs: Any) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to ``engine`` for the duration of the ``async with`` block.

    ``expire_on_commit`` defaults to ``False``; any ``kwargs`` are forwarded to
    ``async_sessionmaker`` and take precedence over that default.
    """
    # You may want to move some of this to some sort of global constant
    session_options = {"expire_on_commit": False, **kwargs}
    factory = async_sessionmaker(engine, **session_options)  # type: ignore[call-overload]
    # The session's own async context manager handles its cleanup on exit.
    async with factory() as session:
        yield session
@contextmanager
def db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> Generator[Engine, None, None]:
    """Yield a pooled sync engine for ``db_url`` and dispose of it on exit.

    The engine is not new on every call: it comes from the ``functools.cache``-memoized
    ``_get_db_engine``, so calls with the same arguments share one ``Engine`` instance. For that
    reason ``kwargs`` must be hashable (annotated ``AnyCacheable``, like ``async_db_engine``).
    """
    engine = _get_db_engine(db_url, sync=True, **kwargs)
    try:
        yield engine
    finally:
        # ``finally`` ensures dispose() runs even when the caller's block raises.
        engine.dispose()
@contextmanager
def db_session(engine: Engine, **kwargs: Any) -> Generator[Session, None, None]:
    """Yield an ORM ``Session`` bound to ``engine`` for the duration of the ``with`` block.

    ``expire_on_commit`` defaults to ``False``; any ``kwargs`` are forwarded to ``sessionmaker``
    and take precedence over that default.
    """
    # You may want to move some of this to some sort of global constant
    params = {
        "expire_on_commit": False,
    }
    params.update(kwargs)
    # Give the factory its own name: previously ``session`` referred first to the sessionmaker and
    # was then rebound to the Session instance by the ``with`` statement.
    session_factory = sessionmaker(engine, **params)  # type: ignore[call-overload]
    # The Session's context manager handles its cleanup on exit.
    with session_factory() as session:
        yield session
def asdict(row: Row) -> dict[str, Any]:
    """Return ``row`` as a plain ``dict`` whose keys are plain ``str`` column names.

    Despite its leading underscore, ``Row._asdict`` is documented API. Its keys may be
    ``sqlalchemy.sql.elements.quoted_name`` instances, so each key is coerced to ``str``.
    """
    # See: https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Row._asdict
    # noinspection PyProtectedMember
    mapping = row._asdict()
    return {str(column): mapping[column] for column in mapping}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment