Last active
September 7, 2023 13:55
-
-
Save truetug/f0637642f216c260ededb6037e858a1b to your computer and use it in GitHub Desktop.
Async SQLAlchemy for FastAPI in 2023
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from asyncio import current_task
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_scoped_session,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, declarative_base, scoped_session, sessionmaker
__all__ = ["DBConfig", "DB", "db", "BaseModel"]

logger = logging.getLogger(__name__)

# Declarative base class for ORM models. ``__all__`` lists ``BaseModel``, so
# the name has to exist: a star-import of this module raises AttributeError
# for any name in ``__all__`` that the module does not define.
BaseModel = declarative_base()
@dataclass
class DBConfig:
    """Database connection settings plus helpers for inspecting the DSN."""

    # Connection URL; parsed with ``make_url`` by the helper methods below.
    dsn: str
    # Engine options.
    pool_size: int = 50
    pool_pre_ping: bool = True
    echo: bool = False
    # Session options.
    auto_commit: bool = False
    auto_flush: bool = False
    expire_on_commit: bool = False
    # NOTE(review): the two settings below are not read anywhere in this
    # module -- confirm whether other code still relies on them.
    executemany_mode: str = ""
    disable_test_suffix: bool = False

    def get_dsn_as_dict(self) -> dict[str, Any]:
        """Return the DSN split into username, database, port, host, password."""
        parsed = make_url(self.dsn)
        parts = ("username", "database", "port", "host", "password")
        return {part: getattr(parsed, part) for part in parts}

    def get_dsn_as_safe_url(self) -> str:
        """Return ``user:***@host:port/database`` with the password masked."""
        parsed = make_url(self.dsn)
        return f"{parsed.username}:***@{parsed.host}:{parsed.port}/{parsed.database}"
@dataclass
class DB:
    """Hold a SQLAlchemy engine and scoped session built from a ``DBConfig``.

    Instances start empty; ``setup()`` stores the config and creates
    ``engine`` and ``session``. Subclasses can swap the engine factory and
    the session class through ``create_engine_func`` and ``session_class``.
    """

    # Assigned by setup(). Declared without defaults, so reading them before
    # setup() raises AttributeError (logged by __getattribute__ below).
    config: DBConfig = field(init=False, repr=False)
    session: Session = field(init=False, repr=False)

    # No annotation, so these are plain class attributes, not dataclass fields.
    create_engine_func = create_engine
    session_class = Session

    def __getattribute__(self, name: str) -> Any:
        """Look up ``name`` normally; log a hint when setup() was skipped.

        Raises:
            AttributeError: re-raised unchanged when the attribute is missing.
        """
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            if name in ("config", "session"):
                # Use the module logger rather than print so the hint goes
                # through the application's logging configuration.
                logger.error(
                    "DB: You need to call setup() for getting attribute %s", name
                )
            raise

    def get_session_kwargs(self):
        """Return the keyword arguments passed to the session factory."""
        return {
            "autocommit": self.config.auto_commit,
            "autoflush": self.config.auto_flush,
            "expire_on_commit": self.config.expire_on_commit,
            "bind": self.engine,
            "class_": self.session_class,
        }

    def get_session(self):
        """Build a ``scoped_session`` around a configured ``sessionmaker``."""
        return scoped_session(sessionmaker(**self.get_session_kwargs()))

    def get_engine_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments passed to the engine factory."""
        return {
            "url": self.config.dsn,
            "pool_size": self.config.pool_size,
            "pool_pre_ping": self.config.pool_pre_ping,
            "echo": self.config.echo,
        }

    def get_engine(self):
        """Create the engine with ``create_engine_func``."""
        # Reading the function through the instance yields a bound method;
        # __func__ recovers the plain function so ``self`` is not passed on.
        func = self.create_engine_func.__func__  # type: ignore[misc, attr-defined]
        return func(**self.get_engine_kwargs())

    def setup(self, config: DBConfig) -> None:
        """Store ``config`` and create the engine and the scoped session."""
        logger.info(
            "Init database connection pool: %s",
            config.get_dsn_as_safe_url(),
        )
        self.config = config
        self.engine = self.get_engine()  # pylint:disable=attribute-defined-outside-init
        self.session = self.get_session()

    def shutdown(self) -> None:
        """Log the shutdown and call ``remove()`` on the scoped session."""
        logger.info("Shutdown database connection pool")
        self.session.remove()  # type: ignore[attr-defined]

    @contextmanager
    def session_scope(self) -> Iterator[Session]:
        """Yield the scoped session and close it when the block exits.

        The close runs in ``finally`` so the session is also closed when the
        ``with`` body raises.
        """
        try:
            yield self.session  # type: ignore[misc]
        finally:
            self.session.close()

    def get_status_info(self) -> tuple[dict[str, Any], bool]:
        """Run ``select version();`` and report whether it succeeded.

        Returns:
            ``({"status": "OK"}, True)`` when the query ran, otherwise
            ``({"status": "FAILED"}, False)``; the exception is logged.
        """
        session = self.session()  # type: ignore[operator]
        try:
            # SQLAlchemy 2.0 rejects plain strings in execute(); without
            # text() this check reports FAILED even on a healthy database.
            session.execute(text("select version();"))
        except Exception as e:  # pylint: disable=broad-except
            logger.exception(e)
            status = False
        else:
            status = True
        finally:
            session.close()
        return {"status": "OK" if status else "FAILED"}, status
class AsyncDB(DB):
    """Async variant of ``DB`` built on ``create_async_engine``/``AsyncSession``.

    NOTE(review): ``get_status_info`` is inherited unchanged and calls
    ``session.execute``/``session.close`` without ``await``; with an
    ``AsyncSession`` those calls presumably return un-awaited coroutines, so
    the reported status is not reliable here -- confirm and add an awaited
    variant if callers need a real health check.
    """

    # Bare annotation only. This class is not decorated with @dataclass, so
    # assigning field(...) here would leave a dataclasses.Field object as the
    # class attribute and ``session`` would resolve to it before setup()
    # instead of raising AttributeError.
    session: AsyncSession  # type: ignore[assignment]

    create_engine_func = create_async_engine  # type: ignore[assignment]
    session_class = AsyncSession  # type: ignore[assignment]

    def get_engine_kwargs(self):
        """Return the base engine kwargs without ``executemany_mode``."""
        result = super().get_engine_kwargs()
        # The base implementation does not add this key; the pop only matters
        # if an override further up the chain does.
        result.pop("executemany_mode", None)
        return result

    def get_session(self):
        """Build an ``async_scoped_session`` scoped to the current asyncio task."""
        return async_scoped_session(
            async_sessionmaker(**self.get_session_kwargs()), scopefunc=current_task
        )

    async def shutdown(self) -> None:  # type: ignore[override]
        """Log the shutdown and await ``remove()`` on the scoped session.

        ``async_scoped_session.remove()`` is a coroutine; the synchronous
        base implementation would create it without ever awaiting it. This
        method must itself be awaited.
        """
        logger.info("Shutdown database connection pool")
        await self.session.remove()

    @asynccontextmanager
    async def session_scope(self) -> AsyncIterator[AsyncSession]:  # type: ignore[override]
        """Yield the scoped session; log errors and release it on exit.

        Raises:
            Exception: whatever the ``async with`` body raised, after logging.
        """
        try:
            yield self.session
        except Exception as exc:
            logger.error("Exception in session_scope: %s", exc)
            raise
        finally:
            # remove() closes the current task's session and also drops its
            # registry entry; close() alone would keep the entry (keyed by the
            # task) alive after the task finishes.
            await self.session.remove()
# Module-level AsyncDB instance (exported through ``__all__``). Its ``setup()``
# must be called before ``config`` or ``session`` are accessed.
db = AsyncDB()
async def get_db_session() -> AsyncIterator[AsyncSession]:
    """Initialise session for usage in depends.

    Yields the session that the scoped registry holds for the current task.
    On exit the session is closed by its ``async with`` block and the
    registry entry for the task is removed.
    """
    try:
        async with db.session() as session:  # type: ignore[operator]
            yield session
    finally:
        # The registry is keyed by asyncio task (scopefunc=current_task), so
        # the entry must be removed explicitly or every finished task leaves
        # its session behind. remove() is a no-op when nothing is registered.
        await db.session.remove()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run every recipe line with bash instead of make's default /bin/sh.
SHELL=/bin/bash
# Lint the current directory with yamllint using the rules in .yamllint;
# "-f parsable" selects yamllint's machine-readable output format.
# NOTE(review): the help text mentions a docker image, but the recipe calls
# yamllint directly -- confirm which one is intended.
.PHONY: lint-yml
lint-yml: ## Lints all yaml files with yamllint docker image
	@echo "* Lint YAML"
	@yamllint -f parsable -c .yamllint .
	@echo "done"
# Verify import ordering without modifying files (--check-only).
# Also prints the "* Check lint code" banner; this target is the first
# prerequisite of `check`.
.PHONY: check-isort
check-isort: ## Check isort
	@echo "* Check lint code"
	@echo "Run isort"
	@exec isort --check-only .
# Report formatting differences as a diff (--check --diff) without rewriting
# any file.
.PHONY: check-black
check-black: ## Check black
	@echo "Run black"
	@exec black --check --diff .
# Run pflake8 over the current directory.
.PHONY: flake
flake: ## Check flake
	@echo "Run flake"
	@exec pflake8 .
# Run vulture on the `app` path.
.PHONY: vulture
vulture: ## Check vulture
	@echo "Run vulture"
	@exec vulture app
# Run bandit recursively (-r) over everything the shell expands app/* to.
.PHONY: bandit
bandit: ## Check bandit
	@echo "Run bandit"
	@exec bandit -r app/*
# Type-check the current directory with mypy.
.PHONY: mypy
mypy: ## Check mypy
	@echo "Run mypy"
	@exec mypy .
# Aggregate target: it has no recipe of its own and only runs the check
# targets listed as prerequisites.
.PHONY: check # Runs linters only for check
check: check-isort check-black flake vulture bandit mypy ## Run all checks
# Rewrite files in place: autoflake (-r recursive, -i in-place) removes unused
# imports, then isort and black reformat the current directory.
.PHONY: autolint
autolint: ## Run linters
	@echo "* Lint code"
	@echo "Run autoflake"
	@exec autoflake -r -i --remove-all-unused-imports --ignore-init-module-imports .
	@echo "Run isort"
	@exec isort .
	@echo "Run black"
	@exec black .
# Aggregate target: auto-fix first (autolint), then run the remaining checks.
.PHONY: lint
lint: autolint flake vulture bandit mypy ## Runs linters and fixes auto-fixable errors
# Run the test suite with coverage for `app`.
# -s: no output capture, -vvv: verbose, -rs: report skip reasons,
# -x: stop at the first failure.
.PHONY: test
test: ## Runs tests
	@echo "* Run tests"
	pytest -svvv -rs --cov app --cov-report term-missing -x
# Erase old coverage data, collect new data, then write the XML report
# (-i: ignore errors while reporting).
# NOTE(review): `coverage run` is called without a command; presumably the
# command comes from the coverage configuration -- confirm.
.PHONY: coverage
coverage: ## Runs tests with coverage
	@echo "* Run tests with coverage"
	coverage erase
	coverage run
	coverage xml -i
# Create a new Alembic revision with --autogenerate.
.PHONY: migrations
migrations: ## Generate migrations_psycopg2 for your models
	@echo "* Alembic Revision Autogenerate"
	alembic revision --autogenerate
# Upgrade the database to the latest Alembic revision (head).
.PHONY: migrate
migrate: ## Migrate database
	@echo "* Alembic Upgrade Head"
	alembic upgrade head
# Same upgrade as `migrate`, but --sql makes Alembic emit the SQL instead of
# executing it against the database.
.PHONY: sqlmigrate
sqlmigrate: ## SQL to migrate database
	@echo "* Alembic show SQL for Upgrade Head"
	alembic upgrade head --sql
# Upgrade pip-tools, then compile requirements/base.in into a pinned
# requirements/base.txt with hashes and without the index URL line.
.PHONY: requirements
requirements: ## Lock current requirements
	@echo "* Lock requirements"
	pip install -U pip-tools
	pip-compile \
		--verbose \
		--generate-hashes \
		--no-emit-index-url \
		-r requirements/base.in \
		-o requirements/base.txt
# Remove caches, editor/patch leftovers, coverage output and build artefacts.
# find passes the matched paths straight to rm via -exec, so names containing
# whitespace are not word-split the way a backtick substitution splits them.
.PHONY: clean-extra
clean-extra: ## Cleans all temporary files (caches, dists, etc.)
	@echo -n "* Clean temp files... "
	@# -prune stops find from descending into a directory it is about to delete.
	@find . -name __pycache__ -prune -exec rm -rf {} +
	@find . -type f -name '*.py[co]' -exec rm -rf {} +
	@find . -type f -name '*~' -exec rm -rf {} +
	@find . -type f -name '.*~' -exec rm -rf {} +
	@find . -type f -name '@*' -exec rm -rf {} +
	@find . -type f -name '#*#' -exec rm -rf {} +
	@find . -type f -name '*.orig' -exec rm -rf {} +
	@find . -type f -name '*.rej' -exec rm -rf {} +
	@rm -rf .coverage
	@rm -rf coverage.html
	@rm -rf coverage.xml
	@rm -rf htmlcov
	@rm -rf build
	@rm -rf cover
	@# Best effort: a missing setup.py or a failing clean must not abort.
	@test -f setup.py && python setup.py clean || echo -n ""
	@rm -rf .develop
	@rm -rf .flake
	@rm -rf .install-deps
	@rm -rf *.egg-info
	@rm -rf .mypy_cache
	@rm -rf .pytest_cache
	@rm -rf dist
	@echo "done"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment