Skip to content

Instantly share code, notes, and snippets.

@yasamoka
Last active October 22, 2022 02:20
Show Gist options
  • Save yasamoka/241c79782d88cb05b33b5dc7be41e222 to your computer and use it in GitHub Desktop.
Save yasamoka/241c79782d88cb05b33b5dc7be41e222 to your computer and use it in GitHub Desktop.
SQLAlchemy 2.0: example of declarative mixins with dataclasses, extended with relationships; takes advantage of PEP 681 (dataclass transforms) so that __init__ is synthesized from the fields
# Original example at: https://docs.sqlalchemy.org/en/20/orm/dataclasses.html
from dataclasses import dataclass, field
from typing import List, Optional
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import (
Session,
mapped_column,
registry,
relationship,
)
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import Integer, String
# Registry shared by the mapped classes below: ``mapper_registry.mapped``
# maps User/Address, and ``mapper_registry.metadata`` creates their tables.
mapper_registry = registry()
@dataclass
class UserMixin:
    """Dataclass mixin carrying the columns and relationships of ``user``.

    Each field's SQLAlchemy construct is stored in ``field(metadata=...)``
    under the key named by ``__sa_dataclass_metadata_key__`` ("sa").  The
    relationship is supplied as a zero-argument lambda rather than a
    ready-made ``relationship()`` object, so it can be built for the mapped
    subclass instead of being shared from the mixin.
    """

    __tablename__ = "user"
    __sa_dataclass_metadata_key__ = "sa"

    # Integer primary key; not accepted by the generated __init__.
    id: int = field(
        init=False,
        metadata={"sa": mapped_column(Integer, primary_key=True)},
    )
    # Required constructor argument, stored in a NOT NULL column.
    name: str = field(
        metadata={"sa": mapped_column(String, nullable=False)},
    )
    # One-to-many side of User <-> Address; starts as an empty list.
    addresses: List["Address"] = field(
        default_factory=list,
        metadata={"sa": lambda: relationship("Address", back_populates="user")},
    )
@dataclass
class AddressMixin:
    """Dataclass mixin carrying the columns and relationships of ``address``.

    As with ``UserMixin``, the SQLAlchemy constructs live in the "sa" entry
    of each field's metadata, and the relationship is given as a
    zero-argument lambda.
    """

    __tablename__ = "address"
    __sa_dataclass_metadata_key__ = "sa"

    # Integer primary key; not accepted by the generated __init__.
    id: int = field(
        init=False,
        metadata={"sa": mapped_column(Integer, primary_key=True)},
    )
    # Foreign key to user.id.  It is not an __init__ argument; callers link
    # an address to its user through the ``user`` / ``addresses`` relationship.
    user_id: int = field(
        init=False,
        metadata={"sa": mapped_column(ForeignKey("user.id"))},
    )
    email_address: str = field(metadata={"sa": mapped_column(String)})
    # Many-to-one side of User <-> Address; optional at construction time.
    user: Optional["User"] = field(
        default=None,
        metadata={"sa": lambda: relationship("User", back_populates="addresses")},
    )
@mapper_registry.mapped
class User(UserMixin):
    """Mapped ``user`` entity; every field is inherited from ``UserMixin``."""
@mapper_registry.mapped
class Address(AddressMixin):
    """Mapped ``address`` entity; every field is inherited from ``AddressMixin``."""
# Three ways of associating an Address with its User:
# 1) many-to-one: hand the user to Address(...)
user1 = User(name="Name 1")
address1 = Address(email_address="email1@gmail.com", user=user1)
# 2) one-to-many: hand the address list to User(...)
address2 = Address(email_address="email2@gmail.com")
user2 = User(name="Name 2", addresses=[address2])
# 3) append to the (initially empty) collection after construction
address3 = Address(email_address="email3@gmail.com")
user3 = User(name="Name 3")
user3.addresses.append(address3)

# Persist everything into a SQLite file named test.db.
engine = create_engine("sqlite:///test.db")
mapper_registry.metadata.create_all(engine)
with Session(engine) as session, session.begin():
    # session.begin() commits when this block exits without an exception.
    session.add_all((user1, user2, user3, address1, address2, address3))
# Original example at: https://docs.sqlalchemy.org/en/20/orm/dataclasses.html
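# NOTE: this variant requires a locally patched SQLAlchemy.
# The relationship factories below (load_addresses_relationship and
# load_user_relationship) expect the mapped class as their argument, but
# stock SQLAlchemy invokes the dataclass-metadata callable with no arguments
# (obj.fget()).  Apply this change in sqlalchemy/orm/decl_base.py,
# class _ClassScanMapperConfig, method _scan_attributes:
#
#     if not isinstance(ret, InspectionAttr):
# -       ret = obj.fget()
# +       ret = obj.fget(cls)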
from dataclasses import dataclass, field
from typing import ClassVar, Generic, List, Optional, Protocol, Type, TypeVar
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import (
Session,
mapped_column,
registry,
relationship,
)
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import Integer, String
# Registry for the generic-mixin variant: ``mapper_registry.mapped`` maps
# User/Address and ``mapper_registry.metadata`` creates their tables.
mapper_registry = registry()
# Type parameter standing for the concrete Address model that UserMixin's
# ``addresses`` list holds (see ``UserMixin["Address"]``).
AddressModel = TypeVar("AddressModel")
class UserMixinProtocol(Protocol):
    """Structural type for classes that name their related Address model.

    ``load_addresses_relationship`` reads ``address_model_name`` from the
    class it receives to decide which model the relationship targets.
    """

    # Class name of the mapped Address model, e.g. "Address".
    address_model_name: ClassVar[str]
def load_addresses_relationship(cls: Type[UserMixinProtocol]):
    """Build the one-to-many ``addresses`` relationship for *cls*.

    Stored as the "sa" metadata of ``UserMixin.addresses``.  Because it takes
    the mapped class as an argument, it only works with the patched
    SQLAlchemy described in the comment at the top of this example (stock
    SQLAlchemy calls the factory without arguments).
    """
    target_model = cls.address_model_name
    return relationship(target_model, back_populates="user")
@dataclass
class UserMixin(UserMixinProtocol, Generic[AddressModel]):
    """Generic dataclass mixin for the ``user`` table.

    ``AddressModel`` is the concrete Address class.  Concrete subclasses set
    ``address_model_name`` so that ``load_addresses_relationship`` knows
    which model the ``addresses`` relationship targets.
    """

    __tablename__ = "user"
    __sa_dataclass_metadata_key__ = "sa"
    # Supplied by each concrete subclass, e.g. address_model_name = "Address".
    address_model_name: ClassVar[str]

    # Integer primary key; not accepted by the generated __init__.
    id: int = field(
        init=False,
        metadata={"sa": mapped_column(Integer, primary_key=True)},
    )
    # Required constructor argument, stored in a NOT NULL column.
    name: str = field(
        metadata={"sa": mapped_column(String, nullable=False)},
    )
    # One-to-many side of User <-> Address; starts as an empty list.
    addresses: List[AddressModel] = field(
        default_factory=list,
        # A factory that receives the mapped class (needs patched SQLAlchemy).
        metadata={"sa": load_addresses_relationship},
    )
# Type parameter standing for the concrete User model referenced by
# AddressMixin's ``user`` field (see ``AddressMixin[User]``).
UserModel = TypeVar("UserModel")
class AddressMixinProtocol(Protocol):
    """Structural type for classes that name their related User model.

    ``load_user_relationship`` reads ``user_model_name`` from the class it
    receives to decide which model the relationship targets.
    """

    # Class name of the mapped User model, e.g. "User".
    user_model_name: ClassVar[str]
def load_user_relationship(cls: Type[AddressMixinProtocol]):
    """Build the many-to-one ``user`` relationship for *cls*.

    Stored as the "sa" metadata of ``AddressMixin.user``.  Like
    ``load_addresses_relationship``, it needs the patched SQLAlchemy that
    passes the mapped class to the factory.
    """
    target_model = cls.user_model_name
    return relationship(target_model, back_populates="addresses")
@dataclass
class AddressMixin(AddressMixinProtocol, Generic[UserModel]):
    """Generic dataclass mixin for the ``address`` table.

    ``UserModel`` is the concrete User class.  Concrete subclasses set
    ``user_model_name`` so that ``load_user_relationship`` knows which model
    the ``user`` relationship targets.
    """

    __tablename__ = "address"
    __sa_dataclass_metadata_key__ = "sa"
    # Supplied by each concrete subclass, e.g. user_model_name = "User".
    user_model_name: ClassVar[str]

    # Integer primary key; not accepted by the generated __init__.
    id: int = field(
        init=False,
        metadata={"sa": mapped_column(Integer, primary_key=True)},
    )
    # Foreign key to user.id; not an __init__ argument (link via ``user``).
    user_id: int = field(
        init=False,
        metadata={"sa": mapped_column(ForeignKey("user.id"))},
    )
    email_address: str = field(metadata={"sa": mapped_column(String)})
    # Many-to-one side of User <-> Address; optional at construction time.
    user: Optional[UserModel] = field(
        default=None,
        metadata={"sa": load_user_relationship},
    )
@mapper_registry.mapped
class User(UserMixin["Address"]):
    """Mapped ``user`` entity whose ``addresses`` hold ``Address`` objects."""

    # Read by load_addresses_relationship to target the Address model.
    address_model_name = "Address"
@mapper_registry.mapped
class Address(AddressMixin[User]):
    """Mapped ``address`` entity whose ``user`` is a ``User`` object."""

    # Read by load_user_relationship to target the User model.
    user_model_name = "User"
# Same object graph as the lambda-based variant, built three different ways.
# Many-to-one: the user is passed to Address(...).
user1 = User(name="Name 1")
address1 = Address(email_address="email1@gmail.com", user=user1)
# One-to-many: the address list is passed to User(...).
address2 = Address(email_address="email2@gmail.com")
user2 = User(name="Name 2", addresses=[address2])
# After construction: append to the default empty collection.
address3 = Address(email_address="email3@gmail.com")
user3 = User(name="Name 3")
user3.addresses.append(address3)

# Create the tables in the SQLite file test.db and store the objects.
engine = create_engine("sqlite:///test.db")
mapper_registry.metadata.create_all(engine)
with Session(engine) as session, session.begin():
    # Leaving the begin() block without an exception commits the transaction.
    session.add_all((user1, user2, user3, address1, address2, address3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment