Skip to content

Instantly share code, notes, and snippets.

@andriykohut
Last active July 14, 2016 14:57
Show Gist options
  • Save andriykohut/0dafc22c9138b3cb95444b662228c0ea to your computer and use it in GitHub Desktop.
Save andriykohut/0dafc22c9138b3cb95444b662228c0ea to your computer and use it in GitHub Desktop.
torm - a minimal ORM-like query layer for tornado_mysql
import os

from tornado_mysql.cursors import DictCursor
from tornado_mysql.pools import Pool
# Module-wide connection pool, used as the default by cursor_context and by
# the model query helpers.
# Connection settings can be overridden with TORM_MYSQL_* environment
# variables. The defaults are exactly the values this module used to hard-code,
# so existing setups keep working unchanged.
# NOTE(review): shipping real credentials as defaults is risky; consider
# requiring the environment variables outside local development.
pool = Pool(dict(host=os.environ.get('TORM_MYSQL_HOST', '127.0.0.1'),
                 port=int(os.environ.get('TORM_MYSQL_PORT', '3306')),
                 user=os.environ.get('TORM_MYSQL_USER', 'root'),
                 passwd=os.environ.get('TORM_MYSQL_PASSWORD', 'password'),
                 db=os.environ.get('TORM_MYSQL_DB', 'ia_manager_api'),
                 cursorclass=DictCursor),
            max_idle_connections=5,
            max_open_connections=20)
class cursor_context:
    """Async context manager that borrows a pooled connection and yields a DictCursor.

    Usage::

        async with cursor_context() as cursor:
            await cursor.execute('SHOW TABLES')
            rows = cursor.fetchall()

    On exit the cursor is closed and the connection is handed back to the pool.
    The connection is returned even if closing the cursor raises.

    NOTE(review): relies on the underscore-prefixed (private) Pool methods
    ``_get_conn``/``_put_conn``/``_close_conn``; confirm they exist in the
    tornado_mysql version in use.
    """

    def __init__(self, pool=pool):
        """:param pool: pool to borrow from; defaults to the module-level ``pool``."""
        self.pool = pool
        self.__conn = None
        self.__cursor = None

    async def __aenter__(self):
        self.__conn = await self.pool._get_conn()
        try:
            self.__cursor = self.__conn.cursor(DictCursor)
        except BaseException:
            # Creating the cursor failed. Close the connection instead of
            # putting it back into the pool, then propagate the error.
            self.pool._close_conn(self.__conn)
            raise
        return self.__cursor

    async def __aexit__(self, *args):
        try:
            await self.__cursor.close()
        finally:
            # Always hand the connection back. Otherwise a failing close()
            # would leak it from the pool.
            self.pool._put_conn(self.__conn)
import math
from collections import defaultdict
from connection import pool
from connection import cursor_context
# Maps the lookup suffix used in filter keys (e.g. ``count__lte``) to the SQL
# comparison operator emitted by Model.filter_expr / filter_expr_join.
LOOKUP_TYPES = {
    'eq': '=',
    'gt': '>',
    'gte': '>=',
    'lt': '<',
    'lte': '<=',
    'in': 'IN',
}

# Join keywords accepted by Model.join_on (compared after upper-casing).
# NOTE(review): MySQL has no FULL OUTER JOIN; that entry passes validation but
# the server will reject the generated SQL.
JOIN_TYPES = {
    'JOIN',
    'LEFT JOIN',
    'RIGHT JOIN',
    'FULL OUTER JOIN',
}

# Separator used by Model.select_expr_group_concat. It is also used to split
# the aggregated strings back into lists in Model.get_related.
GROUP_CONCAT_SEPARATOR = '|||'
class MetaModel(type):
    """Metaclass for Model that validates each concrete model declaration.

    The root class (created with this metaclass but without any MetaModel
    parent) is not validated. Every subclass is checked against its *effective*
    attributes, so inherited values count. A model may therefore rely on a
    base class's ``__pk__ = 'id'`` default or subclass another model.

    Raises:
        TypeError: if ``__pk__``, ``__table__`` or ``__columns__`` is missing,
            the primary key is not a column, or ``__relations__`` refers to an
            unknown local or related column.
    """

    def __new__(mcs, name, bases, attrs):
        new_cls = super().__new__(mcs, name, bases, attrs)
        if not any(isinstance(base, MetaModel) for base in bases):
            # Root base class: nothing to validate.
            return new_cls

        # Read through the MRO (getattr) rather than `attrs`, so that
        # inherited declarations are honoured.
        pk = getattr(new_cls, '__pk__', None)
        table = getattr(new_cls, '__table__', None)
        columns = getattr(new_cls, '__columns__', None)
        # Explicit raises instead of `assert`, so validation survives `python -O`.
        if not pk:
            raise TypeError('Primary key is required')
        if not table:
            raise TypeError('Table is required')
        if not columns:
            raise TypeError('Column dict missing')
        if pk not in columns:
            raise TypeError('Primary key not in columns dict')

        relations = getattr(new_cls, '__relations__', None)
        if relations:
            for col, (model, rel_col) in relations.items():
                if col not in columns:
                    raise TypeError('__relations__: no such column {}'.format(col))
                if rel_col not in model.__columns__:
                    raise TypeError('__relations__: no such column {}.{}'.format(
                        model.__table__, rel_col))
        return new_cls
class Model(metaclass=MetaModel):
    """Base class for simple table-backed models with async query helpers.

    Subclasses declare:

    * ``__table__`` -- table name.
    * ``__pk__`` -- primary-key column name (defaults to ``'id'``).
    * ``__columns__`` -- dict of column name -> alias.
    * ``__relations__`` (optional) -- dict of local column ->
      ``(RelatedModel, related_column)``, used by ``join_on``/``get_related``.

    All queries run through the module-level ``pool``.

    NOTE(review): ``select_expr``/``select_expr_join`` treat ``__columns__``
    keys as column names and values as aliases. ``filter_expr``,
    ``filter_expr_join`` and ``join_on`` instead put the *values* into SQL as
    column names. The two only agree when keys equal values, which is true for
    every model visible here.
    """

    __pk__ = 'id'
    __table__ = None
    __columns__ = None
    # Empty default so models without relations can call filter_expr_join /
    # get_related without AttributeError. Treat as read-only; subclasses assign
    # their own dict.
    __relations__ = {}
    __ordering_col__ = __pk__  # NOTE(review): not used by any method below.

    @classmethod
    def select_expr(cls):
        """Return the SELECT list ``"col AS alias, ..."`` for this model's columns."""
        return ', '.join('{} AS {}'.format(k, v) for k, v in cls.__columns__.items())

    @classmethod
    def select_expr_join(cls):
        """Return a table-qualified SELECT list for use in joins.

        The result looks like
        ``"table.col1 AS table__col1, table.col2 AS table__col2"``.
        The ``table__`` alias prefix is what get_related uses to regroup
        columns per table.
        """
        return ', '.join('{0}.{1} AS {0}__{2}'.format(cls.__table__, key, val)
                         for key, val in cls.__columns__.items())

    @classmethod
    def select_expr_group_concat(cls, sep=GROUP_CONCAT_SEPARATOR):
        """Return a SELECT list that GROUP_CONCATs every column of this model.

        :param sep: separator between concatenated values. It is embedded
            verbatim inside a SQL string literal, so it must not contain a
            single quote.
        :returns: string like ``"GROUP_CONCAT(t.c SEPARATOR '|||') AS t__c, ..."``.
        """
        return ', '.join("GROUP_CONCAT({0}.{1} SEPARATOR '{3}') AS {0}__{2}".format(cls.__table__, k, v, sep)
                         for k, v in cls.__columns__.items())

    @classmethod
    def group_by_expr(cls, grouping):
        """Return a ``GROUP BY`` clause.

        :param grouping: iterable of ``(model, field)`` tuples. Each becomes
            ``model.__table__.field``.
        """
        groups = ', '.join('{}.{}'.format(m.__table__, f) for m, f in grouping)
        return 'GROUP BY {}'.format(groups)

    @classmethod
    def filter_expr(cls, key, val):
        """Return a single-table WHERE predicate for a filter key.

        Lookup suffixes from LOOKUP_TYPES are supported. For example,
        ``cls.filter_expr('count__lte', 2)`` gives ``"count <= %(count__lte)s"``.
        The value itself is bound later under the name *key*.

        :param key: ``field`` or ``field__lookup``.
        :param val: filter value (unused here; the caller binds it).
        :raises AssertionError: on an unknown field or lookup type.
        """
        try:
            field, lookup_type = key.split('__')
        except ValueError:
            field = key
            lookup_type = 'eq'
        assert field in cls.__columns__, 'Unknown field: {}'.format(field)
        assert lookup_type in LOOKUP_TYPES, 'Unknown lookup type {}'.format(lookup_type)
        column_expr = cls.__columns__[field]
        lookup_expr = LOOKUP_TYPES[lookup_type]
        return "{} {} %({})s".format(column_expr, lookup_expr, key)

    @classmethod
    def __validate_related_field(cls, table, field):
        """Return the model (this one or a related one) that owns ``table.field``.

        :param table: table name, either this model's or one from ``__relations__``.
        :param field: column key expected in that model's ``__columns__``.
        :raises AssertionError: if the table or field is unknown.
        """
        if table == cls.__table__:
            assert field in cls.__columns__, "Field {}.{} not found".format(table, field)
            return cls
        model = None
        for m, _ in cls.__relations__.values():
            if m.__table__ == table:
                model = m
                break
        assert model, "Model for {}.{} not found".format(table, field)
        assert field in model.__columns__, "Field {}.{} not found".format(table, field)
        return model

    @classmethod
    def filter_expr_join(cls, key, val, tables=None):
        """Return a table-qualified WHERE predicate for get_related filters.

        Keys have the form ``[table___]field[__lookup]``. Without a ``table___``
        prefix the filter applies to this model's own table. For example,
        ``cls.filter_expr_join('article_status___id', 1)`` gives
        ``"article_status.id = %(article_status___id)s"``.

        :param key: filter key to parse.
        :param val: filter value (unused here; the caller binds it under *key*).
        :param tables: related tables that are joined. Defaults to every table
            in ``__relations__``. This model's own table is always allowed,
            because it is the FROM table.
        :raises AssertionError: on an unjoined table, unknown field or lookup.
        """
        if tables:
            allowed_tables = set(tables)
        else:
            allowed_tables = {model.__table__ for model, _ in cls.__relations__.values()}
        # Own-table filters (e.g. ``id=1``) must be accepted; previously they
        # were rejected because the FROM table was never in the allowed list.
        allowed_tables.add(cls.__table__)
        try:
            table, rest = key.split('___')
        except ValueError:
            table = cls.__table__
            rest = key
        try:
            field, lookup_type = rest.split('__')
        except ValueError:
            field = rest
            lookup_type = 'eq'
        assert table in allowed_tables, 'Filtering by table that is not joined: {}'.format(table)
        assert lookup_type in LOOKUP_TYPES, 'Unknown lookup type {}'.format(lookup_type)
        model = cls.__validate_related_field(table, field)
        column_expr = "{}.{}".format(model.__table__, model.__columns__[field])
        lookup_expr = LOOKUP_TYPES[lookup_type]
        return "{} {} %({})s".format(column_expr, lookup_expr, key)

    @classmethod
    def __pagination_to_limit(cls, page, page_size):
        """Return ``LIMIT offset, page_size`` for a 1-based page number."""
        offset = (page - 1) * page_size
        return 'LIMIT {}, {}'.format(offset, page_size)

    @staticmethod
    def _page_result(data, page, page_size, found_rows):
        """Build the pagination envelope shared by get_all and get_related."""
        total_pages = math.ceil(found_rows / page_size)
        return {
            'data': data,
            'page': page,
            'page_size': page_size,
            'pages': total_pages,
            # A next page exists only while we are strictly before the last one.
            'next_page': page + 1 if page < total_pages else None,
            # Misspelled key kept as-is for backward compatibility with callers.
            'previuos_page': page - 1 if page > 1 else None,
        }

    @staticmethod
    def _nest_row(row, group_concat_tables):
        """Regroup a flat ``{'table__column': value}`` row into ``{table: {column: value}}``.

        String values of tables in *group_concat_tables* are split on
        GROUP_CONCAT_SEPARATOR into lists. A NULL aggregate stays ``None``
        instead of being dropped.
        """
        nested = defaultdict(dict)
        for key, value in row.items():
            table, column = key.split('__', 1)
            if table in group_concat_tables and isinstance(value, str):
                value = value.split(GROUP_CONCAT_SEPARATOR)
            nested[table][column] = value
        return nested

    @classmethod
    async def get_one(cls, pk):
        """Fetch one row by primary key.

        :param pk: primary-key value.
        :returns: dict with the row's columns (keyed by alias), or None if no
            row matches.
        """
        select_expr = cls.select_expr()
        sql = 'SELECT {} FROM {} WHERE {} = %s'.format(select_expr, cls.__table__, cls.__pk__)
        cursor = await pool.execute(sql, pk)
        return cursor.fetchone()

    @classmethod
    async def get_all(cls, where=None, page=1, page_size=50, **kwargs):
        """Fetch one page of rows with optional filtering.

        See Model.filter_expr for the supported lookup types.

        Usage::

            await Model.get_all(dict(id=2, name__lte=12), page=2, page_size=10)

        :param where: dict of filter key -> value, combined with AND.
        :param page: 1-based page number.
        :param page_size: rows per page.
        :returns: dict with keys ``data``, ``page``, ``page_size``, ``pages``,
            ``next_page`` and ``previuos_page``.
        """
        select_expr = cls.select_expr()
        sql_args = {}
        filter_exprs = []
        for key, value in (where or {}).items():
            filter_exprs.append(cls.filter_expr(key, value))
            sql_args[key] = value
        limit_expr = cls.__pagination_to_limit(page, page_size)
        where_expr = 'WHERE ' + ' AND '.join(filter_exprs) if filter_exprs else ''
        sql = ('SELECT SQL_CALC_FOUND_ROWS {select_expr}\n'
               'FROM {table} {where_expr}\n'
               '{limit_expr}').format(select_expr=select_expr,
                                      table=cls.__table__,
                                      where_expr=where_expr,
                                      limit_expr=limit_expr)
        # FOUND_ROWS() must run on the same connection as the paged query.
        async with cursor_context(pool) as c:
            await c.execute(sql, sql_args)
            data = c.fetchall()
            await c.execute('SELECT FOUND_ROWS()')
            found_rows = c.fetchone()['FOUND_ROWS()']
        return cls._page_result(data, page, page_size, found_rows)

    @classmethod
    def join_on(cls, field, join_type='JOIN'):
        """Return the JOIN clause for a foreign-key column in ``__relations__``.

        :param field: local column that has an entry in ``__relations__``.
        :param join_type: one of JOIN_TYPES (case-insensitive).
        :returns: e.g. ``"LEFT JOIN rel ON rel.id = table.rel_id"``.
        :raises AssertionError: on an unsupported join type.
        """
        join_type = join_type.upper()
        assert join_type in JOIN_TYPES, 'Invalid join: {}'.format(join_type)
        model, rel_field = cls.__relations__[field]
        return '{join_type} {rel_table} ON {rel_table}.{rel_field} = {table}.{field}'.format(
            join_type=join_type,
            rel_table=model.__table__,
            rel_field=model.__columns__[rel_field],
            table=cls.__table__,
            field=field)

    @classmethod
    async def get_related(cls, where, tables=None, page=1, page_size=50, **kwargs):
        """Like get_all, but also joins related tables and can GROUP_CONCAT some of them.

        Usage::

            result = await Article.get_related(
                dict(id=1, article_status___id=1, article_source_type___description='buzz'),
                joins=dict(status_id='LEFT JOIN', source_type_id='JOIN'))

        This produces SQL roughly like::

            SELECT article.id AS article__id, ..., article_status.id AS article_status__id, ...
            FROM article
            LEFT JOIN article_status ON article_status.id = article.status_id
            JOIN article_source_type ON article_source_type.id = article.source_type_id
            WHERE article.id = %(id)s AND article_status.id = %(article_status___id)s
              AND article_source_type.description = %(article_source_type___description)s

        GROUP_CONCAT example::

            result = await Article.get_related(
                dict(article_status___description__in=('done', ),
                     article_source_type___description='buzz'),
                group_concat=[Import],
                group_by=[(Article, 'id')],
                joins=dict(id='LEFT JOIN'))

        Every column of the ``group_concat`` models is aggregated, which
        requires ``group_by``. Each such column comes back as a list of
        strings, or None when the aggregate is NULL. This is not robust if
        values contain GROUP_CONCAT_SEPARATOR.

        :param where: dict of ``[table___]field[__lookup]`` -> value, combined
            with AND (see filter_expr_join). None is treated as no filters.
        :param tables: related tables to include. Defaults to all tables in
            ``__relations__``.
        :param page: 1-based page number.
        :param page_size: rows per page.
        :param joins: (kwarg) dict mapping a relation column to a JOIN_TYPES
            entry. Relations not listed use plain ``JOIN``.
        :param group_concat: (kwarg) list of models whose columns are
            aggregated with GROUP_CONCAT.
        :param group_by: (kwarg) list of ``(model, field)`` tuples. Required
            with ``group_concat``.
        :returns: pagination dict as in get_all. ``data`` holds one
            ``{table: {column: value}}`` mapping per row.
        """
        where = where or {}
        group_by = kwargs.get('group_by')
        group_concat_models = kwargs.get('group_concat') or []
        joins = kwargs.get('joins') or {}

        group_by_expr = ''
        if group_concat_models:
            assert group_by, 'group_by is required with group_concat'
            group_by_expr = cls.group_by_expr(group_by)

        # Restrict the relations to the requested tables, if any were given.
        if tables:
            relations = {col: (model, rel_col)
                         for col, (model, rel_col) in cls.__relations__.items()
                         if model.__table__ in tables}
        else:
            relations = cls.__relations__

        select_exprs = [cls.select_expr_join()]
        for model, _ in relations.values():
            if model in group_concat_models:
                select_exprs.append(model.select_expr_group_concat())
            else:
                select_exprs.append(model.select_expr_join())

        sql_args = {}
        filter_exprs = []
        for key, value in where.items():
            filter_exprs.append(cls.filter_expr_join(key, value, tables=tables))
            sql_args[key] = value

        where_expr = 'WHERE ' + ' AND '.join(filter_exprs) if filter_exprs else ''
        join_on = '\n'.join(cls.join_on(field, joins.get(field, 'JOIN')) for field in relations)
        sql = ('SELECT SQL_CALC_FOUND_ROWS {select_expr}\n'
               'FROM {table}\n'
               '{join_on}\n'
               '{where_expr}\n'
               '{group_by_expr}\n'
               '{limit_expr}').format(select_expr=', '.join(select_exprs),
                                      table=cls.__table__,
                                      join_on=join_on,
                                      where_expr=where_expr,
                                      group_by_expr=group_by_expr,
                                      limit_expr=cls.__pagination_to_limit(page, page_size))
        # FOUND_ROWS() must run on the same connection as the paged query.
        async with cursor_context(pool) as c:
            await c.execute(sql, sql_args)
            raw_rows = c.fetchall()
            await c.execute('SELECT FOUND_ROWS()')
            found_rows = c.fetchone()['FOUND_ROWS()']

        group_concat_tables = {m.__table__ for m in group_concat_models}
        data = [cls._nest_row(row, group_concat_tables) for row in raw_rows]
        return cls._page_result(data, page, page_size, found_rows)

    @classmethod
    async def create(cls, **kwargs):
        """Insert a new row built from the keyword arguments that are known columns.

        Keyword arguments that are not in ``__columns__`` are ignored.

        :returns: ``cursor.lastrowid``. If that is falsy (e.g. a primary key
            that is not AUTO_INCREMENT), returns the primary-key value passed
            in kwargs, or 0.
        """
        insert_keys = [key for key in kwargs if key in cls.__columns__]
        insert_expr = ', '.join(insert_keys)
        values_expr = ', '.join('%({})s'.format(key) for key in insert_keys)
        sql = 'INSERT INTO {} ({}) VALUES ({})'.format(cls.__table__, insert_expr, values_expr)
        # Bind only the referenced values, so unrelated kwargs are never escaped.
        cursor = await pool.execute(sql, {key: kwargs[key] for key in insert_keys})
        return cursor.lastrowid or kwargs.get(cls.__pk__, 0)

    @classmethod
    async def update_one(cls, pk, **kwargs):
        """Update the row with primary key *pk*.

        Usage::

            await Buzz.update_one(42, title='this is neat')

        Keyword arguments that are not in ``__columns__`` are ignored.

        :raises ValueError: if no keyword argument names a known column.
        :returns: ``cursor.lastrowid``.
            NOTE(review): for UPDATE statements this is not normally a useful
            value; ``rowcount`` may have been intended. It is kept unchanged
            for compatibility.
        """
        update_keys = [key for key in kwargs if key in cls.__columns__]
        if not update_keys:
            raise ValueError('Valid SET clause required')
        set_expr = ', '.join('{0}=%({0})s'.format(k) for k in update_keys)
        # Distinct bind name, so a column literally called "pk" cannot clash
        # with the primary-key value.
        sql = 'UPDATE {} SET {} WHERE {}=%(__pk_value__)s'.format(cls.__table__, set_expr, cls.__pk__)
        sql_args = {key: kwargs[key] for key in update_keys}
        sql_args['__pk_value__'] = pk
        cursor = await pool.execute(sql, sql_args)
        return cursor.lastrowid

    @classmethod
    async def update(cls, sargs, wargs):
        """Update multiple rows at once.

        *wargs* supports the lookup types in LOOKUP_TYPES.

        Usage::

            await BackfillStatus.update({'description': 'whooo'},
                                        {'description': 'deleted', 'id__lte': 2})

        :param sargs: SET column -> value pairs (unknown columns are ignored).
        :param wargs: WHERE filters.
        :raises ValueError: if no valid SET or WHERE entries remain.
        :returns: number of affected rows (``cursor.rowcount``).

        NOTE(review): invalid WHERE filters are silently skipped. If some
        filters are valid and others are not, the UPDATE therefore matches
        more rows than the caller asked for.
        """
        where_expr = []
        where_args = {}
        for k, v in wargs.items():
            try:
                where_expr.append(cls.filter_expr(k, v))
                where_args[k] = v
            except AssertionError:
                pass
        # SET placeholders get a __set suffix so they cannot collide with
        # WHERE placeholders for the same column.
        set_keys = [k for k in sargs if k in cls.__columns__]
        set_expr = ', '.join('{0}=%({0}__set)s'.format(k) for k in set_keys)
        if not all([set_expr, where_expr]):
            raise ValueError('Valid SET and WHERE clauses required')
        sql = 'UPDATE {} SET {} WHERE {}'.format(cls.__table__, set_expr, ' AND '.join(where_expr))
        sql_args = {**where_args, **{k + '__set': sargs[k] for k in set_keys}}
        cursor = await pool.execute(sql, sql_args)
        return cursor.rowcount

    @classmethod
    async def delete_one(cls, pk):
        """Delete the row with primary key *pk*.

        :returns: number of deleted rows (``cursor.rowcount``). Previously this
            called ``fetchone()`` on the DELETE, which has no result set.
        """
        sql = 'DELETE FROM {} WHERE {} = %s'.format(cls.__table__, cls.__pk__)
        cursor = await pool.execute(sql, pk)
        return cursor.rowcount

    @classmethod
    async def delete(cls, **kwargs):
        """Delete every row that matches all filters (LOOKUP_TYPES supported).

        :raises ValueError: if no filters are given. This guards against
            emitting a malformed or unbounded DELETE.
        :returns: number of deleted rows (``cursor.rowcount``).
        """
        if not kwargs:
            raise ValueError('At least one filter is required for delete')
        sql_args = {}
        filter_exprs = []
        for key, value in kwargs.items():
            filter_exprs.append(cls.filter_expr(key, value))
            sql_args[key] = value
        sql = 'DELETE FROM {} WHERE {}'.format(cls.__table__, ' AND '.join(filter_exprs))
        cursor = await pool.execute(sql, sql_args)
        return cursor.rowcount
from tornado import ioloop
from connection import cursor_context
async def main():
    """Connectivity smoke test: print every table visible through the pool."""
    async with cursor_context() as cur:
        await cur.execute('SHOW TABLES')
        tables = cur.fetchall()
        print(tables)


if __name__ == '__main__':
    loop = ioloop.IOLoop.instance()
    loop.run_sync(main)
import ujson
import tornado.ioloop
from model import Model
class ArticleStatus(Model):
    """Model for the ``article_status`` table (primary key ``id``)."""
    __pk__ = 'id'
    __table__ = 'article_status'
    # Identity mapping: every column name doubles as its alias.
    __columns__ = {name: name for name in ('id', 'description', 'created_at', 'updated_at')}
class ArticleSourceType(Model):
    """Model for the ``article_source_type`` table (primary key ``id``)."""
    __pk__ = 'id'
    __table__ = 'article_source_type'
    # Identity mapping: every column name doubles as its alias.
    __columns__ = {name: name for name in ('id', 'description', 'created_at', 'updated_at')}
class Import(Model):
    """Model for the ``import`` table (primary key ``id``).

    Article's ``__relations__`` maps ``article.id`` to ``import.article_id``
    (one article, possibly many imports).
    """
    __pk__ = 'id'
    __table__ = 'import'
    # Identity mapping: every column name doubles as its alias.
    __columns__ = {
        name: name
        for name in (
            'id',
            'import_status_id',
            'article_id',
            'instant_article_id',
            'import_status',
            'errors',
            'created_at',
            'updated_at',
        )
    }
class Article(Model):
    """Model for the ``article`` table, with joins to status, source type and imports."""
    __pk__ = 'id'
    __table__ = 'article'
    # Identity mapping: every column name doubles as its alias.
    __columns__ = {
        name: name
        for name in ('id', 'source_id', 'source_type_id', 'status_id', 'created_at', 'updated_at')
    }
    # local column -> (related model, related column); consumed by join_on/get_related.
    __relations__ = dict(
        source_type_id=(ArticleSourceType, 'id'),
        status_id=(ArticleStatus, 'id'),
        id=(Import, 'article_id'),
    )
async def main():
    """Demo: insert fixture rows, run get_related and get_all, then delete the rows.

    Every created row is recorded. In a ``finally`` block the rows are deleted
    in reverse creation order (imports, then articles, then lookup rows), so a
    failing query does not leave fixtures behind.
    """
    created = []  # (model, pk) pairs in creation order

    async def make(model, **values):
        pk = await model.create(**values)
        created.append((model, pk))
        return pk

    try:
        await make(ArticleStatus, id=1, description='done')
        await make(ArticleSourceType, id=1, description='buzz')
        article1_id = await make(Article, source_id=214, source_type_id=1, status_id=1)
        article2_id = await make(Article, source_id=215, source_type_id=1, status_id=1)
        # article2 gets two imports (one-to-many), which is what the
        # GROUP_CONCAT query below aggregates.
        await make(Import, import_status_id=1, article_id=article1_id)
        await make(Import, import_status_id=1, article_id=article2_id)
        await make(Import, import_status_id=1, article_id=article2_id)

        result = await Article.get_related(
            dict(article_status___description__in=('done', ),
                 article_source_type___description='buzz'),
            group_concat=[Import],
            group_by=[(Article, 'id')],
            joins=dict(id='LEFT JOIN'),
            page_size=1,
            page=2)
        json = ujson.dumps(result, indent=2)
        print("get_related example")
        print('-' * 120)
        print(json)

        result = await Import.get_all(page_size=1)
        json = ujson.dumps(result, indent=2)
        print("get_all example")
        print('-' * 120)
        print(json)
    finally:
        for model, pk in reversed(created):
            await model.delete_one(pk)


if __name__ == '__main__':
    tornado.ioloop.IOLoop.instance().run_sync(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment