Skip to content

Instantly share code, notes, and snippets.

@Compro-Prasad
Last active December 10, 2023 13:43
Show Gist options
  • Save Compro-Prasad/87dc9942e94296e4d98a11403c915135 to your computer and use it in GitHub Desktop.
Save Compro-Prasad/87dc9942e94296e4d98a11403c915135 to your computer and use it in GitHub Desktop.
SQLAlchemy base template for FastAPI projects
import os
import re
from datetime import datetime
from typing import Any, AsyncGenerator, Generator

from alembic_utils.pg_function import PGFunction
from alembic_utils.pg_trigger import PGTrigger
from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.orm import Mapped as T
from sqlalchemy.orm import mapped_column as column
from sqlalchemy.orm import sessionmaker
def updated_at_trigger(tablename):
    """Build the ``BEFORE UPDATE`` trigger entity for *tablename*.

    The trigger is named ``<tablename>_set_updated_at_on_update``, lives in
    the ``public`` schema and executes ``set_updated_at()`` for each row.

    Args:
        tablename: Name of the table the trigger is attached to.

    Returns:
        A ``PGTrigger`` describing the trigger.
    """
    trigger_signature = f"{tablename}_set_updated_at_on_update"
    # Assembled line by line so the SQL text does not depend on how this
    # function body happens to be indented.
    trigger_definition = (
        "\n"
        f"BEFORE UPDATE ON {tablename}\n"
        "FOR EACH ROW\n"
        "EXECUTE PROCEDURE set_updated_at();\n"
    )
    return PGTrigger(
        schema="public",
        signature=trigger_signature,
        on_entity=tablename,
        definition=trigger_definition,
    )
# PGFunction entity for the `set_updated_at()` trigger function that the
# triggers built by `updated_at_trigger` execute. It is stored as an attribute
# on the `updated_at_trigger` function object itself.
# The PL/pgSQL body assigns now() to NEW.updated_at and returns the row.
updated_at_trigger.function = PGFunction(
schema="public",
signature="set_updated_at()", # Can be reused for any table with column updated_at
definition="""
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at := now();
return NEW;
END;
$$ language 'plpgsql'
""",
)
# NOTE: both patterns need the `re` module, which must be imported at the top
# of the file.
# pat1 matches any run of two or more consecutive capital letters (e.g. the
# "HTTPS" in "HTTPServer"); used to reject names that are not plain CamelCase.
pat1 = re.compile(r"[A-Z]{2,}")
# pat2 matches the empty position before every capital letter except one at
# the very start of the string; substituting "_" there and lower-casing turns
# "UserProfile" into "user_profile".
pat2 = re.compile(r"(?<!^)(?=[A-Z])")
@as_declarative()
class Base:
    """Declarative base class shared by all models.

    Provides an auto-incrementing integer primary key, ``created_at`` and
    ``updated_at`` timestamps, and a table name derived from the class name
    (``UserProfile`` -> ``user_profile``).
    """

    id: T[int] = column(primary_key=True, autoincrement=True)
    created_at: T[datetime] = column(server_default=func.now())
    # Only the insert-time default is declared here; `updated_at_trigger` in
    # this module builds the trigger that refreshes the value on UPDATE.
    updated_at: T[datetime] = column(server_default=func.now())

    # Declared for type checkers only; every class already has `__name__`.
    __name__: str

    # Generate __tablename__ automatically
    @declared_attr
    def __tablename__(cls) -> str:
        """Return the snake_case table name derived from the class name.

        Raises:
            AssertionError: If the class name contains two or more
                consecutive capital letters (e.g. ``HTTPServer``), which
                would not convert to a sensible snake_case name.
        """
        name = cls.__name__
        assert not pat1.search(name), "Use proper camel case for model names"
        return pat2.sub("_", name).lower()
# Both engines are configured from the DATABASE_URL environment variable.
_database_url = os.getenv("DATABASE_URL")
if not _database_url:
    # Fail with a clear message instead of an opaque error from
    # create_engine(None) or `None.replace(...)` further down.
    raise RuntimeError("DATABASE_URL environment variable is not set")

# Synchronous engine and session factory.
engine = create_engine(_database_url, pool_pre_ping=True)
session_maker = sessionmaker(bind=engine)

# Asynchronous engine: same URL, switched to the asyncpg driver.
aio_engine = create_async_engine(
    _database_url.replace("postgresql://", "postgresql+asyncpg://"),
    pool_pre_ping=True,
)
# The async session factory must be bound to the async engine; binding it to
# the synchronous `engine` makes every AsyncSession unusable.
aio_session_maker = async_sessionmaker(aio_engine)
def get_session() -> Generator:
    """Yield a session from ``session_maker`` and close it afterwards.

    Written as a generator so it can be used as a dependency: the session is
    yielded to the caller and closed when the generator is finalised, whether
    or not the caller raised.

    Yields:
        The session object returned by ``session_maker()``.
    """
    # Create the session before entering `try`: if the factory itself raises,
    # there is nothing to close, and the original error must propagate rather
    # than being masked by an UnboundLocalError in `finally`.
    db = session_maker()
    try:
        yield db
    finally:
        db.close()
async def get_aio_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session from ``aio_session_maker`` and close it afterwards.

    Async counterpart of ``get_session``: the session is yielded to the
    caller and its ``close()`` is awaited when the generator is finalised,
    whether or not the caller raised.

    Yields:
        The session object returned by ``aio_session_maker()``.
    """
    # Create the session before entering `try`: if the factory itself raises,
    # there is nothing to close, and the original error must propagate rather
    # than being masked by an UnboundLocalError in `finally`.
    db = aio_session_maker()
    try:
        yield db
    finally:
        await db.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment