Last active
July 14, 2016 14:57
-
-
Save andriykohut/0dafc22c9138b3cb95444b662228c0ea to your computer and use it in GitHub Desktop.
torm - something orm-ish for tornado_mysql
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tornado_mysql.cursors import DictCursor | |
from tornado_mysql.pools import Pool | |
# Shared tornado_mysql connection pool for the whole process.
# NOTE(review): credentials and DB name are hard-coded; consider loading them
# from configuration or the environment instead of committing a password.
pool = Pool(dict(host='127.0.0.1',
                 port=3306,
                 user='root',
                 passwd='password',
                 db='ia_manager_api',
                 cursorclass=DictCursor),
            max_idle_connections=5,   # connections kept warm when idle
            max_open_connections=20)  # hard cap on concurrent connections
class cursor_context:
    """Async context manager that borrows a connection from the pool and
    yields a DictCursor on it, returning the connection to the pool on exit.

    Usage:
        async with cursor_context() as c:
            await c.execute(sql, args)

    :param pool: tornado_mysql Pool to borrow connections from.
    """

    def __init__(self, pool=pool):
        self.pool = pool
        self.__conn = None
        self.__cursor = None

    async def __aenter__(self):
        self.__conn = await self.pool._get_conn()
        try:
            self.__cursor = self.__conn.cursor(DictCursor)
        except BaseException:
            # Explicit (was a bare `except:`): if cursor creation fails for
            # any reason -- including KeyboardInterrupt/SystemExit -- close
            # the connection so it is not leaked, then re-raise.
            self.pool._close_conn(self.__conn)
            raise
        return self.__cursor

    async def __aexit__(self, *args):
        try:
            await self.__cursor.close()
        finally:
            # Bug fix: previously a failing cursor.close() skipped this line
            # and leaked the connection. Always hand it back to the pool.
            self.pool._put_conn(self.__conn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from collections import defaultdict | |
from connection import pool | |
from connection import cursor_context | |
# Django-style lookup suffixes (field__<lookup>) mapped to SQL operators.
LOOKUP_TYPES = {
    'eq': '=',
    'gt': '>',
    'gte': '>=',
    'lt': '<',
    'lte': '<=',
    'in': 'IN',
}

# Join clauses accepted by Model.join_on().
JOIN_TYPES = {
    'JOIN',
    'LEFT JOIN',
    # Bug fix: was misspelled 'RIGTH JOIN', which rejected the valid
    # 'RIGHT JOIN' spelling and would have emitted invalid SQL if used.
    'RIGHT JOIN',
    'FULL OUTER JOIN',
}

# Separator for GROUP_CONCAT aggregation; result values are split on this
# string back into lists when rows are post-processed.
GROUP_CONCAT_SEPARATOR = '|||'
class MetaModel(type):
    """Metaclass that sanity-checks Model subclasses at definition time.

    The base Model class itself (no MetaModel parents) is created untouched;
    every subclass must declare a primary key, a table name and a column
    dict, and any declared relations must reference existing columns.
    """

    def __new__(cls, name, bases, attrs):
        is_subclass = any(isinstance(base, MetaModel) for base in bases)
        if not is_subclass:
            # Creating the abstract base itself -- nothing to validate.
            return super().__new__(cls, name, bases, attrs)
        # Required attributes
        pk = attrs.get('__pk__')
        table = attrs.get('__table__')
        columns = attrs.get('__columns__')
        assert pk, 'Primary key is required'
        assert table, 'Table is required'
        assert columns, 'Column dict missing'
        assert pk in columns, 'Primary key not in columns dict'
        # Optional relations: each maps a local column to (model, column).
        relations = attrs.get('__relations__') or {}
        for col, (model, rel_col) in relations.items():
            assert col in columns, ('__relations__: no such column {}').format(col)
            assert rel_col in model.__columns__, (
                '__relations__: no such column {}.{}').format(model.__table__, rel_col)
        return super().__new__(cls, name, bases, attrs)
class Model(metaclass=MetaModel):
    """Base class for tiny ORM models on top of tornado_mysql.

    Subclasses declare (validated by MetaModel):
      * ``__pk__``      -- primary key column name.
      * ``__table__``   -- table name.
      * ``__columns__`` -- dict mapping column name -> select alias.
      * ``__relations__`` (optional) -- dict mapping a local FK column to a
        ``(related Model, related column)`` tuple.

    All queries go through the module-level ``pool`` / ``cursor_context``.
    """

    __pk__ = 'id'
    __table__ = None
    __columns__ = None
    # NOTE(review): evaluated once at class-definition time, so this stays
    # 'id' even in subclasses that override __pk__ -- confirm that's intended.
    __ordering_col__ = __pk__

    @classmethod
    def select_expr(cls):
        '''Generate SELECT expression from Model columns.'''
        return ', '.join('{} AS {}'.format(k, v) for k, v in cls.__columns__.items())

    @classmethod
    def select_expr_join(cls):
        '''Generate select expression for joins.
        Will build something like:
            "table.col1 AS table__col1, table.col2 AS table__col2"
        '''
        return ', '.join('{0}.{1} AS {0}__{2}'.format(cls.__table__, key, val) for key, val in cls.__columns__.items())

    @classmethod
    def select_expr_group_concat(cls, sep=GROUP_CONCAT_SEPARATOR):
        """Generate group concat select expression.
        :param sep: separator string.
        """
        return ', '.join("GROUP_CONCAT({0}.{1} SEPARATOR '{3}') AS {0}__{2}".format(cls.__table__, k, v, sep)
                         for k, v in cls.__columns__.items())

    @classmethod
    def group_by_expr(cls, grouping):
        """Generate GROUP BY expression.
        :param grouping: list of tuples (model: Model, field: str)
        """
        groups = ', '.join('{}.{}'.format(m.__table__, f) for m, f in grouping)
        return 'GROUP BY {}'.format(groups)

    @classmethod
    def filter_expr(cls, key, val):
        '''Generate filter expression for WHERE clauses, etc.
        Different lookup types are supported, e.g.:
            cls.filter_expr('count__lte', 2) -> count <= %(count__lte)s
        See LOOKUP_TYPES for more details.
        :param key: filter argument to parse.
        :param val: value for filter expression.
        :returns: string with filter expression.
        '''
        # TODO: Not sure we need `val` here
        # A key without the '__<lookup>' suffix defaults to equality.
        try:
            field, lookup_type = key.split('__')
        except ValueError:
            field = key
            lookup_type = 'eq'
        assert field in cls.__columns__, 'Unknown field: {}'.format(field)
        assert lookup_type in LOOKUP_TYPES, 'Unknown lookup type {}'.format(lookup_type)
        column_expr = cls.__columns__[field]
        lookup_expr = LOOKUP_TYPES[lookup_type]
        # The original key doubles as the pyformat parameter name.
        return "{} {} %({})s".format(column_expr, lookup_expr, key)

    @classmethod
    def __validate_related_field(cls, table, field):
        """Find the related model for table/field combination.
        :param table: string with related table name.
        :param field: string with field name.
        :returns: related model.
        """
        model = None
        if table == cls.__table__:
            assert field in cls.__columns__, "Field {}.{} not found".format(table, field)
            return cls
        # Search the relation targets for a model backing `table`.
        for m, _ in cls.__relations__.values():
            if m.__table__ == table:
                model = m
                break
        assert model, "Model for {}.{} not found".format(table, field)
        assert field in model.__columns__, "Field {}.{} not found".format(table, field)
        return model

    @classmethod
    def filter_expr_join(cls, key, val, tables=None):
        """Generate filters for WHERE clause based on query key/value pair.
        Example:
            cls.filter_expr_join('article_status___id', 1)
            # will produce: "article_status.id = %(article_status___id)s"
        :param key: filter argument to parse.
        :param val: value for filter expression.
        :param tables: list of related tables.
        :returns: string with where clause filter expression
        """
        # NOTE(review): the default omits cls.__table__ itself, yet `table`
        # falls back to cls.__table__ below -- filters on the model's own
        # columns look like they would fail the joined-table assert when
        # `tables` is None; verify against get_related's usage.
        tables = tables if tables else [v[0].__table__
                                        for v in cls.__relations__.values()]
        # Keys look like "<table>___<field>__<lookup>"; table defaults to the
        # model's own table and lookup defaults to equality.
        try:
            table, rest = key.split('___')
        except ValueError:
            table = cls.__table__
            rest = key
        try:
            field, lookup_type = rest.split('__')
        except ValueError:
            field = rest
            lookup_type = 'eq'
        assert table in tables, 'Filtering by table that is not joined: {}'.format(table)
        assert lookup_type in LOOKUP_TYPES, 'Unknown lookup type {}'.format(lookup_type)
        model = cls.__validate_related_field(table, field)
        column_expr = "{}.{}".format(model.__table__, model.__columns__[field])
        lookup_expr = LOOKUP_TYPES[lookup_type]
        return "{} {} %({})s".format(column_expr, lookup_expr, key)

    @classmethod
    def __pagination_to_limit(cls, page, page_size):
        """Just turn the page and page size into proper LIMIT.
        :param page: integer, page number (1-based).
        :param page_size: integer, number of records per page.
        :returns: string with `LIMIT ...` clause.
        """
        offset = (page-1) * page_size
        return 'LIMIT {offset}, {page_size}'.format(**locals())

    @classmethod
    async def get_one(cls, pk):
        """Get one row by primary key.
        :param pk: value of primary key.
        :returns: dict with row data.
        """
        select_expr = cls.select_expr()
        sql = 'SELECT {} FROM {} WHERE {} = %s'.format(select_expr, cls.__table__, cls.__pk__)
        cursor = await pool.execute(sql, pk)
        result = cursor.fetchone()
        return result

    @classmethod
    async def get_all(cls, where=None, page=1, page_size=50, **kwargs):
        '''Get multiple rows with filtering, ordering and limit.
        See Model.filter_expr for more lookup type details.
        Usage:
            await Model.get_all(dict(id=2, name__lte=12), page=2, page_size=10)
        :param where: dict with where params.
        :param page: page number to retrieve.
        :param page_size: row count for page
        '''
        select_expr = cls.select_expr()
        sql_args = {}
        filter_exprs = []
        if where:
            for key, value in where.items():
                filter_expr = cls.filter_expr(key, value)
                filter_exprs.append(filter_expr)
                sql_args[key] = value
        limit_expr = cls.__pagination_to_limit(page, page_size)
        where_expr = 'WHERE ' + ' AND '.join(filter_exprs) if filter_exprs else ''
        # SQL_CALC_FOUND_ROWS lets us read the unlimited row count afterwards
        # with SELECT FOUND_ROWS() on the same connection.
        sql = ('SELECT SQL_CALC_FOUND_ROWS {select_expr}\n'
               'FROM {table} {where_expr}\n'
               '{limit_expr}').format(select_expr=select_expr,
                                      table=cls.__table__,
                                      where_expr=where_expr,
                                      limit_expr=limit_expr)
        async with cursor_context(pool) as c:
            await c.execute(sql, sql_args)
            data = c.fetchall()
            await c.execute('SELECT FOUND_ROWS()')
            found_rows = c.fetchone()['FOUND_ROWS()']
            total_pages = math.ceil(found_rows/page_size)
            result = {'data': data, 'page': page, 'page_size': page_size, 'pages': total_pages}
            # NOTE(review): when page == total_pages this still reports a
            # next_page that does not exist -- confirm `<` was not intended.
            result['next_page'] = page + 1 if page <= total_pages else None
            # NOTE(review): 'previuos_page' is misspelled, but it is a key in
            # the returned payload, so renaming it would break consumers.
            result['previuos_page'] = page - 1 if page > 1 else None
            return result

    @classmethod
    def join_on(cls, field, join_type='JOIN'):
        """Build join expression based on model foreign keys.
        :param field: field with foreign key.
        :param join_type: JOIN_TYPES string.
        :returns: string with join expression.
        """
        join_type = join_type.upper()
        assert join_type in JOIN_TYPES, 'Invalid join: {}'.format(join_type)
        model, rel_field = cls.__relations__[field]
        rel_field = model.__columns__[rel_field]
        rel_table = model.__table__
        table = cls.__table__
        sql = ('{join_type} {rel_table} ' 'ON {rel_table}.{rel_field} = {table}.{field}')
        return sql.format(**locals())

    @classmethod
    async def get_related(cls, where, tables=None, page=1, page_size=50, **kwargs):
        """Same as cls.get_all(...), but also joins with related tables, allows to perform
        GROUP_CONCAT with GROUP BY.
        Usage:
            result = await Article.get_related(
                dict(id=1, article_status___id=1, article_source_type___description='buzz'),
                joins=dict(status_id='LEFT JOIN', source_type_id='JOIN'))
        This assumes that article has one-to-one relationship with two models: one for `article_status` table and
        another one for `article_source_type`, and following columns exist:
            * `article_status`.`id`
            * `article_source_type`.`description`
        So the resulting query will be something like:
            SELECT article.updated_at AS article__updated_at, article.source_type_id AS article__source_type_id,
                   article.created_at AS article__created_at, article.id AS article__id,
                   ...etc...
            FROM article
            LEFT JOIN article_status ON article_status.id = article.status_id
            JOIN article_source_type ON article_source_type.id = article.source_type_id
            WHERE article_source_type.description = %(article_source_type___description)s
              AND article_status.id = %(article_status___id)s AND article.id = %(id)s
        Example with group_concat:
            result = await Article.get_related(
                dict(article_status___description__in=('done', ),
                     article_source_type___description='buzz'),
                group_concat=[Import],
                group_by=[(Article, 'id')],
                joins=dict(id='LEFT JOIN'))
        Here we perform GROUP_CONCAT aggregation on all fields from Import model, this however requires group_by
        parameter. Just keep in mind this is not very robust, and the resulting `import` dict for each field
        will have a list of strings as a values (or null).
        :param where: positional, dict with query params.
        :param tables: positional, list of tables to include in select query, defaults to ones in __relations__
        :param page: positional, page number to retrieve.
        :param page_size: positional, row count for page
        :param joins: dict with mapping of field to join_type. The keys should be model's fields with foreign keys,
            and values should be join expression from JOIN_TYPES.
        :param group_concat: List of models, the fields of this models will be aggregated by GROUP_CONCAT.
        :param group_by: List of tuples of (model, field), where `model` is either current Model or one of related,
            and `field` is the string with the field name to join on.
        :returns: List of dicts with resulting data.
        """
        select_exprs = [cls.select_expr_join()]
        rel_tables = [cls.__table__]
        sql_args = {}
        filter_exprs = []
        group_by = kwargs.get('group_by')
        group_concat_models = kwargs.get('group_concat')
        group_by_expr = ''
        if group_concat_models:
            assert group_by, 'group_by is required with group_concat'
            group_by_expr = cls.group_by_expr(group_by)
        # Restrict the joined relations to the requested tables
        # (default: every declared relation).
        relations = {c: (m, rc) for c, (m, rc)
                     in cls.__relations__.items()
                     if m.__table__ in tables} if tables else cls.__relations__
        for _, (model, _) in relations.items():
            if group_concat_models and model in group_concat_models:
                select_exprs.append(model.select_expr_group_concat())
            else:
                select_exprs.append(model.select_expr_join())
            rel_tables.append(model.__table__)
        for key, value in where.items():
            filter_expr = cls.filter_expr_join(key, value, tables=tables)
            filter_exprs.append(filter_expr)
            sql_args[key] = value
        where_expr = 'WHERE ' + ' AND '.join(filter_exprs) if filter_exprs else ''
        select_expr = ', '.join(select_exprs)
        # NOTE(review): rel_tables_expr is computed but never used below.
        rel_tables_expr = ', '.join('{}'.format(t) for t in rel_tables)
        table = cls.__table__
        joins = kwargs.get('joins')
        if not joins:
            join_on = '\n'.join(cls.join_on(f) for f in relations)
        else:
            join_on = '\n'.join(cls.join_on(f, joins.get(f, 'JOIN')) for f in relations)
        limit_expr = cls.__pagination_to_limit(page, page_size)
        # format(**locals()) fills select_expr / table / join_on / where_expr /
        # group_by_expr / limit_expr from the locals computed above.
        sql = ('SELECT SQL_CALC_FOUND_ROWS {select_expr}\n'
               'FROM {table}\n'
               '{join_on}\n'
               '{where_expr}\n'
               '{group_by_expr}\n'
               '{limit_expr}').format(**locals())
        async with cursor_context(pool) as c:
            await c.execute(sql, sql_args)
            raw_result = c.fetchall()
            await c.execute('SELECT FOUND_ROWS()')
            found_rows = c.fetchone()['FOUND_ROWS()']
            total_pages = math.ceil(found_rows/page_size)
            result = {'data': [], 'page': page, 'page_size': page_size, 'pages': total_pages}
            result['next_page'] = page + 1 if page <= total_pages else None
            # NOTE(review): same misspelled payload key as in get_all.
            result['previuos_page'] = page - 1 if page > 1 else None
            group_concat_tables = [m.__table__ for m in group_concat_models] if group_concat_models else None
            # Re-nest flat "table__column" aliases into {table: {column: value}}.
            for item in raw_result:
                new_item = defaultdict(dict)
                for k, v in item.items():
                    table, column = k.split('__')
                    # Split by GROUP_CONCAT_SEPARATOR, keep in mind that values for group_concat columns will
                    # be a list of strings, unfortunately group_concat just dumps everything into single string.
                    if group_concat_tables and table in group_concat_tables:
                        try:
                            new_item[table][column] = v.split(GROUP_CONCAT_SEPARATOR)
                        except AttributeError:
                            # NULL aggregate (v is None): key is simply omitted.
                            pass
                    else:
                        new_item[table][column] = v
                result['data'].append(new_item)
            return result

    @classmethod
    async def create(cls, **kwargs):
        """Insert new row.

        Unknown kwargs (not in __columns__) are silently ignored.
        :returns: new row's primary key value.
        """
        insert_keys = [key for key in kwargs.keys() if key in cls.__columns__]
        insert_expr = ', '.join('{}'.format(key) for key in insert_keys)
        values_expr = ', '.join('%({})s'.format(key) for key in insert_keys)
        sql = 'INSERT INTO {} ({}) VALUES ({})'.format(cls.__table__, insert_expr, values_expr)
        cursor = await pool.execute(sql, kwargs)
        # in case primary key is not AUTO INCREMENT
        result = cursor.lastrowid or kwargs.get(cls.__pk__, 0)
        return result

    @classmethod
    async def update_one(cls, pk, **kwargs):
        '''Update by primary key.
        Usage:
            await Buzz.update_by_pk(42, title='this is neat')
        '''
        update_keys = [key for key in kwargs.keys() if key in cls.__columns__]
        set_expr = ', '.join('{0}=%({0})s'.format(k) for k in update_keys)
        # The 'pk' parameter name cannot collide with columns thanks to the
        # %(pk)s placeholder being merged in last.
        sql = 'UPDATE {} SET {} WHERE {}=%(pk)s'.format(cls.__table__, set_expr, cls.__pk__)
        cursor = await pool.execute(sql, {**kwargs, **{'pk': pk}})
        # NOTE(review): lastrowid after an UPDATE is typically 0 -- confirm
        # callers expect that rather than a rowcount.
        result = cursor.lastrowid
        return result

    @classmethod
    async def update(cls, sargs, wargs):
        '''Update multiple rows at once.
        wargs supports filtering with LOOKUP_TYPES.
        Usage:
            await BackfillStatus.update({'description': 'whooo'},
                                        {'description': 'deleted',
                                         'id__lte': 2})
        :param sargs: SET key/value pairs.
        :param wargs: WHERE filters.
        '''
        where_expr = []
        where_args = {}
        for k, v in wargs.items():
            try:
                where_expr.append(cls.filter_expr(k, v))
                where_args[k] = v
            except AssertionError:
                # Invalid filter keys are deliberately skipped (best-effort).
                pass
        # adding the __set postfix here, since we need to distinguish k/v
        # pairs in SET expression from ones in WHERE clause
        set_keys = [k for k in sargs.keys() if k in cls.__columns__]
        set_expr = ', '.join('{0}=%({0}__set)s'.format(k) for k in set_keys)
        if not all([set_expr, where_expr]):
            raise ValueError('Valid SET and WHERE clauses required')
        sql = 'UPDATE {} SET {} WHERE {}'.format(cls.__table__, set_expr, ' AND '.join(where_expr))
        sql_args = {**where_args, **{k + '__set': sargs[k] for k in set_keys}}
        cursor = await pool.execute(sql, sql_args)
        return cursor.rowcount

    @classmethod
    async def delete_one(cls, pk):
        '''DELETE row by primary key.'''
        sql = 'DELETE FROM {} WHERE {} = %s'.format(cls.__table__, cls.__pk__)
        cursor = await pool.execute(sql, pk)
        # NOTE(review): fetchone() after a DELETE has no result set --
        # presumably this returns None; a rowcount may have been intended.
        result = cursor.fetchone()
        return result

    @classmethod
    async def delete(cls, **kwargs):
        '''DELETE multiple rows.
        Supports LOOKUP_TYPES in kwargs.
        :returns: number of deleted rows.
        '''
        sql_args = {}
        filter_exprs = []
        for key, value in kwargs.items():
            filter_expr = cls.filter_expr(key, value)
            filter_exprs.append(filter_expr)
            sql_args[key] = value
        # NOTE(review): with no kwargs this builds 'DELETE FROM t WHERE '
        # (invalid SQL) rather than refusing to delete everything.
        sql = 'DELETE FROM {} WHERE {}'.format(cls.__table__, ' AND '.join(filter_exprs))
        cursor = await pool.execute(sql, sql_args)
        result = cursor.rowcount
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tornado import ioloop | |
from connection import cursor_context | |
async def main():
    """Smoke test: list the database tables through cursor_context."""
    async with cursor_context() as c:
        await c.execute('SHOW TABLES')
        print(c.fetchall())


if __name__ == '__main__':
    # run_sync drives the coroutine to completion on the Tornado IOLoop.
    ioloop.IOLoop.instance().run_sync(main)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ujson | |
import tornado.ioloop | |
from model import Model | |
class ArticleStatus(Model):
    """Lookup table of article statuses (e.g. 'done')."""
    __pk__ = 'id'
    __table__ = 'article_status'
    # column name -> select alias (identity mapping here)
    __columns__ = {
        'id': 'id',
        'description': 'description',
        'created_at': 'created_at',
        'updated_at': 'updated_at'
    }
class ArticleSourceType(Model):
    """Lookup table of article source types (e.g. 'buzz')."""
    __pk__ = 'id'
    __table__ = 'article_source_type'
    # column name -> select alias (identity mapping here)
    __columns__ = {
        'id': 'id',
        'description': 'description',
        'created_at': 'created_at',
        'updated_at': 'updated_at'
    }
class Import(Model):
    """Import attempt of an article; references article via article_id."""
    __pk__ = 'id'
    __table__ = 'import'
    # column name -> select alias (identity mapping here)
    __columns__ = {
        'id': 'id',
        'import_status_id': 'import_status_id',
        'article_id': 'article_id',
        'instant_article_id': 'instant_article_id',
        'import_status': 'import_status',
        'errors': 'errors',
        'created_at': 'created_at',
        'updated_at': 'updated_at'
    }
class Article(Model):
    """Article row with relations to status, source type, and its imports."""
    __pk__ = 'id'
    __table__ = 'article'
    # column name -> select alias (identity mapping here)
    __columns__ = {
        'id': 'id',
        'source_id': 'source_id',
        'source_type_id': 'source_type_id',
        'status_id': 'status_id',
        'created_at': 'created_at',
        'updated_at': 'updated_at'
    }
    # local FK column -> (related model, column it points at);
    # note 'id' maps to Import.article_id, i.e. a reverse one-to-many link.
    __relations__ = {
        'source_type_id': (ArticleSourceType, 'id'),
        'status_id': (ArticleStatus, 'id'),
        'id': (Import, 'article_id'),
    }
async def main():
    """Exercise the ORM end-to-end: create fixture rows, run the
    get_related and get_all examples, then delete everything created."""
    article_status_id = await ArticleStatus.create(id=1, description='done')
    article_source_id = await ArticleSourceType.create(id=1, description='buzz')
    article1_id = await Article.create(source_id=214,
                                       source_type_id=1,
                                       status_id=1)
    article2_id = await Article.create(source_id=215,
                                       source_type_id=1,
                                       status_id=1)
    # import1_id has two articles
    # NOTE(review): the comment above looks inverted -- it is article2 that
    # gets two imports (import2 and import3) below.
    import1_id = await Import.create(import_status_id=1, article_id=article1_id)
    import2_id = await Import.create(import_status_id=1, article_id=article2_id)
    import3_id = await Import.create(import_status_id=1, article_id=article2_id)
    result = await Article.get_related(
        dict(article_status___description__in=('done', ),
             article_source_type___description='buzz'),
        group_concat=[Import],
        group_by=[(Article, 'id')],
        joins=dict(id='LEFT JOIN'),
        page_size=1,
        page=2)
    json = ujson.dumps(result, indent=2)
    print("get_related example")
    print('-'*120)
    print(json)
    result = await Import.get_all(page_size=1)
    json = ujson.dumps(result, indent=2)
    # Bug fix: this banner previously repeated "get_related example".
    print("get_all example")
    print('-'*120)
    print(json)
    # Clean up every fixture row we created above.
    await Import.delete_one(import1_id)
    await Import.delete_one(import2_id)
    await Import.delete_one(import3_id)
    await Article.delete_one(article1_id)
    await Article.delete_one(article2_id)
    await ArticleSourceType.delete_one(article_source_id)
    await ArticleStatus.delete_one(article_status_id)


if __name__ == '__main__':
    # run_sync drives the demo coroutine to completion on the Tornado IOLoop.
    tornado.ioloop.IOLoop.instance().run_sync(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment