Created
August 8, 2019 23:00
-
-
Save mjbryant/51d69ca04faa9fd7cd1a6b00f1e22a45 to your computer and use it in GitHub Desktop.
SQLAlchemy deals with Enums oddly
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import subprocess | |
from sqlalchemy import create_engine | |
from sqlalchemy import Column | |
from sqlalchemy import Enum | |
from sqlalchemy import types | |
from sqlalchemy.exc import DataError | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker | |
# Name of the scratch PostgreSQL database this script drops and recreates.
DATABASE_NAME = 'sqlalchemy_enum_test'
# Labels the Python-side Enum column type is declared with.
ENUM_VALUES = ('a', 'b', 'c')
# Name given to the enum type on the PostgreSQL side.
ENUM_NAME = 'me_thing_enum'

# Engine pointing at a PostgreSQL server on localhost:5432.
engine = create_engine('postgresql://localhost:5432/{}'.format(DATABASE_NAME))
# Session factory bound to the engine above.
Session = sessionmaker(bind=engine)
# The engine is intentionally not passed to declarative_base(): binding an
# engine to the metadata is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
# Every consumer in this file hands the engine over explicitly instead.
Base = declarative_base()
class Me(Base):
    """ORM model for the ``me`` table: an integer key plus one enum column."""

    __tablename__ = 'me'

    # Integer primary key.
    id = Column(types.Integer, primary_key=True)
    # Enum column whose type is named ENUM_NAME and declared with ENUM_VALUES.
    thing = Column(
        Enum(*ENUM_VALUES, name=ENUM_NAME),
    )
def create_tables():
    """Create every table registered on ``Base.metadata`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(engine)
def recreate_database():
    """Drop and recreate the scratch database, then create its tables.

    Both statements are run through the ``psql`` command-line client.

    Raises:
        subprocess.CalledProcessError: if either ``psql`` invocation exits
            with a non-zero status.
    """
    # IF EXISTS lets the very first run (no database yet) succeed, so a
    # non-zero exit status now always signals a real problem.
    subprocess.check_call(
        ['psql', '-c', 'DROP DATABASE IF EXISTS {}'.format(DATABASE_NAME)]
    )
    # Fail loudly here rather than with a confusing connection error later.
    subprocess.check_call(
        ['psql', '-c', 'CREATE DATABASE {}'.format(DATABASE_NAME)]
    )
    create_tables()
def add_enum_value(value='d'):
    """Add ``value`` to the PostgreSQL enum type ``ENUM_NAME`` via ``psql``.

    Only the database-side type is altered; the exit status of ``psql`` is
    not inspected.
    """
    statement = "ALTER TYPE {} ADD VALUE IF NOT EXISTS '{}';".format(
        ENUM_NAME, value
    )
    command = ['psql', '-d', DATABASE_NAME, '-c', statement]
    subprocess.call(command)
def insert(*things):
    """Insert one ``Me`` row per value in ``things`` and commit.

    Args:
        *things: values to store in the ``thing`` column, one row each.

    Raises:
        Whatever the flush/commit raises (for example ``DataError``); the
        exception is re-raised unchanged after the transaction is rolled back.
    """
    session = Session()
    try:
        session.add_all([Me(thing=thing) for thing in things])
        session.commit()
    except Exception:
        # Leave no half-finished transaction behind before propagating.
        session.rollback()
        raise
    finally:
        # Always hand the connection back, including on the failure path.
        session.close()
def get_all():
    """Return a list of all ``Me`` rows.

    The session used for the query is closed before returning, so the
    returned instances are detached from it.

    Raises:
        Any exception raised while running the query or loading rows; it
        propagates unchanged after the session is closed.
    """
    session = Session()
    try:
        return session.query(Me).all()
    finally:
        # Release the connection whether or not loading the rows succeeded.
        session.close()
def test():
    """Run the enum demonstration end to end against a fresh database.

    Steps: recreate the database, insert two declared values, attempt to
    insert 'd' (tolerating a ``DataError``), add 'd' to the database enum
    type, insert 'd' again, then try to read every row back and report any
    failure of that final read.
    """
    recreate_database()
    assert not get_all()

    insert('a', 'b')
    assert len(get_all()) == 2

    # 'd' has not been added to the database enum yet; a DataError from this
    # insert is tolerated on purpose.
    try:
        insert('d')
    except DataError:
        pass

    add_enum_value('d')
    insert('d')

    try:
        get_all()
    except Exception as exc:
        # Exception (not a bare except) so Ctrl-C and SystemExit still work;
        # the error itself is printed so the failure can be diagnosed.
        print('WTF SQLAlchemy?')
        print(repr(exc))
# Run the demonstration only when this file is executed as a script.
if __name__ == '__main__':
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment