Last active
August 29, 2015 14:07
-
-
Save javasboy/06eccfbb2c6069f442e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/ali/bin/python | |
# coding=utf-8 | |
'''Implements a simple database interface | |
Example 0: Create connection: | |
# Set auto commit to false | |
db = DB(False, host = 'x', user = 'x', passwd = 'x', db = 'x') | |
Example 1: Select SQL | |
a. Select the first two rows from ip table: | |
# normal select | |
db.select('select * from ip limit 2') | |
# add a where condition: | |
db.select('select * from ip where name != %s limit 2', ('0',))
b. Select all results but get only the first two: | |
db.execute('select * from ip') | |
# get dict rows | |
db.get_rows(2, is_dict = True) | |
Example 2: Insert/Replace SQL | |
a. Insert a new record into ip table: | |
db.insert('ip', address='192.168.0.1', name='xxx') | |
db.commit() | |
b. Insert multi-records into ip table: | |
db.multi_insert('ip', ('address','name'), [('192.168.0.1', 'xxx'), | |
('192.168.0.2', 'yyy'), ('192.168.0.3', 'zzz')]) | |
db.commit() | |
Example 3: Update SQL | |
a. Update the address of row whose name is xxx: | |
db.update('ip', where = 'name = "xxx"', address = '192.168.0.1')
db.commit() | |
Example 4: Delete SQL | |
a. Delete the row whose name is 'xxx': | |
db.delete('ip', where = 'name = "xxx"') | |
db.commit() | |
Example 5: Debug SQL | |
db = DB(False, debug = True, host = 'x', user = 'x', passwd = 'x', db = 'x') | |
print db.execute('select name from ip where address = %s and id = %s', ('10.10.10.10', 1))
select name from ip where address = '10.10.10.10' and id = 1
''' | |
__author__ = 'tuantuan.lv <dangoakachan@foxmail.com>'
# Lifecycle stage of this module: one of 'Prototype', 'Development' or 'Product'.
__status__ = 'Development'
import sys

import MySQLdb

# see https://gist.github.com/3883162
from storage import Storage
def not_supported(func_name):
    '''Return a stand-in callable that refuses to run *func_name*.

    The returned function accepts any positional/keyword arguments and
    always raises DeprecationWarning with the message
    '<func_name> is not supported in debug mode'.  DB.__init__ installs
    such stand-ins over methods that the debug executor cannot simulate.
    '''
    def new_func(*args, **kwargs):
        # Call form of raise: valid on both Python 2 and Python 3, unlike
        # the Python-2-only "raise Exc, msg" statement.
        raise DeprecationWarning('%s is not supported in debug mode' % func_name)
    return new_func
def _format(sql): | |
'''Format the sql.''' | |
return ' '.join(sql.split()) | |
class DB(object):
    '''A simple query interface around one MySQLdb connection and cursor.

    All statements run on the single cursor created in __init__.  In debug
    mode, execute() is replaced by execute_dbg(), which returns the rendered
    SQL instead of running it.  select(), executemany() and multi_insert()
    then raise DeprecationWarning.
    '''

    def __init__(self, auto_commit=False, debug=False, **kwargs):
        '''Connect and create the cursor.

        auto_commit -- passed to the connection's autocommit().
        debug       -- if true, statements are rendered, not executed.
        kwargs      -- forwarded to MySQLdb.connect(); 'charset' defaults
                       to 'utf8' when not given.
        '''
        kwargs.setdefault('charset', 'utf8')
        self.conn = MySQLdb.connect(**kwargs)
        self.cursor = self.conn.cursor()
        self.autocommit(auto_commit)
        self._debug = debug
        if debug:
            # NOTE: storing a bound method on the instance creates a
            # self -> execute -> self reference cycle.  Python 2 never frees
            # cycles whose objects define __del__, so call close() explicitly.
            self.execute = self.execute_dbg
            self.select = not_supported('select')
            self.executemany = not_supported('executemany')
            self.multi_insert = not_supported('multi_insert')

    def execute(self, sql, args=None):
        '''Execute one SQL statement after whitespace normalisation.

        Returns the value returned by cursor.execute().
        '''
        return self.cursor.execute(_format(sql), args)

    def execute_dbg(self, sql, args=None):
        '''Debug stand-in for execute(): render *sql* with *args*, do not run it.

        Returns the whitespace-normalised SQL text, with *args* interpolated
        using the connection's literal() quoting.  Relies on private or
        implementation-specific MySQLdb members (cursor.messages,
        cursor._get_db()).
        '''
        del self.cursor.messages[:]
        db = self.cursor._get_db()
        charset = db.character_set_name()
        # `unicode` exists only on Python 2: encode to the connection charset
        # before interpolating.
        if isinstance(sql, unicode):
            sql = sql.encode(charset)
        if args is not None:
            sql = sql % db.literal(args)
        return _format(sql)

    def executemany(self, sql, args):
        '''Execute *sql* once per parameter sequence in *args*.'''
        return self.cursor.executemany(_format(sql), args)

    # Execute a multi-row insert, the same as executemany
    multi_insert = executemany

    def select(self, sql, args=None, size=None, is_dict=False):
        '''Execute a SELECT, then return get_rows(size, is_dict).'''
        self.execute(sql, args)
        return self.get_rows(size, is_dict)

    def insert(self, table, **column_dict):
        '''INSERT one row into *table*; keyword arguments map column -> value.

        Column names are back-quoted and values are sent as query parameters.
        *table* is interpolated verbatim, so never pass untrusted text.
        Returns the result of self.execute().
        '''
        columns = list(column_dict.keys())
        # An explicit list, in the same order as the columns; a dict view is
        # not a sequence the driver can rely on (e.g. on Python 3).
        values = [column_dict[c] for c in columns]
        ins_sql = 'INSERT INTO %s (`%s`) VALUES (%s)' % (
            table, '`,`'.join(columns), ','.join(['%s'] * len(columns)))
        return self.execute(ins_sql, values)

    def update(self, table, where, args=None, **column_dict):
        '''UPDATE *table* SET the given columns WHERE *where*.

        where       -- raw SQL condition, may contain %s placeholders.
        args        -- sequence (list or tuple) of values for those
                       placeholders; they follow the SET values.
        column_dict -- column -> new value, sent as query parameters.
        *table* and *where* are interpolated verbatim.
        Returns the result of self.execute().
        '''
        columns = list(column_dict.keys())
        set_stmt = ','.join(['`%s`=%%s' % c for c in columns])
        params = [column_dict[c] for c in columns]
        if args is not None:
            # list() accepts tuples as well as lists; the old list + args
            # raised TypeError for tuple args.
            params.extend(list(args))
        upd_sql = 'UPDATE %s SET %s WHERE %s' % (table, set_stmt, where)
        return self.execute(upd_sql, params)

    def delete(self, table, where, args=None):
        '''DELETE FROM *table* WHERE *where*, binding *args* to placeholders.

        *table* and *where* are interpolated verbatim.  The statement is
        formatted exactly once, so %s placeholders in *where* reach the
        driver intact.  Returns the result of self.execute().
        '''
        del_sql = 'DELETE FROM %s WHERE %s' % (table, where)
        return self.execute(del_sql, args)

    def get_rows(self, size=None, is_dict=False):
        '''Fetch rows from the last executed statement.

        size    -- None fetches all rows, otherwise up to *size* rows.
        is_dict -- return a list of Storage objects keyed by column name
                   instead of the cursor's tuples.
        '''
        if size is None:
            rows = self.cursor.fetchall()
        else:
            rows = self.cursor.fetchmany(size)
        if rows is None:
            rows = []
        if is_dict:
            # Per DB-API, description is None when the statement produced
            # no result set; treat that as "no columns" instead of crashing.
            keys = [col[0] for col in (self.cursor.description or ())]
            rows = [Storage(zip(keys, row)) for row in rows]
        return rows

    def get_rows_num(self):
        '''Return cursor.rowcount for the last executed statement.'''
        return self.cursor.rowcount

    def get_mysql_version(self):
        '''Return MySQLdb.get_client_info(), i.e. the client library version.'''
        return MySQLdb.get_client_info()

    def autocommit(self, flag):
        '''Turn the connection's auto-commit mode on or off.'''
        self.conn.autocommit(flag)

    def commit(self):
        '''Commit the current transaction.'''
        self.conn.commit()

    def __del__(self):
        '''Close the cursor and connection when the object is collected.

        Pending work is NOT committed here.  Call commit() first when
        auto_commit is off.
        '''
        self.close()

    def close(self):
        '''Close the cursor and then the connection.

        Safe to call repeatedly, and on a partially built object (e.g. when
        MySQLdb.connect() raised in __init__).  The connection is closed even
        if closing the cursor fails.
        '''
        if getattr(self, '_closed', False):
            return
        self._closed = True
        cursor = getattr(self, 'cursor', None)
        conn = getattr(self, 'conn', None)
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None:
                conn.close()
# vim: set expandtab smarttab shiftwidth=4 tabstop=4: |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment