Skip to content

Instantly share code, notes, and snippets.

@kurtbrose
Created January 25, 2023 21:55
Show Gist options
  • Save kurtbrose/dc11fd9149f63159008f829d519ce517 to your computer and use it in GitHub Desktop.
Save kurtbrose/dc11fd9149f63159008f829d519ce517 to your computer and use it in GitHub Desktop.
Helper for converting type annotations into sqlalchemy columns
"""
This module is a helper for converting a dataclass-like annotated class into a SQLAlchemy ORM model.
"""
from dataclasses import dataclass
from datetime import date, datetime
import enum
import functools
import inspect
import re
import types
import typing

from sqlalchemy import (
    Boolean,
    Column,
    Date,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
)
from sqlalchemy.orm import relationship
# Python scalar types that map directly onto a SQLAlchemy column type.
# Annotations are matched with ``type_ in _TYPE_MAP`` (dict-key equality),
# so related classes such as ``bool``/``int`` or ``datetime``/``date`` each
# need their own entry; a subclass does not inherit its parent's mapping.
_TYPE_MAP = {
    int: Integer,
    float: Float,
    str: String,
    bool: Boolean,
    date: Date,
    datetime: DateTime,
}
def _to_snake_case(camel_case):
"""CamelCase -> camel_case; SnakeCase -> snake_case"""
return re.sub(r'(?<!^)(?=[A-Z])', '_', camel_case).lower()
def _is_optional(annotation):
"""Check if an annotation was declared as Optional[type]."""
if typing.get_origin(annotation) in (typing.Union, types.UnionType):
args = typing.get_args(annotation)
if len(args) == 2 and args[-1] is None.__class__:
return True, args[0]
raise ValueError(f"unsupported type annotation: {annotation}")
return False, annotation
def _pytype2sqla_cols(name, type_) -> dict:
    """
    Translate one annotated attribute into the ORM class-body entries it needs.

    Only a deliberately small subset of annotations is handled:

    * a type registered in ``_TYPE_MAP`` -> ``{name: Column(<mapped type>)}``
    * a model referenced by class name (a string, or a ``typing.ForwardRef``
      wrapping one) -> ``{name + "_id": <indexed Integer foreign-key Column>,
      name: relationship(<class name>)}``
    * an ``enum.Enum`` subclass -> ``{name: Column(Enum(<enum class>))}``

    Wrapping any of these in ``Optional[...]`` / ``... | None`` makes the
    generated column nullable; otherwise it is created with ``nullable=False``.

    :raises ValueError: for annotations outside that subset.
    """
    nullable, inner = _is_optional(type_)

    if inner in _TYPE_MAP:
        return {name: Column(_TYPE_MAP[inner], nullable=nullable)}

    # A quoted name nested inside another annotation (e.g. Optional["Model"])
    # arrives as a ForwardRef rather than a plain string.
    if isinstance(inner, typing.ForwardRef):
        inner = inner.__forward_arg__

    if isinstance(inner, str):
        # The referenced table name is "<snake_case class name>s", matching the
        # __tablename__ that auto_orm assigns.
        target_table = f"{_to_snake_case(inner)}s"
        fk_column = Column(
            Integer,
            ForeignKey(f"{target_table}.id"),
            index=True,
            nullable=nullable,
        )
        return {
            name + "_id": fk_column,
            name: relationship(inner, foreign_keys=[fk_column]),
        }

    if isinstance(inner, type) and issubclass(inner, enum.Enum):
        return {name: Column(Enum(inner), nullable=nullable)}

    raise ValueError(f"unsupported type annotation: {inner}")
def auto_orm(cls: type) -> type:
    """
    Build a SQLAlchemy ORM model class from a dataclass-like annotated class.

    The returned class has the same name, module, qualified name, docstring
    and bases as *cls*, and its body contains:

    * ``__tablename__``: the snake_cased class name with ``s`` appended;
    * ``id``: an integer primary-key column;
    * the column(s) / relationship generated for each of *cls*'s own
      annotations (via ``_pytype2sqla_cols``);
    * ``__original_class__``: a reference back to *cls*.

    Only the annotations are translated; other attributes and methods defined
    on *cls* are not copied into the new class.

    :raises ValueError: if *cls* has no annotations, if an annotation is not
        supported, or if an annotation would overwrite one of the attributes
        generated here (most notably the ``id`` primary key).
    """
    annotations = inspect.get_annotations(cls)
    if not annotations:
        raise ValueError(f"{cls} does not define any columns!")
    body = dict(
        __original_class__=cls,
        # Without these, three-argument type() records *this* helper module as
        # __module__ and drops the qualified name and docstring of cls.
        __module__=cls.__module__,
        __qualname__=cls.__qualname__,
        __doc__=cls.__doc__,
        __tablename__=_to_snake_case(cls.__name__) + "s",
        id=Column(Integer, primary_key=True),
    )
    reserved = set(body)
    for name, type_ in annotations.items():
        new_attrs = _pytype2sqla_cols(name, type_)
        # e.g. annotating ``id: int`` would silently replace the primary key
        # with a plain NOT NULL column, leaving a model with no primary key.
        clashes = reserved & new_attrs.keys()
        if clashes:
            raise ValueError(
                f"{cls.__name__}.{name} would overwrite reserved attribute(s): "
                f"{', '.join(sorted(clashes))}"
            )
        body.update(new_attrs)
    return type(cls.__name__, cls.__bases__, body)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment