Skip to content

Instantly share code, notes, and snippets.

@DomWeldon
Last active June 24, 2022 15:38
Show Gist options
  • Save DomWeldon/adb17318202dd0b673a0753e37c46b4b to your computer and use it in GitHub Desktop.
Save DomWeldon/adb17318202dd0b673a0753e37c46b4b to your computer and use it in GitHub Desktop.
CRUD Router
# Standard Library
from typing import (
Any,
Callable,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Set,
Type,
cast,
)
# Third Party Libraries
from fastapi import APIRouter, HTTPException, status
from fastapi.routing import APIRoute
from pydantic import BaseModel
from sqlalchemy import ForeignKeyConstraint, Table, inspect
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import UnaryExpression
from starlette import routing
from starlette.responses import Response
from starlette.types import ASGIApp
# App and Model Imports
from app.utils.oop import all_subclasses
class CRUDApiRouter(APIRouter):
    """Automatically generate a router with essential CRUD endpoints on it.

    Configure by subclassing and setting class attributes. Each endpoint is
    only registered when its enabling attribute is set:

    * ``list_schema``      -> ``GET list_view_path``
    * ``detail_schema``    -> ``GET detail_view_path``
    * ``create_schema_in`` -> ``POST create_view_path`` (also requires
      ``create_schema_out`` and ``model_base``)
    * ``update_schema_in`` -> ``PUT update_view_path`` (also requires
      ``model_base``)
    * ``delete_view``      -> ``DELETE delete_view_path``

    Detail, update and delete only support models whose primary key is a
    single integer column.
    """

    # Used as the default of every endpoint's ``db`` parameter, so it is
    # presumably a FastAPI ``Depends(...)`` that yields a Session.
    db_dep: Session
    # Model this router is for (must be set by subclasses).
    model: DeclarativeMeta

    # Success status codes that must not carry a response body.
    _HTTP_2XX_NO_RETURN_CODES: Set[int] = {
        status.HTTP_204_NO_CONTENT,
        status.HTTP_205_RESET_CONTENT,
    }

    # Declarative base of ``model``. Required for create/update so that
    # foreign-key target tables can be mapped back to their models.
    model_base: Optional[DeclarativeMeta] = None

    # for listing rows
    list_schema: Optional[Type[BaseModel]] = None
    list_deps: Optional[List[Any]] = None
    list_sort: Optional[Sequence[UnaryExpression]] = None
    list_default_offset: Optional[int] = 0
    list_default_limit: Optional[int] = 100
    # Upper bound on the ``limit`` a client may request.
    list_max_limit: int = 1_000
    list_view_path: str = "/"

    # for detail page
    detail_schema: Optional[Type[BaseModel]] = None
    detail_deps: Optional[List[Any]] = None
    detail_view_path: str = "/{id}"
    detail_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    # for create view
    create_schema_in: Optional[Type[BaseModel]] = None
    create_schema_out: Optional[Type[BaseModel]] = None
    create_status_code: int = status.HTTP_201_CREATED
    create_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    create_status_code_conflict: int = status.HTTP_409_CONFLICT
    create_view_path: str = "/"
    create_deps: Optional[List[Any]] = None

    # for update view
    update_schema_in: Optional[Type[BaseModel]] = None
    update_schema_out: Optional[Type[BaseModel]] = None
    update_status_code: int = status.HTTP_204_NO_CONTENT
    update_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    update_status_code_conflict: int = status.HTTP_409_CONFLICT
    update_status_code_not_found: int = status.HTTP_404_NOT_FOUND
    update_view_path: str = "/{id}"
    update_deps: Optional[List[Any]] = None

    # for delete view
    delete_view: bool = False
    delete_deps: Optional[List[Any]] = None
    delete_view_path: str = "/{id}"
    delete_status_code: int = status.HTTP_204_NO_CONTENT
    delete_schema_out: Optional[Type[BaseModel]] = None
    delete_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    def __init__(
        self,
        routes: Optional[List[routing.BaseRoute]] = None,
        redirect_slashes: bool = True,
        default: Optional[ASGIApp] = None,
        dependency_overrides_provider: Optional[Any] = None,
        route_class: Type[APIRoute] = APIRoute,
        default_response_class: Optional[Type[Response]] = None,
        on_startup: Optional[Sequence[Callable]] = None,
        on_shutdown: Optional[Sequence[Callable]] = None,
    ) -> None:
        """Instantiate like a normal APIRouter, then add the CRUD endpoints.

        Only the endpoints whose configuration attributes are set (see the
        class docstring) are registered.
        """
        # ``model`` is only annotated on this class, so reading it directly
        # on an unconfigured subclass would raise AttributeError instead of
        # a clear assertion message.
        assert getattr(self, "model", None) is not None, (
            f"{type(self).__name__} must set `model`"
        )
        super().__init__(
            routes=routes,
            redirect_slashes=redirect_slashes,
            default=default,
            dependency_overrides_provider=dependency_overrides_provider,
            route_class=route_class,
            default_response_class=default_response_class,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        )
        if self.list_schema is not None:
            endpoint = self._generate_list_view()
            self.get(
                self.list_view_path,
                response_model=List[self.list_schema],  # type: ignore
                dependencies=(self.list_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.detail_schema is not None:
            endpoint = self._generate_detail_view()
            self.get(
                self.detail_view_path,
                response_model=self.detail_schema,  # type: ignore
                dependencies=(self.detail_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.create_schema_in is not None:
            assert self.create_schema_out is not None
            assert self.model_base is not None
            endpoint = self._generate_create_view()
            self.post(
                self.create_view_path,
                response_model=self.create_schema_out,  # type: ignore
                status_code=self.create_status_code,
                dependencies=(self.create_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.update_schema_in is not None:
            assert self.model_base is not None
            endpoint = self._generate_update_view()
            self.put(
                self.update_view_path,
                response_model=self.update_schema_out,  # type: ignore
                status_code=self.update_status_code,
                dependencies=(self.update_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.delete_view is True:
            endpoint = self._generate_delete_view()
            self.delete(
                self.delete_view_path,
                response_model=self.delete_schema_out,  # type: ignore
                status_code=self.delete_status_code,
                dependencies=(self.delete_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------

    def _single_int_pk(self) -> Any:
        """Return the model's primary-key Column, asserting it is a lone int.

        Detail, update and delete views address rows by an integer ``id``
        path parameter, so only this primary-key shape is supported.
        """
        cols = inspect(self.model).primary_key
        assert len(cols) == 1, "only single-column primary keys are supported"
        [pk_col] = cols
        assert pk_col.type.python_type == int, "the primary key must be an int"
        return pk_col

    @staticmethod
    def _schema_fields(schema: Optional[Type[BaseModel]]) -> Set[str]:
        """Return the field *names* declared on a pydantic schema class.

        Field names (not JSON aliases) are used because the values are later
        read with ``getattr(obj_in, name)``.
        """
        assert schema is not None
        return set(schema.__fields__.keys())

    def _fk_constraints_covered_by(
        self, fields: Set[str]
    ) -> Set[ForeignKeyConstraint]:
        """Return FK constraints whose local columns are all in ``fields``."""
        table = self.model.__table__  # type: ignore
        return {
            constraint
            for constraint in table.foreign_key_constraints
            # ``<=``: a constraint covering exactly the schema's fields must
            # still be checked.
            if {c.key for c in constraint.columns} <= fields
        }

    def _unique_indexes_covered_by(
        self, fields: Set[str]
    ) -> Set[FrozenSet[str]]:
        """Return column-key sets of unique indexes fully within ``fields``.

        Only ``Index`` objects are inspected. A column declared with just
        ``unique=True`` (and no ``index=True``) produces a UniqueConstraint
        rather than an Index, so it is not covered here.
        """
        table = self.model.__table__  # type: ignore
        column_sets = (
            frozenset(c.key for c in ix.columns)
            for ix in table.indexes
            if ix.unique
        )
        return {keys for keys in column_sets if keys <= fields}

    # ------------------------------------------------------------------
    # view generators
    # ------------------------------------------------------------------

    def _generate_list_view(self) -> Callable:
        """Create a generic list view for ``model``.

        To sort, set e.g. ``list_sort = [SomeModel.some_property.asc()]``.

        Pagination is applied only when both ``list_default_offset`` and
        ``list_default_limit`` are set. In that case a negative offset is
        treated as 0, and the requested ``limit`` is clamped to
        ``[0, list_max_limit]``. Without the clamp, a negative limit means
        "unlimited" on some backends, e.g. SQLite.
        """

        def list_view(
            db: Session = self.db_dep,
            offset: Optional[int] = self.list_default_offset,
            limit: Optional[int] = self.list_default_limit,
        ) -> Any:
            query = db.query(self.model)
            if self.list_sort is not None:
                query = query.order_by(*self.list_sort)
            if (
                self.list_default_offset is not None
                and self.list_default_limit is not None
            ):
                if offset is None or offset < 0:
                    offset = 0
                if limit is None:
                    limit = self.list_max_limit
                limit = min(max(limit, 0), self.list_max_limit)
                query = query.offset(offset).limit(limit)
            return query.all()

        sort_desc = ", ".join(str(x) for x in self.list_sort or [])
        description = f"🗃️ List {self.model.__name__}"
        if sort_desc:
            description += f" sorted by {sort_desc}"
        list_view._description = description  # type: ignore
        return list_view

    def _generate_detail_view(self) -> Callable:
        """Create a generic detail view for ``model``.

        Only models with a single, non-composite integer primary key are
        supported. A missing row raises ``detail_status_code_not_found``.
        """
        pk_col = self._single_int_pk()

        def detail_view(id: int, db: Session = self.db_dep,) -> Any:
            obj = (
                db.query(self.model)
                .filter(getattr(self.model, pk_col.key) == id)
                .scalar()
            )
            if obj is None:
                raise HTTPException(
                    status_code=self.detail_status_code_not_found
                )
            return obj

        detail_view._description = (  # type: ignore
            f"""📁 Show {self.model.__name__} identified by {pk_col.key}"""
        )
        return detail_view

    @property
    def _MODEL_MAP(self) -> Dict[Table, DeclarativeMeta]:
        """Map each ``Table`` object to the model class under ``model_base``."""
        return {
            cast(Table, m.__table__): cast(DeclarativeMeta, m)
            for m in all_subclasses(self.model_base)  # type: ignore
        }

    def _check_fk_constraints(
        self,
        *,
        db: Session,
        model_map: Dict[Table, DeclarativeMeta],
        fk_constraints: Set[ForeignKeyConstraint],
        obj_in: BaseModel,
        status_code_fk_error: int,
    ) -> None:
        """Raise if ``obj_in`` references rows that do not exist.

        For each constraint, query the referred model for a row that matches
        the values ``obj_in`` gives for the constraint's local columns.

        Constraints where any local value is ``None`` are skipped. Under
        SQL's default MATCH SIMPLE semantics, such a foreign key is
        satisfied, so a nullable FK left empty is valid.

        :raises HTTPException: with ``status_code_fk_error`` when no
            referenced row exists.
        """
        for constraint in fk_constraints:
            referred_model = model_map[constraint.referred_table]
            # local column key -> column key on the referred model
            key_map = {
                fk.parent.key: fk.column.key for fk in constraint.elements
            }
            values = {
                local_key: getattr(obj_in, local_key) for local_key in key_map
            }
            if any(value is None for value in values.values()):
                continue
            num_rows = (
                db.query(referred_model)
                .filter(
                    *(
                        getattr(referred_model, target_key) == values[local_key]
                        for local_key, target_key in key_map.items()
                    )
                )
                .count()
            )
            if num_rows == 0:
                error_row = " ".join(
                    f"{target_key}={values[local_key]}"
                    for local_key, target_key in key_map.items()
                )
                raise HTTPException(
                    status_code=status_code_fk_error,
                    detail=(
                        "I could not find a value of "
                        f"{referred_model.__name__} with values "
                        f"{error_row}"
                    ),
                )

    def _check_unique_indexes(
        self,
        db: Session,
        unique_indexes: Set[FrozenSet[str]],
        obj_in: BaseModel,
        status_code_conflict: int,
        exclude_id: Optional[Any] = None,
    ) -> None:
        """Check that writing ``obj_in`` will not violate a unique index.

        :param exclude_id: primary-key value of a row to ignore. Pass the id
            of the row being updated, so that keeping its own unique values
            is not reported as a conflict.
        :raises HTTPException: with ``status_code_conflict`` when another row
            already holds the same values.

        Indexes where any supplied value is ``None`` are skipped. Most
        databases (e.g. PostgreSQL by default, SQLite, MySQL) let NULLs
        repeat in a unique index. NOTE(review): SQL Server treats NULLs as
        equal; confirm if that backend is used.
        """
        exclusion = None
        if exclude_id is not None:
            pk_col = self._single_int_pk()
            exclusion = getattr(self.model, pk_col.key) != exclude_id
        for ix in unique_indexes:
            # sorted: stable column order in the error message
            values = {k: getattr(obj_in, k) for k in sorted(ix)}
            if any(value is None for value in values.values()):
                continue
            query = db.query(self.model).filter(
                *(getattr(self.model, k) == v for k, v in values.items())
            )
            if exclusion is not None:
                query = query.filter(exclusion)
            if query.count() != 0:
                error_row = " ".join(f"{k}={v}" for k, v in values.items())
                raise HTTPException(
                    status_code=status_code_conflict,
                    detail=(
                        f"A row already exists in {self.model.__name__} "
                        f"with values {error_row}"
                    ),
                )

    def _generate_create_view(self) -> Callable:
        """Generate a generic create view for ``model``.

        The view:
        - checks that referenced foreign keys exist (``create_status_code_fk_error``);
        - checks for conflicts on unique indexes (``create_status_code_conflict``);
        - inserts and returns the row (``create_status_code``, 201 by default).
        """
        # The payload is already validated by the schema; work out which
        # constraints it is able to violate.
        fields = self._schema_fields(self.create_schema_in)
        fk_constraints = self._fk_constraints_covered_by(fields)
        unique_indexes = self._unique_indexes_covered_by(fields)
        # we'll need this to look up models based on tables
        model_map = self._MODEL_MAP

        def create_view(
            *,
            obj_in: self.create_schema_in,  # type: ignore
            db: Session = self.db_dep,
        ) -> Any:
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.create_status_code_fk_error,
            )
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.create_status_code_conflict,
            )
            # make the insert
            instance = self.model()
            for k, v in obj_in.dict().items():
                setattr(instance, k, v)
            db.add(instance)
            db.commit()
            db.refresh(instance)
            return instance

        create_view._description = (  # type: ignore
            f"""💾 Create new {self.model.__name__}"""
        )
        return create_view

    def _generate_update_view(self) -> Callable:
        """Generate a generic update (PUT) view for ``model``.

        The row is looked up by its integer primary key
        (``update_status_code_not_found`` if it is missing). FK and
        unique-index checks run as in the create view; the unique check
        ignores the row being updated. The updated row is returned, unless
        ``update_status_code`` is a no-body code (204/205).
        """
        assert not (
            self.update_schema_out is not None
            and self.update_status_code in self._HTTP_2XX_NO_RETURN_CODES
        )
        pk_col = self._single_int_pk()
        fields = self._schema_fields(self.update_schema_in)
        fk_constraints = self._fk_constraints_covered_by(fields)
        unique_indexes = self._unique_indexes_covered_by(fields)
        # we'll need this to look up models based on tables
        model_map = self._MODEL_MAP

        def update_view(
            *,
            id: int,
            db: Session = self.db_dep,
            obj_in: self.update_schema_in,  # type: ignore
        ) -> Any:
            obj = (
                db.query(self.model)
                .filter(getattr(self.model, pk_col.key) == id)
                .scalar()
            )
            if obj is None:
                raise HTTPException(
                    status_code=self.update_status_code_not_found
                )
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.update_status_code_fk_error,
            )
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.update_status_code_conflict,
                exclude_id=id,
            )
            for k, v in obj_in.dict().items():
                setattr(obj, k, v)
            db.add(obj)
            db.commit()
            db.refresh(obj)
            if self.update_status_code in self._HTTP_2XX_NO_RETURN_CODES:
                return None
            return obj

        update_view._description = (  # type: ignore
            f"""📝 Update {self.model.__name__} identified by {pk_col.key}"""
        )
        return update_view

    def _generate_delete_view(self) -> Callable:
        """Generate a generic delete view for ``model``.

        A body (the deleted row) is returned if and only if
        ``delete_schema_out`` is set. The asserts keep this consistent with
        ``delete_status_code``.
        """
        assert not (
            self.delete_schema_out is not None
            and self.delete_status_code in self._HTTP_2XX_NO_RETURN_CODES
        )
        assert not (
            self.delete_schema_out is None
            and self.delete_status_code not in self._HTTP_2XX_NO_RETURN_CODES
        )
        pk_col = self._single_int_pk()

        def delete_view(*, id: int, db: Session = self.db_dep,) -> Any:
            obj = (
                db.query(self.model)
                .filter(getattr(self.model, pk_col.key) == id)
                .scalar()
            )
            if obj is None:
                raise HTTPException(
                    status_code=self.delete_status_code_not_found
                )
            db.delete(obj)
            db.commit()
            return obj if self.delete_schema_out is not None else None

        delete_view._description = (  # type: ignore
            f"""❌ Delete {self.model.__name__} identified by {pk_col.key}"""
        )
        return delete_view
# Standard Library
from typing import TYPE_CHECKING
# Third Party Libraries
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
# App and Model Imports
from app.db.base_class import Base
if TYPE_CHECKING:
    # Imported only while type checking (TYPE_CHECKING is False at runtime),
    # presumably to avoid a circular import with the user model. At runtime,
    # "User" is referenced by name in relationship().
    from .user import User  # noqa: F401
class Item(Base):
    """ORM model for an item, optionally linked to an owning ``User``."""

    # NOTE(review): no __tablename__ here; presumably `Base` (app.db.base_class)
    # derives it. Confirm.
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    description = Column(String, index=True)
    # References `user.id`. Nullable, since SQLAlchemy columns default to
    # nullable=True.
    owner_id = Column(Integer, ForeignKey("user.id"))
    # ORM link to the owner. back_populates expects User to declare a
    # matching `items` relationship.
    owner = relationship("User", back_populates="items")
# Standard Library
from typing import TYPE_CHECKING
# Third Party Libraries
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
# App and Model Imports
from app.db.base_class import Base
if TYPE_CHECKING:
    # Imported only while type checking (TYPE_CHECKING is False at runtime),
    # presumably to avoid a circular import with the user model. At runtime,
    # "User" is referenced by name in relationship().
    from .user import User  # noqa: F401
class Item(Base):
    """ORM model for an item, optionally linked to an owning ``User``."""

    # NOTE(review): an identical Item model appears earlier in this file.
    # If both definitions end up in one module, the second would clash on
    # the same table; presumably these were separate gist files. Verify.
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    description = Column(String, index=True)
    # References `user.id`. Nullable, since SQLAlchemy columns default to
    # nullable=True.
    owner_id = Column(Integer, ForeignKey("user.id"))
    # ORM link to the owner. back_populates expects User to declare a
    # matching `items` relationship.
    owner = relationship("User", back_populates="items")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment