Created
January 12, 2013 21:10
-
-
Save jdp/4520486 to your computer and use it in GitHub Desktop.
Quick DSL for building SimpleDB queries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
SimpleDB Query DSL | |
>>> stmt = select(Star(), 'mydomain') | |
>>> stmt.where(Field('city') == 'Seattle').to_sql() | |
'select * from mydomain where city = "Seattle"' | |
>>> city = Field('city') | |
>>> stmt.where((city == 'Seattle') | (city == 'Portland')).to_sql() | |
'select * from mydomain where (city = "Seattle") or (city = "Portland")' | |
>>> name = Field('name') | |
>>> stmt.where(name != 'John').to_sql() | |
'select * from mydomain where name != "John"' | |
>>> stmt.where((name != 'John') & (name != 'Humberto')).to_sql() | |
'select * from mydomain where (name != "John") and (name != "Humberto")' | |
>>> weight = Field('weight') | |
>>> stmt.where(weight > 34).to_sql() | |
'select * from mydomain where weight > 34' | |
>>> stmt.where(weight >= 65).to_sql() | |
'select * from mydomain where weight >= 65' | |
>>> stmt.where(weight < 34).to_sql() | |
'select * from mydomain where weight < 34' | |
>>> year = Field('Year') | |
>>> stmt.where(year <= 2000).to_sql() | |
'select * from mydomain where Year <= 2000' | |
>>> author = Field('Author') | |
>>> stmt.where(author % 'Henry%').to_sql() | |
'select * from mydomain where Author like "Henry%"' | |
>>> keyword = Field('Keyword') | |
>>> stmt.where((keyword == 'Book') & (author % '%Miller')).to_sql() | |
'select * from mydomain where (Keyword = "Book") and (Author like "%Miller")' | |
>>> stmt.where(~(author % 'Henry%')).to_sql() | |
'select * from mydomain where Author not like "Henry%"' | |
>>> stmt.where(year.between(1998, 2000)).to_sql() | |
'select * from mydomain where Year between 1998 and 2000' | |
>>> stmt.where(year.in_(1998, 2000, 2003)).to_sql() | |
'select * from mydomain where Year in (1998, 2000, 2003)' | |
>>> stmt.where(year.is_null()).to_sql() | |
'select * from mydomain where Year is null' | |
>>> stmt.where(~year.is_null()).to_sql() | |
'select * from mydomain where Year is not null' | |
>>> stmt.where(Every(keyword) == 'Book').to_sql() | |
'select * from mydomain where every(Keyword) = "Book"' | |
>>> title = Field('Title') | |
>>> stmt.where(title == 'The Right Stuff').to_sql() | |
'select * from mydomain where Title = "The Right Stuff"' | |
>>> stmt.where(year > '1985').to_sql() | |
'select * from mydomain where Year > "1985"' | |
>>> rating = Field('Rating') | |
>>> stmt.where(rating.like('****%')).to_sql() | |
'select * from mydomain where Rating like "****%"' | |
>>> pages = Field('Pages') | |
>>> stmt.where(pages < '00320').to_sql() | |
'select * from mydomain where Pages < "00320"' | |
>>> year = Field('Year') | |
>>> stmt.where((year > '1975') & (year < '2008')).to_sql() | |
'select * from mydomain where (Year > "1975") and (Year < "2008")' | |
>>> stmt.where(year.between('1975', '2008')).to_sql() | |
'select * from mydomain where Year between "1975" and "2008"' | |
>>> stmt.where((rating == '***') | (rating == '*****')).to_sql() | |
'select * from mydomain where (Rating = "***") or (Rating = "*****")' | |
>>> stmt.where(((year > '1950') & (year < '1960')) | year.like('193%') | (year == '2007')).to_sql() | |
'select * from mydomain where ((Year > "1950") and (Year < "1960")) or (Year like "193%") or (Year = "2007")' | |
>>> stmt.where((rating == '4 stars') | (rating == '****')).to_sql() | |
'select * from mydomain where (Rating = "4 stars") or (Rating = "****")' | |
>>> stmt.where((keyword == 'Book') & (keyword == 'Hardcover')).to_sql() | |
'select * from mydomain where (Keyword = "Book") and (Keyword = "Hardcover")' | |
>>> stmt.where(Every(keyword).in_('Book', 'Paperback')).to_sql() | |
'select * from mydomain where every(Keyword) in ("Book", "Paperback")' | |
>>> stmt.where(rating == '****').to_sql() | |
'select * from mydomain where Rating = "****"' | |
>>> stmt.where(Every(rating) == '****').to_sql() | |
'select * from mydomain where every(Rating) = "****"' | |
>>> stmt.where((keyword == 'Book') ^ (keyword == 'Hardcover')).to_sql() | |
'select * from mydomain where (Keyword = "Book") intersection (Keyword = "Hardcover")' | |
>>> stmt.where(year < '1980').order_by(year).to_sql() | |
'select * from mydomain where Year < "1980" order by Year' | |
>>> stmt.where((year == '2007') ^ ~author.is_null()).order_by(author, desc=True).to_sql() | |
'select * from mydomain where (Year = "2007") intersection (Author is not null) order by Author desc' | |
>>> stmt.where(year < '1980').order_by(year).limit(2).to_sql() | |
'select * from mydomain where Year < "1980" order by Year limit 2' | |
>>> stmt = select(ItemName(), 'mydomain') | |
>>> stmt.where(ItemName().like('B000%')).order_by(ItemName()).to_sql() | |
'select itemName() from mydomain where itemName() like "B000%" order by itemName()' | |
>>> stmt = select(Count(), 'mydomain') | |
>>> stmt.where(title == 'The Right Stuff').to_sql() | |
'select count(*) from mydomain where Title = "The Right Stuff"' | |
>>> stmt.where(year > '1985').to_sql() | |
'select count(*) from mydomain where Year > "1985"' | |
>>> stmt.limit(500).to_sql() | |
'select count(*) from mydomain limit 500' | |
>>> stmt = select(Star(), 'mydomain') | |
>>> stmt.where(Field('abc`123') == '1').to_sql() | |
'select * from mydomain where `abc``123` = "1"' | |
>>> stmt.where(Field('between') == '1').to_sql() | |
'select * from mydomain where `between` = "1"' | |
""" | |
import re | |
from collections import OrderedDict | |
class Node(object):
    """Common base class of every query AST node (statements, clauses, expressions)."""
class Literal(Node):
    """AST node carrying raw query text in ``value``.

    The compiler emits ``value`` as-is, without quoting or escaping.
    """

    def __init__(self, value):
        # Raw text that ends up verbatim in the generated query.
        self.value = value
class Statement(Node):
    """Marker base class for complete query statements."""
class SelectStatement(Statement):
    """A ``select`` query assembled from optional clauses.

    The builder methods (``fields``, ``from_``, ``where``, ``order_by``,
    ``limit``) never modify the statement they are called on; each returns
    a new ``SelectStatement`` produced by ``factory``.
    """

    def __init__(self, fields=None, from_=None, where=None, order=None, limit=None):
        # Insertion order is the order in which clauses are rendered.
        self.clauses = OrderedDict([
            ('fields', fields),
            ('from', from_),
            ('where', where),
            ('order', order),
            ('limit', limit),
        ])

    def factory(self, fields=None, from_=None, where=None, order=None, limit=None):
        """Return a new statement; truthy arguments replace this statement's clauses.

        Any argument left falsy is carried over from the current statement.
        """
        current = self.clauses
        new_fields = fields or current['fields']
        new_from = from_ or current['from']
        new_where = where or current['where']
        new_order = order or current['order']
        new_limit = limit or current['limit']
        return SelectStatement(
            fields=new_fields,
            from_=new_from,
            where=new_where,
            order=new_order,
            limit=new_limit,
        )

    def fields(self, fields):
        """Return a copy of this statement selecting ``fields``."""
        clause = FieldsClause(fields)
        return self.factory(fields=clause)

    def from_(self, domain):
        """Return a copy of this statement reading from ``domain``."""
        clause = FromClause(domain)
        return self.factory(from_=clause)

    def where(self, expr):
        """Return a copy of this statement filtered by ``expr``."""
        clause = WhereClause(expr)
        return self.factory(where=clause)

    def order_by(self, field, desc=False):
        """Return a copy of this statement ordered by ``field``."""
        clause = OrderClause(field, desc)
        return self.factory(order=clause)

    def limit(self, count):
        """Return a copy of this statement limited to ``count`` results."""
        clause = LimitClause(count)
        return self.factory(limit=clause)

    def to_sql(self):
        """Compile this statement to its query string."""
        return Compiler().compile(self)

    def __str__(self):
        return self.to_sql()
class Clause(Node):
    """Marker base class for the individual clauses of a statement."""
class FieldsClause(Clause):
    """The ``select ...`` part of a statement.

    ``fields`` is an iterable of nodes; the compiler renders each one and
    joins them with ``', '``.
    """

    def __init__(self, fields):
        self.fields = fields
class FromClause(Clause):
    """The ``from ...`` part of a statement; ``domain`` is the node naming the domain."""

    def __init__(self, domain):
        self.domain = domain
class WhereClause(Clause):
    """The ``where ...`` part of a statement; ``expr`` is the filter expression."""

    def __init__(self, expr):
        self.expr = expr
class OrderClause(Clause):
    """The ``order by ...`` part of a statement.

    ``field`` is the expression to sort on; when ``desc`` is truthy the
    compiler appends ``desc`` to the clause.
    """

    def __init__(self, field, desc=False):
        self.field, self.desc = field, desc
class LimitClause(Clause):
    """The ``limit ...`` part of a statement; ``limit`` is the value rendered after the keyword."""

    def __init__(self, limit):
        self.limit = limit
class Expr(Node):
    """Base class for expressions; Python operators are overloaded to build operator nodes.

    Comparison operators (``==``, ``!=``, ``>``, ``>=``, ``<``, ``<=``) and
    ``%`` (like) return binary operator nodes rather than booleans.  ``|``
    and ``&`` build ``OrOp``/``AndOp`` nodes, ``^`` builds an
    ``IntersectionOp``, and ``~`` negates the operators that support it.

    NOTE: because ``__eq__`` is overridden without ``__hash__``, instances
    are unhashable on Python 3.
    """

    def __eq__(self, other):
        return EqualsOp(self, other)

    def __ne__(self, other):
        return NotEqualsOp(self, other)

    def __gt__(self, other):
        return GreaterThanOp(self, other)

    def __ge__(self, other):
        return GreaterThanOrEqualsOp(self, other)

    def __lt__(self, other):
        return LessThanOp(self, other)

    def __le__(self, other):
        return LessThanOrEqualsOp(self, other)

    def __mod__(self, other):
        return LikeOp(self, other)

    def __or__(self, other):
        # Chained ``a | b | c`` is flattened into a single OrOp.  A new node
        # is built instead of appending to ``self.terms`` so that an
        # expression reused in several queries is never modified.
        if isinstance(self, OrOp):
            return OrOp(*(list(self.terms) + [other]))
        return OrOp(self, other)

    def __and__(self, other):
        # Same flattening as ``__or__``, again without mutating ``self``.
        if isinstance(self, AndOp):
            return AndOp(*(list(self.terms) + [other]))
        return AndOp(self, other)

    def __xor__(self, other):
        return IntersectionOp(self, other)

    def __invert__(self):
        # Only operators with a negated form override this.
        raise NotImplementedError

    def like(self, other):
        """Return ``self like other``; equivalent to ``self % other``."""
        return LikeOp(self, other)

    def between(self, lower, upper):
        """Return ``self between lower and upper``."""
        return BetweenOp(self, (lower, upper))

    def in_(self, *terms):
        """Return ``self in (terms...)``."""
        return InOp(self, terms)

    def is_null(self):
        """Return ``self is null``; invert with ``~`` for ``is not null``."""
        return IsOp(self, Literal('null'))
class Field(Expr):
    """Expression referring to an attribute (or domain) by ``name``.

    The compiler backtick-quotes the name when it is a reserved word or
    contains characters outside the unquoted-name pattern.
    """

    def __init__(self, name):
        self.name = name
class Every(Field):
    """Wraps another field so it is rendered as ``every(<field>)``.

    Only ``field`` is stored; ``Field.__init__`` is not invoked, so an
    ``Every`` instance has no ``name`` attribute of its own.
    """

    def __init__(self, field):
        self.field = field
class ItemName(Field, Literal):
    """The ``itemName()`` pseudo-field.

    It is both a ``Field`` (so it supports the expression operators) and a
    ``Literal`` (so the compiler emits it without quoting).
    """

    def __init__(self):
        text = 'itemName()'
        Field.__init__(self, text)
        Literal.__init__(self, text)
class Star(Literal):
    """The ``*`` field list, i.e. select every attribute."""

    def __init__(self):
        Literal.__init__(self, '*')
class Count(Field, Literal):
    """The ``count(*)`` pseudo-field.

    It is both a ``Field`` and a ``Literal``; the compiler emits it without
    quoting.
    """

    def __init__(self):
        text = 'count(*)'
        Field.__init__(self, text)
        Literal.__init__(self, text)
class Op(Expr):
    """Marker base class for operator expressions."""
class BinaryOp(Op):
    """Operator with a ``left`` and a ``right`` operand."""

    def __init__(self, left, right):
        self.left, self.right = left, right
class EqualsOp(BinaryOp):
    """``left = right``."""


class NotEqualsOp(BinaryOp):
    """``left != right``."""


class GreaterThanOp(BinaryOp):
    """``left > right``."""


class GreaterThanOrEqualsOp(BinaryOp):
    """``left >= right``."""


class LessThanOp(BinaryOp):
    """``left < right``."""


class LessThanOrEqualsOp(BinaryOp):
    """``left <= right``."""
class LikeOp(BinaryOp):
    """``left like right``; ``~`` turns it into ``not like``."""

    def __invert__(self):
        negated = NotLikeOp(self.left, self.right)
        return negated


class NotLikeOp(BinaryOp):
    """``left not like right``; ``~`` turns it back into ``like``."""

    def __invert__(self):
        positive = LikeOp(self.left, self.right)
        return positive
class IsOp(BinaryOp):
    """``left is right``; ``~`` turns it into ``is not``."""

    def __invert__(self):
        negated = IsNotOp(self.left, self.right)
        return negated


class IsNotOp(BinaryOp):
    """``left is not right``; ``~`` turns it back into ``is``."""

    def __invert__(self):
        positive = IsOp(self.left, self.right)
        return positive
class BetweenOp(BinaryOp):
    """``left between lower and upper``; ``right`` holds the ``(lower, upper)`` pair."""


class InOp(BinaryOp):
    """``left in (...)``; ``right`` holds the sequence of candidate terms."""


class IntersectionOp(BinaryOp):
    """``left intersection right``."""
class AssociativeOp(Op):
    """Operator over any number of terms, kept as a list in ``terms``."""

    def __init__(self, *terms):
        self.terms = [term for term in terms]
class OrOp(AssociativeOp):
    """Terms joined with ``or``."""


class AndOp(AssociativeOp):
    """Terms joined with ``and``."""
class Compiler(object):
    """Render a query AST (statement, clause or expression) as a SimpleDB select string."""

    # Query token for each operator class.  Lookup is by exact class.
    op_tokens = {
        EqualsOp: '=',
        NotEqualsOp: '!=',
        GreaterThanOp: '>',
        GreaterThanOrEqualsOp: '>=',
        LessThanOp: '<',
        LessThanOrEqualsOp: '<=',
        LikeOp: 'like',
        NotLikeOp: 'not like',
        IsOp: 'is',
        IsNotOp: 'is not',
        OrOp: 'or',
        AndOp: 'and',
        IntersectionOp: 'intersection'
    }

    # Names equal to one of these must be backtick-quoted.
    reserved_words = [
        'or', 'and', 'not', 'from', 'where', 'select', 'like', 'null', 'is',
        'order', 'by', 'asc', 'desc', 'in', 'between', 'intersection', 'limit',
        'every'
    ]

    # Names that may appear unquoted: letters, digits, ``_`` and ``$``, not
    # starting with a digit.  Compiled once rather than on every call.
    field_re = re.compile(r'^[a-z_$][a-z0-9_$]*$', re.I)

    # ``basestring``/``long`` exist only on Python 2; fall back to the
    # Python 3 equivalents so the compiler works on both.
    try:
        _string_types = (basestring,)
        _number_types = (int, long, float)
    except NameError:
        _string_types = (str,)
        _number_types = (int, float)

    def _op_token(self, op):
        """Return the query token for ``op``; raise ValueError if its class has none."""
        try:
            return self.op_tokens[op.__class__]
        except KeyError:
            # Previously an unknown operator was rendered as the text "None".
            raise ValueError(op)

    def compile_field(self, field):
        """Render a field name, backtick-quoting it when required.

        ``Every`` wrappers are rendered as ``every(<inner field>)``.
        """
        if isinstance(field, Every):
            return "every({})".format(self.compile_field(field.field))
        name = field.name
        # Reserved words are compared case-insensitively; quoting a name
        # that did not strictly need it is harmless.
        if self.field_re.match(name) and name.lower() not in self.reserved_words:
            return name
        # A backtick inside a quoted name is escaped by doubling it.
        return "`{}`".format(name.replace('`', '``'))

    def compile_binary_op(self, op, depth=0):
        """Render a binary operator; wrapped in parentheses when nested (``depth > 0``)."""
        left = self.compile_expr(op.left, depth + 1)
        if isinstance(op, BetweenOp):
            lower, upper = op.right[0], op.right[1]
            fmt = "{} between {} and {}".format(
                left,
                self.compile_expr(lower, depth + 1),
                self.compile_expr(upper, depth + 1)
            )
        elif isinstance(op, InOp):
            members = [self.compile_expr(t, depth + 1) for t in op.right]
            fmt = "{} in ({})".format(left, ', '.join(members))
        else:
            fmt = "{} {} {}".format(
                left,
                self._op_token(op),
                self.compile_expr(op.right, depth + 1)
            )
        return '(' + fmt + ')' if depth > 0 else fmt

    def compile_associative_op(self, expr, depth=0):
        """Render an n-ary operator by joining its terms with the operator token."""
        separator = ' {} '.format(self._op_token(expr))
        fmt = separator.join(self.compile_expr(t, depth + 1) for t in expr.terms)
        return '(' + fmt + ')' if depth > 0 else fmt

    def compile_expr(self, expr, depth=0):
        """Render a value or expression node.

        Strings are double-quoted with embedded ``"`` doubled, numbers use
        ``str()``, and literals are emitted verbatim.  ``Literal`` is tested
        before ``Field`` so that nodes which are both are not quoted.

        Raises:
            ValueError: if ``expr`` is of an unsupported type.
        """
        if isinstance(expr, self._string_types):
            return '"{}"'.format(expr.replace('"', '""'))
        if isinstance(expr, self._number_types):
            return str(expr)
        if isinstance(expr, Literal):
            return expr.value
        if isinstance(expr, Field):
            return self.compile_field(expr)
        if isinstance(expr, BinaryOp):
            return self.compile_binary_op(expr, depth)
        if isinstance(expr, AssociativeOp):
            return self.compile_associative_op(expr, depth)
        raise ValueError(expr)

    def compile_clause(self, clause):
        """Render a single clause.

        Raises:
            ValueError: if ``clause`` is not a known clause type.
        """
        if isinstance(clause, FieldsClause):
            return "select {0}".format(', '.join([self.compile(f) for f in clause.fields]))
        elif isinstance(clause, FromClause):
            return "from {0}".format(self.compile(clause.domain))
        elif isinstance(clause, WhereClause):
            return "where {0}".format(self.compile(clause.expr))
        elif isinstance(clause, OrderClause):
            fmt = "order by {}".format(self.compile(clause.field))
            if clause.desc:
                fmt += " desc"
            return fmt
        elif isinstance(clause, LimitClause):
            return "limit {}".format(self.compile(clause.limit))
        else:
            raise ValueError(clause)

    def compile_statement(self, stmt):
        """Render the statement's clauses, in order, skipping unset ones."""
        return ' '.join(
            self.compile(clause) for clause in stmt.clauses.values() if clause
        )

    def compile(self, node):
        """Render any node, dispatching on statement / clause / expression."""
        if isinstance(node, Statement):
            return self.compile_statement(node)
        elif isinstance(node, Clause):
            return self.compile_clause(node)
        else:
            return self.compile_expr(node)
def select(fields, domain):
    """Build a ``SelectStatement`` for ``fields`` from ``domain``.

    Args:
        fields: a single field or an iterable of fields.  Each may be a
            ``Node`` (used as-is) or any other value, which is wrapped in
            ``Field``.  A plain string counts as ONE field name, not as an
            iterable of characters.
        domain: the domain, as a ``Node`` or a name to wrap in ``Field``.

    Returns:
        A new ``SelectStatement`` with its fields and from clauses set.
    """
    try:
        string_types = (basestring,)  # Python 2
    except NameError:
        string_types = (str,)  # Python 3

    def field_promote(name):
        return name if isinstance(name, Node) else Field(name)

    if isinstance(fields, (Node,) + string_types):
        # Strings are iterable, so they must be special-cased; otherwise
        # 'name' would become the fields 'n', 'a', 'm', 'e'.
        fields = [fields]
    else:
        try:
            iter(fields)
        except TypeError:
            # Any other non-iterable value is treated as a single field.
            fields = [fields]

    # Materialise a list: the fields clause is shared by every statement
    # derived from this one, and a lazy ``map`` (Python 3) would be
    # exhausted by the first compilation.
    promoted = [field_promote(f) for f in fields]
    return SelectStatement().fields(promoted).from_(field_promote(domain))
if __name__ == '__main__':
    # Running the module directly executes the doctests in its docstrings.
    from doctest import testmod
    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment