Created September 13, 2018 08:55
-
-
Save fcharmy/7e63c345a898861035c2aef631752e10 to your computer and use it in GitHub Desktop.
SQLAlchemy models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from sqlalchemy import create_engine, MetaData, Table | |
from sqlalchemy.engine import reflection | |
from sqlalchemy.orm import mapper, sessionmaker | |
logger = logging.getLogger(__name__)


def get_session(host=None, port=3306, db=None,
                user=None, password=None, charset='utf8mb4'):
    """Reflect every table of a MySQL database and map each one to a class.

    For each reflected table a subclass of ``Model`` named ``<table>_<db>``
    is created with class attributes ``t`` (the reflected ``Table``), ``s``
    (the shared session) and ``insp`` (the inspector). The class is mapped
    onto the table and instantiated once. That instance is exposed on the
    returned object under the table's name (e.g. ``db.auth_user``).
    Tables without a primary key cannot be mapped by the ORM; they are
    skipped and a warning is logged.

    :param host: mysql db host
    :param port: mysql db port (int, or anything ``int()`` accepts)
    :param db: database name
    :param user: user name
    :param password: user password (raw, not pre-encoded). It is
        percent-encoded before being placed in the connection URL, so
        characters such as ``@``, ``/``, ``:`` or ``%`` are safe.
    :param charset: value of the ``charset`` URL query parameter
    :return: DBase instance holding the engine, the session and one
        attribute per mapped table
    """
    from urllib.parse import quote
    from sqlalchemy import inspect as sa_inspect

    # Percent-encode the password: raw special characters would otherwise
    # be parsed as URL delimiters (e.g. '@' as the host separator).
    engine = create_engine(
        'mysql+pymysql://%s:%s@%s:%d/%s?charset=%s' %
        (user, quote(str(password), safe=''), host, int(port), db, charset))
    # Unbound MetaData: reflection is driven explicitly through
    # ``autoload_with`` below instead of relying on a bound engine.
    metadata = MetaData()
    session = sessionmaker(bind=engine)()
    insp = sa_inspect(engine)
    logger.info("Connect to DB: %s@%s", db, host)

    # create table class for mapper, then map table to class
    classes = {}
    for t in insp.get_table_names():
        table = Table(t, metadata, autoload_with=engine)
        if not table.primary_key.columns:
            # The ORM mapper needs a primary key; skip this table rather
            # than failing the whole session setup.
            logger.warning("Skip table %s: no primary key", t)
            continue
        new_class = type(
            '%s_%s' % (t, db),
            (Model,), {
                't': table,
                's': session,
                'insp': insp,
            })
        # NOTE(review): classical ``mapper()`` is deprecated in SQLAlchemy
        # 1.4 and removed in 2.0 (``registry().map_imperatively`` replaces
        # it) -- confirm the SQLAlchemy version this targets.
        mapper(new_class, table)
        classes[t] = new_class()
        logger.debug("Map table %s to %s", t, new_class)
    return DBase(engine, session, **classes)
class DBase(object):
    """Container for an engine, its ORM session and arbitrary attributes.

    Every extra keyword argument passed to the constructor becomes an
    attribute of the instance (``get_session`` uses this to expose one
    mapped-model instance per table, keyed by table name).
    """

    def __init__(self, engine, session, **kwargs):
        self.engine = engine
        self.session = session
        for key, val in kwargs.items():
            setattr(self, key, val)

    def rollback(self):
        """Roll back the ORM session."""
        self.session.rollback()

    def close(self):
        """Close the ORM session."""
        self.session.close()

    def execute(self, sql):
        """Execute ``sql`` on a fresh connection obtained from the engine.

        :param sql: statement passed straight to ``Connection.execute``
        :return: whatever ``Connection.execute`` returned, or ``None`` if it
            raised an ``Exception``. In that case the error is logged and
            the ORM session is rolled back.
        """
        try:
            with self.engine.connect() as con:
                # NOTE(review): the result is handed back after the ``with``
                # block closes the connection; whether rows can still be
                # fetched from it depends on the driver buffering results --
                # confirm for the SQLAlchemy/driver versions in use.
                return con.execute(sql)
        except Exception:
            # Catch Exception (not everything) so KeyboardInterrupt and
            # SystemExit still propagate; log instead of failing silently.
            logging.getLogger(__name__).exception(
                "Failed to execute SQL: %s", sql)
            self.rollback()
            return None
class Model(object):
    """Base for the per-table classes generated by ``get_session``.

    Subclasses are expected to provide the class attributes ``t`` (the
    reflected ``Table``), ``s`` (the session) and ``insp`` (the inspector).
    They are read here but are not defined on ``Model`` itself.
    """

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

        table_name = self.t.name
        if not table_name:
            return
        # Attach a query helper bound to this class and its table.
        self.objects = QuerySet(type(self), table_name, self.s, self.insp)
class QuerySet(object):
    """Django-style query helper bound to one mapped model class.

    :param model: mapped class whose rows are queried, created and deleted
    :param name: table name (stored as ``table_name``)
    :param session: ORM session used for every query and commit
    :param inspector: inspector (stored as ``insp``)
    """

    def __init__(self, model, name, session, inspector):
        self.s = session
        self.insp = inspector
        self.table_name = name
        self.model = model
        self.logger = logging.getLogger(__name__)

    def get(self, **kwargs):
        """Return the first row whose columns equal ``kwargs``, or ``None``."""
        return self.s.query(self.model).filter_by(**kwargs).first()

    def filter(self, **kwargs):
        """Return a query for the rows whose columns equal ``kwargs``."""
        return self.s.query(self.model).filter_by(**kwargs)

    def get_or_create(self, default=None, **kwargs):
        """Fetch the row matching ``kwargs``, creating it when absent.

        :param default: additional attributes if create. On a key clash
            with ``kwargs``, the value from ``default`` wins.
        :param kwargs: attributes for query existence
        :return: ``(instance, created)`` tuple
        :raises ValueError: if creating the row fails to commit
        """
        instance = self.get(**kwargs)
        if instance is not None:
            return instance, False
        # Merge into one dict: passing both mappings as separate ``**``
        # arguments raises TypeError when they share a key.
        params = dict(kwargs)
        if default:
            params.update(default)
        return self.create(**params), True

    def create(self, **kwargs):
        """Build a model instance from ``kwargs``, add it and commit.

        :return: the new instance
        :raises ValueError: if the commit fails
        """
        instance = self.model(**kwargs)
        self.s.add(instance)
        self.save()
        return instance

    def all(self):
        """Return every row of the model as a list."""
        return self.s.query(self.model).all()

    def save(self):
        """Commit the session, rolling it back if the commit fails.

        :raises ValueError: if the commit raises; the original error is
            chained as ``__cause__``
        """
        try:
            self.s.commit()
        except Exception as exc:
            self.s.rollback()
            raise ValueError('can not save instance') from exc

    def delete(self, **kwargs):
        """Delete every row whose columns equal ``kwargs`` and commit.

        :param kwargs: column/value pairs identifying the rows. At least one
            is required so that an empty call cannot wipe the whole table.
        :return: number of rows deleted
        :raises ValueError: if ``kwargs`` is empty or the commit fails
        """
        if not kwargs:
            raise ValueError('delete() needs at least one filter argument')
        # Load the persistent rows first: Session.delete() rejects objects
        # that were never persisted, such as a freshly constructed instance.
        matches = self.filter(**kwargs).all()
        for instance in matches:
            self.s.delete(instance)
        self.save()
        return len(matches)
if __name__ == '__main__':
    # Demo: expects a local MySQL database named "db" containing an
    # ``auth_user`` table (presumably the Django auth schema -- adjust the
    # connection settings to your setup).
    database = get_session(host='localhost', db='db', user='root', password="xxx")
    users = database.auth_user.objects

    # Insert a row, then read it back by username.
    users.create(username='admin')
    print(users.get(username='admin').id)

    # Iterate over a filtered query.
    active_users = users.filter(is_active=1)
    for account in active_users:
        print(account.username)

    # Fetch-or-insert, modify an attribute, then commit the change.
    admin, _created = users.get_or_create(username='admin')
    admin.first_name = 'superman'
    users.save()
    print(users.get(username='admin').first_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment