Skip to content

Instantly share code, notes, and snippets.

@antsaasma
Created April 30, 2009 14:39
Show Gist options
  • Save antsaasma/104474 to your computer and use it in GitHub Desktop.
Save antsaasma/104474 to your computer and use it in GitHub Desktop.
Python decompiler for language integrated queries
# -*- coding: utf-8 -*-
import opcode
import new
import bisect
from itertools import islice, chain
def sliding_window(seq, size=2):
    """Yield successive overlapping tuples of ``size`` items from ``seq``.

    Nothing is yielded when ``seq`` holds fewer than ``size`` items.

    The generator ends by running out of input in a ``for`` loop rather than
    by letting ``StopIteration`` escape from ``iterator.next()``: that old
    form is Python 2 only and becomes a ``RuntimeError`` under PEP 479.
    """
    iterator = iter(seq)
    window = tuple(islice(iterator, size))
    if len(window) != size:
        return
    yield window
    for item in iterator:
        # Drop the oldest element, append the newest.
        window = window[1:] + (item,)
        yield window
class Function(object):
    """Snapshot of a function: its decoded code block plus its call context.

    Attributes:
        code: the Codeblock built from the function's code object.
        globals: the function's global namespace (used to resolve globals).
        name: the function name.
        argdefs: default argument values.
        closure: closure cells, if any.
        doc: the function's docstring.
    """

    def __init__(self, code, globals_, name, argdefs, closure, doc):
        self.code = code
        self.globals = globals_
        self.name = name
        self.argdefs = argdefs
        self.closure = closure
        # The docstring used to be accepted and silently discarded.
        self.doc = doc

    @classmethod
    def from_function(cls, func):
        """Build an instance from a live function (Python 2 ``func_*`` attributes)."""
        return cls(
            Codeblock.from_code(func.func_code),
            func.func_globals,
            func.func_name,
            func.func_defaults,
            func.func_closure,
            func.func_doc)

    def analyze(self, argvalues=None):
        """Run the code block's analysis for this function.

        ``argvalues`` optionally supplies one symbolic value per positional
        argument; the list of basic blocks is returned.
        """
        return self.code.analyze(self, argvalues)
class Codeblock(object):
    """Plain-attribute copy of a Python 2 code object, plus its CFG builder.

    The constructor arguments correspond one-to-one to the ``co_*``
    attributes copied by ``from_code``.
    """

    def __init__(self, argcount, nlocals, stacksize, flags, codestring, constants, names,
                 varnames, filename, name, firstlineno, lnotab, freevars, cellvars):
        self.argcount = argcount
        self.nlocals = nlocals
        self.stacksize = stacksize
        self.flags = flags
        self.codestring = codestring
        self.constants = constants
        self.names = names
        self.varnames = varnames
        self.filename = filename
        self.name = name
        self.firstlineno = firstlineno
        self.lnotab = lnotab
        self.freevars = freevars
        self.cellvars = cellvars

    @classmethod
    def from_code(cls, code):
        """Copy the ``co_*`` attributes of ``code`` into a new instance."""
        return cls(
            code.co_argcount,
            code.co_nlocals,
            code.co_stacksize,
            code.co_flags,
            code.co_code,
            code.co_consts,
            code.co_names,
            code.co_varnames,
            code.co_filename,
            code.co_name,
            code.co_firstlineno,
            code.co_lnotab,
            code.co_freevars,
            code.co_cellvars
        )

    @staticmethod
    def _decode_bytecode(codestring):
        """Yield ``(offset, opcode, arg)`` triples; ``arg`` is None when absent.

        Opcodes greater than *or equal to* ``opcode.HAVE_ARGUMENT`` carry a
        two-byte little-endian argument.  The previous ``>`` comparison
        decoded the opcode equal to ``HAVE_ARGUMENT`` as argument-less and
        then misread its argument bytes as instructions.
        """
        idx = 0
        end = len(codestring)
        while idx < end:
            op = ord(codestring[idx])
            if op >= opcode.HAVE_ARGUMENT:
                arg = ord(codestring[idx + 1]) + 256 * ord(codestring[idx + 2])
                yield idx, op, arg
                idx += 3
            else:
                yield idx, op, None
                idx += 1

    @property
    def opcodes(self):
        """Generator of Opcode instances decoded from ``codestring``."""
        for idx, op, arg in self._decode_bytecode(self.codestring):
            yield Opcode.from_code(idx, op, arg)

    def analyze(self, func, argvalues=None):
        """Split the bytecode into basic blocks, link them and evaluate them.

        Side effects: sets ``self.globals`` (from ``func``) and
        ``self.return_block`` (a synthetic block, numbered -1, that every
        returning block targets).  ``argvalues`` is forwarded to the first
        block as the symbolic argument values.  Returns the list of blocks in
        bytecode order (the synthetic return block is not included).
        """
        self.globals = func.globals
        opcode_list = list(self.opcodes)
        # Byte offset of each instruction -> its position in opcode_list.
        opnr_map = dict((op.idx, opnr) for opnr, op in enumerate(opcode_list))

        # A block boundary sits after every jump and at every jump target.
        boundaries = set([0, len(opcode_list)])
        for op in opcode_list:
            if op.is_jump:
                boundaries.add(opnr_map[op.next_op])
                boundaries.add(opnr_map[op.jump_target])
        basic_block_lines = sorted(boundaries)

        blocks = [BasicBlock(block_nr, opcode_list[start:end])
                  for block_nr, (start, end) in enumerate(sliding_window(basic_block_lines))]

        self.return_block = BasicBlock(-1, [])
        # The trailing None gives the last block a "no successor" partner.
        for block, next_block in sliding_window(chain(blocks, [None])):
            if any(op.is_return for op in block.opcodes):
                block.target = self.return_block
                self.return_block.incoming.append(block)
            else:
                last_op = block.opcodes[-1]
                if last_op.is_jump:
                    target_nr = bisect.bisect_left(basic_block_lines, opnr_map[last_op.jump_target])
                    block.target = blocks[target_nr]
                    block.target.incoming.append(block)
                # Fall through unless the block ends in an unconditional jump.
                if not last_op.is_jump or last_op.conditional:
                    block.next = next_block
                    if next_block:
                        block.next.incoming.append(block)
        blocks[0].output(self, [], args=argvalues)
        return blocks
def idom(block):
    """Walk predecessor edges breadth-first until all open paths have merged.

    A counter tracks how many backward paths are still open: it starts at
    ``len(block.incoming) - 1``, grows by the extra fan-in of every newly
    visited block and shrinks each time an already visited block is reached
    again.  The block being examined when the counter hits zero is returned;
    callers use it as the dominator of ``block``.

    Raises IndexError if the predecessor queue runs dry first (for example
    when ``block`` has no incoming edges).
    """
    pending = list(block.incoming)
    open_paths = len(pending) - 1
    seen = set()
    while True:
        candidate = pending.pop(0)
        if candidate not in seen:
            seen.add(candidate)
            extra_fan_in = len(candidate.incoming) - 1
            if extra_fan_in > 0:
                open_paths += extra_fan_in
            pending.extend(candidate.incoming)
        else:
            open_paths -= 1
        if open_paths == 0:
            return candidate
def optimize_phi(phinode):
    """Rewrite a PHI (condition ? true : false) into a simpler expression.

    The rules are tried in order; the first one that matches returns.  When
    none applies, the (possibly merged) PHI node itself is returned.
    """
    cond, on_true, on_false = phinode.condition, phinode.true, phinode.false

    # Trivial shapes: identical branches, or a branch equal to the condition.
    if on_true == on_false:
        return on_true
    if cond == on_true:
        return BinaryExpression('OR', cond, on_false)
    if cond == on_false:
        return BinaryExpression('AND', cond, on_true)

    # Merge a nested PHI that shares a branch with its parent.  Both steps
    # rebind ``phinode``, so the second step and everything below look at
    # the merged node.
    if isinstance(on_true, PHI) and on_true.false == on_false:
        phinode = PHI(BinaryExpression('AND', cond, on_true.condition), on_true.true, on_false)
    if isinstance(phinode.false, PHI) and phinode.false.true == phinode.true:
        phinode = PHI(BinaryExpression('OR', phinode.condition, phinode.false.condition),
                      phinode.true, phinode.false.false)
    cond, on_true, on_false = phinode.condition, phinode.true, phinode.false

    if isinstance(cond, BinaryExpression) and cond.op == 'AND' and on_true in [cond.left, cond.right]:
        return BinaryExpression('OR', cond, on_false)
    if isinstance(cond, BinaryExpression) and cond.op == 'OR' and on_false in [cond.left, cond.right]:
        return BinaryExpression('AND', cond, on_true)

    if isinstance(on_false, BinaryExpression) and on_false.op == 'AND':
        if on_true == on_false.left:
            return BinaryExpression('AND', BinaryExpression('OR', cond, on_false.right), on_true)
        if on_true == on_false.right:
            return BinaryExpression('AND', BinaryExpression('OR', cond, on_false.left), on_true)

    # Sink PHI: "x * y if c else x" becomes "x * (y if c else 1)".
    if isinstance(on_true, BinaryExpression) and on_false in [on_true.left, on_true.right]:
        if on_true.op == '*':
            if on_false == on_true.left:
                return BinaryExpression(on_true.op, on_true.left,
                                        PHI(cond, on_true.right, Literal(1)))
            else:
                return BinaryExpression(on_true.op, on_true.right,
                                        PHI(cond, on_true.left, Literal(1)))
    return phinode
def phinode(block, varname):
    """Compute the value of ``varname`` on entry to ``block``.

    Starting from the block returned by ``idom(block)``, every path towards
    ``block`` is followed; each block's own assignment to ``varname``
    overrides the value carried so far.  Where a block diverges, the values
    of its two branches are combined into an optimized PHI on the block's
    condition.  Paths that end without reaching ``block`` contribute None
    and are dropped.
    """
    dominator = idom(block)

    def recurse(node, current_value):
        if node is block:
            return current_value
        if node.is_terminal:
            return None
        if varname in node.local_assignments:
            value = node.local_assignments[varname]
        else:
            value = current_value
        if not node.divergent:
            return recurse(node.rtarget, value)
        true_value = recurse(node.target_if_true, value)
        false_value = recurse(node.target_if_false, value)
        if true_value is None:
            # Also covers "both branches are None".
            return false_value
        if false_value is None:
            return true_value
        return optimize_phi(PHI(node.condition, true_value, false_value))

    return recurse(dominator, dominator.get_var(varname))
class BasicBlock(object):
    """Straight-line run of opcodes plus its edges in the control-flow graph.

    ``next`` is the fall-through successor and ``target`` the jump successor;
    ``incoming`` lists predecessor blocks.  ``local_assignments`` maps
    variable names (and ``_STACK_<n>`` stack slots, 0 being the top) to the
    symbolic values produced inside this block.
    """

    def __init__(self, nr, opcodes):
        self.nr = nr
        self.opcodes = opcodes
        self.target = None
        self.next = None
        self.incoming = []
        self.stack = None
        self.inverse = None
        self.numvisits = 0
        self.local_assignments = {}

    @property
    def divergent(self):
        """Truthy when the block has both a fall-through and a jump successor."""
        return self.next and self.target

    @property
    def target_if_true(self):
        """Successor taken when the block's condition holds."""
        if self.inverse:
            return self.next
        return self.target

    @property
    def target_if_false(self):
        """Successor taken when the block's condition does not hold."""
        if self.inverse:
            return self.target
        return self.next

    @property
    def is_terminal(self):
        """True when the block has no successor at all."""
        return not (self.next or self.target)

    @property
    def rtarget(self):
        """The fall-through successor if there is one, else the jump target."""
        return self.next or self.target

    def output(self, code, in_stack, ssa=None, visited=None, args=None):
        """Symbolically execute this block, then recurse into its successors.

        The first call (``visited`` is None) seeds the argument variables,
        from ``args`` when given and otherwise with their own names, and
        initialises ``RETURN``.  A block with several predecessors is only
        executed on the visit that matches its incoming-edge count.  Returns
        the symbolic stack after the block ran, or None if execution was
        deferred.
        """
        self.numvisits += 1
        if visited is None:
            visited = set()
            arg_names = code.varnames[:code.argcount]
            if args is None:
                self.local_assignments = dict((varname, varname) for varname in arg_names)
            else:
                self.local_assignments = dict(zip(arg_names, args))
            self.local_assignments['RETURN'] = None
        if self.numvisits < len(self.incoming):
            return
        visited.add(self.nr)

        # Rebuild the incoming stack bottom-up; slot 0 is the top of stack.
        stack = [phinode(self, "_STACK_%d" % depth)
                 for depth in reversed(range(len(in_stack)))]
        for op in self.opcodes:
            op.apply(code, stack, self)
            if op.is_jump and op.conditional:
                self.condition = stack[-1]
                self.inverse = op.inverse
        self.stack = stack
        for pos, value in enumerate(reversed(stack)):
            self.local_assignments["_STACK_%d" % pos] = value

        if self.next:
            self.next.output(code, stack, ssa, visited)
        if self.target:
            self.target.output(code, stack, ssa, visited)
        return stack

    def get_var(self, name):
        """Return the value of ``name`` as seen in this block."""
        if name in self.local_assignments:
            return self.local_assignments[name]
        return phinode(self, name)

    def __repr__(self):
        next_label = "B%s" % self.next.nr if self.next else None
        target_label = "B%s" % self.target.nr if self.target else None
        return "<Block %d %r %r>" % (self.nr, next_label, target_label)
class Expression(object):
    """Base class of all symbolic expression nodes.

    Two nodes are equal when their ``base_value`` objects have the same type
    and the same attributes.  Subclasses that merely wrap another value
    override ``base_value`` so that the wrapper is transparent to comparison.
    """

    base_value = property(lambda self: self)

    def __eq__(self, other):
        left = self.base_value
        # ``other`` may be a plain object (a builtin pushed by LOAD_GLOBAL,
        # None, ...), which has no ``base_value`` of its own.
        right = getattr(other, 'base_value', other)
        if type(left) != type(right):
            return False
        left_attrs = getattr(left, '__dict__', None)
        right_attrs = getattr(right, '__dict__', None)
        if left_attrs is None or right_attrs is None:
            # Unwrapped plain values (str, int, ...) have no __dict__.
            return left == right
        return left_attrs == right_attrs

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self.__eq__(other)

    def __str__(self):
        return repr(self)
class ValueExpression(Expression):
    """Wrapper around a value supplied from outside the decompiled code."""

    def __init__(self, value):
        self.value = value

    @property
    def base_value(self):
        """The wrapped value, unwrapped further if it is itself an Expression."""
        inner = self.value
        if isinstance(inner, Expression):
            return inner.base_value
        return inner

    def __repr__(self):
        return repr(self.value)
class PHI(Expression):
    """Conditional value: ``true`` when ``condition`` holds, else ``false``."""

    def __init__(self, condition, true, false):
        self.condition = condition
        self.true = true
        self.false = false

    def __repr__(self):
        """Render as a CASE expression, flattening PHIs chained via ``false``."""
        parts = ["(CASE\n"]
        caseexp = self
        while isinstance(caseexp, PHI):
            parts.append(" WHEN %r THEN %r\n" % (caseexp.condition, caseexp.true))
            caseexp = caseexp.false
        # One-element tuple so that a tuple-valued operand is formatted as a
        # single value.  (An unreachable second ``return`` used to follow.)
        parts.append(" ELSE %r\nEND)" % (caseexp,))
        return "".join(parts)
class LoadVar(Expression):
    """Read of local variable ``name`` that resolved to ``value``."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    @property
    def base_value(self):
        """The resolved value, unwrapped further if it is itself an Expression."""
        resolved = self.value
        if isinstance(resolved, Expression):
            return resolved.base_value
        return resolved

    def __repr__(self):
        return "(%s)" % self.value
class GetAttr(Expression):
    """Attribute access ``obj.name`` on a symbolic object."""

    def __init__(self, obj, name):
        self.obj = obj
        self.name = name

    def __repr__(self):
        owner, attribute = self.obj, self.name
        return "%r.%s" % (owner, attribute)
class Literal(Expression):
    """A constant taken from the code object's constant table."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # repr() rather than '"%r" % value': the latter raises TypeError (or
        # drops the parentheses) when the constant is itself a tuple.
        return repr(self.value)
class BinaryExpression(Expression):
    """Binary operation ``left <op> right``; ``op`` is the operator's text."""

    def __init__(self, op, left, right):
        self.op = op
        self.left = left
        self.right = right

    def __repr__(self):
        operands = (self.left, self.op, self.right)
        return "(%r %s %r)" % operands
class TupleExpression(Expression):
    """Tuple built from symbolic ``values`` (see BUILD_TUPLE)."""

    def __init__(self, values):
        self.values = values

    def __repr__(self):
        members = self.values
        return repr(members)
class FunctionCall(Expression):
    """Call of ``func`` (a callable that is not inlined) with symbolic ``args``."""

    def __init__(self, func, args):
        self.func = func
        self.args = args

    def __repr__(self):
        return "%s(%s)" % (
            self.func.__name__,
            ", ".join([repr(argument) for argument in self.args]),
        )
class Opcode(object):
    """One decoded bytecode instruction.

    Subclasses named after an entry of ``opcode.opname`` register themselves
    in ``opmap`` (via the Python 2 ``__metaclass__`` hook) so that
    ``from_code`` can instantiate the matching class for a numeric opcode.

    Attributes:
        code: numeric opcode.
        arg: decoded argument, or None.
        idx: byte offset of the instruction in the code string.
    """

    # Numeric opcode -> Opcode subclass implementing it.
    opmap = {}

    class __metaclass__(type):
        def __init__(self, name, bases, dict_):
            # Register classes whose name is a real opcode name.
            if name in opcode.opmap:
                self.opmap[opcode.opmap[name]] = self
            type.__init__(self, name, bases, dict_)

    def __init__(self, code, arg=None, idx=None):
        self.code = code
        self.arg = arg
        self.idx = idx

    @classmethod
    def from_code(cls, idx, code, arg=None):
        """Instantiate the class registered for ``code`` (plain Opcode if none)."""
        return cls.opmap.get(code, Opcode)(code, arg, idx)

    @property
    def name(self):
        """Mnemonic of this opcode."""
        return opcode.opname[self.code]

    @property
    def has_arg(self):
        """True when the instruction carries an argument.

        Opcodes greater than *or equal to* ``HAVE_ARGUMENT`` take one; the
        former ``>`` test got the boundary opcode wrong.
        """
        return self.code >= opcode.HAVE_ARGUMENT

    @property
    def length(self):
        """Encoded size in bytes: 3 with an argument, 1 without."""
        return 3 if self.has_arg else 1

    @property
    def next_op(self):
        """Byte offset of the instruction that follows this one."""
        return self.idx + self.length

    @property
    def is_jump(self):
        return self.is_absjump or self.is_reljump

    @property
    def is_absjump(self):
        return self.code in opcode.hasjabs

    @property
    def is_reljump(self):
        return self.code in opcode.hasjrel

    is_return = False

    def format(self, code):
        """Human-readable form; subclasses append details looked up in ``code``."""
        return str(self)

    def __str__(self):
        if self.has_arg:
            return "%4s: %s %s" % (self.idx, self.name, self.arg)
        else:
            return "%4s: %s" % (self.idx, self.name)

    def __repr__(self):
        return str(self)

    def apply(self, code, stack, block):
        """Apply the instruction's effect to the symbolic ``stack`` of ``block``."""
        raise NotImplementedError("Apply not implemented for %s" % self.name)
class AbsJump(Opcode):
    """Jump whose argument is the byte offset of its target."""

    conditional = False

    @property
    def jump_target(self):
        return self.arg

    def __str__(self):
        fields = (self.idx, self.name, self.arg, self.jump_target)
        return "%4s: %s +%d to %d" % fields

    def apply(self, code, stack, block):
        """Jumps leave the symbolic stack untouched."""
        return None
class JUMP_ABSOLUTE(AbsJump):
    """JUMP_ABSOLUTE opcode; inherits all AbsJump behaviour unchanged."""
class RelJump(Opcode):
    """Jump whose argument is an offset from the following instruction."""

    conditional = True

    @property
    def jump_target(self):
        following = self.idx + self.length
        return following + self.arg

    def __str__(self):
        fields = (self.idx, self.name, self.arg, self.jump_target)
        return "%4s: %s +%d to %d" % fields

    def apply(self, code, stack, block):
        """Jumps leave the symbolic stack untouched."""
        return None
class FOR_ITER(RelJump):
    """FOR_ITER opcode; inherits all RelJump behaviour unchanged."""
class JUMP_FORWARD(RelJump):
    """JUMP_FORWARD opcode: a relative jump that is always taken."""

    conditional = False
class JUMP_IF_FALSE(RelJump):
    """JUMP_IF_FALSE opcode.

    ``inverse`` is True: the jump target is the branch taken when the
    condition is false, so the fall-through block is the "true" branch.
    """

    inverse = True
class JUMP_IF_TRUE(RelJump):
    """JUMP_IF_TRUE opcode.

    ``inverse`` is False: the jump target is the branch taken when the
    condition is true.
    """

    inverse = False
class SETUP_LOOP(RelJump):
    """SETUP_LOOP opcode; inherits all RelJump behaviour unchanged."""
class SETUP_EXCEPT(RelJump):
    """SETUP_EXCEPT opcode; inherits all RelJump behaviour unchanged."""
class SETUP_FINALLY(RelJump):
    """SETUP_FINALLY opcode; inherits all RelJump behaviour unchanged."""
class LocalOp(Opcode):
    """Base for opcodes whose argument indexes the local-variable table."""

    def format(self, code):
        """Append the referenced local variable's name to the base format."""
        base_text = Opcode.format(self, code)
        return "%s (%s)" % (base_text, code.varnames[self.arg])
class LOAD_FAST(LocalOp):
    """Push the current symbolic value of a local variable."""

    def apply(self, code, stack, block):
        name = code.varnames[self.arg]
        current = block.get_var(name)
        stack.append(LoadVar(name, current))
class LOAD_ATTR(Opcode):
    """Replace the top of stack with an attribute access on it."""

    def apply(self, code, stack, block):
        attribute = code.names[self.arg]
        owner = stack.pop()
        stack.append(GetAttr(owner, attribute))

    def format(self, code):
        """Append the attribute name to the base format."""
        base_text = Opcode.format(self, code)
        return "%s (%s)" % (base_text, code.names[self.arg])
class STORE_FAST(LocalOp):
    """Pop the top of stack and record it as the local variable's value."""

    def apply(self, code, stack, block):
        # The right-hand side (the pop) is evaluated before the name lookup.
        block.local_assignments[code.varnames[self.arg]] = stack.pop()
class LOAD_CONST(Opcode):
    """Push a constant from the code object's constant table as a Literal."""

    def format(self, code):
        """Append the constant's repr to the base format."""
        # One-element tuple: '"%r" % const' raises TypeError (or misformats)
        # when the constant is itself a tuple.
        return Opcode.format(self, code) + " (%r)" % (code.constants[self.arg],)

    def apply(self, code, stack, block):
        stack.append(Literal(code.constants[self.arg]))
class BinaryOp(Opcode):
    """Base for two-operand opcodes; subclasses provide ``operator``."""

    def apply(self, code, stack, block):
        # The right operand is on top of the stack, the left one below it.
        rhs, lhs = stack.pop(), stack.pop()
        stack.append(BinaryExpression(self.operator, lhs, rhs))
class COMPARE_OP(BinaryOp):
    """Comparison; the argument selects the operator from ``opcode.cmp_op``."""

    @property
    def operator(self):
        return opcode.cmp_op[self.arg]

    def format(self, code):
        """Append the comparison operator to the base format."""
        base_text = Opcode.format(self, code)
        return "%s (%s)" % (base_text, self.operator)
class BINARY_ADD(BinaryOp):
    """BINARY_ADD opcode, emitted as a ``+`` expression."""

    operator = "+"
class BINARY_SUBTRACT(BinaryOp):
    """BINARY_SUBTRACT opcode, emitted as a ``-`` expression."""

    operator = "-"
class BINARY_MULTIPLY(BinaryOp):
    """BINARY_MULTIPLY opcode, emitted as a ``*`` expression."""

    operator = "*"
class INPLACE_MULTIPLY(BinaryOp):
    """INPLACE_MULTIPLY opcode, emitted as a ``*`` expression like BINARY_MULTIPLY."""

    operator = "*"
class BINARY_DIVIDE(BinaryOp):
    """BINARY_DIVIDE opcode, emitted as a ``/`` expression."""

    operator = "/"
class POP_TOP(Opcode):
def apply(self, code, stack, block):
stack.pop()
class RETURN_VALUE(Opcode):
    """Pop the top of stack and record it as the block's ``RETURN`` value."""

    is_return = True

    def apply(self, code, stack, block):
        returned = stack.pop()
        block.local_assignments['RETURN'] = returned
import __builtin__
class LOAD_GLOBAL(Opcode):
    """Push the object a global name refers to.

    The name is looked up in the analysed function's globals first and in
    ``__builtin__`` otherwise.  The object itself is pushed, not an
    Expression wrapper.
    """

    def apply(self, code, stack, block):
        name = code.names[self.arg]
        namespace = code.globals
        if name in namespace:
            resolved = namespace[name]
        else:
            resolved = getattr(__builtin__, name)
        stack.append(resolved)
class CALL_FUNCTION(Opcode):
    """Call the callable found below the arguments on the stack.

    Plain Python functions are inlined: they are analysed with the symbolic
    arguments and their ``RETURN`` value is pushed.  Any other callable is
    pushed as a FunctionCall node.

    Raises:
        NotImplementedError: if the call passes keyword arguments.
    """

    def apply(self, code, stack, block):
        # The low byte of the argument is the number of positional arguments,
        # the high byte the number of keyword arguments.  Treating the whole
        # argument as a positional count used to fail with an unrelated
        # IndexError on calls that use keywords.
        positional_count = self.arg & 0xFF
        keyword_count = (self.arg >> 8) & 0xFF
        if keyword_count:
            raise NotImplementedError(
                "Calls with keyword arguments are not supported (%d given)" % keyword_count)
        args = [stack.pop() for _ in range(positional_count)]
        args.reverse()  # popped last-to-first
        func = stack.pop()
        if isinstance(func, types.FunctionType):
            f = Function.from_function(func)
            f.analyze(args)
            stack.append(f.code.return_block.get_var('RETURN'))
        else:
            stack.append(FunctionCall(func, args))
class DUP_TOP(Opcode):
    """Push a second reference to the current top of stack."""

    def apply(self, code, stack, block):
        top = stack[-1]
        stack.append(top)
class ROT_THREE(Opcode):
    """Move the top of stack down two places, lifting the two items below it."""

    def apply(self, code, stack, block):
        top, second, third = stack.pop(), stack.pop(), stack.pop()
        stack += [top, third, second]
class ROT_TWO(Opcode):
    """Swap the two topmost stack items."""

    def apply(self, code, stack, block):
        top, second = stack.pop(), stack.pop()
        stack.append(top)
        stack.append(second)
class BUILD_TUPLE(Opcode):
    """Replace the top ``arg`` stack items with one TupleExpression."""

    def apply(self, code, stack, block):
        count = self.arg
        if count:
            items = tuple(stack[-count:])
            del stack[-count:]
        else:
            # ``stack[-0:]`` is the whole stack, so an empty tuple used to
            # swallow every value below it.
            items = ()
        stack.append(TupleExpression(items))
class UNPACK_SEQUENCE(Opcode):
    """Spread a TupleExpression over ``arg`` stack slots, first element on top.

    Raises:
        TypeError: if the tuple does not hold exactly ``arg`` values.
        NotImplementedError: if the value is not a TupleExpression.
    """

    def apply(self, code, stack, block):
        value = stack.pop()
        if not isinstance(value, TupleExpression):
            raise NotImplementedError("Cannot unpack %r" % value)
        if len(value.values) != self.arg:
            # Format the message itself; the ``%`` used to be applied to the
            # exception instance, which hid the real error.
            raise TypeError("Cannot unpack a tuple of %d to %d elements"
                            % (len(value.values), self.arg))
        stack.extend(reversed(value.values))
import sqlalchemy
import sqlalchemy.orm
import types
class SqlalchemyCompiler(object):
    """Translate a decompiled Expression tree into SQLAlchemy constructs.

    ``process`` dispatches on the node's class name to ``visit_<ClassName>``.
    """

    def process(self, node):
        """Compile ``node``; raises Exception for anything that is not an Expression."""
        if isinstance(node, Expression):
            visitor = getattr(self, 'visit_%s' % type(node).__name__)
            return visitor(node)
        # One-element tuple so that a tuple-valued node is reported, not
        # mistaken for formatting arguments.
        raise Exception("Invalid node %r" % (node,))

    def visit_PHI(self, node):
        """Compile a PHI chain (linked through ``false``) into one CASE."""
        phinode = node
        whens = []
        while isinstance(phinode, PHI):
            whens.append((self.process(phinode.condition), self.process(phinode.true)))
            phinode = phinode.false
        return sqlalchemy.case(whens, else_=self.process(phinode))

    def visit_BinaryExpression(self, node):
        opmap = {
            '<': sqlalchemy.sql.operators.lt,
            '>': sqlalchemy.sql.operators.gt,
            '==': sqlalchemy.sql.operators.eq,
            '+': sqlalchemy.sql.operators.add,
            '-': sqlalchemy.sql.operators.sub,
            '*': sqlalchemy.sql.operators.mul,
            '/': sqlalchemy.sql.operators.div,
            'AND': sqlalchemy.sql.operators.and_,
            'OR': sqlalchemy.sql.operators.or_,
        }
        return opmap[node.op](self.process(node.left), self.process(node.right))

    def visit_LoadVar(self, node):
        return self.process(node.value)

    def visit_Literal(self, node):
        return sqlalchemy.literal(node.value)

    def visit_GetAttr(self, node):
        """Resolve an attribute; Python properties are decompiled and inlined."""
        # Compile the owner once and reuse it.  It used to be compiled a
        # second time for the getattr, repeating the work for every level of
        # nested attribute access.
        obj = self.process(node.obj)
        value = getattr(obj, node.name)
        if isinstance(value, property):
            f = Function.from_function(value.fget)
            f.analyze([ValueExpression(obj)])
            return self.process(f.code.return_block.get_var('RETURN'))
        return value

    def visit_ValueExpression(self, node):
        return node.value

    # Python callables that have a SQL function equivalent.
    funcs = {
        len: sqlalchemy.func.length,
    }

    def visit_FunctionCall(self, node):
        return self.funcs[node.func](*map(self.process, node.args))
import types
class LambdaQuery(sqlalchemy.orm.Query):
    """Query whose ``filter`` also accepts a plain Python function.

    A function is decompiled, with the query's entity classes bound to its
    arguments in order, and the resulting expression is passed on as the
    filter criterion.
    """

    compiler = SqlalchemyCompiler()

    def filter(self, condition):
        if not isinstance(condition, types.FunctionType):
            # ``self`` is already bound by super(); passing it again made
            # every ordinary filter call fail with a TypeError.
            return super(LambdaQuery, self).filter(condition)
        f = Function.from_function(condition)
        f.analyze([ValueExpression(e.entity.class_) for e in self._entities])
        retval = f.code.return_block.get_var('RETURN')
        return super(LambdaQuery, self).filter(self.compiler.process(retval))
def output_callgraph(func):
    """Analyse ``func`` and write its basic-block graph to ``code.png``.

    Each block becomes a node listing its opcodes (the label text is cut at
    400 characters); blocks whose ``inverse`` flag is set are drawn in red.
    One edge is drawn from every predecessor to the block.
    """
    f = Function.from_function(func)
    blocks = f.analyze()
    import pydot
    g = pydot.Dot()

    def b_name(b):
        return "Block %d" % b.nr

    table_template = ('<<TABLE CELLSPACING="0" CELLPADDING="1" BORDER="0" CELLBORDER="1" >'
                      '<TR><TD>%s</TD></TR></TABLE>>')
    for block in blocks:
        node = pydot.Node(b_name(block), shape="plaintext")
        # Characters that would clash with the label markup are replaced.
        rows = [b_name(block)]
        for op in block.opcodes:
            rows.append(op.format(f.code).replace(':','-').replace('<','lt').replace('>','gt'))
        lbl = "<BR/>".join(rows) + "<BR/>"
        node.set_label(table_template % lbl[0:400])
        if block.inverse:
            node.set_color('red')
        g.add_node(node)
        for parent in block.incoming:
            edge = pydot.Edge(b_name(parent), b_name(block))
            # Edge labels showing the branch condition are currently disabled:
            #if parent.target and parent.next:
            #if (block is parent.next and not parent.inverse) or (block is not parent.next and parent.inverse):
            # edge.set_label("not "+str(parent.stack[-1])[0:20].replace('(','[').replace(')',']'))
            #else:
            # edge.set_label(str(parent.stack[-1])[0:20].replace('(','[').replace(')',']'))
            g.add_edge(edge)
    g.write_png('code.png')
# -*- coding: utf-8 -*-
from analyze import LambdaQuery
from sqlalchemy import Column, String, Integer, Float, Boolean, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Query
# Declarative base bound to an SQLite engine created from the bare 'sqlite://' URL.
Base = declarative_base(bind=create_engine('sqlite://'))
# Session whose queries are LambdaQuery objects, so filter() accepts Python lambdas.
session = sessionmaker(query_cls=LambdaQuery)()
class User(Base):
    """Demo mapped class for the ``user`` table queried by the example below."""
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String(32))
    age = Column(Integer)
    shoe_size = Column(Integer)
    @property
    def importance(self):
        # Ordinary Python property.  When a filter lambda reads it, the
        # compiler decompiles this getter (inlining larger()) rather than
        # calling it, so the body is what ends up translated to SQL.
        return larger(self.age, self.shoe_size) + len(self.name)
def larger(a,b):
    # Return a when a > b, otherwise b.
    # This function is inlined by the decompiler when called from analysed
    # code; the example output shows it rendered as a CASE WHEN expression.
    if a > b:
        return a
    return b
# Demo: the lambda is decompiled by LambdaQuery.filter; printing the query shows the generated SQL.
print session.query(User).filter(lambda u: u.name == 'ants' and u.importance > 10)
#Output:
#SELECT user.id AS user_id, user.name AS user_name, user.age AS user_age, user.shoe_size AS user_shoe_size
#FROM user
#WHERE user.name = ? AND CASE WHEN (user.age > user.shoe_size) THEN user.age ELSE user.shoe_size END + length(user.name) > ?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment