Created April 9, 2014 05:36
-
-
Save allieus/10229030 to your computer and use it in GitHub Desktop.
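http://techspot.zzzeek.org/2011/01/14/the-enum-recipe/ 에서 custom enum, enum array를 등록하는 부분 보충 (Supplements the enum recipe at the link above with registration of custom enum and enum-array types.)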
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import re

from psycopg2._psycopg import new_type, new_array_type, register_type
from sqlalchemy import cast, literal, text
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.types import SchemaType, TypeDecorator, Enum
from sqlalchemy.util import set_creation_order, OrderedDict

class DeclarativeEnum(object):
    """Declarative enumeration base class.

    Members are looked up by name in ``_reg``.
    NOTE(review): ``DeclEnumMeta`` (defined elsewhere, not visible here)
    presumably populates ``_reg`` per subclass; confirm against it.
    """

    # Python 2 metaclass hook. Python 3 ignores this attribute and would need
    # ``class DeclarativeEnum(metaclass=DeclEnumMeta)`` instead.
    __metaclass__ = DeclEnumMeta

    # Registry mapping member name -> member, kept in insertion order.
    _reg = OrderedDict()

    @classmethod
    def db_type(cls):
        """Return a new ``DeclarativeEnumType`` bound to this enum class."""
        return DeclarativeEnumType(cls)

    @classmethod
    def from_string(cls, value):
        """Return the member registered under the name *value*.

        Raises:
            ValueError: if *value* is not a registered member name.
        """
        try:
            return cls._reg[value]
        except KeyError:
            raise ValueError("Invalid value for {!r}: {!r}".format(cls.__name__, value))

    @classmethod
    def names(cls):
        """Return the registered member names, in registry order."""
        return cls._reg.keys()

    @classmethod
    def choices(cls):
        """Return ``(name, member)`` pairs, in registry order."""
        return cls._reg.items()

    # REF: psycopg2/_json.py
    @classmethod
    def register(cls, db):
        """Register psycopg2 typecasters for this enum and its array type.

        Looks up the PostgreSQL type named ``cls.db_type().impl.name`` in
        ``pg_type``. It registers a caster that returns the raw text value
        unchanged. When the type has an array type, it also registers a
        caster for that array type. If no such type exists yet, it returns
        without registering anything.

        The session's transaction is always rolled back afterwards, which
        discards any uncommitted work in ``db.session``. Call this before
        doing other work with the session.

        :param db: object exposing a SQLAlchemy ``session`` (presumably a
            Flask-SQLAlchemy instance -- confirm against callers).
        """
        typname = str(cls.db_type().impl.name)
        try:
            # One round trip: the type's OID plus its array type's OID and
            # name. Bound parameter instead of string formatting.
            row = db.session.execute(
                text(
                    "SELECT t.oid, t.typarray, a.typname "
                    "FROM pg_type t "
                    "LEFT JOIN pg_type a ON a.oid = t.typarray "
                    "WHERE t.typname = :typname"
                ),
                {"typname": typname},
            ).fetchone()
        finally:
            # End the transaction the SELECT implicitly opened, even on error.
            db.session.rollback()

        if row is None:
            # The enum type has not been created in the database yet.
            return
        oid, array_oid, array_typname = row

        # Caster that hands the raw text value back unchanged.
        _type = new_type((oid,), typname, lambda value, cursor: value)
        register_type(_type)

        # pg_type.typarray is 0 when the type has no associated array type.
        if array_oid:
            _type_array = new_array_type((array_oid,), str(array_typname), _type)
            register_type(_type_array)
class DeclarativeEnumArrayType(TypeDecorator):
    """Column type that binds Python sequences as PostgreSQL enum arrays.

    Subclasses set ``enumTypeCls`` and ``impl``.
    NOTE(review): ``impl`` is presumably ``ARRAY(enumTypeCls.db_type())``;
    confirm in subclasses, since the empty-array cast below relies on it.
    """

    # DeclarativeEnum subclass providing the element type; set by subclasses.
    enumTypeCls = None
    # Underlying SQLAlchemy type; set by subclasses.
    impl = None

    def bind_expression(self, bind_value):
        """Render the bound value as ``ARRAY[CAST(<elem> AS <enum>), ...]``.

        ``None`` is bound as an empty array. An empty array is cast to this
        type as a whole. A non-iterable value is likewise cast to this type
        as a whole.
        """
        val = bind_value.effective_value
        if val is None:
            val = []
        elif not hasattr(val, '__iter__'):
            # NOTE: on Python 2, str has no __iter__ and lands here; on
            # Python 3 a str would instead be iterated character by character.
            return cast(bind_value, self.__class__)
        if not val:
            # PostgreSQL cannot infer the element type of a bare ARRAY[]
            # ("cannot determine type of empty array"), so cast it explicitly.
            return cast(array([]), self.__class__)
        enum_type = self.__class__.enumTypeCls.db_type()
        # Pass a list: some SQLAlchemy versions iterate `clauses` more than once.
        return array([cast(literal(str(ele)), enum_type) for ele in val])
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.