Last active: January 9, 2023 12:24
Postgres composites for SQLAlchemy 1.4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import TypeAlias, Type

from frozendict import frozendict
import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.ext.compiler
import sqlalchemy.schema
import sqlalchemy.sql.sqltypes
import sqlalchemy.sql.visitors
import sqlalchemy.types
# Public API of this module: the factory for composite types and the two
# DDL elements that create/drop them.
__all__ = (
    "define_composite",
    "CreateComposite",
    "DropComposite",
)
class _CompositeField(sa.sql.expression.FunctionElement):
    """SQL expression selecting a single field of a composite-typed expression.

    Compiled to ``(<base>).<name>`` by ``_compile_composite_field``; its
    ``type`` is the SQL type of the selected field, not of the composite.

    :param base: the composite-typed expression the field is read from.
    :param name: name of the field inside the composite type.
    :param type: SQL type of that field.
    """

    cache_ok, inherit_cache = True, True

    # With ``inherit_cache`` alone, the statement cache key would be built
    # only from FunctionElement's internals, so ``(x).a`` and ``(x).b`` would
    # share one cached compiled form. Including ``name`` keeps each field
    # access distinct in the compiled-SQL cache.
    _traverse_internals = [
        *sa.sql.expression.FunctionElement._traverse_internals,
        ("name", sa.sql.visitors.InternalTraversal.dp_string),
    ]

    def __init__(self, base, name: str, type: sa.types.TypeEngine) -> None:
        # ``name`` is assigned before the base initializer runs, as in the
        # original; NOTE(review): the base class may consult ``self.name``
        # while coercing its arguments, so keep this order.
        self.name = name
        super().__init__(base)
        self.type = type
@sa.ext.compiler.compiles(_CompositeField)
def _compile_composite_field(expr, compiler, **kw):
    """Compile a composite field access as ``(<base expression>).<field>``.

    The base is wrapped in parentheses so that PostgreSQL parses the dotted
    suffix as a field selection rather than as a table-qualified column.
    """
    base_sql = compiler.process(expr.clauses, **kw)
    return "({}).{}".format(base_sql, expr.name)
# The type half of a field definition: a TypeEngine class or an instance.
_FieldTypeSpec: TypeAlias = Type[sa.types.TypeEngine] | sa.types.TypeEngine
# One field definition accepted by ``define_composite``: ``(field_name, type)``.
_ColumnDef: TypeAlias = tuple[str, _FieldTypeSpec]
class _CompositeType(sa.types.UserDefinedType, sa.sql.sqltypes.SchemaType):
    """SQLAlchemy type representing a named PostgreSQL composite type.

    :param name: SQL name of the composite type (used as the column type).
    :param fields: ordered mapping of field name to field SQL type; the order
        defines how positional (sequence) values are bound.
    """

    cache_ok = True
    python_type = tuple

    def __init__(self, name: str, fields: frozendict[str, sa.types.TypeEngine]):
        sa.sql.sqltypes.SchemaType.__init__(self, name=name)
        sa.types.UserDefinedType.__init__(self)
        self.fields = fields

    class comparator_factory(sa.types.UserDefinedType.Comparator):
        """Comparator exposing each composite field as an attribute."""

        def __getattr__(self, key):
            # Only the field lookup is guarded: a failure while building the
            # expression should surface as-is, not as a missing attribute.
            try:
                field_type = self.type.fields[key]
            except KeyError:
                raise AttributeError(
                    f"{self.type.name} ({type(self).__name__}) doesn't have an attribute named '{key}'"
                ) from None
            return _CompositeField(self.expr, key, field_type)

    def get_col_spec(self):
        """Return the type name as used in column definitions."""
        return self.name

    def bind_processor(self, dialect):
        """Return a callable turning a Python value into a bind-parameter tuple.

        ``None`` passes through. A ``dict`` is read by field name (missing
        keys bind as ``None``); any other value is indexed positionally in
        field order. Values of ``TypeDecorator`` fields go through that
        decorator's ``process_bind_param``; other values are left unchanged.
        """
        # ``fields`` is read once here instead of on every bound value.
        field_items = tuple(self.fields.items())

        def process(value):
            if value is None:
                return None
            by_name = isinstance(value, dict)
            processed_value = []
            # Each item is a (name, type) pair, so it must be unpacked
            # alongside the index.
            for i, (name, type_) in enumerate(field_items):
                current_value = value.get(name) if by_name else value[i]
                if isinstance(type_, sa.sql.sqltypes.TypeDecorator):
                    current_value = type_.process_bind_param(current_value, dialect)
                processed_value.append(current_value)
            # Build the tuple from the list; ``tuple(*values)`` raises for
            # more than one field and splits a single iterable field.
            return self.python_type(processed_value)

        return process

    def create(self, bind=None, checkfirst=None):
        """Emit ``CREATE TYPE`` on ``bind``.

        When ``checkfirst`` is true, skip it if the dialect reports that the
        type already exists.
        """
        if not checkfirst or not bind.dialect.has_type(bind, self.name, schema=self.schema):
            bind.execute(CreateComposite(self))

    def drop(self, bind=None, checkfirst=True):
        """Emit ``DROP TYPE`` on ``bind``.

        When ``checkfirst`` is true, skip it if the dialect reports that the
        type does not exist. With ``checkfirst`` false the DROP is always
        emitted (previously it was silently skipped).
        """
        if not checkfirst or bind.dialect.has_type(bind, self.name, schema=self.schema):
            bind.execute(DropComposite(self))
def define_composite(name: str, metadata: sa.MetaData, *fields: _ColumnDef) -> _CompositeType:
    """Define a PostgreSQL composite type and tie its DDL to ``metadata``.

    :param name: SQL name of the composite type.
    :param metadata: MetaData whose ``create_all``/``drop_all`` should also
        create/drop this type.
    :param fields: ``(field_name, type)`` pairs in field order; each type may
        be a TypeEngine class or instance (classes are instantiated).
    :returns: the composite type, usable as a column type.
    """
    composite = _CompositeType(
        name,
        frozendict(
            (field_name, sa.sql.sqltypes.to_instance(field_type))
            for field_name, field_type in fields
        ),
    )

    # The type must exist before any table that references it is created,
    # so it is created in ``before_create``; creating it in ``after_create``
    # ran only after the CREATE TABLE statements that need it.
    @sa.event.listens_for(metadata, "before_create")
    def before_create(_: sa.MetaData, connection: sa.engine.Connection, checkfirst: bool, **kwargs):
        composite.create(connection, checkfirst)

    # Conversely, drop the type only after the tables using it are gone.
    @sa.event.listens_for(metadata, "after_drop")
    def after_drop(_: sa.MetaData, connection: sa.engine.Connection, checkfirst: bool, **kwargs):
        composite.drop(connection, checkfirst)

    return composite
class CreateComposite(sa.schema._CreateDropBase):
    """DDL element emitting ``CREATE TYPE <name> AS (...)`` for a composite type.

    The wrapped composite type is available as ``element``; the SQL text is
    produced by the ``@compiles`` hook registered for this class.
    """
@sa.ext.compiler.compiles(CreateComposite)
def _compile_create_composite(create, compiler, **kwargs):
    """Render ``CREATE TYPE <name> AS (<field> <type>, ...)``.

    The type name goes through the identifier preparer; field names are
    emitted verbatim, each followed by its dialect-compiled SQL type.
    """
    composite = create.element
    type_name = compiler.preparer.format_type(composite)
    field_defs = []
    for field_name, field_type in composite.fields.items():
        sql_type = compiler.dialect.type_compiler.process(sa.sql.sqltypes.to_instance(field_type))
        field_defs.append(f"{field_name} {sql_type}")
    joined_fields = ", ".join(field_defs)
    return f"CREATE TYPE {type_name} AS ({joined_fields})"
class DropComposite(sa.schema._CreateDropBase):
    """DDL element emitting ``DROP TYPE`` for a composite type.

    :param type: the composite type to drop (stored as ``element``).
    :param cascade: when true, ``CASCADE`` is appended to the statement.
    """

    def __init__(self, type: _CompositeType, cascade: bool = False):
        self.cascade = cascade
        super().__init__(type)
@sa.ext.compiler.compiles(DropComposite)
def _compile_drop_composite(drop: DropComposite, compiler, **kwargs):
    """Render ``DROP TYPE <name>``, with ``CASCADE`` when ``drop.cascade`` is set."""
    quoted_name = compiler.preparer.format_type(drop.element)
    cascade_clause = " CASCADE" if drop.cascade else ""
    return "DROP TYPE " + quoted_name + cascade_clause
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Inspired by this question on Stack Overflow. A usage example can be found in the answer.