Skip to content

Instantly share code, notes, and snippets.

@SF-300
Last active January 9, 2023 12:24
Postgres composites for SQLAlchemy 1.4
from typing import TypeAlias, Type

from frozendict import frozendict
import sqlalchemy as sa
import sqlalchemy.types
import sqlalchemy.event
import sqlalchemy.schema
import sqlalchemy.sql.sqltypes
import sqlalchemy.ext.compiler
# Names exported by ``from <module> import *``.
__all__ = "define_composite", "CreateComposite", "DropComposite"
class _CompositeField(sa.sql.expression.FunctionElement):
    """SQL expression that selects a single named field of a composite value.

    The wrapped expression is stored as the element's only clause; the
    compiler hook registered for this class renders it as ``(<base>).<name>``.

    Args:
        base: The expression producing the composite value.
        name: Name of the field to extract.
        type: SQLAlchemy type of the extracted field; becomes the type of
            this expression.
    """

    # ``name`` changes the rendered SQL but is not part of the cache key that
    # ``FunctionElement`` generates. Inheriting that key would let two
    # accessors for different fields of the same base share one cached
    # compiled statement, so caching is disabled for this construct.
    cache_ok = True
    inherit_cache = False

    def __init__(self, base, name: str, type: sa.types.TypeEngine) -> None:
        self.name = name
        super().__init__(base)
        # Assigned after the base initializer so it is not overwritten.
        self.type = type
@sa.ext.compiler.compiles(_CompositeField)
def _compile_composite_field(expr, compiler, **kw):
    """Render a composite field access as ``(<base expression>).<field name>``."""
    # The parentheses keep the base expression from being parsed as a
    # schema/table qualifier of the field name.
    base_sql = compiler.process(expr.clauses, **kw)
    return f"({base_sql}).{expr.name}"
# One composite field definition: (field name, SQLAlchemy type class or instance).
_ColumnDef: TypeAlias = tuple[str, Type[sa.types.TypeEngine] | sa.types.TypeEngine]
class _CompositeType(sa.types.UserDefinedType, sa.sql.sqltypes.SchemaType):
    """Column type backed by a named composite type in the database.

    Attributes of a column expression of this type resolve to field
    accessors, e.g. ``table.c.col.some_field``.

    Args:
        name: Name of the composite type; also used as the column spec.
        fields: Mapping of field name to the SQLAlchemy type of that field.
            Iteration order defines the positional order of the fields.
    """

    cache_ok = True
    python_type = tuple

    def __init__(self, name: str, fields: frozendict[str, sa.types.TypeEngine]):
        sa.sql.sqltypes.SchemaType.__init__(self, name=name)
        sa.types.UserDefinedType.__init__(self)
        self.fields = fields

    class comparator_factory(sa.types.UserDefinedType.Comparator):
        """Comparator exposing composite fields as attributes of the expression."""

        def __getattr__(self, key):
            # Only reached when normal attribute lookup fails, so ordinary
            # comparator attributes take precedence over field names.
            try:
                return _CompositeField(self.expr, key, self.type.fields[key])
            except KeyError:
                raise AttributeError(
                    f"{self.type.name} ({type(self).__name__}) doesn't have an attribute named '{key}'"
                ) from None

    def get_col_spec(self):
        """Return the type name used in column definitions."""
        return self.name

    def bind_processor(self, dialect):
        """Return a callable that converts a bound value into a tuple.

        The callable accepts ``None`` (passed through), a ``dict`` keyed by
        field name, or a positional sequence ordered like ``self.fields``.
        Fields whose type is a ``TypeDecorator`` are passed through its
        ``process_bind_param``; all other values are used unchanged.
        """

        def process(value):
            if value is None:
                return None
            processed_value = []
            # enumerate() yields (index, (name, type)) pairs, so the item
            # pair has to be unpacked as a nested tuple.
            for i, (name, type_) in enumerate(self.fields.items()):
                current_value = value.get(name) if isinstance(value, dict) else value[i]
                if isinstance(type_, sa.sql.sqltypes.TypeDecorator):
                    processed_value.append(
                        type_.process_bind_param(current_value, dialect)
                    )
                else:
                    processed_value.append(current_value)
            # tuple() takes a single iterable, not one argument per element.
            return self.python_type(processed_value)

        return process

    def create(self, bind=None, checkfirst=None):
        """Emit ``CreateComposite`` for this type on ``bind``.

        With a truthy ``checkfirst`` the statement is skipped when the
        dialect reports that the type already exists.
        """
        if not checkfirst or not bind.dialect.has_type(bind, self.name, schema=self.schema):
            bind.execute(CreateComposite(self))

    def drop(self, bind=None, checkfirst=True):
        """Emit ``DropComposite`` for this type on ``bind``.

        With a truthy ``checkfirst`` the statement is skipped when the
        dialect reports that the type does not exist; with a falsy
        ``checkfirst`` the type is dropped unconditionally.
        """
        if not checkfirst or bind.dialect.has_type(bind, self.name, schema=self.schema):
            bind.execute(DropComposite(self))
def define_composite(name: str, metadata: sa.MetaData, *fields: _ColumnDef) -> _CompositeType:
    """Build a composite type and tie its lifecycle to ``metadata``.

    Args:
        name: Name of the composite type in the database.
        metadata: ``MetaData`` whose create/drop events trigger creation and
            removal of the type.
        *fields: ``(field name, type)`` pairs; each type may be a type class
            or an instance and is normalised to an instance.

    Returns:
        The composite type, ready to be used as a column type.
    """
    composite = _CompositeType(name, frozendict((n, sa.sql.sqltypes.to_instance(t)) for n, t in fields))

    # The type has to exist before any CREATE TABLE that references it is
    # emitted, so it is created ahead of the metadata's tables.
    @sa.event.listens_for(metadata, "before_create")
    def before_create(_: sa.MetaData, connection: sa.engine.Connection, checkfirst: bool, **kwargs):
        composite.create(connection, checkfirst)

    # The type can only be removed once the tables using it are gone.
    @sa.event.listens_for(metadata, "after_drop")
    def after_drop(_: sa.MetaData, connection: sa.engine.Connection, checkfirst: bool, **kwargs):
        composite.drop(connection, checkfirst)

    return composite
class CreateComposite(sa.schema._CreateDropBase):
    """DDL element that creates a composite type; ``element`` is the type."""
@sa.ext.compiler.compiles(CreateComposite)
def _compile_create_composite(create, compiler, **kwargs):
    """Render ``CREATE TYPE <name> AS (<field> <type>, ...)`` for a composite."""
    composite = create.element
    type_name = compiler.preparer.format_type(composite)

    # One "<field name> <compiled type>" entry per field, in mapping order.
    column_specs = []
    for field_name, field_type in composite.fields.items():
        rendered_type = compiler.dialect.type_compiler.process(
            sa.sql.sqltypes.to_instance(field_type)
        )
        column_specs.append(
            "{name} {type}".format(name=field_name, type=rendered_type)
        )

    return "CREATE TYPE {name} AS ({fields})".format(
        name=type_name,
        fields=", ".join(column_specs),
    )
class DropComposite(sa.schema._CreateDropBase):
    """DDL element that drops a composite type.

    Args:
        type: The composite type to drop; stored as ``element`` by the base.
        cascade: When true, ``CASCADE`` is appended to the rendered statement.
    """

    def __init__(self, type: _CompositeType, cascade: bool = False):
        super().__init__(type)
        self.cascade = cascade
@sa.ext.compiler.compiles(DropComposite)
def _compile_drop_composite(drop: DropComposite, compiler, **kwargs):
    """Render ``DROP TYPE <name>``, with ``CASCADE`` appended when requested."""
    statement = f"DROP TYPE {compiler.preparer.format_type(drop.element)}"
    if not drop.cascade:
        return statement
    return statement + " CASCADE"
@SF-300
Copy link
Author

SF-300 commented Jan 9, 2023

Inspired by this question on Stack Overflow. A usage example can be found in the answer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment