Last active
May 30, 2019 21:37
-
-
Save pourquoi/dd88e08e72b4a2a4e960a0106bdeda24 to your computer and use it in GitHub Desktop.
SQLAlchemy model preparation for pydantic
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from typing import Union | |
from pydantic import BaseModel | |
from pydantic.fields import Field | |
from app.db.base_class import Base | |
def _get_field_model(field: Field):
    """
    Return the pydantic model class associated with ``field``, or None.

    The field's ``sub_fields`` (if any) are scanned first, in order, and the
    first one whose ``type_`` is a BaseModel subclass is returned. Failing
    that, the field's own ``type_`` is checked. If neither matches, None is
    returned.
    """
    def _is_model(candidate):
        # issubclass() raises on non-classes (e.g. typing constructs), so
        # guard with isclass() first.
        return inspect.isclass(candidate) and issubclass(candidate, BaseModel)

    for sub in field.sub_fields or ():
        if _is_model(sub.type_):
            return sub.type_
    return field.type_ if _is_model(field.type_) else None
def marshal(data, model: BaseModel):
    """
    Get a dict/list representation of SQLAlchemy models suitable for pydantic.

    Only the fields declared on ``model`` are copied. A field whose key is
    missing from the source, or whose value is None, is omitted from the
    result so that pydantic can fall back to the field's default.

    :param data: None, a list/tuple of items, an instance of the declarative
        ``Base``, or a mapping keyed by field name.
    :param model: pydantic model class whose ``__fields__`` drive the copy.
    :return: None for None input; a list (one entry per item) for a
        list/tuple input; otherwise a dict of field name -> marshalled value.
    """
    if data is None:
        return None
    if isinstance(data, (list, tuple)):
        return [marshal(item, model) for item in data]

    # NOTE(review): an ORM instance is read through its __dict__, which
    # presumably holds only attributes already loaded on the instance.
    # Attributes never loaded (or expired, e.g. after a commit) are then
    # absent and silently skipped rather than lazy-loaded. Confirm this is
    # intended.
    source = data.__dict__ if isinstance(data, Base) else data

    return {
        name: _marshal_field(source[name], field)
        for name, field in model.__fields__.items()
        if name in source and source[name] is not None
    }
def _marshal_field(data, field: Field):
    """
    Marshal a single field value according to its pydantic ``field``.

    If the field is associated with a BaseModel (see ``_get_field_model``),
    the value is handed to ``marshal``. ``marshal`` itself walks through
    (possibly nested) lists/tuples and maps None to None. Otherwise,
    lists/tuples are rebuilt as lists (recursively) and any other value is
    returned unchanged.

    :param data: the raw attribute value taken from the source object.
    :param field: the pydantic field the value is destined for.
    :return: the marshalled value.
    """
    # The target model depends only on `field`, so resolve it once instead of
    # once per list element (and again at every nesting level).
    model = _get_field_model(field)
    if model:
        return marshal(data, model)

    def _plain(value):
        # No model: copy list/tuple containers into lists, keep leaves as-is.
        if isinstance(value, (list, tuple)):
            return [_plain(item) for item in value]
        return value

    return _plain(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, List | |
from pydantic import BaseModel | |
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Table, Text | |
from sqlalchemy.orm import relationship, backref | |
from app.db.base_class import Base | |
class User(Base):
    """SQLAlchemy model for a user."""
    # NOTE(review): no __tablename__ is declared here; it presumably comes from
    # app.db.base_class.Base. ForeignKey("user.id") in the other models expects
    # a table named "user". Confirm.
    id = Column(Integer, primary_key=True, index=True)
    # Required, indexed, and unique across users.
    username = Column(String, index=True, unique=True, nullable=False)
    # Indexed but neither unique nor declared non-nullable.
    email = Column(String, index=True)
    # One-to-many: this user's posts. Paired with Post.author via back_populates.
    posts = relationship("Post", back_populates="author")
class Post(Base):
    """SQLAlchemy model for a post written by a User."""
    id = Column(Integer, primary_key=True, index=True)
    body = Column(Text)
    # Required, indexed title.
    title = Column(String, index=True, nullable=False)
    # Not declared nullable=False, unlike Comment.author_id.
    author_id = Column(Integer, ForeignKey("user.id"))
    # lazy=False is SQLAlchemy's synonym for joined eager loading, so the
    # author is fetched together with the post (and is then presumably present
    # in the instance __dict__ that marshal() reads).
    author = relationship("User", back_populates="posts", lazy=False)
    # One-to-many: comments on this post. Paired with Comment.post.
    comments = relationship("Comment", back_populates="post")
class Comment(Base):
    """SQLAlchemy model for a comment on a Post; comments can nest as replies."""
    id = Column(Integer, primary_key=True, index=True)
    body = Column(Text)
    # Every comment must have an author.
    author_id = Column(Integer, ForeignKey("user.id"), nullable=False)
    # Many-to-one to User. No back-reference is declared on User.
    author = relationship("User")
    post_id = Column(Integer, ForeignKey("post.id"))
    # Paired with Post.comments via back_populates.
    post = relationship("Post", back_populates="comments")
    # Self-referential (adjacency list) link to the parent comment.
    parent_id = Column(Integer, ForeignKey('comment.id'))
    # `replies` are the child comments; the backref adds `Comment.parent`.
    # remote_side=[id] (this class's own `id` column, bound above in the class
    # body) marks the parent side, so `parent` is many-to-one and `replies`
    # is one-to-many.
    replies = relationship("Comment", backref=backref('parent', remote_side=[id]))
class ApiUser(BaseModel):
    """API schema for a User: only id and username are exposed (email is not)."""
    id: int
    username: str
class ApiComment(BaseModel):
    """API schema for a Comment; nested replies use this same schema."""
    id: Optional[int] = None
    post_id: Optional[int] = None
    body: Optional[str] = None
    # Self-reference is written as a string because the name ApiComment is not
    # bound yet while the class body is executing.
    replies: Optional[List['ApiComment']] = None


# Resolve the 'ApiComment' forward reference above. Until this runs, pydantic
# refuses to validate the `replies` field of a self-referencing model.
ApiComment.update_forward_refs()
class ApiPost(BaseModel):
    """API schema for a Post, with its author and comments nested."""
    id: int
    # Must be the pydantic ApiUser schema, not the SQLAlchemy User class.
    # Pydantic has no validator for the ORM class, and marshal() only recurses
    # into fields whose type is a BaseModel subclass.
    author: ApiUser
    title: str
    body: Optional[str] = None
    comments: Optional[List[ApiComment]] = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from app.db.session import db_session | |
from .models import User, Post, Comment, ApiUser, ApiPost, ApiComment | |
from .marshal import marshal | |
# Build a small object graph: one user, one post, one comment on that post.
user = User(username='bobby')
post = Post(title='post', author=user)
comment = Comment(body='a comment', post=post, author=user)
db_session.add_all([user, post, comment])
# Flush (without committing) so the rows are INSERTed and, presumably, their
# autoincrement ids are populated; ApiPost requires `id`.
db_session.flush()

# Convert the ORM graph into plain dicts/lists shaped like ApiPost.
# Named `post_data` rather than `repr` to avoid shadowing the builtin.
post_data = marshal(post, ApiPost)
# now we can pass post_data to a pydantic model or return it in a FastAPI response
out = ApiPost(**post_data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment