Skip to content

Instantly share code, notes, and snippets.

@ixycoexzckwpmlcu
Created September 1, 2024 01:26
Show Gist options
  • Save ixycoexzckwpmlcu/d02b5951d17e6b25d4e5889b375ede53 to your computer and use it in GitHub Desktop.
Save ixycoexzckwpmlcu/d02b5951d17e6b25d4e5889b375ede53 to your computer and use it in GitHub Desktop.
SQLModel single table inheritance metaclass workaround
Note: the output below has been edited for brevity. It is the combined output (stdout and stderr) of running testing_code.py.
UserWarning: Field name "id" in "Admin" shadows an attribute in parent "User"
UserWarning: Field name "type" in "Admin" shadows an attribute in parent "User"
UserWarning: Field name "name" in "Admin" shadows an attribute in parent "User"
UserWarning: Field name "id" in "Moderator" shadows an attribute in parent "User"
UserWarning: Field name "type" in "Moderator" shadows an attribute in parent "User"
UserWarning: Field name "name" in "Moderator" shadows an attribute in parent "User"
BEGIN (implicit)
PRAGMA main.table_info("user")
PRAGMA temp.table_info("user")
PRAGMA main.table_info("post")
PRAGMA temp.table_info("post")
CREATE TABLE user (
id INTEGER NOT NULL,
type VARCHAR NOT NULL,
name VARCHAR NOT NULL,
can_ban_users BOOLEAN,
PRIMARY KEY (id)
)
CREATE TABLE post (
id INTEGER NOT NULL,
author_id INTEGER NOT NULL,
content VARCHAR NOT NULL,
PRIMARY KEY (id),
FOREIGN KEY(author_id) REFERENCES user (id)
)
COMMIT
BEGIN (implicit)
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_a', 1) RETURNING id
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_b', 1) RETURNING id
INSERT INTO user (type, name, can_ban_users) VALUES ('user', 'user_c', 1) RETURNING id
INSERT INTO user (type, name, can_ban_users) VALUES ('admin', 'admin_a', 1) RETURNING id
INSERT INTO user (type, name, can_ban_users) VALUES ('moderator', 'moderator_a', 1) RETURNING id
INSERT INTO post (author_id, content) VALUES (1, 'post_a') RETURNING id
INSERT INTO post (author_id, content) VALUES (4, 'post_b') RETURNING id
COMMIT
BEGIN (implicit)
SELECT user.id, user.type, user.name FROM user
SELECT user.id, user.type, user.name, user.can_ban_users FROM user WHERE user.type IN ('admin',)
SELECT post.id, post.author_id, post.content FROM post
users=[User(type='user', id=1, name='user_a'), User(type='user', id=2, name='user_b'), User(type='user', id=3, name='user_c'), Admin(type='admin', id=4, name='admin_a', can_ban_users=True), Moderator(type='moderator', id=5, name='moderator_a')]
admins=[Admin(type='admin', id=4, name='admin_a', can_ban_users=True)]
posts=[Post(author_id=1, id=1, content='post_a'), Post(author_id=4, id=2, content='post_b')]
ROLLBACK
import enum
import uuid
from datetime import datetime
from uuid import UUID
from sqlmodel import SQLModel, Field, Relationship, Session, create_engine
engine = create_engine("sqlite://", echo=True)


def get_db_session() -> Session:
    """Open a new SQLModel ``Session`` bound to the module-level ``engine``.

    ``engine`` points at ``sqlite://`` (an in-memory SQLite database) with
    SQL echoing enabled. The caller owns the returned session and is
    responsible for closing it; the helpers below do so by using it as a
    context manager.
    """
    session = Session(engine)
    return session
class HardwareCategory(enum.Enum):
    # Closed set of hardware kinds; Hardware.category is annotated with it.
    # Declaration order below is also the enum's iteration order.
    CONTROLLER = "controller"
    LIGHT = "light"
    CABLE_XLR = "cable_xlr"
    PLUG_COLD_APPLIANCE = "plug_cold_appliance"
    LAPTOP_STAND = "laptop_stand"
    OTHER = "other"
class User(SQLModel, table=True):
    # Application user, persisted in the "users" table; Hardware.owner_id
    # is a foreign key to users.id.
    __tablename__ = "users"
    # UUID primary key; None until create_user assigns a uuid4 before saving.
    id: UUID | None = Field(default=None, primary_key=True)
    username: str
    first_name: str | None = None
    last_name: str | None = None
    # Both timestamps are overwritten by create_user with datetime.now()
    # (naive local time).
    created_at: datetime | None = None
    updated_at: datetime | None = None
    # Collection side of Hardware.owner; back_populates keeps the two
    # relationship attributes in sync in memory.
    hardware: list["Hardware"] = Relationship(
        back_populates="owner"
    )
class Hardware(SQLModel, table=True):
    # A piece of hardware owned by a User, persisted in the "hardware" table.
    __tablename__ = "hardware"
    # UUID primary key; None until create_hardware assigns a uuid4.
    id: UUID | None = Field(default=None, primary_key=True)
    name: str
    serial: str
    image: str | None = None
    # NOTE(review): SQLAlchemy's Enum column type stores member *names*
    # (e.g. "CONTROLLER") by default rather than values — confirm intended.
    category: HardwareCategory
    # Required (no default): model_validate in create_hardware needs it set
    # even when the `owner` relationship is also provided.
    owner_id: UUID = Field(foreign_key="users.id")
    # Both timestamps are overwritten by create_hardware with datetime.now()
    # (naive local time).
    created_at: datetime | None = None
    updated_at: datetime | None = None
    # Scalar side of User.hardware (kept in sync via back_populates).
    owner: User = Relationship(back_populates="hardware")
def create_hardware(hardware: Hardware) -> Hardware:
    """Persist a new ``Hardware`` row and return the stored, refreshed copy.

    The argument itself is modified: ``hardware.id`` is replaced by a fresh
    ``uuid4`` and ``created_at``/``updated_at`` are both set to the same
    naive local ``datetime.now()``, overwriting anything the caller set.
    A copy built with ``Hardware.model_validate`` is then added, committed
    and refreshed in its own session, which is closed when the ``with``
    block exits. The returned object is therefore no longer attached to a
    session.
    """
    stamp = datetime.now()
    hardware.id = uuid.uuid4()
    hardware.created_at = hardware.updated_at = stamp
    stored = Hardware.model_validate(hardware)
    with get_db_session() as session:
        session.add(stored)
        session.commit()
        session.refresh(stored)
    return stored
def create_user(user: User) -> User:
    """Persist a new ``User`` row and return the stored, refreshed copy.

    The argument itself is modified: ``user.id`` is replaced by a fresh
    ``uuid4`` and ``created_at``/``updated_at`` are both set to the same
    naive local ``datetime.now()``, overwriting anything the caller set.
    A copy built with ``User.model_validate`` is then added, committed and
    refreshed in its own session, which is closed when the ``with`` block
    exits. The returned object is therefore no longer attached to a
    session.
    """
    stamp = datetime.now()
    user.id = uuid.uuid4()
    user.created_at = user.updated_at = stamp
    stored = User.model_validate(user)
    with get_db_session() as session:
        session.add(stored)
        session.commit()
        session.refresh(stored)
    return stored
if __name__ == "__main__":
    # Demo: create the schema in the in-memory database, then store one
    # user and one piece of hardware owned by that user.
    SQLModel.metadata.create_all(engine)
    # create_user / create_hardware always assign a fresh uuid4 id and set
    # created_at/updated_at themselves, overwriting caller values, so those
    # fields are intentionally not passed here.
    user = create_user(
        User(
            username='testuser',
            first_name='Test',
            last_name='User',
        )
    )
    hardware = create_hardware(
        Hardware(
            name='test hardware',
            serial=f'hdw-{uuid.uuid4()}',
            category=HardwareCategory.CONTROLLER,
            owner=user,
            # owner_id has no default, and create_hardware runs
            # Hardware.model_validate, so it is passed explicitly next to
            # the `owner` relationship.
            owner_id=user.id,
        )
    )
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine
from sqlmodel import select
from single_table_inheritance import BaseMetaclass
class Base(SQLModel, metaclass=BaseMetaclass):
    # Shared root for models that need single-table inheritance. It only
    # swaps in BaseMetaclass (from the single_table_inheritance module,
    # which is not included in this gist); the subclasses below rely on it
    # to accept the `single_table_inheritance=True` class keyword.
    ...
class User(Base, table=True):
    # Root of the single-table hierarchy. Its table ("user" in the echoed
    # DDL) also holds the Admin/Moderator columns, and the `type` column
    # tells the row kinds apart via `polymorphic_on`.
    #
    # The primary key is None until the INSERT assigns it (callers never
    # pass it), so it is annotated `int | None` to match its None default.
    # SQLModel keeps primary-key columns NOT NULL regardless.
    id: int | None = Field(default=None, primary_key=True)
    # Discriminator column. Callers never set it; SQLAlchemy fills it from
    # the mapper's polymorphic_identity ("user" for this class).
    type: str = Field()
    name: str = Field()
    posts: list["Post"] = Relationship(back_populates="author")
    __mapper_args__ = {
        "polymorphic_identity": "user",
        "polymorphic_on": "type",
    }
class Admin(User, table=True, single_table_inheritance=True):
    # Admin users. `single_table_inheritance=True` is a class keyword
    # handled by BaseMetaclass (not shown in this gist). Per the echoed
    # output, admin rows are stored in the single "user" table, which also
    # carries can_ban_users, and select(Admin) filters on user.type.
    #
    # nullable=True: the column is shared with rows of other types,
    # presumably so those rows can leave it empty.
    # NOTE(review): the echoed INSERTs also write can_ban_users=1 for plain
    # type='user' rows, which suggests default=True is being applied as the
    # shared column's default — confirm whether plain users should store
    # NULL instead.
    can_ban_users: bool = Field(default=True, nullable=True)
    __mapper_args__ = {
        "polymorphic_identity": "admin",
    }
class Moderator(User, table=True, single_table_inheritance=True):
    # Moderator users, stored in the shared "user" table like Admin and
    # distinguished by type == "moderator".
    #
    # Declares the same can_ban_users field as Admin, but with a Python-side
    # default of False.
    # NOTE(review): the echoed DDL shows only one can_ban_users column.
    # Which class's Field settings end up defining that column depends on
    # BaseMetaclass (not shown) — confirm the intended column default.
    can_ban_users: bool = Field(default=False, nullable=True)
    __mapper_args__ = {
        "polymorphic_identity": "moderator",
    }
class Post(SQLModel, table=True):
    # A post written by a User; each post belongs to exactly one author.
    #
    # None until the INSERT assigns it, hence `int | None`; SQLModel keeps
    # the primary-key column NOT NULL.
    id: int | None = Field(default=None, primary_key=True)
    # New posts get their author through the `author` relationship, and the
    # key is only filled in at flush time, so it is None in Python until
    # then. nullable=False keeps the column NOT NULL, as before; without it
    # SQLModel would derive a nullable column from the Optional annotation.
    author_id: int | None = Field(
        default=None, foreign_key="user.id", nullable=False
    )
    content: str = Field()
    author: User = Relationship(back_populates="posts")
if __name__ == "__main__":
    engine = create_engine("sqlite://", echo=True)
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        # Seed data: three plain users, one admin and one moderator. Posts
        # are attached through the `posts` relationship and are inserted
        # along with their author.
        session.add_all(
            [
                User(name="user_a", posts=[Post(content="post_a")]),
                User(name="user_b"),
                User(name="user_c"),
                Admin(name="admin_a", posts=[Post(content="post_b")]),
                Moderator(name="moderator_a", can_ban_users=True),
            ]
        )
        session.commit()

        # All three queries run before any printing, so the echoed SELECTs
        # precede the printed results.
        users, admins, posts = (
            session.exec(select(model)).all() for model in (User, Admin, Post)
        )
        print(f"users={users!r}")
        print(f"admins={admins!r}")
        print(f"posts={posts!r}")
@aidenprice
Copy link

Hi, could you give an example of the definition of BaseMetaclass? Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment