-
-
Save ixycoexzckwpmlcu/d02b5951d17e6b25d4e5889b375ede53 to your computer and use it in GitHub Desktop.
SQLModel single table inheritance metaclass workaround
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Note: Output has been edited for brevity. This is the output (both stdout and stderr) when running testing_code.py | |
UserWarning: Field name "id" in "Admin" shadows an attribute in parent "User" | |
UserWarning: Field name "type" in "Admin" shadows an attribute in parent "User" | |
UserWarning: Field name "name" in "Admin" shadows an attribute in parent "User" | |
UserWarning: Field name "id" in "Moderator" shadows an attribute in parent "User" | |
UserWarning: Field name "type" in "Moderator" shadows an attribute in parent "User" | |
UserWarning: Field name "name" in "Moderator" shadows an attribute in parent "User" | |
BEGIN (implicit) | |
PRAGMA main.table_info("user") | |
PRAGMA temp.table_info("user") | |
PRAGMA main.table_info("post") | |
PRAGMA temp.table_info("post") | |
CREATE TABLE user ( | |
id INTEGER NOT NULL, | |
type VARCHAR NOT NULL, | |
name VARCHAR NOT NULL, | |
can_ban_users BOOLEAN, | |
PRIMARY KEY (id) | |
) | |
CREATE TABLE post ( | |
id INTEGER NOT NULL, | |
author_id INTEGER NOT NULL, | |
content VARCHAR NOT NULL, | |
PRIMARY KEY (id), | |
FOREIGN KEY(author_id) REFERENCES user (id) | |
) | |
COMMIT | |
BEGIN (implicit) | |
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_a', 1) RETURNING id | |
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_b', 1) RETURNING id | |
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_c', 1) RETURNING id | |
INSERT INTO user (type, name, can_ban_users) VALUES ('admin', 'admin_a', 1) RETURNING id | |
INSERT INTO user (type, name, can_ban_users) VALUES ('moderator', 'moderator_a', 1) RETURNING id | |
INSERT INTO post (author_id, content) VALUES (1, 'post_a') RETURNING id | |
INSERT INTO post (author_id, content) VALUES (4, 'post_b') RETURNING id | |
COMMIT | |
BEGIN (implicit) | |
SELECT user.id, user.type, user.name FROM user | |
SELECT user.id, user.type, user.name, user.can_ban_users FROM user WHERE user.type IN ('admin',) | |
SELECT post.id, post.author_id, post.content FROM post | |
users=[User(type='user', id=1, name='user_a'), User(type='user', id=2, name='user_b'), User(type='user', id=3, name='user_c'), Admin(type='admin', id=4, name='admin_a', can_ban_users=True), Moderator(type='moderator', id=5, name='moderator_a')] | |
admins=[Admin(type='admin', id=4, name='admin_a', can_ban_users=True)] | |
posts=[Post(author_id=1, id=1, content='post_a'), Post(author_id=4, id=2, content='post_b')] | |
ROLLBACK |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import uuid | |
from datetime import datetime | |
from uuid import UUID | |
from sqlmodel import SQLModel, Field, Relationship, Session, create_engine | |
# Module-wide engine used by get_db_session(). The URL names SQLite with no
# file path; echo=True asks SQLAlchemy to log the SQL statements it emits.
engine = create_engine(
    "sqlite://",
    echo=True,
)
def get_db_session() -> Session:
    """Return a new ``Session`` bound to the module-level ``engine``.

    A fresh session object is built on every call; nothing is cached, so the
    caller owns the session and is expected to close it (the helpers in this
    file do so by using it as a context manager).
    """
    db_session = Session(engine)
    return db_session
class HardwareCategory(enum.Enum):
    """Closed set of categories that a ``Hardware`` row can be filed under."""

    CONTROLLER = "controller"
    LIGHT = "light"
    CABLE_XLR = "cable_xlr"
    PLUG_COLD_APPLIANCE = "plug_cold_appliance"
    LAPTOP_STAND = "laptop_stand"
    # Catch-all for anything that does not fit the categories above.
    OTHER = "other"
# Table model for an account, stored in the "users" table. Each user can own
# any number of Hardware rows (see the `hardware` relationship).
class User(SQLModel, table=True):
    __tablename__ = "users"

    # Primary key. Defaults to None; create_user() assigns a UUID before insert.
    id: UUID | None = Field(primary_key=True, default=None)
    username: str
    first_name: str | None = None
    last_name: str | None = None
    # Timestamps default to None; create_user() fills both in.
    created_at: datetime | None = None
    updated_at: datetime | None = None

    # Inverse side of Hardware.owner.
    hardware: list["Hardware"] = Relationship(back_populates="owner")
# Table model for a piece of equipment, stored in the "hardware" table and
# linked to its owning User through `owner_id`.
class Hardware(SQLModel, table=True):
    __tablename__ = "hardware"

    # Primary key. Defaults to None; create_hardware() assigns a UUID before insert.
    id: UUID | None = Field(primary_key=True, default=None)
    name: str
    serial: str
    image: str | None = None
    category: HardwareCategory
    # Foreign key to users.id.
    owner_id: UUID = Field(foreign_key="users.id")
    # Timestamps default to None; create_hardware() fills both in.
    created_at: datetime | None = None
    updated_at: datetime | None = None

    # Inverse side of User.hardware.
    owner: User = Relationship(back_populates="hardware")
def create_hardware(hardware: Hardware) -> Hardware:
    """Stamp *hardware* with a new id and timestamps, then insert it.

    The passed-in object is modified in place: its ``id`` is replaced with a
    freshly generated UUID and ``created_at`` / ``updated_at`` are both set to
    the current local time, overwriting whatever the caller supplied. The
    stamped object is then run through ``Hardware.model_validate`` and the
    resulting instance is added, committed and refreshed in a new session.

    Returns:
        The instance produced by ``Hardware.model_validate`` (not necessarily
        the object that was passed in), refreshed after the commit.
    """
    timestamp = datetime.now()
    hardware.id = uuid.uuid4()
    hardware.created_at = hardware.updated_at = timestamp

    record = Hardware.model_validate(hardware)

    with get_db_session() as db:
        db.add(record)
        db.commit()
        # Reload the row so the returned object carries its persisted state.
        db.refresh(record)
    return record
def create_user(user: User) -> User:
    """Stamp *user* with a new id and timestamps, then insert it.

    The passed-in object is modified in place: its ``id`` is replaced with a
    freshly generated UUID and ``created_at`` / ``updated_at`` are both set to
    the current local time, overwriting whatever the caller supplied. The
    stamped object is then run through ``User.model_validate`` and the
    resulting instance is added, committed and refreshed in a new session.

    Returns:
        The instance produced by ``User.model_validate`` (not necessarily the
        object that was passed in), refreshed after the commit.
    """
    timestamp = datetime.now()
    user.id = uuid.uuid4()
    user.created_at = user.updated_at = timestamp

    record = User.model_validate(user)

    with get_db_session() as db:
        db.add(record)
        db.commit()
        # Reload the row so the returned object carries its persisted state.
        db.refresh(record)
    return record
if __name__ == "__main__":
    # Create the tables for every model registered on SQLModel.metadata.
    SQLModel.metadata.create_all(engine)

    # NOTE: the id / created_at / updated_at values passed to the constructors
    # below are replaced inside create_user() and create_hardware().
    user_draft = User(
        id=uuid.uuid4(),
        username='testuser',
        first_name='Test',
        last_name='User',
        created_at=datetime.now(),
        updated_at=datetime.now(),
    )
    user = create_user(user_draft)

    # The hardware is linked to the stored user both through the relationship
    # attribute (owner) and through the raw foreign key (owner_id).
    hardware_draft = Hardware(
        id=uuid.uuid4(),
        name='test hardware',
        serial=f'hdw-{uuid.uuid4()}',
        image=None,
        category=HardwareCategory.CONTROLLER,
        owner=user,
        owner_id=user.id,
        created_at=datetime.now(),
        updated_at=datetime.now(),
    )
    hardware = create_hardware(hardware_draft)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine | |
from sqlmodel import select | |
from single_table_inheritance import BaseMetaclass | |
# Shared base class for the inheritance hierarchy below. It declares no fields;
# its only job is to install BaseMetaclass (imported from
# single_table_inheritance) as the metaclass of its subclasses.
class Base(SQLModel, metaclass=BaseMetaclass):
    ...
# Root table model of the polymorphic hierarchy. Admin and Moderator subclass
# it and are declared with single_table_inheritance=True.
class User(Base, table=True):
    # The "type" column is the discriminator; plain User rows carry "user".
    __mapper_args__ = dict(
        polymorphic_identity="user",
        polymorphic_on="type",
    )

    id: int = Field(primary_key=True, default=None)
    type: str = Field()
    name: str = Field()

    # Inverse side of Post.author.
    posts: list["Post"] = Relationship(back_populates="author")
# User subtype with discriminator value "admin"; adds a can_ban_users flag
# that defaults to True.
class Admin(User, table=True, single_table_inheritance=True):
    __mapper_args__ = dict(polymorphic_identity="admin")

    # Declared nullable=True, so the column itself accepts NULL.
    can_ban_users: bool = Field(nullable=True, default=True)
# User subtype with discriminator value "moderator"; adds a can_ban_users flag
# that defaults to False.
class Moderator(User, table=True, single_table_inheritance=True):
    __mapper_args__ = dict(polymorphic_identity="moderator")

    # Declared nullable=True, so the column itself accepts NULL.
    can_ban_users: bool = Field(nullable=True, default=False)
# Table model for a post written by a User (or one of its subtypes).
class Post(SQLModel, table=True):
    id: int = Field(primary_key=True, default=None)
    # Foreign key to user.id.
    author_id: int = Field(foreign_key="user.id", default=None)
    content: str = Field()

    # Inverse side of User.posts.
    author: User = Relationship(back_populates="posts")
if __name__ == "__main__":
    engine = create_engine("sqlite://", echo=True)
    SQLModel.metadata.create_all(engine)

    with Session(engine) as session:
        # Seed three plain users, one admin and one moderator; two of them
        # bring a post along through the `posts` relationship.
        seed_rows = [
            User(name="user_a", posts=[Post(content="post_a")]),
            User(name="user_b"),
            User(name="user_c"),
            Admin(name="admin_a", posts=[Post(content="post_b")]),
            Moderator(name="moderator_a", can_ban_users=True),
        ]
        session.add_all(seed_rows)
        session.commit()

        # Query the base class, one subclass, and the related table.
        users = session.exec(select(User)).all()
        admins = session.exec(select(Admin)).all()
        posts = session.exec(select(Post)).all()

        # The variable names are part of the output: f"{name=}" prints them.
        print(f"{users=}")
        print(f"{admins=}")
        print(f"{posts=}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, could you give an example of the definition of `BaseMetaclass`? Thank you!