|
from datetime import datetime |
|
from fastapi import BackgroundTasks, Depends, FastAPI |
|
from pydantic import BaseModel |
|
from sqlalchemy import ( |
|
Column, |
|
create_engine, |
|
DateTime, |
|
TIMESTAMP, |
|
Boolean, |
|
Numeric, |
|
Integer, |
|
String, |
|
engine, |
|
Table, |
|
ForeignKey, |
|
ARRAY, |
|
) |
|
from sqlalchemy import ( |
|
DECIMAL, |
|
TEXT, |
|
TIMESTAMP, |
|
BigInteger, |
|
Boolean, |
|
CheckConstraint, |
|
Column, |
|
Date, |
|
Enum, |
|
Float, |
|
ForeignKey, |
|
Index, |
|
Integer, |
|
Numeric, |
|
PrimaryKeyConstraint, |
|
String, |
|
Text, |
|
UniqueConstraint, |
|
and_, |
|
create_engine, |
|
event, |
|
func, |
|
or_, |
|
) |
|
from sqlalchemy.orm import Session, sessionmaker |
|
from sqlalchemy import select |
|
from sqlalchemy.ext.declarative import declared_attr |
|
from starlette.middleware.cors import CORSMiddleware |
|
|
|
import decimal |
|
from sqlalchemy.schema import Index |
|
from typing import Optional, Dict, List, Any, Tuple |
|
from contextlib import asynccontextmanager |
|
from functools import lru_cache |
|
|
|
# from async_lru import alru_cache as async_lru_cache |
|
|
|
|
|
from typing import List |
|
from typing import Optional |
|
|
|
from dataclasses import dataclass |
|
from dataclasses import field, dataclass |
|
from sqlalchemy.orm import registry |
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine |
|
from sqlalchemy.ext.asyncio import create_async_engine |
|
|
|
import pydantic |
|
import asyncio |
|
import typer |
|
|
|
# Database URL passed to create_async_engine() by get_engine().
# SQLite alternative kept for reference:
# SQLALCHEMY_DATABASE_URL = "sqlite:///test10.db"
SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg://test@localhost:5432/test"

# Registry used by the `@mapper_registry.mapped` classes and to generate `Base`.
mapper_registry = registry()
|
|
|
|
|
@lru_cache()
def get_engine() -> AsyncEngine:
    """Return the async engine for ``SQLALCHEMY_DATABASE_URL``.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    engine is created on the first call and the same object is returned on
    every later call.
    """
    async_engine = create_async_engine(
        SQLALCHEMY_DATABASE_URL,
        # connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    return async_engine
|
|
|
|
|
@asynccontextmanager
async def get_db(db_conn=Depends(get_engine)) -> AsyncSession:
    """Yield an ``AsyncSession`` bound to ``db_conn`` and finalise it on exit.

    The session is committed once the ``async with`` body finishes without
    error, rolled back (with the error re-raised) otherwise, and closed in
    both cases.

    NOTE(review): because of ``@asynccontextmanager`` a call returns an async
    context manager; ``AsyncSession`` is the type of the value it yields.
    """
    session_factory = sessionmaker(
        bind=db_conn,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    # Explicit annotation because the sessionmaker.__call__ stub is typed Any.
    session: AsyncSession = session_factory()
    try:
        yield session
        # Commit stays inside the try so a failing commit is rolled back too.
        await session.commit()
    except BaseException:
        # Catches every exception type, not only Exception subclasses;
        # the error is always propagated after the rollback.
        await session.rollback()
        raise
    finally:
        await session.close()
|
|
|
|
|
@dataclass
class SurrogatePK:
    """Dataclass mixin contributing an integer primary-key ``id`` column.

    ``id`` is declared with ``init=False``, so it is not a constructor
    parameter.
    """

    # Metadata key under which each field stores its SQLAlchemy Column.
    __sa_dataclass_metadata_key__ = "sa"

    id: int = field(
        init=False,
        metadata={"sa": Column(Integer, primary_key=True)},
    )
|
|
|
|
|
@dataclass
class TimeStampMixin:
    """Dataclass mixin contributing ``created_at`` / ``updated_at`` columns.

    Both are ``DateTime`` columns whose default is ``datetime.utcnow``;
    ``updated_at`` additionally uses ``datetime.utcnow`` as its ``onupdate``
    value. Neither field is a constructor parameter (``init=False``).
    """

    # Metadata key under which each field stores its SQLAlchemy Column.
    __sa_dataclass_metadata_key__ = "sa"

    created_at: datetime = field(
        init=False,
        metadata={"sa": Column(DateTime, default=datetime.utcnow)},
    )
    updated_at: datetime = field(
        init=False,
        metadata={
            "sa": Column(
                DateTime,
                default=datetime.utcnow,
                onupdate=datetime.utcnow,
            )
        },
    )
|
|
|
|
|
# @mapper_registry.mapped |
|
@dataclass
class User(SurrogatePK, TimeStampMixin):
    """Plain dataclass describing the columns of the ``user`` table.

    Combines the ``id`` and timestamp mixins with two non-nullable string
    columns, ``identity`` and ``row_status``. Both fields default to ``None``
    on the Python side.
    """

    __tablename__ = "user"

    identity: Optional[str] = field(
        default=None,
        metadata={"sa": Column(String(length=255), nullable=False)},
    )

    row_status: Optional[str] = field(
        default=None,
        metadata={"sa": Column(String(length=20), nullable=False)},
    )

    @declared_attr
    def __table_args__(cls):
        """Return the table-level arguments: one partial unique index.

        The index covers ``(identity, row_status)`` and is restricted to rows
        whose ``row_status`` is ``'active'``.
        """
        # Local import keeps this block self-contained.
        from sqlalchemy import text

        # Use a textual predicate instead of `cls.row_status == "active"`.
        # `row_status` is a dataclass field with default=None, so the class
        # attribute may still be that plain default when this hook runs.
        # The comparison would then be a Python bool, not a SQL expression.
        active_only = text("row_status = 'active'")
        return (
            Index(
                "index_on_identity_v3_user_identity",
                "identity",
                "row_status",
                unique=True,
                postgresql_where=active_only,
            ),
        )
|
|
|
|
|
@mapper_registry.mapped
@dataclass
class UserSQL(User):
    """``User`` registered with ``mapper_registry``; adds no fields of its own."""
|
|
|
|
|
# Pydantic dataclass built from the plain `User` dataclass; it is the item
# type of the "/" route's response model.
UserPyd = pydantic.dataclasses.dataclass(User)

# FastAPI application and Typer command-line application.
app = FastAPI()
cli = typer.Typer()

# Declarative base generated from the shared registry; init_models() drops
# and recreates the tables in its metadata.
Base = mapper_registry.generate_base()
|
|
|
|
|
async def init_models():
    """Drop, then recreate, every table in ``Base.metadata``.

    Both steps run via ``run_sync`` on one connection obtained from
    ``get_engine().begin()``. Destructive: existing tables are dropped first.
    """
    async_engine = get_engine()
    async with async_engine.begin() as connection:
        # Order matters: drop first, then create.
        for ddl_step in (Base.metadata.drop_all, Base.metadata.create_all):
            await connection.run_sync(ddl_step)
|
|
|
|
|
# CLI command: run init_models() to completion on a fresh event loop, then
# print "Done". The `name` argument is accepted but not used by the body.
@cli.command()
def db_init_models(name: str):
    rebuild_tables = init_models()
    asyncio.run(rebuild_tables)
    print("Done")
|
|
|
|
|
@app.on_event("startup")
def open_database_connection_pools():
    """Startup hook: call ``get_engine()`` so the cached engine exists early."""
    # Only the side effect of populating the cache is wanted here.
    _ = get_engine()
|
|
|
|
|
@app.on_event("shutdown")
async def close_database_connection_pools():
    """Shutdown hook: dispose of the cached async engine.

    ``AsyncEngine.dispose()`` is a coroutine, so the handler is ``async`` and
    awaits it. Calling it without ``await`` only creates a coroutine object
    that never runs, leaving the pool undisposed.
    """
    _db_conn = get_engine()
    if _db_conn:
        await _db_conn.dispose()
|
|
|
|
|
# init_models() |
|
|
|
|
|
@app.get("/", response_model=List[UserPyd])
async def foo(context_session: AsyncSession = Depends(get_db)):
    """Return all users."""
    # `get_db` provides an async context manager; entering it yields the
    # session and leaving it commits (or rolls back) and closes the session.
    async with context_session as db:
        # Query the user table.
        result = await db.execute(select(UserSQL))

        # Materialise the entities exactly once. A Result can only be
        # consumed one time, and plain iteration yields Row objects rather
        # than UserSQL instances; scalars() unwraps the single column.
        users = result.scalars().all()

        for user in users:
            print(
                f"{user.identity}\t\n"
                f"{user.row_status}\t\n"
                f"{user.created_at}\t\n"
                f"{user.updated_at}"
            )

        return users
|
|
|
|
|
# Script entry point: hand control to the Typer CLI.
if __name__ == "__main__":
    cli()