Skip to content

Instantly share code, notes, and snippets.

@michaeltoohig
Last active March 4, 2022 10:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save michaeltoohig/15d7b12913498959987afa516fd7f8c9 to your computer and use it in GitHub Desktop.
Save michaeltoohig/15d7b12913498959987afa516fd7f8c9 to your computer and use it in GitHub Desktop.
SQLAlchemy-Continuum plugin for FastAPI with FastAPI-Users and async SQLAlchemy

The files here are the "full example" from the FastAPI-Users documentation, plus some modifications for async SQLAlchemy 1.4 (possibly also 2.0). Altogether they produce a simple example of using SQLAlchemy-Continuum where the transaction table has a relationship to the user.

For most users, though, the plugin.py, middlewares.py and users.py files give the basic gist of what is needed to use FastAPI with SQLAlchemy-Continuum.

EDIT: it appears that a FastAPI update released less than two weeks ago may make all of the issues I was facing moot: https://github.com/tiangolo/fastapi/releases/tag/0.74.1. Keep an eye on newer versions of FastAPI to judge whether this gist is still relevant to you.

file tree

.
├── app
│   ├── crud.py
│   ├── database.py
│   ├── help.md
│   ├── main.py
│   ├── middlewares.py
│   ├── models.py
│   ├── plugin.py
│   ├── schemas.py
│   └── users.py
├── scripts
│   └── init.sh
├── docker-compose.yml
├── main.py
├── poetry.lock
└── pyproject.toml

dependencies

[tool.poetry.dependencies]
python = "^3.8"
fastapi = "^0.74.1"
uvicorn = "^0.17.5"
asyncpg = "^0.25.0"
# NOTE(review): git dependency without a `rev`; re-locking picks up whatever
# the repository's default branch is at that moment. Consider pinning
# `rev = "<commit>"` for reproducible installs.
sqlalchemy-continuum = {git = "https://github.com/kvesteri/sqlalchemy-continuum.git"}
fastapi-users = "^9.2.5"
fastapi-users-db-sqlalchemy = "^2.0.4"
# Provides the per-request context used by app/middlewares.py and app/plugin.py.
starlette-context = "^0.3.3"

docker-compose.yml

version: "3.9"

services:
  # PostgreSQL used by the app; user, password and port match
  # SQLALCHEMY_DATABASE_URL in app/database.py.
  db:
    image: postgres:10
    environment:
      - POSTGRES_USER=db_user
      - POSTGRES_PASSWORD=db_pass
      - PGDATA=/var/lib/postgresql/data/pgdata
    ports:
      - "5432:5432"
    volumes:
      # Mounts scripts/ (init.sh) as the image's initialisation-script directory.
      - ./scripts:/docker-entrypoint-initdb.d

  # Adminer web UI for inspecting the database, exposed on host port 8081.
  admin:
    image: adminer
    restart: always
    ports:
      - 8081:8080
    depends_on: 
      - db

scripts/init.sh

#!/bin/bash
# Database initialisation script, mounted into /docker-entrypoint-initdb.d by
# docker-compose.yml. Enables the btree_gist extension in the app database.
set -e
# Target the database explicitly instead of relying on psql's default
# (a database named after the user); with this compose file both are db_user.
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<EOF
CREATE EXTENSION IF NOT EXISTS btree_gist;
EOF
import abc
from typing import Generic, List, Type, TypeVar
from sqlalchemy import select, update
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import NoResultFound
from . import models, schemas
# Type parameters of CRUDBase: the Pydantic schema accepted by `create`,
# the Pydantic schema returned by every method, and the ORM table class.
IN_SCHEMA = TypeVar("IN_SCHEMA", bound=BaseModel)
SCHEMA = TypeVar("SCHEMA", bound=BaseModel)
TABLE = TypeVar("TABLE")
class CRUDBase(Generic[TABLE, IN_SCHEMA, SCHEMA], metaclass=abc.ABCMeta):
    """Generic async CRUD helper for one ORM table and its Pydantic schemas.

    Subclasses supply ``_table`` (the ORM model class) and ``_schema`` (a
    Pydantic model built with ``from_orm``). Every mutating method commits
    the session passed to the constructor.

    NOTE(review): when no row has the requested id, ``_get_one`` returns
    ``None`` and ``get_by_id``/``update``/``remove`` then fail on that
    ``None`` instead of signalling "not found" explicitly.
    """

    def __init__(self, db_session: AsyncSession, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted and ignored.
        self._db_session: AsyncSession = db_session

    @property
    @abc.abstractmethod
    def _table(self) -> Type[TABLE]:
        """ORM model class this helper reads and writes."""
        ...

    @property
    @abc.abstractmethod
    def _schema(self) -> Type[SCHEMA]:
        """Pydantic schema every public method returns."""
        ...

    async def create(self, in_schema: IN_SCHEMA) -> SCHEMA:
        """Insert a row built from ``in_schema``, commit, and return it."""
        item = self._table(**in_schema.dict())
        self._db_session.add(item)
        await self._db_session.commit()
        return self._schema.from_orm(item)

    async def _get_one(self, item_id: int):
        """Return the row whose ``id`` equals ``item_id``, or ``None``."""
        query = select(self._table).filter(self._table.id == item_id)
        result = await self._db_session.execute(query)
        return result.scalar_one_or_none()

    async def get_by_id(self, item_id: int) -> SCHEMA:
        """Return the row with ``item_id`` as a schema instance."""
        item = await self._get_one(item_id)
        return self._schema.from_orm(item)

    async def get_multi(self) -> List[SCHEMA]:
        """Return every row of the table as a list of schema instances."""
        query = select(self._table)
        results = await self._db_session.execute(query)
        # Materialise a real list to honour the List[SCHEMA] annotation
        # (previously a one-shot generator was returned).
        return [self._schema.from_orm(item) for item in results.scalars()]

    async def update(self, item_id: int, update_schema) -> SCHEMA:
        """Apply the fields explicitly set on ``update_schema`` and commit."""
        item = await self._get_one(item_id)
        # exclude_unset: only fields the client actually sent are written.
        for key, value in update_schema.dict(exclude_unset=True).items():
            setattr(item, key, value)
        self._db_session.add(item)
        await self._db_session.commit()
        return self._schema.from_orm(item)

    async def remove(self, item_id: int) -> SCHEMA:
        """Delete the row with ``item_id``, commit, and return its last state."""
        item = await self._get_one(item_id)
        await self._db_session.delete(item)
        await self._db_session.commit()
        return self._schema.from_orm(item)
class CRUDItem(CRUDBase[models.Item, schemas.ItemCreate, schemas.Item]):
    """CRUD helper for the versioned ``models.Item`` table."""

    @property
    def _table(self) -> Type[models.Item]:
        """ORM model backing this helper."""
        return models.Item

    @property
    def _schema(self) -> Type[schemas.Item]:
        """Schema used for every returned item."""
        return schemas.Item

    @property
    def _in_schema(self) -> Type[schemas.ItemCreate]:
        """Schema expected as input when creating items."""
        return schemas.ItemCreate
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, configure_mappers
from .plugin import FastAPIUsersPlugin
from sqlalchemy_continuum import make_versioned
# Configure SQLAlchemy-Continuum at import time. models.py imports Base from
# this module, so this runs before the versioned models are declared. The
# plugin supplies user_id / remote_addr for each Continuum transaction row,
# and user_cls names the ORM class the transaction's user relationship targets.
make_versioned(plugins=[FastAPIUsersPlugin()], user_cls="UserTable")
# NOTE(review): credentials are hard-coded to match docker-compose.yml;
# load them from configuration for anything beyond local experimentation.
SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg://db_user:db_pass@localhost:5432/db_user"
engine = create_async_engine(SQLALCHEMY_DATABASE_URL)
# expire_on_commit=False keeps attributes loaded after commit; the CRUD layer
# builds response schemas from objects right after committing them.
async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
Base = declarative_base()
async def create_db_and_tables():
    """Configure all mappers, then create any missing tables."""
    # Mapper configuration must happen first so Continuum has generated its
    # version/transaction tables before metadata.create_all runs.
    configure_mappers()
    async with engine.begin() as connection:
        create_all = Base.metadata.create_all
        await connection.run_sync(create_all)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding one AsyncSession per request.

    After the consumer finishes, the session is committed and then closed
    when the ``async with`` block exits.
    """
    session = async_session_maker()
    async with session:
        yield session
        await session.commit()
from typing import List
from fastapi import Depends, FastAPI
from sqlalchemy.ext.asyncio.session import AsyncSession
from . import crud, schemas
from .middlewares import middleware
from .database import create_db_and_tables, get_async_session
from .users import auth_backend, current_active_user, fastapi_users_instance
# Install the starlette-context middleware from middlewares.py so each request
# carries its JWT user id and forwarded address for the Continuum plugin.
app = FastAPI(middleware=middleware)
@app.on_event("startup")
async def on_startup():
    """Create any missing tables when the application starts.

    Not needed if you set up a migration system such as Alembic.
    """
    await create_db_and_tables()
# FastAPI-Users endpoints
# Authentication routes for the JWT backend, mounted under /auth/jwt
# (matches bearer_transport's tokenUrl "auth/jwt/login" in users.py).
app.include_router(
    fastapi_users_instance.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
)
# Registration routes, mounted under /auth.
app.include_router(fastapi_users_instance.get_register_router(), prefix="/auth", tags=["auth"])
# Item endpoints
@app.post("/items/", response_model=schemas.Item)
async def create_item(
    db: AsyncSession = Depends(get_async_session),
    user: schemas.UserDB = Depends(current_active_user),
    *,
    item: schemas.ItemCreate,
):
    # Create an item. `user` is unused here; its dependency restricts the
    # route to active users. The versioning plugin reads the user from the
    # request context, not from this argument.
    return await crud.CRUDItem(db).create(in_schema=item)
@app.put("/items/{item_id}", response_model=schemas.Item)
async def update_item(
    db: AsyncSession = Depends(get_async_session),
    user: schemas.UserDB = Depends(current_active_user),
    *,
    item_id: int,
    schema_in: schemas.ItemCreate,
):
    # Update an item with the fields set in `schema_in`; requires an active
    # user through the `current_active_user` dependency.
    return await crud.CRUDItem(db).update(item_id, schema_in)
@app.delete("/items/{item_id}", response_model=schemas.Item)
async def remove_item(
    db: AsyncSession = Depends(get_async_session),
    user: schemas.UserDB = Depends(current_active_user),
    *,
    item_id: int,
):
    # Delete an item and return its final state; requires an active user.
    return await crud.CRUDItem(db).remove(item_id)
@app.get("/items/", response_model=List[schemas.Item])
async def read_items(
    db: AsyncSession = Depends(get_async_session),
):
    # List every item; this route declares no authentication dependency.
    return await crud.CRUDItem(db).get_multi()
from typing import Any, Optional
from fastapi import Request
from fastapi.security.http import HTTPBearer
from fastapi.middleware import Middleware
from starlette_context import plugins
from starlette_context.middleware import ContextMiddleware
from .users import SECRET, AUDIENCE, get_jwt_strategy
class FastAPIUsersJWTPlugin(plugins.base.Plugin):
    """starlette-context plugin storing the bearer token's user id.

    The value is stored in the request context under ``key`` ("user_id").
    It is ``None`` when no bearer token is sent or the strategy cannot
    extract a user id from it. No database lookup is performed.
    """

    key = "user_id"
    # Built once at class creation from users.get_jwt_strategy.
    get_jwt_strategy = get_jwt_strategy()
    # auto_error=False: a missing/non-bearer header yields None instead of 403.
    jwt_bearer_authorization = HTTPBearer(auto_error=False)
    secret = SECRET
    audience = AUDIENCE

    async def process_request(
        self, request: Request
    ) -> Optional[Any]:
        assert isinstance(self.key, str)
        bearer = await self.jwt_bearer_authorization(request)
        if bearer:
            return await self.get_jwt_strategy.get_user_id(bearer.credentials)
        return None
# Middleware list passed to FastAPI(middleware=...) in main.py.
middleware = [
    Middleware(
        ContextMiddleware,
        plugins=[
            plugins.RequestIdPlugin(),
            # Stores the X-Forwarded-For header; read by plugin.fetch_remote_addr.
            plugins.ForwardedForPlugin(),
            # Stores "user_id"; read by plugin.fetch_current_user_id.
            FastAPIUsersJWTPlugin()
        ]
    )
]
# The following is an alternative middleware not using `starlette-context` as seen above.
# Also, it gives an idea how you could do without FastAPI-Users and just manually decode the jwt token.
# from contextvars import ContextVar
# from fastapi import Request
# from fastapi.security.http import HTTPBearer
# from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
#
# USER_ID_CTX_KEY = "user_id"
# _user_id_ctx_var: ContextVar[str] = ContextVar(USER_ID_CTX_KEY, default=None)
#
# def get_user_id() -> str:
# return _user_id_ctx_var.get()
#
# class MyContextMiddleware(BaseHTTPMiddleware):
# async def dispatch(
# self, request: Request, call_next: RequestResponseEndpoint
# ):
# jwt_bearer_authorization = HTTPBearer(auto_error=False)
# auth = await jwt_bearer_authorization(request)
# if auth:
# try:
# payload = decode_jwt(auth.credentials, SECRET, [AUDIENCE])
# user_id = payload.get("user_id", None)
# if not user_id:
# user_id = None
# except jwt.PyJWTError:
# user_id = None
# else:
# user_id = None
# user_id_ctx = _user_id_ctx_var.set(user_id)
# response = await call_next(request)
# _user_id_ctx_var.reset(user_id_ctx)
# return response
from fastapi_users.db import SQLAlchemyBaseUserTable
from sqlalchemy import Column, Integer, String, Boolean
from .database import Base
class UserTable(Base, SQLAlchemyBaseUserTable):
    """FastAPI-Users user table; referenced by name in make_versioned(user_cls="UserTable")."""
class Item(Base):
    # Example model tracked by SQLAlchemy-Continuum: the __versioned__
    # marker opts this table into versioning.
    __tablename__ = "items"
    __versioned__ = {}
    id = Column(Integer, primary_key=True, autoincrement=True)
    title = Column(String)
    description = Column(String)
    checked = Column(Boolean)
from starlette_context import context
from sqlalchemy_continuum.plugins import Plugin
def fetch_current_user_id():
    """Return the current request's authenticated user id, or ``None``.

    Reads the "user_id" entry stored by the JWT context plugin. Returns
    ``None`` when no user was identified, or when there is no request
    context at all (e.g. a versioned write from a script or startup hook).
    """
    # context.data raises when accessed outside a ContextMiddleware-wrapped
    # request; a versioned flush there should record "no user", not crash.
    if not context.exists():
        return None
    return context.data.get("user_id")
def fetch_remote_addr():
    """Return the request's X-Forwarded-For value, or ``None``.

    Returns ``None`` when the header was absent or when there is no request
    context (e.g. a versioned write outside an HTTP request).
    """
    # Same guard as for the user id: context.data raises outside a request.
    if not context.exists():
        return None
    return context.data.get("X-Forwarded-For")
class FastAPIUsersPlugin(Plugin):
    """SQLAlchemy-Continuum plugin that stamps transactions with request info."""

    def transaction_args(self, uow, session):
        """Return extra column values for the Continuum transaction row."""
        user_id = fetch_current_user_id()
        remote_addr = fetch_remote_addr()
        return {"user_id": user_id, "remote_addr": remote_addr}
from typing import List, Optional
from fastapi_users import models
from pydantic import BaseModel
# FastAPI-Users schema classes used to configure FastAPIUsers in users.py.
# They add nothing to the library bases. Comments are used instead of
# docstrings because Pydantic exposes docstrings in the OpenAPI schema.
class User(models.BaseUser):
    pass
class UserCreate(models.BaseUserCreate):
    pass
class UserUpdate(models.BaseUserUpdate):
    pass
# Database-side user model; also the type of the authenticated user in routes.
class UserDB(User, models.BaseUserDB):
    pass
# Fields shared by item input and output schemas.
class ItemBase(BaseModel):
    title: str
    description: Optional[str] = None
    checked: Optional[bool] = None
# Request body for creating and updating items.
class ItemCreate(ItemBase):
    pass
# Response schema; orm_mode lets CRUDBase build it with from_orm().
class Item(ItemBase):
    id: int
    class Config:
        orm_mode = True
from typing import Optional
from fastapi import Depends, Request
from fastapi_users import BaseUserManager, FastAPIUsers
from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import UUID4
from .database import get_async_session
from .schemas import UserDB, UserCreate, UserUpdate, User
from .models import UserTable
# NOTE(review): placeholder secret used to sign JWTs and reset/verification
# tokens. Replace it with a strong value loaded from configuration.
SECRET = "SECRET"
# JWT audience passed to MyJWTStrategy in get_jwt_strategy.
AUDIENCE = "test-sqlalchemy-continuum:auth"
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
    """Yield a FastAPI-Users database adapter bound to the request session."""
    user_db = SQLAlchemyUserDatabase(UserDB, session, UserTable)
    yield user_db
class UserManager(BaseUserManager[UserCreate, UserDB]):
    """User manager whose lifecycle hooks just print to stdout."""

    user_db_model = UserDB
    reset_password_token_secret = SECRET
    verification_token_secret = SECRET

    async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
        """Report a successful registration."""
        print("User {} has registered.".format(user.id))

    async def on_after_forgot_password(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """Report a password-reset request, including the token (demo only)."""
        print("User {} has forgot their password. Reset token: {}".format(user.id, token))

    async def on_after_request_verify(
        self, user: UserDB, token: str, request: Optional[Request] = None
    ):
        """Report a verification request, including the token (demo only)."""
        print("Verification requested for user {}. Verification token: {}".format(user.id, token))
async def get_user_manager(user_db=Depends(get_user_db)):
    """Yield a UserManager wrapping the request's user database adapter."""
    manager = UserManager(user_db)
    yield manager
# Clients send the JWT as a Bearer token. tokenUrl is the login route
# advertised in OpenAPI (the auth router is mounted at /auth/jwt in main.py).
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
#
## The important part below vvv
#
# Custom JWT Strategy
import jwt
from fastapi_users.jwt import decode_jwt
from fastapi_users import models
from fastapi_users.manager import BaseUserManager, UserNotExists
class MyJWTStrategy(JWTStrategy):
    """JWT strategy that also exposes the token's raw ``user_id`` claim.

    ``get_user_id`` lets the starlette-context middleware identify the user
    without a database lookup; ``read_token`` reuses it and then loads the
    user through the user manager.
    """

    async def get_user_id(self, token: Optional[str]) -> Optional[str]:
        """Return the ``user_id`` claim of a decodable token, else ``None``.

        ``None`` is returned for a missing token, for any ``jwt.PyJWTError``
        raised by ``decode_jwt``, and when the claim is absent or not a string.
        """
        if token is None:
            return None
        try:
            data = decode_jwt(token, self.secret, self.token_audience)
        except jwt.PyJWTError:
            return None
        user_id = data.get("user_id")
        # Reject non-string claims. They previously crashed UUID4() with a
        # non-ValueError exception and leaked into the Continuum transaction.
        if not isinstance(user_id, str):
            return None
        return user_id

    async def read_token(
        self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD]
    ) -> Optional[models.UD]:
        """Resolve ``token`` to a user, or ``None`` if it does not identify one."""
        if token is None:
            return None
        user_id = await self.get_user_id(token)
        if user_id is None:
            return None
        try:
            user_uuid = UUID4(user_id)
        except ValueError:
            # Claim is a string but not a valid UUID.
            return None
        try:
            return await user_manager.get(user_uuid)
        except UserNotExists:
            return None
def get_jwt_strategy() -> MyJWTStrategy:
    """Build the JWT strategy (1-hour tokens, audience AUDIENCE).

    Used both by ``auth_backend`` and by the middleware's JWT plugin.
    """
    audiences = [AUDIENCE]
    return MyJWTStrategy(
        token_audience=audiences,
        lifetime_seconds=3600,
        secret=SECRET,
    )
#
## The important part above ^^^
#
# Bearer-token transport combined with the custom JWT strategy above.
auth_backend = AuthenticationBackend(
    name="jwt",
    transport=bearer_transport,
    get_strategy=get_jwt_strategy,
)
# FastAPI-Users entry point; main.py builds its auth/register routers from it.
fastapi_users_instance = FastAPIUsers(
    get_user_manager,
    [auth_backend],
    User,
    UserCreate,
    UserUpdate,
    UserDB,
)
# Dependency used by the item routes in main.py to require an active user.
current_active_user = fastapi_users_instance.current_user(active=True)
@michaeltoohig
Copy link
Author

It appears the latest FastAPI updates add new middleware functionality that may make this gist obsolete in the future. However, FastAPI-Users does not currently support this new version, so I cannot look further into it.

https://github.com/tiangolo/fastapi/releases/tag/0.74.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment