Skip to content

Instantly share code, notes, and snippets.

@rsj217
Created May 6, 2014 15:36
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save rsj217/a05adb7b71638060d6f4 to your computer and use it in GitHub Desktop.
Save rsj217/a05adb7b71638060d6f4 to your computer and use it in GitHub Desktop.
a python class to use mysql. mysql 的连接关闭查询的python封装
# -*- coding: utf-8 -*-
__author__ = 'ghost'
import time, uuid, functools, threading, logging
def next_id(t=None):
"""
"""
if t is None:
t = time.time()
return '%015d%s000' % (int(t * 1000), uuid.uuid4().hex)
def _profiling(start, sql=''):
t = time.time() - start
if t > 0.1:
logging.warning('[PROFILING] [DB] %s: %s' % (t, sql))
else:
logging.info('[PROFILING] [DB] %s: %s' % (t, sql))
class DBError(Exception):
"""
数据异常类
"""
pass
class MultiColumnsError(DBError):
"""
"""
pass
class Dict(dict):
"""
增强型字典,继承原有的字典,可以将两个列表打包成字典,实现 dict(zip(list1, list2))
>>> d1 = Dict()
>>> type(d1)
<class '__main__.Dict'>
>>> d1['name'] = 'python'
>>> d1.name
'python'
>>> d1['age'] = 13
>>> d1['age']
13
>>> d1.get('name')
'python'
>>> d1.get('lll', 0)
0
>>> d2 = Dict(name='python', age=13)
>>> d2['name']
'python'
>>> d2['none']
Traceback (most recent call last):
...
KeyError: 'none'
>>> d2.name
'python'
>>> d2.none
Traceback (most recent call last):
...
AttributeError: 'Dict' object has no attribute 'none'
>>> d3 = Dict(('name', 'age'), ('python', 13), isgood=True)
>>> d3
{'isgood': True, 'age': 13, 'name': 'python'}
"""
def __init__(self, names=(), values=(), **kwargs):
super(Dict, self).__init__(**kwargs)
self.update(dict(zip(names, values)))
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Dict' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
class _LasyConnection(object):
"""
获取数据库引擎`连接资源句柄connection`
通过connection获取cursor
操作 commit, rollback
关闭连接 cleanup
"""
def __init__(self):
self.connection = None
def cursor(self):
if self.connection is None:
connection = engine.connect()
logging.info('open connection <%s>...' % hex(id(connection)))
self.connection = connection
return self.connection.cursor()
def commit(self):
self.connection.commit()
def rollback(self):
self.connection.rollback()
def cleanup(self):
if self.connection:
connection = self.connection
self.connection = None
logging.info('close connection <%s>...' % hex(id(connection)))
connection.close()
class _DbCtx(threading.local):
"""
数据库上下文操作类,实例全局数据库上下文实例 `_db_Ctx`
主要提供给 `_ConnectionCtx` 进行判断 connection 是否初始化`is_init`,进行初始化`init`和关闭`cleanup`
"""
def __init__(self):
self.connection = None
self.transactions = 0
def is_init(self):
return not self.connection is None
def init(self):
logging.info('open lazy connections...')
self.connection = _LasyConnection()
self.transactions = 0
def cleanup(self):
self.connection.cleanup()
self.connection = None
def cursor(self):
return self.connection.cursor()
class _Engine(object):
"""
数据库引擎类,用于连接数据库
"""
def __init__(self, connect):
self._connect = connect
def connect(self):
return self._connect()
# 数据库上下文操作连接
_db_ctx = _DbCtx()
# 全局数据库引擎对象
engine = None
def create_engine(user, passwd, db, host='127.0.0.1', port=3306, **kwargs):
"""
创建数据库引擎,实现全局对象 `engine`
"""
import MySQLdb
global engine
if engine is not None:
raise DBError('Engine is already initialized.')
# 连接参数
params = dict(user=user, passwd=passwd, db=db, host=host, port=port)
# 默认的连接参数
defaults = dict(use_unicode=True, charset='utf8')
for k, v in defaults.iteritems():
params[k] = kwargs.pop(k, v)
# 通过函数参数更新连接参数
params.update(kwargs)
# 创建engine全局对象
engine = _Engine(lambda: MySQLdb.connect(**params))
logging.info('Init mysql engine <%s> ok.' % hex(id(engine)))
class _ConnectionCtx(object):
"""
打开关闭数据库上下文类,用于进行数据库操作时候,获取数据库引擎连接`connection`,操作结束后关闭连接
with _ConnectionCtx():
pass
"""
def __enter__(self):
global _db_ctx
self.should_cleanup = False
if not _db_ctx.is_init():
_db_ctx.init()
self.should_cleanup = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _db_ctx
if self.should_cleanup:
_db_ctx.cleanup()
def connection():
"""
对`_ConnectionCtx`的封装, 提供对外接口
with connection():
do_some_db_operation()
"""
return _ConnectionCtx()
def with_connection(func):
"""
获取数据库连接和关闭装饰器
@with_connection
def foo(*args, **kwargs):
do_some_db_operation()
do_some_db_operation()
"""
@functools.wraps(func)
def _wrapper(*args, **kwargs):
with _ConnectionCtx():
return func(*args, **kwargs)
return _wrapper
class _TransactionCtx(object):
def __enter__(self):
global _db_ctx
self.should_close_conn = False
if not _db_ctx.is_init():
_db_ctx.init()
self.should_close_conn = True
_db_ctx.transactions += 1
logging.info('begin transaction...' if _db_ctx.transactions==1 else 'join current transaction...')
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _db_ctx
_db_ctx.transactions -= 1
try:
if _db_ctx.transactions == 0:
if exc_type is None:
self.commit()
else:
self.rollback()
finally:
if self.should_close_conn:
_db_ctx.cleanup()
def commit(self):
global _db_ctx
logging.info('commit transaction...')
try:
_db_ctx.connection.commit()
logging.info('commit ok.')
except:
logging.warning('commit failed. try rollback...')
_db_ctx.connection.rollback()
logging.warning('rollback ok.')
raise
def rollback(self):
global _db_ctx
logging.warning('rollback transaction...')
_db_ctx.connection.rollback()
logging.info('rollback ok.')
def transaction():
"""
"""
return _TransactionCtx()
def with_transaction(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
_start = time.time()
with _TransactionCtx():
return func(*args, **kwargs)
_profiling(_start)
return _wrapper
@with_connection
def _select(sql, first, *args):
"""
查询函数
"""
global _db_ctx
cursor = None
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARFS: %s' % (sql, args))
try:
# 通过数据库上下文获取查询游标`cursor`
cursor = _db_ctx.connection.cursor()
# 执行sql查询
cursor.execute(sql, args)
# 处理查询结果,返回 对象列表
if cursor.description:
names = [x[0] for x in cursor.description]
if first:
values = cursor.fetchone()
if not values:
return None
return Dict(names, values)
return [Dict(names, x) for x in cursor.fetchall()]
finally:
# 关闭游标
if cursor:
cursor.close()
@with_connection
def _update(sql, *args):
global _db_ctx
cursor = None
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARGS: %s' % (sql, args))
try:
cursor = _db_ctx.connection.cursor()
cursor.execute(sql, args)
r = cursor.rowcount
if _db_ctx.transactions == 0:
logging.info('auto commit')
_db_ctx.connection.commit()
return r
finally:
if cursor:
cursor.close()
def update(sql, *args):
return _update(sql, *args)
def insert(table, **kwargs):
"""
"""
cols, args = zip(*kwargs.iteritems())
sql = 'insert into `%s` (%s) values (%s)' % (table, ','.join(['`%s`' % col for col in cols]), ','.join(['?' for i in range(len(cols))]))
return _update(sql, *args)
def delete(sql, *args):
return _update(sql, *args)
def select_int(sql, *args):
d = _select(sql, True, *args)
if len(d) != 1:
raise MultiColumnsError('Expect only one column.')
return d.values()[0]
def select(sql, *args):
"""
"""
return _select(sql, False, *args)
def select_one(sql, *args):
"""
"""
return _select(sql, True, *args)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
create_engine(user='root', passwd='', db='webapp')
# update("DROP TABLE IF EXISTS user")
# update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40), last_modified REAL)")
def test(rollback):
with transaction():
u = dict(nickname='test', email='test@gmail.com', passwd='test')
insert('user', **u)
r = update("UPDATE user SET nickname='chage' WHERE passwd='test'")
if rollback:
raise StandardError('will cause rollback...')
import doctest
doctest.testmod()
# -*- coding: utf-8 -*-
__author__ = 'ghost'
import logging, threading, functools, time
def _profiling(start, sql=''):
"""
"""
t = time.time() - start
if t > 0.1:
logging.warning('[PROFILING] [DB] %s: %s' % (t, sql))
else:
logging.info('[PROFILING] [DB] %s: %s' % (t, sql))
class Dict(dict):
""" 增强型字典,能够将序列一次打包成字典
>>> d = Dict(name='python', age=12)
>>> d
{'age': 12, 'name': 'python'}
>>> d['name']
'python'
>>> d.get('passwd', '1111')
'1111'
>>> d['passwd'] = 123
>>> d
{'passwd': 123, 'age': 12, 'name': 'python'}
>>>
"""
def __init__(self, names=(), values=(), **kwargs):
""" 初始化
@params names 元组或者列表序列
values 元组或者列表序列
kwargs 字典
"""
super(Dict, self).__init__(**kwargs)
self.update(dict(zip(names, values)))
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(r"'Dict' object has no attribute '%s'" % item)
def __setattr__(self, key, value):
self[key] = value
class DBError(Exception):
pass
class MultiColumnsError(DBError):
pass
class _LazyConnection(object):
"""
获取数据库引擎`连接资源句柄connection`
通过connection获取cursor
操作 commit, rollback
关闭连接 cleanup
"""
def __init__(self):
self.connection = None
def cursor(self):
""" 获取游标
"""
if self.connection is None:
connection = engine.connect()
logging.info('open connection <%s>...' % hex(id(connection)))
self.connection = connection
return self.connection.cursor()
def commit(self):
""" 提交session
"""
self.connection.commit()
def rollback(self):
""" 回滚
"""
self.connection.rollback()
def cleanup(self):
""" 释放连接
"""
if self.connection:
connection = self.connection
self.connection = None
logging.info('close connection <%s>...' % hex(id(connection)))
connection.close()
class _DbCtx(threading.local):
"""
数据库上下文操作类,用来生成全局数据库上下文实例 `_db_Ctx`
主要提供给 `_ConnectionCtx` 进行判断 connection 是否初始化`is_init`,
进行初始化`init`和关闭`cleanup`
"""
def __init__(self):
self.connection = None
self.transactions = 0
def is_init(self):
""" 判断数据库上下文连接是否存在"""
return not self.connection is None
def init(self):
""" 初始化数据库上下文连接"""
logging.info('open lazy connections...')
self.connection = _LazyConnection()
self.transactions = 0
def cleanup(self):
""" 清除数据库上下文连接"""
self.connection.cleanup()
self.connection = None
def cursor(self):
""" 获取上下文连接游标"""
return self.connection.cursor()
# 全局数据库引擎,用于获取连接
engine = None
# 全局数据库连接池对象
dbpool = None
# 全局数据库上下文对象
_db_ctx = _DbCtx()
class _Engine(object):
""" 数据库引擎对象,用于动态生成连接
"""
def __init__(self, connect):
self._connect = connect
def connect(self):
return self._connect()
class _ConnectCtx(object):
""" 数据库上下文打开关闭操作类,用于自动打开连接,清理连接
with _ConnectCtx():
pass
with _ConnectCtx():
pass
"""
def __enter__(self):
global _db_ctx
self.should_cleanup = False
# 初始化连接
if not _db_ctx.is_init():
_db_ctx.init()
self.should_cleanup = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _db_ctx
if self.should_cleanup:
_db_ctx.cleanup()
def connection():
""" 对外封装的数据库上下文自动打开和关闭接口方法
with connection():
do_some_db_operation
"""
return _ConnectCtx()
def with_connection(func):
""" 获取数据库连接和关闭装饰器
@with_connection
def foo(*args, **kwargs):
do_some_db_operation()
do_some_db_operation()
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with _ConnectCtx():
return func(*args, **kwargs)
return wrapper
class _TransactionCtx(object):
""" 事务上下文自动管理类,用于事务处理时候获取和关闭上下文
"""
def __enter__(self):
global _db_ctx
self.should_close_conn = False
if not _db_ctx.is_init():
_db_ctx.init()
self.should_close_conn = True
_db_ctx.transactions += 1
logging.info('begin transaction...' if _db_ctx.transactions==1 else 'join current transaction...')
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _db_ctx
_db_ctx.transactions -= 1
try:
if _db_ctx.transactions == 0:
if exc_type is None:
self.commit()
else:
self.rollback()
finally:
if self.should_close_conn:
_db_ctx.cleanup()
def commit(self):
global _db_ctx
logging.info('commit transaction...')
try:
_db_ctx.connection.commit()
logging.info('commit ok.')
except:
logging.warning('commit failed. try rollback...')
_db_ctx.connection.rollback()
logging.warning('rollback ok.')
raise
def rollback(self):
global _db_ctx
logging.warning('rollback transaction...')
_db_ctx.connection.rollback()
logging.info('rollback ok.')
def transaction():
""" 事务上下文对外接口
with transaction():
pass
"""
return _TransactionCtx()
def with_transaction(func):
""" 事务上下文装饰器
@with_transaction
def foo():
pass
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
_start = time.time()
with _TransactionCtx():
return func(*args, **kwargs)
_profiling(_start)
return wrapper
@with_connection
def _update(sql, *args):
""" 执行 cud 操作的 sql 方法
@params sql string SQL查询语句
args tuple SQL 查询参数
"""
global _db_ctx
cursor = None
# 格式化 sql 语句
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARGS: %s' % (sql, args))
try:
# 获取游标
cursor = _db_ctx.connection.cursor()
# 执行 SQL
try:
cursor.execute(sql, args)
except Exception, e:
raise DBError('sql execute error')
# 影响行数
r = cursor.rowcount
# 提交更改
if _db_ctx.transactions == 0:
logging.info('auto commit')
_db_ctx.connection.commit()
return r
finally:
# 关闭游标
if cursor:
cursor.close()
def update(sql, *args):
""" 更新方法
@params: sql string SQL 语句
args tuple 查询参数
@return 返回影响行数
>>> update("UPDATE user SET nickname=? WHERE nickname=? AND passwd=?", 'ruby', 'python', '111111')
"""
return _update(sql, *args)
def insert(table, **kwargs):
""" 插入方法
@params tabel string 需要插入的数据表名
kwargs dict 插入数据字典
@return 返回影响的数据库行数
>>> insert('user', nickname='python', email='python@gmail.com', passwd='111111')
"""
cols, args = zip(*kwargs.iteritems())
sql = 'insert into `%s` (%s) values (%s)' % (table, ','.join(['`%s`' % col for col in cols]), ','.join(['?' for i in range(len(cols))]))
return _update(sql, *args)
def delete(sql, *args):
""" 删除方法 与 update 类似,用于执行 sql 删除
"""
return _update(sql, *args)
@with_connection
def _select(sql, first, *args):
""" 查询数据库方法,执行查询sql语句,返回结果集
@params: sql string SQL查询语句
first bool 是否为查询一条,True为查询一条记录
args tuple sql查询参数
@return: 返回结果集列表
"""
global _db_ctx
cursor = None
# 格式化 sql语句
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARFS: %s' % (sql, args))
try:
# 获取游标
cursor = _db_ctx.connection.cursor()
# 执行 SQL 语句
cursor.execute(sql, args)
if cursor.description:
names = [x[0] for x in cursor.description]
if first:
# 获得结果集
values = cursor.fetchone()
if not values:
return None
# 格式化结果
return Dict(names, values)
return [Dict(names, x) for x in cursor.fetchall()]
finally:
# 关闭游标
if cursor:
cursor.close()
def select(sql, *args):
""" 查询sql方法 返回结果集
@params: sql string sql查询语句
args tuple 查询参数
@return: 返回结果集
"""
return _select(sql, False, *args)
def select_one(sql, *args):
return _select(sql, True, *args)
def create_pool(user, passwd, db, host='127.0.0.1', port=3306, **kwargs):
""" 创建连接池
@params: user string 数据库用户名
passwd string 数据库用户名密码
db string 数据库名
host string 数据库主机地址,默认为 127.0.0.1
port number 数据库端口 , 默认为3306
kwargs dict 其他设置参数
@return None
"""
import MySQLdb
from DBUtils.PooledDB import PooledDB
global dbpool
# 判断连接池是否存在
if dbpool is not None:
logging.info(DBError('pool is already initialized.'))
return
# 连接参数
params = dict(user=user, passwd=passwd, db=db, host=host, port=port)
# 默认的连接参数
# use_unicode 是否使用 unicode, 默认 True
# charset 数据库编码 默认使用utf8
# mincached : 启动时开启的闲置连接数量(缺省值 0 以为着开始时不创建连接)
# maxcached : 连接池中允许的闲置的最多连接数量(缺省值 0 代表不闲置连接池大小)
# maxshared : 共享连接数允许的最大数量(缺省值 0 代表所有连接都是专用的)如果达到了最大数量,被请求为共享的连接将会被共享使用
# maxconnections : 创建连接池的最大数量(缺省值 0 代表不限制)
# blocking : 设置在连接池达到最大数量时的行为(缺省值 0 或 False 代表返回一个错误<toMany......>; 其他代表阻塞直到连接数减少,连接被分配)
# maxusage : 单个连接的最大允许复用次数(缺省值 0 或 False 代表不限制的复用).当达到最大数时,连接会自动重新连接(关闭和重新打开)
# setsession : 一个可选的SQL命令列表用于准备每个会话,
defaults = dict(use_unicode=True, charset='utf8', mincached=10, maxcached=10, maxshared=30, maxconnections=100, blocking=True, maxusage=0, setsession=None)
# 处理自定义参数和默认参数
for k, v in defaults.iteritems():
params[k] = kwargs.pop(k, v)
# 更新连接参数
params.update(kwargs)
# 创建连接池
dbpool = PooledDB(MySQLdb, **params)
logging.info('Init mysql pool <%s>ok' % hex(id(dbpool)))
def create_engine():
""" 创建数据库连接引擎,用于获取数据库连接池连接
"""
global engine
if engine is not None:
logging.info('Engine is already initialized.')
return
engine = _Engine(lambda : dbpool.connection())
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
create_pool(user='root', passwd='', db='pytest', host='127.0.0.1')
create_engine()
update("DROP TABLE IF EXISTS user")
update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40))")
insert('user', nickname='python', email='python@gmail.com', passwd='123456')
insert('user', nickname='python', email='python@gmail.com', passwd='111111')
update("UPDATE user SET nickname=? WHERE nickname=? AND passwd=?", 'ruby', 'python', '111111')
@with_transaction
def update_profile(name, rollback):
u = dict(nickname=name, email='{0}@test.com'.format(name))
insert('user', **u)
r = update("UPDATE user SET nickname=? where passwd=?", name.upper(), '111111')
if rollback:
raise StandardError('will cause rollback...')
update_profile('test', True)
#: -*- coding: utf-8 -*-
"""
database
~~~~~~~~
python对mysql操作的封装类
~~~~~~~~~~~~~~~~~~~~~~~~~
:author: rsj217
:license: BSD.
:contact: rsj217@gmail.com
:version: 0.0.1
"""
try: # 连接 MySQLdb, 或者 pymysql
import MySQLdb as mysql
except ImportError:
import pymysql as mysql
class DataBase(object):
"""
数据库操作类
"""
def __init__(self, host, user, passwd, database, port=3306, charset='utf8'):
"""初始化数据库配置信息,端口默认 3306 编码 utf-8
"""
#: * 数据库主机地址
self.host=host
#: * 数据库用户名
self.user=user
#: * 数据库用户密码
self.passwd=passwd
#: * 数据库名
self.database=database
#: * 数据库端口
self.port=port
#: * 数据库字符编码
self.charset=charset
def __get_db(self):
"""链接数据库,获取数据库句柄"""
#: 数据库连接, 返回资源句柄
db = mysql.connect(
host=self.host,
user=self.user,
passwd=self.passwd,
db=self.database,
port=self.port,
charset=self.charset)
return db
def execrone(self, func):
'''
读取数据库,返回单条记录
parameters
func
函数类型,被装饰器包装的函数,不用显示传递
return
查询数据库单条记录结果和影响行数
sample::
@self.execrone
def getone():
pass
getone 将会被本方法包装
'''
def wrap(*args):
try:
#: 连接数据库
db = self.__get_db()
#: 获取数据查询游标
cursor = db.cursor()
#: 得到 sql 语句
sql = func(*args)
#: 执行单条 sql 语句, 返回受影响的行数
rownum = cursor.execute(sql)
#: 执行查询,返回单条数据
result = cursor.fetchone()
#: 返回查询结果和影响行数
return (rownum, result)
except mysql.Error, e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
finally:
#: 关闭游标
cursor.close()
#: 关闭数据库
db.close()
return wrap
def execrall(self, func):
'''
读取数据库,返回多条记录
parameters:
func
函数类型,被装饰器包装的函数,不用显示传递
return:
查询数据库多条记录结果和影响行数
sample::
@self.execrall
def getall():
pass
getall 将会被本方法包装
'''
def wrap(*args):
try:
#: 连接数据库
db = self.__get_db()
#: 获取数据查询游标
cursor = db.cursor()
#: 得到 sql 语句
sql = func(*args)
#: 执行单条 sql 语句, 返回受影响的行数
rownum = cursor.execute(sql)
#: 执行查询,返回多条数据
result = cursor.fetchall()
#: 返回查询结果和影响行数
return (rownum, result)
except mysql.Error, e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
finally:
#: 关闭游标
cursor.close()
#: 关闭数据库
db.close()
return wrap
def execcud(self, func):
'''
添加数据库记录,用于 create update delete 操作,
如果写入数据库失败,则执行回滚操作。
parameters:
func
函数类型,被装饰器包装的函数,不用显示传递
return:
增加更新删除数据库影响的行数
sample::
@self.execcud
def insert():
pass
insert 将会被本方法包装
'''
def wrap(*args):
try:
#: 连接数据库
db = self.__get_db()
#: 获取数据查询游标
cursor = db.cursor()
#: 得到 sql 语句
sql = func(*args)
#: 执行单条 sql 语句, 返回受影响的行数
rownum = cursor.execute(sql)
#: 提交查询
db.commit()
#: 返回影响行数
return rownum
except mysql.Error, e:
#: 发生错误时回滚
db.rollback()
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
finally:
#: 关闭游标
cursor.close()
#: 关闭数据库
db.close()
return wrap
# -*- coding: utf-8 -*-
__author__ = 'ghost'
import unittest, logging
import pool
class TestPool(unittest.TestCase):
def setUp(self):
pool.create_pool(user='root', passwd='', db='pytest', host='127.0.0.1')
pool.create_engine()
pool.update("DROP TABLE IF EXISTS user")
pool.update("CREATE TABLE user (id INT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT, nickname VARCHAR(40), email VARCHAR(40), passwd VARCHAR(40))")
def tearDown(self):
print 'test end.'
def test_insert(self):
r = pool.insert('user', nickname='insert', email='insert@gmail.com', passwd='insert')
self.assertEquals(1L, r)
with self.assertRaises(pool.DBError):
r = pool.insert('users', nickname='python', email='python@gmail.com', passwd='123456')
print 'test insert end'
def test_delete(self):
r = pool.insert('user', nickname='delete', email='delete@gmail.com', passwd='delete')
self.assertEquals(1, r)
dr = pool.delete("DELETE FROM user WHERE nickname=?", 'delete')
self.assertEquals(dr, 1)
print 'test delete end'
def test_update(self):
r = pool.insert('user', nickname='update', email='delete@gmail.com', passwd='update')
self.assertEquals(1, r)
r = pool.update("UPDATE user SET email=?, passwd=? WHERE nickname=?", 'update@gmail.com', '111111', 'update')
self.assertEquals(1, r)
print 'test update end'
def test_select(self):
r = pool.insert('user', nickname='python', email='python@gmail.com', passwd='python')
r = pool.insert('user', nickname='ruby', email='ruby@gmail.com', passwd='update')
r = pool.insert('user', nickname='python', email='ruby@gmail.com', passwd='update')
users = pool.select("SELECT * FROM user WHERE nickname=?", "python")
self.assertEquals(2, len(users))
print users
print 'test select end'
def test_transaction(self):
def update_profile(name, rollback):
u = dict(nickname=name, email='{0}@test.com'.format(name))
pool.insert('user', **u)
r = pool.update("UPDATE user SET nickname=? where passwd=?", name.upper(), '111111')
if rollback:
raise StandardError('will cause rollback...')
with self.assertRaises(StandardError):
with pool.transaction():
update_profile('test', True)
with pool.transaction():
update_profile('test', False)
u = pool.select("SELECT * FROM user WHERE nickname=?", 'test')
self.assertEquals(1, len(u))
print u
print 'test transaction end'
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
unittest.main()
# id classid title
# 1 -1 python
# 2 -1 ruby
# 3 -1 php
# 4 -1 lisp
# 5 1 flask
# 6 1 django
# 7 1 webpy
# 8 2 rails
# 9 3 zend
# 10 6 dblog
t = (
(1, -1, 'python'),
(2, -1, 'ruby'),
(3, -1, 'php'),
(4, -1, 'lisp'),
(5, 1, 'flask'),
(6, 1, 'django'),
(7, 1, 'webpy'),
(8, 2, 'rails'),
(9, 3, 'zend'),
(10, 6, 'dblog')
)
# l = [
# {
# 'id': 1,
# 'classid': -1,
# 'title': 'python',
# 'son': [
# {
# 'id': 5,
# 'classid': 1,
# 'title': 'flask',
# 'son': None
# },
# {
# 'id': 6,
# 'classid': 1,
# 'title': 'django',
# 'son': [
# {
# 'id': 10,
# 'classid': 6,
# 'title': 'dblog',
# 'son': None
# },
# ]
# },
# {
# 'id': 7,
# 'classid': 1,
# 'title': 'webpy',
# 'son': None
# },
# ]
# },
# {
# 'id': 2,
# 'classid': -1,
# 'title': 'ruby',
# 'son': [
# {
# 'id': 8,
# 'classid': 2,
# 'title': 'rails',
# 'son': None
# },
# ]
# },
# {
# 'id': 3,
# 'classid': -1,
# 'title': 'php',
# 'son': [
# {
# 'id': 9,
# 'classid': 3,
# 'title': 'zend',
# 'son': None
# },
# ]
# },
# {
# 'id': 4,
# 'classid': -1,
# 'title': 'lisp',
# 'son': None
# }
# ]
# from pprint import pprint
# l = []
# entries = {}
# for id, fid, title in t:
# entries[id] = entry = {'id': id, 'fid': fid, 'title': title}
# if fid == -1:
# l.append(entry)
# else:
# parent = entries[fid]
# parent.setdefault('son', []).append(entry)
# pprint(l)
children = {}
objs = []
l = []
for id, parent, title in t:
obj = {
"id": id,
"fid": parent,
"title": title
}
objs.append(obj)
if parent == -1: # keep only roots
l.append(obj)
if parent not in children: # append to children
children[parent] = []
children[parent].append(obj)
for obj in objs:
if obj["id"] in children:
obj["son"] = children[obj["id"]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment