Created
April 30, 2009 14:39
-
-
Save antsaasma/104474 to your computer and use it in GitHub Desktop.
Python decompiler for language integrated queries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import opcode | |
import new | |
import bisect | |
from itertools import islice, chain | |
def sliding_window(seq, size=2):
    """Yield overlapping tuples of ``size`` consecutive items from ``seq``.

    ``sliding_window([1, 2, 3])`` yields ``(1, 2)`` and then ``(2, 3)``.
    Nothing is yielded when ``seq`` holds fewer than ``size`` items.

    The window is advanced with a plain ``for`` loop rather than by calling
    ``iterator.next()`` until StopIteration escapes the generator, so the
    function behaves the same on Python 2 and Python 3.
    """
    iterator = iter(seq)
    window = tuple(islice(iterator, size))
    if len(window) != size:
        return
    yield window
    for item in iterator:
        # Drop the oldest element and append the newest one.
        window = window[1:] + (item,)
        yield window
class Function(object):
    """Analysable snapshot of a Python function.

    Holds the function's code (as a ``Codeblock``), globals, name, default
    arguments, closure and docstring.
    """

    def __init__(self, code, globals_, name, argdefs, closure, doc):
        self.code = code
        self.globals = globals_
        self.name = name
        self.argdefs = argdefs
        self.closure = closure
        # Previously accepted but discarded; keep it so the snapshot is complete.
        self.doc = doc

    @classmethod
    def from_function(cls, func):
        """Build a Function from a Python 2 function object (``func_*`` attributes)."""
        return cls(
            Codeblock.from_code(func.func_code),
            func.func_globals,
            func.func_name,
            func.func_defaults,
            func.func_closure,
            func.func_doc)

    def analyze(self, argvalues=None):
        """Run ``self.code.analyze`` for this function and return its result.

        ``argvalues``, when given, supplies the symbolic values bound to the
        function's arguments.
        """
        return self.code.analyze(self, argvalues)
class Codeblock(object):
    """Copy of a code object's fields plus control-flow analysis of its bytecode.

    ``analyze`` splits the instruction stream into basic blocks, links them
    into a control-flow graph and symbolically executes the graph starting at
    the first block.
    """

    def __init__(self, argcount, nlocals, stacksize, flags, codestring, constants, names,
                 varnames, filename, name, firstlineno, lnotab, freevars, cellvars):
        self.argcount = argcount
        self.nlocals = nlocals
        self.stacksize = stacksize
        self.flags = flags
        self.codestring = codestring
        self.constants = constants
        self.names = names
        self.varnames = varnames
        self.filename = filename
        self.name = name
        self.firstlineno = firstlineno
        self.lnotab = lnotab
        self.freevars = freevars
        self.cellvars = cellvars

    @classmethod
    def from_code(cls, code):
        """Build an instance from the ``co_*`` attributes of ``code``.

        Uses ``cls`` so that subclasses get an instance of their own type.
        """
        return cls(
            code.co_argcount,
            code.co_nlocals,
            code.co_stacksize,
            code.co_flags,
            code.co_code,
            code.co_consts,
            code.co_names,
            code.co_varnames,
            code.co_filename,
            code.co_name,
            code.co_firstlineno,
            code.co_lnotab,
            code.co_freevars,
            code.co_cellvars
        )

    @property
    def opcodes(self):
        """Yield one ``Opcode`` per instruction in ``codestring``.

        Instructions are decoded as a one-byte opcode, followed by a two-byte
        little-endian argument when the opcode takes one.
        """
        idx = 0
        while idx < len(self.codestring):
            op = ord(self.codestring[idx])
            # Opcodes numbered HAVE_ARGUMENT and above carry an argument, so
            # the comparison has to include HAVE_ARGUMENT itself.
            if op >= opcode.HAVE_ARGUMENT:
                arg = ord(self.codestring[idx+1]) + 256*ord(self.codestring[idx+2])
                yield Opcode.from_code(idx, op, arg)
                idx += 3
            else:
                yield Opcode.from_code(idx, op)
                idx += 1

    def analyze(self, func, argvalues=None):
        """Build the basic-block graph and symbolically execute it.

        ``func`` supplies the globals used to resolve global names.
        ``argvalues``, when not None, are the symbolic values for the
        arguments; otherwise each argument is bound to its own name.

        Sets ``self.globals`` and ``self.return_block`` as side effects and
        returns the list of basic blocks in instruction order.
        """
        self.globals = func.globals
        opcode_list = list(self.opcodes)
        # Byte offset of an instruction -> its index in opcode_list.
        opnr_map = dict((op.idx, opnr) for opnr, op in enumerate(opcode_list))

        # Block boundaries: start, end, and both sides of every jump.
        basic_block_lines = [0, len(opcode_list)]
        for op in opcode_list:
            if op.is_jump:
                basic_block_lines.append(opnr_map[op.next_op])
                basic_block_lines.append(opnr_map[op.jump_target])
        basic_block_lines = sorted(set(basic_block_lines))

        blocks = []
        for block_nr, (start, end) in enumerate(sliding_window(basic_block_lines)):
            blocks.append(BasicBlock(block_nr, opcode_list[start:end]))

        # Synthetic sink that every returning block points at.
        self.return_block = BasicBlock(-1, [])
        for block, next_block in sliding_window(chain(blocks, [None])):
            if any(op.is_return for op in block.opcodes):
                block.target = self.return_block
                self.return_block.incoming.append(block)
            else:
                last_op = block.opcodes[-1]
                if last_op.is_jump:
                    # Locate the block that starts at the jump destination.
                    block.target = blocks[bisect.bisect_left(basic_block_lines, opnr_map[last_op.jump_target])]
                    block.target.incoming.append(block)
                if not last_op.is_jump or last_op.conditional:
                    # Control can fall through to the following block.
                    block.next = next_block
                    if next_block:
                        block.next.incoming.append(block)

        blocks[0].output(self, [], args=argvalues)
        return blocks
def idom(block): | |
nodes_to_visit = block.incoming[:] | |
n = len(nodes_to_visit) - 1 | |
visited_nodes = set() | |
while True: | |
node = nodes_to_visit.pop(0) | |
if node in visited_nodes: | |
n -= 1 | |
else: | |
n += max(0,len(node.incoming) - 1) | |
nodes_to_visit.extend(node.incoming) | |
visited_nodes.add(node) | |
if not n: | |
return node | |
def optimize_phi(phinode):
    """Simplify a PHI (condition / true / false) node with a fixed list of rewrite rules.

    Rules are tried in order and the first one that returns wins.  Two rules
    only replace ``phinode`` with a flattened PHI and let the later rules keep
    working on the result.  When no rule returns, the (possibly flattened)
    PHI is returned as is.
    """
    # Both branches yield the same value: the condition is irrelevant.
    if phinode.true == phinode.false:
        return phinode.true
    # The true branch is the condition itself -> "condition OR false".
    if phinode.condition == phinode.true:
        return BinaryExpression('OR', phinode.condition, phinode.false)
    # The false branch is the condition itself -> "condition AND true".
    if phinode.condition == phinode.false:
        return BinaryExpression('AND', phinode.condition, phinode.true)
    # Nested PHI in the true branch sharing our false value: merge the conditions with AND.
    if isinstance(phinode.true, PHI) and phinode.true.false == phinode.false:
        phinode = PHI(BinaryExpression('AND', phinode.condition, phinode.true.condition), phinode.true.true, phinode.false)
    # Nested PHI in the false branch sharing our true value: merge the conditions with OR.
    if isinstance(phinode.false, PHI) and phinode.false.true == phinode.true:
        phinode = PHI(BinaryExpression('OR', phinode.condition, phinode.false.condition), phinode.true, phinode.false.false)
    # The condition is an AND and the true value is one of its operands.
    if isinstance(phinode.condition, BinaryExpression) and phinode.condition.op == 'AND' and phinode.true in [phinode.condition.left, phinode.condition.right]:
        return BinaryExpression('OR', phinode.condition, phinode.false)
    # The condition is an OR and the false value is one of its operands.
    if isinstance(phinode.condition, BinaryExpression) and phinode.condition.op == 'OR' and phinode.false in [phinode.condition.left, phinode.condition.right]:
        return BinaryExpression('AND', phinode.condition, phinode.true)
    # The false value is an AND that contains the true value as an operand.
    if isinstance(phinode.false, BinaryExpression) and phinode.false.op == 'AND':
        if phinode.true == phinode.false.left:
            return BinaryExpression('AND', BinaryExpression('OR', phinode.condition, phinode.false.right), phinode.true)
        if phinode.true == phinode.false.right:
            return BinaryExpression('AND', BinaryExpression('OR', phinode.condition, phinode.false.left), phinode.true)
    # Sink PHI
    # For "x * y" versus plain "x", push the PHI into the other factor,
    # using Literal(1) as the factor for the false branch.
    if isinstance(phinode.true, BinaryExpression) and phinode.false in [phinode.true.left, phinode.true.right]:
        if phinode.true.op == '*':
            if phinode.false == phinode.true.left:
                return BinaryExpression(phinode.true.op, phinode.true.left, PHI(phinode.condition, phinode.true.right, Literal(1)))
            else:
                return BinaryExpression(phinode.true.op, phinode.true.right, PHI(phinode.condition, phinode.true.left, Literal(1)))
    return phinode
def phinode(block, varname):
    """Return the symbolic value of ``varname`` on entry to ``block``.

    Starting at the node returned by ``idom(block)``, every path towards
    ``block`` is followed.  Where a node branches, the values from both
    branches are combined into a PHI (simplified by ``optimize_phi``).  Paths
    that end in a terminal node contribute nothing.
    """
    dominator = idom(block)
    base_value = dominator.get_var(varname)

    def walk(node, inherited):
        if node is block:
            return inherited
        if node.is_terminal:
            return None
        # An assignment inside this node overrides the inherited value.
        value = node.local_assignments.get(varname, inherited)
        if not node.divergent:
            return walk(node.rtarget, value)
        when_true = walk(node.target_if_true, value)
        when_false = walk(node.target_if_false, value)
        # A branch that never reaches ``block`` leaves only the other one
        # (or None when neither reaches it).
        if when_true is None:
            return when_false
        if when_false is None:
            return when_true
        return optimize_phi(PHI(node.condition, when_true, when_false))

    return walk(dominator, base_value)
class BasicBlock(object):
    """A run of opcodes together with its edges in the control-flow graph."""

    def __init__(self, nr, opcodes):
        self.nr = nr
        self.opcodes = opcodes
        # Successor reached through the block's jump (or the return block).
        self.target = None
        # Fall-through successor.
        self.next = None
        # Predecessor blocks.
        self.incoming = []
        # Symbolic stack left behind once output() has executed the block.
        self.stack = None
        # `inverse` flag copied from the conditional jump in the block, if any.
        self.inverse = None
        # How many times output() has been entered for this block.
        self.numvisits = 0
        # Variable name -> symbolic value assigned within this block.
        self.local_assignments = {}

    # Truthy when the block has both a fall-through and a jump successor.
    divergent = property(lambda s: s.next and s.target)
    # Successor taken when the condition holds / does not hold; `inverse` swaps them.
    target_if_true = property(lambda s: s.next if s.inverse else s.target)
    target_if_false = property(lambda s: s.target if s.inverse else s.next)
    # True when the block has no successor at all.
    is_terminal = property(lambda s: not s.next and not s.target)
    # The fall-through successor if there is one, otherwise the jump target.
    rtarget = property(lambda s: s.next if s.next else s.target)

    def output(self, code, in_stack, ssa=None, visited=None, args=None):
        """Symbolically execute this block, then recurse into its successors.

        ``code`` is the owning code block, ``in_stack`` the symbolic stack of
        the predecessor that reached this block.  ``visited`` is None only on
        the initial call; in that case the argument names are bound to
        ``args`` (or to their own names when ``args`` is None).  ``ssa`` is
        only passed through to the recursive calls.

        Returns None while not every incoming edge has been seen yet,
        otherwise the resulting symbolic stack.
        """
        self.numvisits += 1
        if visited is None:
            visited = set()
            if args is not None:
                self.local_assignments = dict(zip(code.varnames[:code.argcount], args))
            else:
                self.local_assignments = dict((varname, varname) for varname in code.varnames[:code.argcount])
            self.local_assignments['RETURN'] = None
        # Wait until the block has been reached once per incoming edge.
        if self.numvisits < len(self.incoming):
            return
        visited.add(self.nr)
        if len(in_stack) > 0:
            # Rebuild each inherited stack slot as a merged value of the
            # corresponding _STACK_<n> pseudo-variable.
            stack = [phinode(self, "_STACK_%d" % (pos-1)) for pos in xrange(len(in_stack), 0, -1)]
        else:
            stack = []
        for op in self.opcodes:
            op.apply(code, stack, self)
            if op.is_jump:
                if op.conditional:
                    # The value on top of the stack is the branch condition.
                    self.condition = stack[-1]
                    self.inverse = op.inverse
        self.stack = stack
        if len(stack):
            # Publish leftover stack entries as pseudo-variables, top of stack first.
            for pos, value in enumerate(reversed(stack)):
                self.local_assignments["_STACK_%d" % pos] = value
        if self.next:
            self.next.output(code, stack, ssa, visited)
        if self.target:
            self.target.output(code, stack, ssa, visited)
        return stack

    def get_var(self, name):
        """Return the value of ``name``: the local assignment, else a merged value from predecessors."""
        if name in self.local_assignments:
            return self.local_assignments[name]
        return phinode(self, name)

    def __repr__(self):
        return "<Block %d %r %r>" % (self.nr, "B%s" % self.next.nr if self.next else None, "B%s" % self.target.nr if self.target else None )
class Expression(object):
    """Base class for symbolic expression nodes.

    Two expressions are equal when their ``base_value`` objects have the same
    type and, for Expression nodes, the same attributes.  Base values that
    are not Expression nodes are compared with ordinary ``==``.
    """

    # Subclasses that merely wrap another value override this to unwrap it.
    base_value = property(lambda self: self)

    def __eq__(self, other):
        a = self.base_value
        # ``other`` may be a plain Python value without a ``base_value``.
        b = getattr(other, 'base_value', other)
        if type(a) != type(b):
            return False
        if isinstance(a, Expression):
            return a.__dict__ == b.__dict__
        # Unwrapped plain values do not necessarily have a ``__dict__``.
        return a is b or bool(a == b)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self.__eq__(other)

    def __str__(self):
        return repr(self)
class ValueExpression(Expression):
    """Expression node that wraps an already known value."""

    def __init__(self, value):
        self.value = value

    @property
    def base_value(self):
        """The wrapped value, unwrapped further when it is itself an Expression."""
        inner = self.value
        if isinstance(inner, Expression):
            return inner.base_value
        return inner

    def __repr__(self):
        return repr(self.value)
class PHI(Expression):
    """Conditional value: ``true`` when ``condition`` holds, ``false`` otherwise."""

    def __init__(self, condition, true, false):
        self.condition = condition
        self.true = true
        self.false = false

    def __repr__(self):
        """Render the chain of PHIs nested in the false branch as one CASE expression."""
        retval = "(CASE\n"
        caseexp = self
        while isinstance(caseexp, PHI):
            retval += " WHEN %r THEN %r\n" % (caseexp.condition, caseexp.true)
            caseexp = caseexp.false
        # Wrap in a 1-tuple so a tuple-valued else branch formats correctly.
        return retval + " ELSE %r\nEND)" % (caseexp,)
class LoadVar(Expression):
    """A read of local variable ``name`` whose symbolic value is ``value``."""

    # Equality looks through the load to the value that was read.
    base_value = property(lambda self: self.value.base_value if isinstance(self.value, Expression) else self.value)

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        # The 1-tuple keeps tuple values from being spread over the format string.
        return "(%s)" % (self.value,)
class GetAttr(Expression):
    """Attribute access ``obj.name`` where ``obj`` is a symbolic value."""

    def __init__(self, obj, name):
        self.obj, self.name = obj, name

    def __repr__(self):
        target = repr(self.obj)
        return "%s.%s" % (target, self.name)
class Literal(Expression):
    """A constant value."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # repr() instead of "%r" % value: the latter misformats tuple constants.
        return repr(self.value)
class BinaryExpression(Expression):
    """Operator ``op`` applied to the operands ``left`` and ``right``."""

    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right

    def __repr__(self):
        parts = (self.left, self.op, self.right)
        return "(%r %s %r)" % parts
class TupleExpression(Expression):
    """A tuple built from symbolic values."""

    def __init__(self, values):
        self.values = values

    def __repr__(self):
        elements = self.values
        return repr(elements)
class FunctionCall(Expression):
    """A call of ``func`` with the symbolic arguments ``args``."""

    def __init__(self, func, args):
        self.func, self.args = func, args

    def __repr__(self):
        rendered_args = ", ".join([repr(arg) for arg in self.args])
        return "%s(%s)" % (self.func.__name__, rendered_args)
class Opcode(object):
    """One bytecode instruction; subclasses implement its symbolic effect.

    ``code`` is the numeric opcode, ``arg`` its argument (None when it has
    none) and ``idx`` its byte offset in the code string.
    """

    # Numeric opcode -> implementing subclass, filled in by the metaclass.
    opmap = {}

    class __metaclass__(type):
        # Python 2 metaclass hook: a class whose name is a known opcode name
        # is registered in ``opmap`` under that opcode's number.
        def __init__(self, name, bases, dict_):
            if name in opcode.opname:
                self.opmap[opcode.opname.index(name)] = self
            type.__init__(self, name, bases, dict_)

    def __init__(self, code, arg=None, idx=None):
        self.code = code
        self.arg = arg
        self.idx = idx

    @classmethod
    def from_code(cls, idx, code, arg=None):
        """Instantiate the class registered for ``code``, or a plain Opcode."""
        return cls.opmap.get(code, Opcode)(code, arg, idx)

    name = property(lambda s: opcode.opname[s.code])
    # HAVE_ARGUMENT itself is the first opcode number that takes an argument.
    has_arg = property(lambda s: s.code >= opcode.HAVE_ARGUMENT)
    # Instruction size in bytes: opcode plus an optional two-byte argument.
    length = property(lambda s: 3 if s.has_arg else 1)
    # Byte offset of the instruction that follows this one.
    next_op = property(lambda s: s.idx + s.length)
    is_jump = property(lambda s: s.is_absjump or s.is_reljump)
    is_absjump = property(lambda s: s.code in opcode.hasjabs)
    is_reljump = property(lambda s: s.code in opcode.hasjrel)
    is_return = False

    def format(self, code):
        """Return a display string; subclasses append details looked up in ``code``."""
        return str(self)

    def __str__(self):
        if self.has_arg:
            return "%4s: %s %s" % (self.idx, self.name, self.arg)
        else:
            return "%4s: %s" % (self.idx, self.name)

    def __repr__(self):
        return str(self)

    def apply(self, code, stack, block):
        """Apply the instruction to the symbolic ``stack`` of ``block``.

        Raises NotImplementedError unless a subclass overrides it.
        """
        raise NotImplementedError("Apply not implemented for %s" % self.name)
class AbsJump(Opcode):
    """Jump whose argument is the byte offset of its destination."""

    conditional = False

    @property
    def jump_target(self):
        return self.arg

    def __str__(self):
        fields = (self.idx, self.name, self.arg, self.jump_target)
        return "%4s: %s +%d to %d" % fields

    def apply(self, code, stack, block):
        """Leave the symbolic stack untouched."""
class JUMP_ABSOLUTE(AbsJump):
    """Absolute jump; keeps the AbsJump defaults (not conditional)."""
class RelJump(Opcode):
    """Jump whose argument is an offset counted from the following instruction."""

    conditional = True

    @property
    def jump_target(self):
        return self.idx + self.length + self.arg

    def __str__(self):
        fields = (self.idx, self.name, self.arg, self.jump_target)
        return "%4s: %s +%d to %d" % fields

    def apply(self, code, stack, block):
        """Leave the symbolic stack untouched."""
class FOR_ITER(RelJump):
    """Relative jump that keeps the RelJump defaults (conditional)."""
class JUMP_FORWARD(RelJump):
    """Relative jump that is always taken."""

    conditional = False
class JUMP_IF_FALSE(RelJump):
    """Conditional relative jump; with ``inverse`` set, the fall-through block is the true branch."""

    inverse = True
class JUMP_IF_TRUE(RelJump):
    """Conditional relative jump; with ``inverse`` unset, the jump target is the true branch."""

    inverse = False
class SETUP_LOOP(RelJump):
    """Relative jump that keeps the RelJump defaults (conditional)."""
class SETUP_EXCEPT(RelJump):
    """Relative jump that keeps the RelJump defaults (conditional)."""
class SETUP_FINALLY(RelJump):
    """Relative jump that keeps the RelJump defaults (conditional)."""
class LocalOp(Opcode):
    """Base for opcodes whose argument indexes ``code.varnames``."""

    def format(self, code):
        """Return the generic rendering followed by the local variable's name."""
        rendered = Opcode.format(self, code)
        suffix = " (%s)" % (code.varnames[self.arg])
        return rendered + suffix
class LOAD_FAST(LocalOp):
    """Push the current symbolic value of a local variable."""

    def apply(self, code, stack, block):
        local_name = code.varnames[self.arg]
        current = block.get_var(local_name)
        stack.append(LoadVar(local_name, current))
class LOAD_ATTR(Opcode):
    """Replace the top of the stack with a GetAttr node on it."""

    def apply(self, code, stack, block):
        attr_name = code.names[self.arg]
        owner = stack.pop()
        stack.append(GetAttr(owner, attr_name))

    def format(self, code):
        """Return the generic rendering followed by the attribute name."""
        rendered = Opcode.format(self, code)
        suffix = " (%s)" % (code.names[self.arg])
        return rendered + suffix
class STORE_FAST(LocalOp):
    """Pop the top of the stack and record it as the block's value for a local."""

    def apply(self, code, stack, block):
        assigned = stack.pop()
        local_name = code.varnames[self.arg]
        block.local_assignments[local_name] = assigned
class LOAD_CONST(Opcode):
    """Push a constant from ``code.constants`` as a Literal."""

    def format(self, code):
        """Return the generic rendering followed by the constant's repr."""
        # 1-tuple on purpose: a bare tuple constant would otherwise be spread
        # over the format string.
        return Opcode.format(self, code) + " (%r)" % (code.constants[self.arg],)

    def apply(self, code, stack, block):
        stack.append(Literal(code.constants[self.arg]))
class BinaryOp(Opcode):
    """Base for two-operand opcodes; subclasses provide ``operator``."""

    def apply(self, code, stack, block):
        # The right operand sits on top of the stack.
        rhs = stack.pop()
        lhs = stack.pop()
        combined = BinaryExpression(self.operator, lhs, rhs)
        stack.append(combined)
class COMPARE_OP(BinaryOp):
    """Comparison whose operator string is ``opcode.cmp_op[arg]``."""

    @property
    def operator(self):
        return opcode.cmp_op[self.arg]

    def format(self, code):
        """Return the generic rendering followed by the comparison operator."""
        rendered = Opcode.format(self, code)
        suffix = " (%s)" % (self.operator)
        return rendered + suffix
class BINARY_ADD(BinaryOp):
    """Binary opcode that produces a ``+`` expression."""

    operator = "+"
class BINARY_SUBTRACT(BinaryOp):
    """Binary opcode that produces a ``-`` expression."""

    operator = "-"
class BINARY_MULTIPLY(BinaryOp):
    """Binary opcode that produces a ``*`` expression."""

    operator = "*"
class INPLACE_MULTIPLY(BinaryOp):
    """In-place multiply, treated as an ordinary ``*`` expression."""

    operator = "*"
class BINARY_DIVIDE(BinaryOp):
    """Binary opcode that produces a ``/`` expression."""

    operator = "/"
class POP_TOP(Opcode):
    """Discard the top of the symbolic stack."""

    def apply(self, code, stack, block):
        del stack[-1:]  # placeholder replaced below
class RETURN_VALUE(Opcode):
    """Pop the top of the stack and record it under the ``RETURN`` pseudo-variable."""

    is_return = True

    def apply(self, code, stack, block):
        returned = stack.pop()
        block.local_assignments['RETURN'] = returned
import __builtin__ | |
class LOAD_GLOBAL(Opcode):
    """Push the object a global name refers to.

    The name is looked up in ``code.globals`` first and in ``__builtin__``
    otherwise; the object itself is pushed, not an expression node.
    """

    def apply(self, code, stack, block):
        global_name = code.names[self.arg]
        namespace = code.globals
        resolved = namespace[global_name] if global_name in namespace else getattr(__builtin__, global_name)
        stack.append(resolved)
class CALL_FUNCTION(Opcode):
    """Pop ``arg`` arguments and a callee, then push the call's symbolic result.

    A plain Python function is analysed recursively and replaced by its
    symbolic return value; any other callee becomes a FunctionCall node.
    """

    def apply(self, code, stack, block):
        # Arguments come off the stack last-first; restore call order.
        call_args = [stack.pop() for i in xrange(self.arg)]
        call_args.reverse()
        callee = stack.pop()
        if not isinstance(callee, types.FunctionType):
            stack.append(FunctionCall(callee, call_args))
            return
        inlined = Function.from_function(callee)
        inlined.analyze(call_args)
        stack.append(inlined.code.return_block.get_var('RETURN'))
class DUP_TOP(Opcode):
    """Push a second reference to the value on top of the stack."""

    def apply(self, code, stack, block):
        top = stack[-1]
        stack.append(top)
class ROT_THREE(Opcode):
    """Move the top of the stack below the two entries beneath it."""

    def apply(self, code, stack, block):
        top, second, third = stack.pop(), stack.pop(), stack.pop()
        stack += [top, third, second]
class ROT_TWO(Opcode):
    """Swap the two topmost stack entries."""

    def apply(self, code, stack, block):
        top, second = stack.pop(), stack.pop()
        stack += [top, second]
class BUILD_TUPLE(Opcode):
    """Replace the top ``arg`` stack entries with a single TupleExpression."""

    def apply(self, code, stack, block):
        # Slice from an explicit start index: with arg == 0 the original
        # ``stack[-0:]`` selected the whole stack instead of nothing.
        start = max(len(stack) - self.arg, 0)
        stack[start:] = [TupleExpression(tuple(stack[start:]))]
class UNPACK_SEQUENCE(Opcode):
    """Pop a TupleExpression and push its elements, first element on top.

    Raises TypeError when the tuple does not have exactly ``arg`` elements
    and NotImplementedError for anything that is not a TupleExpression.
    """

    def apply(self, code, stack, block):
        value = stack.pop()
        if isinstance(value, TupleExpression):
            if len(value.values) != self.arg:
                # Format the message before constructing the exception.
                raise TypeError("Cannot unpack a tuple of %d to %d elements"
                                % (len(value.values), self.arg))
            stack.extend(reversed(value.values))
        else:
            raise NotImplementedError("Cannot unpack %r" % (value,))
import sqlalchemy | |
import sqlalchemy.orm | |
import types | |
class SqlalchemyCompiler(object):
    """Translate an Expression tree into SQLAlchemy constructs.

    ``process`` dispatches on the node's class name to a ``visit_<ClassName>``
    method.
    """

    def process(self, node):
        """Compile ``node``; raises Exception for anything that is not an Expression."""
        if isinstance(node, Expression):
            ret = getattr(self, 'visit_%s' % type(node).__name__)(node)
            return ret
        # 1-tuple so that tuple nodes are reported instead of breaking the format.
        raise Exception("Invalid node %r" % (node,))

    def visit_PHI(self, node):
        """Turn a chain of PHIs nested in the false branch into one CASE expression."""
        phinode = node
        whens = []
        while isinstance(phinode, PHI):
            whens.append((self.process(phinode.condition), self.process(phinode.true)))
            phinode = phinode.false
        return sqlalchemy.case(whens, else_=self.process(phinode))

    def visit_BinaryExpression(self, node):
        opmap = {
            '<': sqlalchemy.sql.operators.lt,
            '<=': sqlalchemy.sql.operators.le,
            '>': sqlalchemy.sql.operators.gt,
            '>=': sqlalchemy.sql.operators.ge,
            '==': sqlalchemy.sql.operators.eq,
            '!=': sqlalchemy.sql.operators.ne,
            '+': sqlalchemy.sql.operators.add,
            '-': sqlalchemy.sql.operators.sub,
            '*': sqlalchemy.sql.operators.mul,
            '/': sqlalchemy.sql.operators.div,
            'AND': sqlalchemy.sql.operators.and_,
            'OR': sqlalchemy.sql.operators.or_,
        }
        return opmap[node.op](self.process(node.left), self.process(node.right))

    def visit_LoadVar(self, node):
        return self.process(node.value)

    def visit_Literal(self, node):
        return sqlalchemy.literal(node.value)

    def visit_GetAttr(self, node):
        """Resolve the attribute; a ``property`` is decompiled and compiled in turn."""
        # Compile the owner once and reuse the result for the lookup.
        obj = self.process(node.obj)
        value = getattr(obj, node.name)
        if isinstance(value, property):
            f = Function.from_function(value.fget)
            f.analyze([ValueExpression(obj)])
            return self.process(f.code.return_block.get_var('RETURN'))
        return value

    def visit_ValueExpression(self, node):
        return node.value

    # Python callables that have a SQL function equivalent.
    funcs = {
        len: sqlalchemy.func.length,
    }

    def visit_FunctionCall(self, node):
        return self.funcs[node.func](*map(self.process, node.args))
import types | |
class LambdaQuery(sqlalchemy.orm.Query):
    """Query whose ``filter`` also accepts a plain Python function as the criterion."""

    compiler = SqlalchemyCompiler()

    def filter(self, condition):
        """Filter by ``condition``.

        A function is decompiled with one argument value per query entity
        and its symbolic return value is compiled to a SQLAlchemy
        expression.  Anything else is handed to the base class unchanged.
        """
        if isinstance(condition, types.FunctionType):
            f = Function.from_function(condition)
            f.analyze([ValueExpression(e.entity.class_) for e in self._entities])
            retval = f.code.return_block.get_var('RETURN')
            return super(LambdaQuery, self).filter(self.compiler.process(retval))
        # The bound super() call already supplies ``self``.
        return super(LambdaQuery, self).filter(condition)
def output_callgraph(func):
    """Write the basic-block graph of ``func`` to ``code.png`` using pydot.

    Each node lists the block's opcodes (label text cut at 400 characters)
    and is coloured red when the block's ``inverse`` flag is set.  Edges run
    from every predecessor to the block and carry no label.
    """
    f = Function.from_function(func)
    blocks = f.analyze()
    import pydot
    graph = pydot.Dot()

    def block_title(b):
        return "Block %d" % b.nr

    for block in blocks:
        title = block_title(block)
        node = pydot.Node(title, shape="plaintext")
        rows = [title + "<BR/>"]
        for op in block.opcodes:
            # Replace characters that would clash with the HTML-like label syntax.
            text = op.format(f.code).replace(':', '-').replace('<', 'lt').replace('>', 'gt')
            rows.append("%s<BR/>" % text)
        label_body = "".join(rows)[0:400]
        node.set_label('<<TABLE CELLSPACING="0" CELLPADDING="1" BORDER="0" CELLBORDER="1" ><TR><TD>%s</TD></TR></TABLE>>' % label_body)
        if block.inverse:
            node.set_color('red')
        graph.add_node(node)
        for parent in block.incoming:
            graph.add_edge(pydot.Edge(block_title(parent), title))
    graph.write_png('code.png')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from analyze import LambdaQuery | |
from sqlalchemy import Column, String, Integer, Float, Boolean, create_engine | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker, Query | |
# Declarative base bound to an engine created from the 'sqlite://' URL.
Base = declarative_base(bind=create_engine('sqlite://'))
# Session whose queries are built with LambdaQuery, so filter() accepts functions.
session = sessionmaker(query_cls=LambdaQuery)()
class User(Base):
    """Example mapped class used to demonstrate lambda filters."""

    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String(32))
    age = Column(Integer)
    shoe_size = Column(Integer)

    # Plain Python property; when used inside a lambda filter its getter is
    # decompiled and translated to SQL (see the expected output of the demo).
    @property
    def importance(self):
        return larger(self.age, self.shoe_size) + len(self.name)
def larger(a,b):
    """Return ``a`` when ``a > b``, otherwise ``b``."""
    # NOTE(review): this function is an input to the decompiler, so its
    # statement structure is kept exactly as written.
    if a > b:
        return a
    return b
# Print the query built from a lambda filter (Python 2 print statement).
print session.query(User).filter(lambda u: u.name == 'ants' and u.importance > 10)
#Output:
#SELECT user.id AS user_id, user.name AS user_name, user.age AS user_age, user.shoe_size AS user_shoe_size
#FROM user
#WHERE user.name = ? AND CASE WHEN (user.age > user.shoe_size) THEN user.age ELSE user.shoe_size END + length(user.name) > ?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment