Skip to content

Instantly share code, notes, and snippets.

@truetug
Last active September 7, 2023 13:55
Show Gist options
  • Save truetug/f0637642f216c260ededb6037e858a1b to your computer and use it in GitHub Desktop.
Save truetug/f0637642f216c260ededb6037e858a1b to your computer and use it in GitHub Desktop.
Async SQLAlchemy for FastAPI in 2023
import logging
from asyncio import current_task
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_scoped_session,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, declarative_base, scoped_session, sessionmaker
__all__ = ["DBConfig", "DB", "db", "BaseModel", "AsyncDB", "get_db_session"]

logger = logging.getLogger(__name__)

# Declarative base class for ORM models. It is listed in ``__all__``, so it
# must exist at module level or ``from <this module> import *`` fails with
# AttributeError.
BaseModel = declarative_base()
@dataclass
class DBConfig:
    """Settings for a SQLAlchemy engine and its sessions.

    Only ``dsn`` (a SQLAlchemy database URL string) is required; every other
    field has a default.
    """

    dsn: str
    pool_size: int = 50
    pool_pre_ping: bool = True
    echo: bool = False
    auto_commit: bool = False
    auto_flush: bool = False
    expire_on_commit: bool = False
    # NOTE(review): not read by DB.get_engine_kwargs() in this module -
    # confirm whether anything else consumes it.
    executemany_mode: str = ""
    # NOTE(review): not referenced in this module; presumably used by test
    # setup elsewhere - verify.
    disable_test_suffix: bool = False

    def get_dsn_as_dict(self) -> dict[str, Any]:
        """Parse ``dsn`` and return its connection components.

        The keys are, in order: ``username``, ``database``, ``port``,
        ``host``, ``password`` (values as parsed by ``make_url``).
        """
        parsed = make_url(self.dsn)
        components = ("username", "database", "port", "host", "password")
        return {component: getattr(parsed, component) for component in components}

    def get_dsn_as_safe_url(self) -> str:
        """Return ``user:***@host:port/database`` with the password masked, for logs."""
        parsed = make_url(self.dsn)
        return f"{parsed.username}:***@{parsed.host}:{parsed.port}/{parsed.database}"
@dataclass
class DB:
    """Holder for a synchronous SQLAlchemy engine and scoped session.

    Instantiate with no arguments, then call :meth:`setup` with a
    :class:`DBConfig`. The ``config``, ``engine`` and ``session`` attributes
    do not exist until then.
    """

    config: DBConfig = field(init=False, repr=False)
    # After setup() this holds the scoped_session registry returned by
    # get_session(), not a plain Session instance.
    session: Session = field(init=False, repr=False)

    # Engine factory and session class used by get_engine() and
    # get_session_kwargs(); subclasses replace them.
    create_engine_func = create_engine
    session_class = Session

    def __getattr__(self, name: str) -> Any:
        """Explain a missing ``config``/``session`` before raising AttributeError.

        ``__getattr__`` only runs after normal lookup has failed, so ordinary
        attribute access carries no extra overhead.
        """
        if name in ("config", "session"):
            logger.warning("DB: You need to call setup() for getting attribute %s", name)
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    def get_session_kwargs(self) -> dict[str, Any]:
        """Build keyword arguments for the session factory from ``self.config``."""
        return {
            "autocommit": self.config.auto_commit,
            "autoflush": self.config.auto_flush,
            "expire_on_commit": self.config.expire_on_commit,
            "bind": self.engine,
            "class_": self.session_class,
        }

    def get_session(self):
        """Return a ``scoped_session`` registry wrapping a configured ``sessionmaker``."""
        return scoped_session(sessionmaker(**self.get_session_kwargs()))

    def get_engine_kwargs(self) -> dict[str, Any]:
        """Build keyword arguments for the engine factory from ``self.config``."""
        return {
            "url": self.config.dsn,
            "pool_size": self.config.pool_size,
            "pool_pre_ping": self.config.pool_pre_ping,
            "echo": self.config.echo,
        }

    def get_engine(self):
        """Create the engine with ``create_engine_func`` and ``get_engine_kwargs()``."""
        # Read the factory through the class: a plain function stored as a
        # class attribute would otherwise be bound to ``self``. This works for
        # plain functions, staticmethods and other callables alike.
        engine_factory = type(self).create_engine_func
        return engine_factory(**self.get_engine_kwargs())

    def setup(self, config: DBConfig) -> None:
        """Store ``config`` and build the engine and scoped session from it."""
        logger.info(
            "Init database connection pool: %s",
            config.get_dsn_as_safe_url(),
        )
        self.config = config
        self.engine = self.get_engine()  # pylint:disable=attribute-defined-outside-init
        self.session = self.get_session()

    def shutdown(self) -> None:
        """Remove the scoped session and release the engine's pooled connections."""
        logger.info("Shutdown database connection pool")
        self.session.remove()  # type: ignore[attr-defined]
        # remove() only closes the session; the pool keeps its connections
        # open until the engine itself is disposed.
        self.engine.dispose()

    @contextmanager
    def session_scope(self) -> Iterator[Session]:
        """Yield the scoped session and close it on exit, including on error."""
        try:
            yield self.session
        finally:
            self.session.close()

    def get_status_info(self) -> tuple[dict[str, Any], bool]:
        """Health check that runs ``select version();`` on a new session.

        Returns ``({"status": "OK"}, True)`` when the query succeeds, and
        ``({"status": "FAILED"}, False)`` when it raises. The error is logged
        with its traceback.
        """
        status = True
        session = self.session()  # type: ignore[operator]
        try:
            # SQLAlchemy 2.x rejects plain strings in Session.execute();
            # textual SQL has to be wrapped in text().
            session.execute(text("select version();"))
        except Exception as exc:  # pylint: disable=broad-except
            status = False
            logger.exception("Database status check failed: %s", exc)
        finally:
            session.close()
        return ({"status": "OK"} if status else {"status": "FAILED"}), status
@dataclass
class AsyncDB(DB):
    """Asyncio flavour of :class:`DB` using ``create_async_engine``/``AsyncSession``.

    Sessions are scoped per asyncio task (``scopefunc=current_task``). The
    inherited synchronous ``shutdown()`` and ``get_status_info()`` call
    coroutine methods of the async session/engine without awaiting them.
    Await ``async_shutdown()`` and ``async_get_status_info()`` instead.
    """

    # @dataclass above is required: without it, ``field(...)`` stays on the
    # class as a raw ``dataclasses.Field`` object, and ``db.session`` returns
    # that object before setup() instead of raising AttributeError.
    session: AsyncSession = field(init=False, repr=False)  # type: ignore[assignment]
    create_engine_func = create_async_engine  # type: ignore[assignment]
    session_class = AsyncSession  # type: ignore[assignment]

    def get_engine_kwargs(self) -> dict[str, Any]:
        """Base engine kwargs with any ``executemany_mode`` entry removed."""
        result = super().get_engine_kwargs()
        # DB.get_engine_kwargs() does not add this key in this module; the
        # pop only matters if an override adds it.
        result.pop("executemany_mode", None)
        return result

    def get_session(self):
        """Return an ``async_scoped_session`` keyed by the current asyncio task."""
        return async_scoped_session(
            async_sessionmaker(**self.get_session_kwargs()), scopefunc=current_task
        )

    @asynccontextmanager
    async def session_scope(self) -> AsyncIterator[AsyncSession]:  # type: ignore[override]
        """Yield the task-scoped session; on exit close it and drop it from the registry.

        Exceptions raised inside the ``async with`` body are logged and re-raised.
        """
        try:
            yield self.session
        except Exception as exc:
            logger.error("Exception in session_scope: %s", exc)
            raise
        finally:
            # remove() closes the session AND deletes this task's registry
            # entry. close() alone keeps the entry, which is keyed by the task
            # object, so finished tasks and their sessions would accumulate.
            await self.session.remove()  # type: ignore[attr-defined]

    async def async_shutdown(self) -> None:
        """Remove the current task's session and dispose of the engine's pool."""
        logger.info("Shutdown database connection pool")
        await self.session.remove()  # type: ignore[attr-defined]
        await self.engine.dispose()

    async def async_get_status_info(self) -> tuple[dict[str, Any], bool]:
        """Async health check that runs ``select version();``.

        Returns ``({"status": "OK"}, True)`` on success and
        ``({"status": "FAILED"}, False)`` if the query raises (the error is
        logged). It uses a dedicated session from the factory, so it does not
        close a session the calling task may share through the scoped registry.
        """
        status = True
        session = self.session.session_factory()  # type: ignore[attr-defined]
        try:
            await session.execute(text("select version();"))
        except Exception as exc:  # pylint: disable=broad-except
            status = False
            logger.exception("Database status check failed: %s", exc)
        finally:
            await session.close()
        return ({"status": "OK"} if status else {"status": "FAILED"}), status
# Module-level instance; call ``db.setup(DBConfig(...))`` before using it.
db = AsyncDB()


async def get_db_session() -> AsyncIterator[AsyncSession]:
    """Yield the current task's session for use as a FastAPI dependency.

    On exit the session is closed and its entry is removed from the scoped
    session registry, so finished request tasks are not retained there.
    """
    async with db.session() as session:  # type: ignore[operator]
        try:
            yield session
        finally:
            # close() alone would leave the (closed) session stored in the
            # registry under this task; remove() closes it and drops the entry.
            await db.session.remove()  # type: ignore[attr-defined]
SHELL=/bin/bash

# Read-only checks: none of the targets below rewrite files.

.PHONY: lint-yml
lint-yml: ## Lints all yaml files with yamllint
	@echo "* Lint YAML"
	@yamllint -f parsable -c .yamllint .
	@echo "done"

.PHONY: check-isort
check-isort: ## Check isort
	@echo "* Check lint code"
	@echo "Run isort"
	@exec isort --check-only .

.PHONY: check-black
check-black: ## Check black
	@echo "Run black"
	@exec black --check --diff .

.PHONY: flake
flake: ## Check flake
	@echo "Run flake"
	@exec pflake8 .

.PHONY: vulture
vulture: ## Check vulture
	@echo "Run vulture"
	@exec vulture app

# `-r app` lets bandit recurse itself; a shell glob is redundant with -r.
.PHONY: bandit
bandit: ## Check bandit
	@echo "Run bandit"
	@exec bandit -r app

.PHONY: mypy
mypy: ## Check mypy
	@echo "Run mypy"
	@exec mypy .

.PHONY: check # Runs linters only for check
check: check-isort check-black flake vulture bandit mypy ## Run all checks
# Auto-fixers: autoflake -i, isort and black rewrite files in place.
.PHONY: autolint lint

autolint: ## Run linters
	@echo "* Lint code"
	@echo "Run autoflake"
	@autoflake -r -i --remove-all-unused-imports --ignore-init-module-imports .
	@echo "Run isort"
	@isort .
	@echo "Run black"
	@black .

lint: autolint flake vulture bandit mypy ## Runs linters and fixes auto-fixable errors
.PHONY: test
test: ## Runs tests
	@echo "* Run tests"
	pytest -svvv -rs --cov app --cov-report term-missing -x

# NOTE(review): `coverage run` with no program argument relies on
# `[run] command_line` being set in the coverage config - confirm.
.PHONY: coverage
coverage: ## Runs tests with coverage
	@echo "* Run tests with coverage"
	coverage erase
	coverage run
	coverage xml -i
# Alembic helpers.
.PHONY: migrations
migrations: ## Generate Alembic migrations for your models
	@echo "* Alembic Revision Autogenerate"
	alembic revision --autogenerate

.PHONY: migrate
migrate: ## Migrate database
	@echo "* Alembic Upgrade Head"
	alembic upgrade head

.PHONY: sqlmigrate
sqlmigrate: ## Print the SQL to migrate the database to head (no changes applied)
	@echo "* Alembic show SQL for Upgrade Head"
	alembic upgrade head --sql
# Re-lock requirements/base.txt (with hashes) from requirements/base.in.
# NOTE(review): in pip-compile, `-r` appears to be the short form of
# `--rebuild` (not pip's "read requirements file"), so base.in is picked up
# as a positional source and caches are cleared on every run - confirm intent.
.PHONY: requirements
requirements: ## Lock current requirements
	@echo "* Lock requirements"
	pip install -U pip-tools
	pip-compile \
	--verbose \
	--generate-hashes \
	--no-emit-index-url \
	-r requirements/base.in \
	-o requirements/base.txt
# Remove caches, editor/merge leftovers and build/coverage artifacts.
# find -exec passes each path as a single argument, so names containing
# whitespace are handled safely; `rm -rf $(find ...)` would word-split them
# and could delete unrelated paths. `|| true` keeps errors (e.g. unreadable
# directories) non-fatal, as they were with the backtick form.
.PHONY: clean-extra
clean-extra: ## Cleans all temporary files (caches, dists, etc.)
	@echo -n "* Clean temp files... "
	@find . -name __pycache__ -prune -exec rm -rf {} + || true
	@find . -type f \( \
	-name '*.py[co]' -o -name '*~' -o -name '.*~' -o -name '@*' \
	-o -name '#*#' -o -name '*.orig' -o -name '*.rej' \
	\) -exec rm -f {} + || true
	@rm -rf .coverage coverage.html coverage.xml htmlcov build cover
	@test -f setup.py && python setup.py clean || echo -n ""
	@rm -rf .develop .flake .install-deps *.egg-info .mypy_cache .pytest_cache dist
	@echo "done"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment