Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sandys/671b8b86ba913e6436d4cb22d04b135f to your computer and use it in GitHub Desktop.
Save sandys/671b8b86ba913e6436d4cb22d04b135f to your computer and use it in GitHub Desktop.
FastAPI with Python 3.10 dataclasses, used to create both SQLAlchemy and Pydantic models at the same time, plus setting up SQLAlchemy the right way (without deadlocks or other problems). It also takes care of unified logging when running under gunicorn, and can run in restartable mode.

Why this code?

P.S. How to quickly run postgres (using docker)

This code was tested on Windows 11 WSL2 (Ubuntu VM)

  1. `docker run -it --rm -p 5433:5432 --name some-postgres -e POSTGRES_PASSWORD=mysecretpassword -e PGDATA=/var/lib/postgresql/data -v /tmp/pgdata:/var/lib/postgresql/data -e POSTGRES_USER=test postgres` — this quickly starts Postgres on port 5433 and creates a database `test` with user `test` and password `mysecretpassword`. I like to use 5433 because I have often seen people with a normal, default installation on 5432, which causes a lot of mistakes and confusion.
  2. `docker inspect -f '{{.Name}} - {{.NetworkSettings.IPAddress }}' $(docker ps -q)` — this gives you the IP address of the postgres container.
  3. `docker run --network=host -it --rm postgres psql postgresql://test:mysecretpassword@0.0.0.0:5433/test` — this connects to Postgres on your localhost, port 5433.

cmdline

gunicorn sync_p:app -w 1 -k sync_p.RestartableUvicornWorker --logger-class sync_p.GunicornLogger

benchmark

docker run --network=host --rm skandyla/wrk -t12 -c400 -d30s http://localhost:8000/users/

IMPORTANT: do not forget the trailing slash at the end of "users/". Otherwise wrk will silently fail and you won't know.

IMPORTANT: expire_on_commit=False is the key setting here. If you don't set it, code like the following fails:

user = UserFactory()
session.add(user)
session.commit()

# missing session.refresh(user) and causing the problem

return user

That is: I created an object, added it and committed it to the DB, and afterwards tried to access one of the original object's attributes without first refreshing it with session.refresh(object).

NOTE: I have removed the async code from here. Right now it is very unstable because of inherent issues in FastAPI (like fastapi/full-stack-fastapi-template#290 and fastapi/fastapi#726 (comment)). Using async in FastAPI with SQLAlchemy is absolutely not recommended right now.

# ...
# settings.py (or settings obj)
# Every value below can be overridden by an environment variable of the same
# name; the defaults are the original hard-coded values.
import os

# Log level for gunicorn's loggers and loguru, given as a level name such as
# "DEBUG" (an int like 10 == logging.DEBUG also works when set in code, but an
# environment variable is always a string, so use the name there).
LOG_LEVEL = os.environ.get("LOG_LEVEL", "DEBUG")

# Uvicorn dictConfig-style logging configuration.
# custom handlers removed, we catch logs via loguru
UVICORN_LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "()": "uvicorn.logging.DefaultFormatter",
            "fmt": "%(levelprefix)s %(message)s",
            "use_colors": None,
        },
        "access": {
            "()": "uvicorn.logging.AccessFormatter",
            "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
        },
    },
    "loggers": {
        "uvicorn": {"level": "INFO"},
        "uvicorn.error": {"level": "INFO"},
        "uvicorn.access": {"level": "INFO", "propagate": False},
    },
}

# Database URL. The default points at the local docker Postgres published on
# port 5433; set SQLALCHEMY_DATABASE_URL so real credentials need not live in
# source control.
SQLALCHEMY_DATABASE_URL = os.environ.get(
    "SQLALCHEMY_DATABASE_URL",
    "postgresql://test:mysecretpassword@0.0.0.0:5433/test",
)
from __future__ import annotations
import decimal
import logging
import os
import signal
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import (
dataclass,
field,
)
from datetime import datetime
from functools import lru_cache
from typing import (
Any,
Dict,
Iterable,
Generator,
List,
Optional,
Tuple,
)
import pydantic
import typer
import uvicorn as uvicorn
import yaml
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
HTTPException,
Response,
status,
)
from fastapi.responses import PlainTextResponse
from gunicorn.glogging import Logger
from loguru import logger
from pydantic import BaseModel
from pydantic_sqlalchemy import sqlalchemy_to_pydantic
from sqlalchemy import (
ARRAY,
DECIMAL,
TEXT,
TIMESTAMP,
BigInteger,
Boolean,
CheckConstraint,
Column,
Date,
DateTime,
Enum,
Float,
ForeignKey,
Index,
Integer,
Numeric,
PrimaryKeyConstraint,
String,
Table,
Text,
UniqueConstraint,
and_,
create_engine,
engine,
event,
func,
or_,
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import (
Session,
registry,
relationship,
sessionmaker,
)
from sqlalchemy.schema import Index
from starlette.middleware.cors import CORSMiddleware
from uvicorn.workers import UvicornWorker
import settings
# from cmath import log
class ReloaderThread(threading.Thread):
    """Watchdog thread that interrupts the current process once a worker dies.

    Every ``sleep_interval`` seconds it checks ``worker.alive``; whenever that
    is false it sends SIGINT to the current process.

    Args:
        worker: the uvicorn/gunicorn worker to watch (must expose ``alive``).
        sleep_interval: seconds to wait between checks.
    """

    def __init__(self, worker: "UvicornWorker", sleep_interval: float = 1.0):
        # Pass daemon=True to Thread.__init__ instead of calling setDaemon(),
        # which is deprecated since Python 3.10 (DeprecationWarning). A daemon
        # thread does not keep the interpreter alive at shutdown.
        super().__init__(daemon=True)
        self._worker = worker
        self._interval = sleep_interval

    def run(self) -> None:
        """Poll the worker forever, sending SIGINT while it is not alive."""
        while True:
            if not self._worker.alive:
                # Signal our own process.
                os.kill(os.getpid(), signal.SIGINT)
            time.sleep(self._interval)
class RestartableUvicornWorker(UvicornWorker):
    """Uvicorn worker for gunicorn with an optional reload watchdog.

    Selected with ``-k sync_p.RestartableUvicornWorker``. When gunicorn's
    ``reload`` setting is enabled, a ReloaderThread watching this worker is
    started just before the normal uvicorn run loop.
    """

    # Extra uvicorn Config options: uvloop event loop, httptools HTTP parser.
    CONFIG_KWARGS = dict(loop="uvloop", http="httptools")

    def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]):
        super().__init__(*args, **kwargs)
        # Created up front, but only started from run() when reload is on.
        self._reloader_thread = ReloaderThread(self)

    def run(self) -> None:
        """Start the reloader thread if reload is enabled, then serve."""
        reload_enabled = self.cfg.reload
        if reload_enabled:
            self._reloader_thread.start()
        super().run()
class InterceptHandler(logging.Handler):
    """Forward standard-library ``logging`` records to loguru.

    Based on the default handler from the loguru documentation.
    See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
    """

    def emit(self, record: logging.LogRecord) -> None:
        # Use loguru's level of the same name when it exists; otherwise fall
        # back to the record's numeric level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk up from this frame past every frame that belongs to the logging
        # module, so loguru attributes the message to the code that called
        # logging rather than to logging internals. The `frame is not None`
        # check stops at the top of the stack instead of raising
        # AttributeError (a hard-coded depth reports the wrong caller).
        frame, depth = sys._getframe(0), 0
        while frame is not None and (
            depth == 0 or frame.f_code.co_filename == logging.__file__
        ):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )
class GunicornLogger(Logger):
    """Gunicorn logger class that routes gunicorn's own logs through loguru.

    Selected with ``--logger-class sync_p.GunicornLogger``.
    """

    def setup(self, cfg) -> None:
        """Attach one InterceptHandler to gunicorn's error and access loggers.

        Both loggers get level ``settings.LOG_LEVEL``, and loguru is
        configured with a single stdout sink at that same level. The base
        class ``Logger.setup`` is not called.
        """
        level = settings.LOG_LEVEL
        intercept = InterceptHandler()
        # InterceptHandler.emit forwards record.getMessage(), so this
        # formatter does not shape the final output.
        intercept.setFormatter(
            logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
        )
        for std_logger in (self.error_log, self.access_log):
            std_logger.addHandler(intercept)
            std_logger.setLevel(level)
        # Configure logger before gunicorn starts logging
        logger.configure(handlers=[{"sink": sys.stdout, "level": level}])
@lru_cache()
def get_engine() -> engine.Engine:
    """Return the shared SQLAlchemy engine, creating it on the first call.

    ``lru_cache`` on a zero-argument function means the engine (and its
    connection pool) is built once and reused by every later call.
    """
    engine_options = {
        # Log every SQL statement that is emitted.
        "echo": True,
        # Test pooled connections on checkout so stale ones are replaced.
        "pool_pre_ping": True,
    }
    return create_engine(settings.SQLALCHEMY_DATABASE_URL, **engine_options)
@lru_cache()
def _session_factory() -> sessionmaker:
    """Return the sessionmaker bound to the shared engine, built only once.

    sessionmaker() creates a new Session subclass each time it is called, so
    building it per request (as before) was needless work. expire_on_commit=False
    keeps loaded attribute values after commit, so returned objects can still be
    read (e.g. serialized into a response) after the session commits.
    """
    return sessionmaker(
        autocommit=False, autoflush=False, expire_on_commit=False, bind=get_engine()
    )


def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency that yields one database session per request.

    Commits after the consumer resumes the generator normally, rolls back and
    re-raises if an exception is thrown into it, and always closes the session.
    """
    # Explicit type because sessionmaker.__call__ stub is Any
    session: Session = _session_factory()()
    try:
        yield session
        session.commit()
    except BaseException:
        # BaseException (not Exception) on purpose: GeneratorExit and
        # cancellation also roll back, exactly like the former bare except.
        session.rollback()
        raise
    finally:
        session.close()
# Registry the dataclass models are mapped with; its metadata holds their tables.
mapper_registry = registry()


@dataclass
class SurrogatePK:
    """Mixin adding an auto-incrementing integer primary key column ``id``.

    ``id`` is excluded from ``__init__`` and defaults to None; the database
    generates the value (autoincrement), so it is typed Optional[int].
    """

    __sa_dataclass_metadata_key__ = "sa"
    id: Optional[int] = field(
        init=False,
        default=None,
        metadata={"sa": Column(Integer, primary_key=True, autoincrement=True)},
    )


@dataclass
class TimeStampMixin:
    """Mixin adding ``created_at`` / ``updated_at`` DateTime columns.

    Every timestamp here is naive local time taken from ``datetime.now``.
    """

    __sa_dataclass_metadata_key__ = "sa"
    created_at: datetime = field(
        default_factory=datetime.now,
        metadata={"sa": Column(DateTime, default=datetime.now)},
    )
    updated_at: datetime = field(
        default_factory=datetime.now,
        # onupdate used datetime.utcnow while every other default here uses
        # local datetime.now, so updated rows were shifted by the UTC offset.
        metadata={
            "sa": Column(DateTime, default=datetime.now, onupdate=datetime.now)
        },
    )
@mapper_registry.mapped
@dataclass
class User(SurrogatePK, TimeStampMixin):
    """Mapped ``user`` table row.

    Inherits ``id`` from SurrogatePK and the timestamp columns from
    TimeStampMixin. ``title`` and ``description`` are String(50) columns.
    """

    __tablename__ = "user"
    __sa_dataclass_metadata_key__ = "sa"
    # Both fields default to None, so they are annotated Optional[str].
    title: Optional[str] = field(default=None, metadata={"sa": Column(String(50))})
    description: Optional[str] = field(
        default=None, metadata={"sa": Column(String(50))}
    )
# Pydantic model generated from User's SQLAlchemy columns; used as the
# response_model of the /items and /users/ endpoints.
UserPyd = sqlalchemy_to_pydantic(User)

# Import-time side effect: create the registry's tables in the database
# (create_all skips tables that already exist by default).
mapper_registry.metadata.create_all(get_engine())

# Create the FastAPI application.
# NOTE(review): with debug=True, Starlette's ServerErrorMiddleware renders a
# traceback response for unhandled exceptions instead of calling an installed
# Exception handler — confirm this is only meant for development.
app = FastAPI(debug=True)
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Catch-all handler: log an unhandled exception and return a plain 500.

    Despite its name it handles every ``Exception``, not just validation
    errors. The client only sees a generic message; details go to the log.
    NOTE(review): Starlette bypasses this handler when the app is created
    with debug=True — confirm the intended mode.
    """
    # Log at ERROR with the traceback attached. The previous
    # logger.debug(str(exc)) dropped the stack trace and hid server errors
    # whenever the configured level was above DEBUG.
    logger.opt(exception=exc).error(
        "Unhandled error on {} {}: {}", request.method, request.url.path, exc
    )
    return PlainTextResponse("Something went wrong", status_code=500)
# Typer command-line interface, used when this file is run as a script.
cli = typer.Typer()


@cli.command()
def db_init_models():
    """Drop and recreate every table in the model registry (destroys all data)."""
    # The registry's MetaData already holds every mapped table; generating a
    # throwaway declarative base just to reach the same MetaData is unneeded.
    metadata = mapper_registry.metadata
    db_engine = get_engine()
    metadata.drop_all(bind=db_engine)
    metadata.create_all(bind=db_engine)
    print("Done")


@cli.command()
def nothing(name: str):
    """Placeholder command that only prints "Done"; ``name`` is ignored.

    NOTE(review): presumably kept so Typer sees two commands and requires a
    subcommand name instead of running db_init_models directly — confirm.
    """
    print("Done")
@app.get("/items", response_model=List[UserPyd])
def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """Return one page of users: skip ``skip`` rows, then at most ``limit``.

    No ORDER BY is applied, so which rows make up a page is up to the database;
    ``skip`` and ``limit`` are passed through unvalidated.
    """
    page = db.query(User).offset(skip).limit(limit)
    return page.all()
@app.get("/users/", response_model=UserPyd, status_code=status.HTTP_201_CREATED)
def create_user(email: str = None, db: Session = Depends(get_db)):
    """Insert a new User titled "sss" and return it with status 201.

    This is a GET route (the endpoint the wrk benchmark hits) even though it
    writes to the database. ``email`` is accepted but not used.
    """
    new_user = User(title="sss")
    db.add(new_user)
    # Commit now so the row (and its generated id) exists before returning.
    # The session is created with expire_on_commit=False, so new_user's
    # attributes stay readable after this commit.
    db.commit()
    return new_user
# Entry point when the file is executed directly: run the Typer CLI.
# (gunicorn imports the module as sync_p:app, so this does not run there.)
if __name__ == "__main__":
    cli()
@lovetoburnswhen
Copy link

get_db

Fair point — at that time, I couldn't figure it out. Does it work for you? Wondering if you tested it.

Yep, pyright and mypy seem happy

@pozsa
Copy link

pozsa commented Jan 11, 2022

I use

from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession

async def get_db() -> AsyncIterator[AsyncSession]:

mypy has been happy so far

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment