Skip to content

Instantly share code, notes, and snippets.

@bigfang
Created November 13, 2012 05:01
Show Gist options
  • Save bigfang/4063999 to your computer and use it in GitHub Desktop.
Save bigfang/4063999 to your computer and use it in GitHub Desktop.
数据库 insert 和 update 操作的封装。可改进之处：减少连接次数、合并 DB 与 Store、简化 update 等。
from functools import partial
import MySQLdb
# Default MySQL server address; DB and Store use it as their ``host`` default.
ip = "192.168.1.42"
class DB(object):
    """Minimal MySQLdb wrapper that opens a new connection for every call.

    ``query`` and ``order`` each connect, run their statement(s), commit,
    and always close the cursor and the connection afterwards, even when
    the statement raises.
    """

    # Class-level default so instances created without __init__ still
    # have a port to connect to.
    port = 3306

    def __init__(self, host=ip, user='xxoo', pwd='', port=3306):
        """Remember connection settings; no connection is opened here.

        :param host: MySQL server address (defaults to module-level ``ip``).
        :param user: login user name.
        :param pwd: login password.
        :param port: server TCP port (3306 unless overridden).
        """
        self.host = host
        self.user = user
        self.pwd = pwd
        self.port = port

    def get_conn(self, db):
        """Open and return a new connection to database ``db``.

        The caller owns the returned connection and must close it.
        Errors raised by ``MySQLdb.connect`` propagate unchanged.
        """
        return MySQLdb.connect(host=self.host,
                               port=self.port,
                               user=self.user,
                               passwd=self.pwd,
                               charset='utf8',
                               use_unicode=True,
                               db=db)

    def query(self, db, q, *args):
        """Execute ``q`` once and return every row it produced.

        ``args`` is passed to ``cursor.execute`` as the parameter sequence,
        so ``q`` should use ``%s`` placeholders. The connection is committed
        before it is closed, so data-modifying statements persist too.

        :returns: the rows as returned by ``cursor.fetchall()``.
        """
        # Connect outside the try: if this fails there is nothing to clean
        # up, and the original connection error reaches the caller intact.
        cxn = self.get_conn(db)
        try:
            cur = cxn.cursor()
            try:
                cur.execute(q, args)
                res = cur.fetchall()
                cxn.commit()
            finally:
                cur.close()
        finally:
            cxn.close()
        return res

    def order(self, db, o, arg):
        """Execute ``o`` once per parameter sequence in ``arg``, then commit.

        :param o: statement with ``%s`` placeholders.
        :param arg: sequence of parameter sequences for ``executemany``.
        :returns: the value returned by ``cursor.executemany``.
        """
        cxn = self.get_conn(db)
        try:
            cur = cxn.cursor()
            try:
                ret = cur.executemany(o, arg)
                cxn.commit()
            finally:
                cur.close()
        finally:
            cxn.close()
        return ret
def log_with_helper(logger):
    """Build a decorator that logs how much data a store hook produced.

    The decorated function's name selects the message:

    * names whose first ``_``-separated part is ``ins`` log ``len(result)``;
    * names whose first part is ``up`` expect a ``(cols, rows)`` pair and
      log ``len(rows)``.

    A falsy result is logged as a warning whatever the prefix. The hook's
    result is always returned unchanged.

    :param logger: object with ``debug``/``warning`` methods, presumably a
        ``logging.Logger``.
    """
    from functools import wraps

    def store(func):
        name = func.__name__
        prefix = name.split('_')[0]

        # wraps() keeps the hook's __name__, so stacking decorators or
        # introspecting the helper still sees e.g. "ins_<table>".
        @wraps(func)
        def inner(data, *args, **kwargs):
            store_data = func(data, *args, **kwargs)
            if not store_data:
                logger.warning('%s store length: 0', name)
            elif prefix == 'ins':
                logger.debug('%s store length: %s', name, len(store_data))
            elif prefix == 'up':
                # up_* hooks return (cols, rows); count the rows.
                logger.debug('%s update length: %s', name, len(store_data[1]))
            return store_data
        return inner
    return store
class StoreHelper(object):
    """Default store helper: every looked-up hook returns its data unchanged.

    Any (non-dunder) attribute access yields a pass-through function named
    after the attribute, so a store can ask for ``ins_<table>`` or
    ``up_<table>`` without a table-specific transformation being defined.
    """

    def __getattr__(self, name):
        # Protocol probes such as __deepcopy__ or __getstate__ must fail
        # normally; handing them the pass-through makes copy.deepcopy()
        # return its memo dict instead of a copy.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        def fn(data, *args):
            return data
        # __name__ is writable on both Python 2 and 3; func_name is kept
        # for callers still reading the Python 2 spelling.
        fn.__name__ = fn.func_name = name
        return fn
class Store(object):
    """Attribute-driven INSERT/UPDATE front end for one MySQL database.

    ``store.ins_<table>(data, *args)`` inserts rows into ``<table>`` and
    ``store.up_<table>(data, *args)`` updates rows in it. ``data`` and
    ``args`` are first passed to ``helper.ins_<table>`` /
    ``helper.up_<table>`` (table name lower-cased), and that hook's return
    value is what gets written.

    database schema format (detected once, in ``__init__``):
        {table_1: (col_1, col_2, col_3),
         table_2: (col_1, col_2),
         ...}
    """

    def __init__(self, db, helper=StoreHelper(), host=ip):
        """Bind to database ``db`` on ``host`` and read its schema.

        :param db: database name.
        :param helper: object supplying the ``ins_<table>``/``up_<table>``
            hooks; the default StoreHelper passes data through unchanged.
        :param host: MySQL server address.
        """
        self.__host = host
        self.__db = db
        self.__helper = helper
        self.__schema = self.__detect_schema()

    @staticmethod
    def _quote_ident(name):
        """Backquote an SQL identifier, doubling any embedded backquote."""
        return '`%s`' % name.replace('`', '``')

    def __detect_schema(self):
        """Return ``{table: (column, ...)}`` for every table in the database.

        A single connection serves ``SHOW TABLES`` and every ``DESC``,
        instead of one new connection per table.
        """
        cxn = DB(self.__host).get_conn(self.__db)
        try:
            cur = cxn.cursor()
            try:
                cur.execute('show tables')
                tables = [row[0] for row in cur.fetchall()]
                schema = {}
                for table in tables:
                    cur.execute('desc %s' % self._quote_ident(table))
                    schema[table] = tuple(col[0] for col in cur.fetchall())
            finally:
                cur.close()
        finally:
            cxn.close()
        return schema

    def __getattr__(self, k):
        """Resolve ``ins_<table>`` / ``up_<table>`` to a bound insert/update.

        Only called for names not found normally. ``<table>`` is matched
        exactly first, then in upper case.

        :raises AttributeError: with message ``'Table Name Error!'`` for an
            unknown table or ``'Operation Error!'`` for an unknown operation.
            AttributeError (rather than a bare Exception) keeps ``hasattr``,
            ``getattr(obj, name, default)`` and copy/pickle working.
        """
        # Private and dunder names are never table operations. Rejecting
        # them first also prevents infinite recursion when self.__schema is
        # read before __init__ has set it (e.g. during copy/unpickling).
        if k.startswith('_'):
            raise AttributeError(k)
        op, _, table = k.partition('_')
        schema = self.__schema
        if table not in schema:
            if table.upper() in schema:
                table = table.upper()
            else:
                raise AttributeError('Table Name Error!')
        if op == 'ins':
            return partial(self.__insert, table)
        if op == 'up':
            return partial(self.__update, table)
        raise AttributeError('Operation Error!')

    def __insert(self, table, data, *args):
        """INSERT IGNORE the rows produced by ``helper.ins_<table>``.

        The hook must return a sequence of parameter sequences, each with
        one value per detected schema column, in schema order.
        """
        # Backquote every field so columns named like keywords still work.
        fields = [self._quote_ident(col) for col in self.__schema[table]]
        sql = 'INSERT IGNORE INTO %s (%s) VALUES (%s)' % (
            self._quote_ident(table),
            ','.join(fields),
            ','.join(['%s'] * len(fields)))
        store_data = getattr(self.__helper, 'ins_%s' % table.lower())(data, *args)
        if not store_data:
            # executemany() is a no-op for empty input; skip the connection.
            return
        DB(self.__host).order(self.__db, sql, store_data)

    def __update(self, table, data, *args):
        """UPDATE rows of ``table`` from the ``helper.up_<table>`` hook.

        The hook must return ``(cols, rows)``: ``cols[0]`` lists the columns
        to SET, ``cols[1]`` the columns ANDed together in the WHERE clause,
        and each row supplies the SET values followed by the WHERE values.
        """
        cols, store_data = getattr(self.__helper, 'up_%s' % table.lower())(data, *args)
        set_cols, where_cols = cols[0], cols[1]
        assignments = ','.join('%s=%%s' % self._quote_ident(c) for c in set_cols)
        # An empty where_cols yields "WHERE " and an SQL error, deliberately
        # preventing an unconditional whole-table UPDATE.
        conditions = ' and '.join('%s=%%s' % self._quote_ident(c) for c in where_cols)
        sql = 'UPDATE %s SET %s WHERE %s' % (
            self._quote_ident(table), assignments, conditions)
        if not store_data:
            # executemany() is a no-op for empty input; skip the connection.
            return
        DB(self.__host).order(self.__db, sql, store_data)
if __name__ == '__main__':
    # Running this file directly does nothing; DB, Store and the helpers
    # are meant to be imported.
    ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment