Skip to content

Instantly share code, notes, and snippets.

@allieus
Created April 9, 2014 05:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save allieus/10229030 to your computer and use it in GitHub Desktop.
Save allieus/10229030 to your computer and use it in GitHub Desktop.
http://techspot.zzzeek.org/2011/01/14/the-enum-recipe/ 에서 custom enum, enum array 를 등록하는 부분 보충
# coding: utf-8
import re
from sqlalchemy import cast, literal
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.types import SchemaType, TypeDecorator, Enum
from sqlalchemy.util import set_creation_order, OrderedDict
from psycopg2._psycopg import new_type, new_array_type, register_type
class DeclarativeEnum(object):
    """Declarative enumeration.

    Base class for enumerations whose members are looked up through the
    class-level ``_reg`` registry.  The registry is presumably populated per
    subclass by ``DeclEnumMeta`` (defined elsewhere) -- confirm against the
    metaclass.
    """

    __metaclass__ = DeclEnumMeta

    # Ordered registry consulted by from_string(), names() and choices().
    _reg = OrderedDict()

    @classmethod
    def db_type(cls):
        """Return a ``DeclarativeEnumType`` wrapping this enumeration class."""
        return DeclarativeEnumType(cls)

    @classmethod
    def from_string(cls, value):
        """Return the registry entry stored under ``value``.

        Raises:
            ValueError: if ``value`` is not a key of the registry.
        """
        try:
            return cls._reg[value]
        except KeyError:
            raise ValueError("Invalid value for {!r}: {!r}".format(cls.__name__, value))

    @classmethod
    def names(cls):
        """Return the registry keys, in registration order."""
        return cls._reg.keys()

    @classmethod
    def choices(cls):
        """Return the registry ``(key, entry)`` pairs, in registration order."""
        return cls._reg.items()

    # REF: psycopg2/_json.py
    @classmethod
    def register(cls, db):
        """Register psycopg2 typecasters for this enum's database type.

        Looks the type name up in ``pg_type`` and registers a typecaster that
        hands the value back unchanged, plus an array typecaster built on top
        of it when the type has an associated array type.  Both are registered
        without a scope argument.

        If ``pg_type`` has no row for the type name, nothing is registered and
        the call returns quietly.  The session is rolled back in every case so
        the catalog lookups never leave a transaction open.

        Args:
            db: object exposing ``session.execute()`` and
                ``session.rollback()`` (presumably a Flask-SQLAlchemy ``db``
                handle -- confirm against callers).
        """
        typname = str(cls.db_type().impl.name)
        try:
            # typname comes from the enum's own type definition, not from
            # caller-supplied input, so it is interpolated directly.
            result = db.session.execute("SELECT t.oid, typarray FROM pg_type t WHERE t.typname = '{}';".format(typname))
            row = result.cursor.fetchone()
            if row is None:
                # No such type in the catalog: best-effort no-op.  Checked
                # explicitly instead of catching the TypeError from unpacking
                # None, which also hid unrelated TypeErrors.
                return
            oid, array_oid = row

            # Identity caster: the raw string from the server is the value.
            _type = new_type((oid, ), typname, lambda s, cur: s)
            register_type(_type)

            if array_oid:
                result = db.session.execute("SELECT typname FROM pg_type t WHERE t.oid = '{}';".format(array_oid))
                array_row = result.cursor.fetchone()
                if array_row is not None:
                    array_typname = str(array_row[0])
                    _type_array = new_array_type((array_oid,), array_typname, _type)
                    register_type(_type_array)
        finally:
            # Always end the transaction opened by the lookups, including on
            # the "type not found" path and on errors.
            db.session.rollback()
class DeclarativeEnumArrayType(TypeDecorator):
    """Type decorator that binds values as an array of enum-typed literals.

    Subclasses are expected to set ``enumTypeCls`` (a class providing
    ``db_type()``) and ``impl``; both default to ``None`` here.
    """

    # Enumeration class whose db_type() each array element is cast to.
    enumTypeCls = None
    # Underlying type for the decorator; left unset on this base class.
    impl = None

    def bind_expression(self, bind_value):
        """Build the SQL expression used in place of ``bind_value``.

        A non-``None`` value without ``__iter__`` is cast as a whole to this
        decorator class.  Otherwise every element (``None`` counts as no
        elements) is stringified, wrapped in a literal and cast to the enum's
        database type, and the results are combined into an array expression.
        """
        decorator_cls = self.__class__
        raw = bind_value.effective_value

        # Scalar (non-iterable) values bypass the per-element handling.
        if raw is not None and not hasattr(raw, '__iter__'):
            return cast(bind_value, decorator_cls)

        elements = [] if raw is None else raw

        def enum_literal(element):
            # db_type() is resolved per element, so an empty value never
            # touches enumTypeCls.
            return cast(literal(str(element)), decorator_cls.enumTypeCls.db_type())

        return array(enum_literal(element) for element in elements)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment