Skip to content

Instantly share code, notes, and snippets.

@serge-sans-paille
Created September 23, 2014 13:15
Show Gist options
  • Star 9 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save serge-sans-paille/79b44dd89f374c96b20f to your computer and use it in GitHub Desktop.
Save serge-sans-paille/79b44dd89f374c96b20f to your computer and use it in GitHub Desktop.
Python - functional style!
import ast
import sys
import shutil
import unparse
import unittest
import doctest
import StringIO
import os
from copy import deepcopy
def _chain(l, s):
"""
chain all lambdas from l, starting with expression s
"""
def _combine(x, y):
return ast.Call(y, [x], [], None, None)
return reduce(_combine, l, s)
class NotSupportedError(RuntimeError):
    """Raised when a construct cannot be turned into functional style."""
class GatherIdentifiers(ast.NodeVisitor):
    """
    Collect, in visiting order, every Name node whose context is exactly
    ast.Param (i.e. the formal parameters of a function signature).

    The collected nodes are available in `result' after visiting.
    """

    def __init__(self):
        self.result = []

    def visit_Name(self, node):
        # anything that is not a formal parameter is ignored
        if type(node.ctx) is not ast.Param:
            return
        self.result.append(node)
class IsRegular(ast.NodeVisitor):
    """
    Check whether a subtree has regular control flow, i.e. contains no
    break, continue or return statement anywhere inside it.

    After visiting, `result' is True when the subtree is regular.
    """

    def __init__(self):
        self.result = True

    def irregular(self, node):
        # a single offending statement is enough: no need to look inside it
        self.result = False

    visit_Return = irregular
    visit_Continue = irregular
    visit_Break = irregular
def isregular(node):
    """Return True when `node' contains no break, continue or return."""
    checker = IsRegular()
    checker.visit(node)
    verdict = checker.result
    return verdict
class FunctionalStyle(ast.NodeTransformer):
    """
    Turns a function into a lambda expression

    The whole idea is to turn a function into a lambda expression that should
    be complex enough to understand to prevent straightforward deobfuscation.

    To do so two operators are introduced:

    E -> (expr -> store) -> new_expr
    S -> (stmt -> store) -> (store -> store)

    E(expr, store) returns the functional version of the expression `expr'
    when evaluated in store `store'.
    The result is an expression.

    S(stmt, store) returns a function that takes a store as input
    and returns a new store.

    The store is a dict mapping local names to values; its `$' key holds the
    return value and its `!' key is scratch space used by assignments.

    The `nesting_level' parameter is only there for debugging purpose.
    Setting it to one simulates the processing of an instruction inside a
    function.
    """

    def __init__(self, nesting_level=0):
        # names used in the generated code: the self-reference of recursive
        # lambdas, the store, its return slot and its scratch slot
        self.rec = '__'
        self.store = "_"
        self.return_ = "$"
        self.wtmp = "!"
        # formal parameters of the y combinator built below
        self.formal_rec = 'f'
        self.formal_store = '_'
        # Y = lambda f, _: f(f, _) -- lets a lambda call itself (visit_While)
        args = ast.arguments([ast.Name(self.formal_rec, ast.Param()),
                              ast.Name(self.formal_store, ast.Param())],
                             None, None, [])
        body = ast.Call(ast.Name(self.formal_rec, ast.Load()),
                        [ast.Name(self.formal_rec, ast.Load()),
                         ast.Name(self.formal_store, ast.Load())],
                        [],
                        None,
                        None)
        self.ycombinator = ast.Lambda(args, body)
        self.nesting_level = nesting_level

    def not_supported(self, node):
        # Inside a function, abort the transformation of that function
        # (visit_FunctionDef catches the error); at module level, keep node.
        if self.nesting_level:
            raise NotSupportedError(str(type(node)))
        else:
            return node

    visit_ClassDef = not_supported
    visit_Print = not_supported
    visit_With = not_supported
    visit_Raise = not_supported
    visit_TryExcept = not_supported
    visit_TryFinally = not_supported
    visit_Assert = not_supported
    visit_ImportFrom = not_supported
    visit_Exec = not_supported
    visit_Global = not_supported
    visit_Break = not_supported
    visit_Continue = not_supported
    visit_Yield = not_supported
    visit_Lambda = not_supported

    def visit_FunctionDef(self, node):
        """
        A function is turned into a lambda declaration

        >>> node = ast.parse('def foo(x): pass')
        >>> newnode = FunctionalStyle().visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        <BLANKLINE>
        foo = (lambda x: (lambda _: _)({'x': x, '$': None})['$'])

        Functions using decorators, *args, **kwargs or default values are
        left untouched, as are functions whose body uses an unsupported
        construct.
        """
        if self.nesting_level:
            raise NotSupportedError("Nested Functions")
        if node.decorator_list:
            return node
        if node.args.vararg:
            return node
        if node.args.kwarg:
            return node
        if node.args.defaults:
            return node
        # Statements following a top-level return are dead code in Python,
        # but the chained lambdas would still run them (and a later return
        # would overwrite the '$' slot), so drop them.
        body = node.body
        for index, stmt in enumerate(body):
            if type(stmt) is ast.Return:
                body = body[:index + 1]
                break
        # keep a pristine copy: visiting rewrites children in place, and the
        # original must be returned intact if the transformation fails
        orig = deepcopy(node)
        # gather all the functions to chain
        saved_level = self.nesting_level
        self.nesting_level += 1
        try:
            calls = map(self.visit, body)
        except NotSupportedError:
            return orig
        finally:
            self.nesting_level = saved_level
        # create the initial state: parameters plus an empty return slot
        gi = GatherIdentifiers()
        map(gi.visit, node.args.args)
        formal_parameters = gi.result
        keys = [ast.Str(n.id) for n in formal_parameters]
        keys.append(ast.Str(self.return_))
        values = [ast.Name(n.id, ast.Load()) for n in formal_parameters]
        values.append(ast.Name('None', ast.Load()))
        init_expr = ast.Dict(keys, values)
        # create the lambda: run the chain, then read the return slot
        lambda_ = ast.Lambda(node.args,
                             ast.Subscript(_chain(calls, init_expr),
                                           ast.Index(ast.Str(self.return_)),
                                           ast.Load())
                             )
        res = ast.Assign([ast.Name(node.name, ast.Store())], lambda_)
        return ast.fix_missing_locations(res)

    def visit_Return(self, node):
        """
        A return just adds an entry in the state and returns the state

        >>> node = ast.parse('return 1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('$', 1), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        if node.value:
            returned = self.visit(node.value)
        else:
            returned = ast.Name("None", ast.Load())
        setreturn = ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                           '__setitem__',
                                           ast.Load()),
                             [ast.Str(self.return_), returned],
                             [],
                             None,
                             None)
        # (side_effect, store)[-1] runs the side effect and yields the store
        body = ast.Subscript(ast.Tuple([setreturn,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Delete(self, node):
        """
        A delete removes entries from the store

        >>> node = ast.parse('del a')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__delitem__('a'), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if any(type(t) is not ast.Name for t in node.targets):
            raise NotSupportedError("deleting non identifiers")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        bodyn = [ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                        '__delitem__', ast.Load()),
                          [ast.Str(t.id)],
                          [], None, None)
                 for t in node.targets]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple(bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load()
                             )
        return ast.Lambda(args, body)

    def visit_Pass(self, node):
        """
        A Pass is similar to applying the identity to the locals

        S('pass', state) = lambda state : state

        >>> node = ast.parse('pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: _)
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body = ast.Name(self.store, ast.Load())
        return ast.Lambda(args, body)

    def visit_While(self, node):
        """
        The definition of a while is recursive!

        The recursive function itself is

        S('while cond: body else: orelse', state) =
            ( S('while cond: body else: orelse', S('body', state))
              if E('cond', state)
              else S('orelse', state) )

        So with the y combinator we get

        S('while cond: body else: orelse', state) =
            Y(lambda self, state:
                self(self, S('body', state))
                if E('cond', state)
                else S('orelse', state),
              state)

        >>> node = ast.parse('while 1: pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (lambda f, _: f(f, _))\
        ((lambda __, _: \
        ((lambda _: __(__, _))((lambda _: _)(_)) if 1 else _)), \
        _))
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        args = ast.arguments([ast.Name(self.rec, ast.Param()),
                              ast.Name(self.store, ast.Param())],
                             None, None, [])
        body_ = map(self.visit, node.body)
        # after the body, loop again: lambda _: __(__, _)
        lambda_args = ast.arguments([ast.Name(self.store, ast.Param())],
                                    None, None, [])
        lambda_ = ast.Lambda(lambda_args,
                             ast.Call(ast.Name(self.rec, ast.Load()),
                                      [ast.Name(self.rec, ast.Load()),
                                       ast.Name(self.store, ast.Load())],
                                      [],
                                      None,
                                      None)
                             )
        body_.append(lambda_)
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = map(self.visit, node.orelse)
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        body = ast.IfExp(self.visit(node.test), body_, orelse_)
        return ast.Lambda(ast.arguments([ast.Name(self.store, ast.Param())],
                                        None, None, []),
                          ast.Call(self.ycombinator,
                                   [ast.Lambda(args, body),
                                    ast.Name(self.store, ast.Load())],
                                   [],
                                   None,
                                   None)
                          )

    def visit_AugAssign(self, node):
        """
        An augassign just updates the store

        >>> node = ast.parse('a += 1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: \
        (_.__setitem__('a', ((_['a'] if ('a' in _) else a) + 1)), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        # Build the "load" view of the target from a private copy: visiting
        # node.target itself would rewrite its children in place (so that
        # assign_helper would transform them a second time) and, for a plain
        # name, would leave the global-fallback Name with a Store context.
        load_target = deepcopy(node.target)
        load_target.ctx = ast.Load()
        load_target = self.visit(load_target)
        op = self.assign_helper(node.target,
                                ast.BinOp(load_target,
                                          node.op,
                                          self.visit(node.value)))
        body = ast.Subscript(ast.Tuple([op,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_If(self, node):
        """
        An if evaluates its condition then yields one of the branches.
        There is an if expression in python, take advantage of it!

        >>> node = ast.parse('if 1: 2')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ((lambda _: (2, _)[(-1)])(_) if 1 else _))
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body_ = map(self.visit, node.body)
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = map(self.visit, node.orelse)
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        body = ast.IfExp(self.visit(node.test), body_, orelse_)
        return ast.Lambda(args, body)

    def assign_helper(self, target, value):
        """
        Return an expression storing `value' into `target'.

        Names become store entries; subscripts with a simple index call
        __setitem__ on the (transformed) container. Anything else raises
        NotSupportedError.
        """
        # assigning to a name is easy
        if type(target) is ast.Name:
            return ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                          '__setitem__', ast.Load()),
                            [ast.Str(target.id), value],
                            [], None, None)
        # but assigning to a subscript is tricky
        elif type(target) is ast.Subscript:
            # really tricky: there are different types of slices
            tslice = type(target.slice)
            if tslice is ast.Index:
                vslice = self.visit(target.slice.value)
            else:
                raise NotSupportedError("complex slices")
            return ast.Call(ast.Attribute(self.visit(target.value),
                                          '__setitem__', ast.Load()),
                            [vslice, value],
                            [], None, None)
        else:
            raise NotSupportedError("Assigning to something "
                                    "not a subscript not a name")

    def visit_Assign(self, node):
        """
        An assign creates one or several entries in the store
        Type destructuring is not supported

        >>> node = ast.parse('a = 2')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('!', 2), \
        _.__setitem__('a', _['!']), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if any(type(t) is ast.Tuple for t in node.targets):
            raise NotSupportedError("type destructuring")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        # evaluate the value once into the scratch slot, then copy it to
        # every target (handles chained assignments like a = b = expr)
        body0 = ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                       '__setitem__', ast.Load()),
                         [ast.Str(self.wtmp), self.visit(node.value)],
                         [], None, None)
        value = ast.Subscript(ast.Name(self.store, ast.Load()),
                              ast.Index(ast.Str(self.wtmp)),
                              ast.Load())
        bodyn = [self.assign_helper(t, deepcopy(value)) for t in node.targets]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple([body0] + bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load()
                             )
        return ast.Lambda(args, body)

    def visit_Name(self, node):
        """
        When visiting a name, we don't know statically whether it is
        - a local name, in which case it should be looked up in the store
        - a global name, in which case it should be looked up in the globals

        Moreover, one cannot use the globals() function:
        it may be monkey patched

        >>> node = ast.parse('i')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ((_['i'] if ('i' in _) else i), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        cond = ast.Compare(ast.Str(node.id),
                           [ast.In()],
                           [ast.Name(self.store, ast.Load())])
        body_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                              ast.Index(ast.Str(node.id)),
                              ast.Load())
        # fall back to the plain name (global / enclosing lookup)
        orelse_ = node
        return ast.IfExp(cond, body_, orelse_)

    def visit_For(self, node):
        """
        A for loop can be emulated using list comprehension
        It assumes there is no break or continue, though

        >>> node = ast.parse('for i in []: pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ([(lambda _: _)(_) for _['i'] in []], _, _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        if type(node.target) is not ast.Name:
            raise NotSupportedError("only identifiers as loop index")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        # turn the for into a list comprehension
        body_ = map(self.visit, node.body)
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = map(self.visit, node.orelse)
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        # the loop index is written straight into the store: for _['i'] in
        target_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                                ast.Index(ast.Str(node.target.id)),
                                ast.Store())
        comp = ast.ListComp(body_, [ast.comprehension(target_,
                                                      self.visit(node.iter),
                                                      [])])
        # combine the orelse statement: it runs once the loop is exhausted
        body = ast.Subscript(ast.Tuple([comp, orelse_,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load()
                             )
        return ast.Lambda(args, body)

    def visit_Import(self, node):
        """
        Emulate import using the __import__ function
        This is slightly fragile, as one could have monkey patched it

        >>> node = ast.parse('import math')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('math', __import__('math')), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        bodyn = [ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                        '__setitem__', ast.Load()),
                          [ast.Str(n.asname or n.name),
                           ast.Call(ast.Name('__import__', ast.Load()),
                                    [ast.Str(n.name)],
                                    [], None, None)],
                          [], None, None
                          )
                 for n in node.names]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple(bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load()
                             )
        return ast.Lambda(args, body)

    def visit_Expr(self, node):
        """
        An expression just needs to be wrapped in a lambda

        >>> node = ast.parse('1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (1, _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body = ast.Subscript(ast.Tuple([self.visit(node.value),
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_comprehension(self, node):
        # comprehension indices are stored in the store, like for loops
        if type(node.target) is not ast.Name:
            raise NotSupportedError("only identifiers as loop index")
        target_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                                ast.Index(ast.Str(node.target.id)),
                                ast.Store())
        return ast.comprehension(target_,
                                 self.visit(node.iter),
                                 map(self.visit, node.ifs))
class TestFunctionalStyle(unittest.TestCase):
def generic_test(self, code, *tests):
ref_env = globals().copy()
exec code in ref_env
for test in tests:
# generate reference
ref = eval(test, ref_env)
# parse, transform and eval
node = ast.parse(code)
node = FunctionalStyle().visit(node)
obj = compile(node, '<test>', 'exec')
obj_env = globals().copy()
exec obj in obj_env
candidate = eval(test, obj_env)
self.assertEqual(ref, candidate)
# also test that generated string can be compiled
out = StringIO.StringIO()
unparse.Unparser(node, out)
ast.parse(out.getvalue())
def test_FunctionDef(self):
self.generic_test("def foo(x): return x",
"foo(1)", "foo(1.5)", "foo('hello')")
self.generic_test("def foo(x,y): return x,y",
"foo(1, True)", "foo(.5, {})", "foo('h', (0, None))")
def test_Pass(self):
self.generic_test("def foo(): pass", 'foo()')
def test_Return(self):
self.generic_test("def foo(x): return", "foo(0)")
self.generic_test("def foo(x): return x + 1", "foo(0)")
def test_AugAssign(self):
self.generic_test("def foo(x, y): x += y ; return x, y",
"foo(1,2)")
self.generic_test("def foo(x, y): x[y] += y ; return x, y",
"foo([1, 2, 3], 2)")
def test_Assign(self):
self.generic_test("def foo(x, y): x = y ; return x, y",
"foo(1,2)")
self.generic_test("def foo(x, y): x = y = x * y; return x, y",
"foo('1.4',2)")
self.generic_test("def foo(x, y): x[y] = y ; return x",
"foo([1, '3'], 1)")
self.generic_test("def foo(x, y): x[y][0][0] = y ; return x",
"foo([1, [['3']]], 1)")
def test_If(self):
self.generic_test("def foo(x, y):\n if x: return x\n else: return y",
"foo(1,2)", "foo(0, 2)")
self.generic_test("""
def foo(x, y):
if x: return x
elif y: return y
else: return 'e'""",
"foo([1], [])",
"foo([], [1])",
"foo([], [])")
def test_While(self):
self.generic_test("def foo(x):\n while x: pass\n return x",
"foo(0)")
self.generic_test("def f(x):\n while x: x-=1\n else: x+=1\n return x",
"f(1)",
"f(0)")
self.generic_test("def foo(x):\n while x>0: x-=1\n return x",
"foo(3)",
"foo(0)")
def test_For(self):
self.generic_test("def foo(x,s):\n for i in x: s+=i;\n return s",
"foo('hello', '')",
"foo([1,2,3], 8)")
def test_Del(self):
self.generic_test("def foo(x): del x", "foo(1)")
def test_Import(self):
self.generic_test("def foo(x): import math as m ; return m.cos(x)",
"foo(1)")
def test_Expr(self):
self.generic_test("def foo(x, y): x(y); return y",
"foo(lambda x: x.append(1),[])")
def test_For(self):
self.generic_test("def foo(x, y):\n for i in x: y+= 1\n return y",
"foo('hello', 0)")
def test_Global(self):
self.generic_test("def foo(x): return range(x)",
"foo(3)")
self.generic_test("range = list\ndef foo(x): return range(x)",
"foo('e')")
def test_bootstrap(self):
# Verify we correctly process ourselves
module = sys.modules[__name__]
module_code = file(module.__file__).read()
module_node = ast.parse(module_code)
module_node = FunctionalStyle().visit(module_node)
module_obj = compile(module_node, '<test>', 'exec')
sys.path.append(os.path.dirname(__file__))
env = {}
exec module_obj in env
def test_on_ast(self):
# Verify we correctly process the ast module
module_code = file(ast.__file__[:-1]).read()
module_node = ast.parse(module_code)
module_node = FunctionalStyle().visit(module_node)
module_obj = compile(module_node, '<test>', 'exec')
env = {}
exec module_obj in env
def transform(input_path, output_path):
    """
    Write the FunctionalStyle version of the Python file `input_path' to
    `output_path', prefixed with a shebang line, and copy the input's
    permission bits and timestamps onto the output.

    Files that cannot be parsed as Python (SyntaxError) are deliberately
    skipped: nothing is written for them.
    """
    try:
        with open(input_path) as input_file:
            node = ast.parse(input_file.read())
    except SyntaxError:
        # not Python source: skipped on purpose
        return
    # NodeTransformer.visit returns the (possibly new) root node: use it
    node = FunctionalStyle().visit(node)
    with open(output_path, 'w') as output_file:
        output_file.write('#! /usr/bin/env python\n')
        unparse.Unparser(node, output_file)
        output_file.write('\n')
    # copystat also copies the permission bits, so copymode is not needed
    shutil.copystat(input_path, output_path)
if __name__ == "__main__":
    # Usage: <script> <input file> [output file]
    # Writes the transformed source to the output file, or to stdout.
    if len(sys.argv) not in (2, 3):
        print >> sys.stderr, 'Usage: %s <input file> [output file]' % sys.argv[0]
        sys.exit(1)
    input_name = sys.argv[1]
    with open(input_name, 'r') as input_file:
        node = ast.parse(input_file.read())
    node = FunctionalStyle().visit(node)
    if len(sys.argv) == 3:
        # only create/truncate the output once the input parsed fine
        with open(sys.argv[2], 'w') as output:
            unparse.Unparser(node, output)
            output.write('\n')
    else:
        unparse.Unparser(node, sys.stdout)
        sys.stdout.write('\n')
"Usage: unparse.py <path to source file>"
import sys
import ast
import cStringIO
import os
# The AST turns float and imaginary literals that are too large into
# infinities; they are unparsed as INFSTR, a decimal literal one order of
# magnitude beyond the largest finite float, so it parses back to infinity.
INFSTR = "1e%d" % (sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq):
    """Apply f to every element of seq, calling inter() between two
    consecutive elements (never before the first nor after the last).
    """
    items = iter(seq)
    try:
        f(next(items))
    except StopIteration:
        # empty sequence: nothing to emit
        return
    for item in items:
        inter()
        f(item)
class Unparser:
    """Methods in this class recursively traverse an AST and
    output source code for the abstract syntax; original formatting
    is disregarded.

    Each node type T is handled by a method named _T (see dispatch).
    Statements start on a fresh, indented line (fill); expressions are
    appended to the current line (write) and are generally wrapped in
    parentheses so that operator precedence never needs to be computed."""

    def __init__(self, tree, file = sys.stdout):
        """Unparser(tree, file=sys.stdout) -> None.
         Print the source for tree to file."""
        self.f = file
        # names imported from __future__ so far (affects string literals)
        self.future_imports = []
        self._indent = 0
        self.dispatch(tree)
        self.f.write("")
        self.f.flush()

    def fill(self, text = ""):
        "Indent a piece of text, according to the current indentation level"
        # every statement starts with a newline, so the output begins
        # with an empty line
        self.f.write("\n"+" "*self._indent + text)

    def write(self, text):
        "Append a piece of text to the current line."
        self.f.write(text)

    def enter(self):
        "Print ':', and increase the indentation."
        self.write(":")
        self._indent += 1

    def leave(self):
        "Decrease the indentation level."
        self._indent -= 1

    def dispatch(self, tree):
        "Dispatcher function, dispatching tree type T to method _T."
        # statement bodies are plain lists of nodes
        if isinstance(tree, list):
            for t in tree:
                self.dispatch(t)
            return
        # an unsupported node type raises AttributeError here
        meth = getattr(self, "_"+tree.__class__.__name__)
        meth(tree)


    ############### Unparsing methods ######################
    # There should be one method per concrete grammar type #
    # Constructors should be grouped by sum type. Ideally, #
    # this would follow the order in the grammar, but      #
    # currently doesn't.                                   #
    ########################################################

    def _Module(self, tree):
        for stmt in tree.body:
            self.dispatch(stmt)

    # stmt
    def _Expr(self, tree):
        self.fill()
        self.dispatch(tree.value)

    def _Import(self, t):
        self.fill("import ")
        interleave(lambda: self.write(", "), self.dispatch, t.names)

    def _ImportFrom(self, t):
        # A from __future__ import may affect unparsing, so record it.
        if t.module and t.module == '__future__':
            self.future_imports.extend(n.name for n in t.names)

        self.fill("from ")
        # one dot per level of relative import
        self.write("." * t.level)
        if t.module:
            self.write(t.module)
        self.write(" import ")
        interleave(lambda: self.write(", "), self.dispatch, t.names)

    def _Assign(self, t):
        self.fill()
        # chained assignment: a = b = value
        for target in t.targets:
            self.dispatch(target)
            self.write(" = ")
        self.dispatch(t.value)

    def _AugAssign(self, t):
        self.fill()
        self.dispatch(t.target)
        # reuse the binary operator table: "+" becomes "+="
        self.write(" "+self.binop[t.op.__class__.__name__]+"= ")
        self.dispatch(t.value)

    def _Return(self, t):
        self.fill("return")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)

    def _Pass(self, t):
        self.fill("pass")

    def _Break(self, t):
        self.fill("break")

    def _Continue(self, t):
        self.fill("continue")

    def _Delete(self, t):
        self.fill("del ")
        interleave(lambda: self.write(", "), self.dispatch, t.targets)

    def _Assert(self, t):
        self.fill("assert ")
        self.dispatch(t.test)
        if t.msg:
            self.write(", ")
            self.dispatch(t.msg)

    def _Exec(self, t):
        # Python 2 statement form: exec body in globals, locals
        self.fill("exec ")
        self.dispatch(t.body)
        if t.globals:
            self.write(" in ")
            self.dispatch(t.globals)
        if t.locals:
            self.write(", ")
            self.dispatch(t.locals)

    def _Print(self, t):
        # Python 2 statement form: print >>dest, values[,]
        self.fill("print ")
        do_comma = False
        if t.dest:
            self.write(">>")
            self.dispatch(t.dest)
            do_comma = True
        for e in t.values:
            if do_comma:self.write(", ")
            else:do_comma=True
            self.dispatch(e)
        # a trailing comma suppresses the newline
        if not t.nl:
            self.write(",")

    def _Global(self, t):
        self.fill("global ")
        # names are plain strings here, hence write rather than dispatch
        interleave(lambda: self.write(", "), self.write, t.names)

    def _Yield(self, t):
        # yield is an expression: parenthesize it
        self.write("(")
        self.write("yield")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)
        self.write(")")

    def _Raise(self, t):
        # Python 2 form: raise type, inst, tback
        self.fill('raise ')
        if t.type:
            self.dispatch(t.type)
        if t.inst:
            self.write(", ")
            self.dispatch(t.inst)
        if t.tback:
            self.write(", ")
            self.dispatch(t.tback)

    def _TryExcept(self, t):
        self.fill("try")
        self.enter()
        self.dispatch(t.body)
        self.leave()

        for ex in t.handlers:
            self.dispatch(ex)
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _TryFinally(self, t):
        # try/except/finally is parsed as a TryFinally wrapping a single
        # TryExcept; print it as one statement rather than two nested ones
        if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept):
            # try-except-finally
            self.dispatch(t.body)
        else:
            self.fill("try")
            self.enter()
            self.dispatch(t.body)
            self.leave()

        self.fill("finally")
        self.enter()
        self.dispatch(t.finalbody)
        self.leave()

    def _ExceptHandler(self, t):
        self.fill("except")
        if t.type:
            self.write(" ")
            self.dispatch(t.type)
        if t.name:
            self.write(" as ")
            # in this AST version the name is an expression node
            self.dispatch(t.name)
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _ClassDef(self, t):
        self.write("\n")
        for deco in t.decorator_list:
            self.fill("@")
            self.dispatch(deco)
        self.fill("class "+t.name)
        if t.bases:
            self.write("(")
            # every base is followed by ", ", including the last one
            # (a trailing comma is valid in the base list)
            for a in t.bases:
                self.dispatch(a)
                self.write(", ")
            self.write(")")
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _FunctionDef(self, t):
        self.write("\n")
        for deco in t.decorator_list:
            self.fill("@")
            self.dispatch(deco)
        self.fill("def "+t.name + "(")
        self.dispatch(t.args)
        self.write(")")
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _For(self, t):
        self.fill("for ")
        self.dispatch(t.target)
        self.write(" in ")
        self.dispatch(t.iter)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _If(self, t):
        self.fill("if ")
        self.dispatch(t.test)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        # collapse nested ifs into equivalent elifs.
        while (t.orelse and len(t.orelse) == 1 and
               isinstance(t.orelse[0], ast.If)):
            t = t.orelse[0]
            self.fill("elif ")
            self.dispatch(t.test)
            self.enter()
            self.dispatch(t.body)
            self.leave()
        # final else
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _While(self, t):
        self.fill("while ")
        self.dispatch(t.test)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _With(self, t):
        # single context manager per With node in this AST version
        self.fill("with ")
        self.dispatch(t.context_expr)
        if t.optional_vars:
            self.write(" as ")
            self.dispatch(t.optional_vars)
        self.enter()
        self.dispatch(t.body)
        self.leave()

    # expr
    def _Str(self, tree):
        # if from __future__ import unicode_literals is in effect,
        # then we want to output string literals using a 'b' prefix
        # and unicode literals with no prefix.
        if "unicode_literals" not in self.future_imports:
            self.write(repr(tree.s))
        elif isinstance(tree.s, str):
            self.write("b" + repr(tree.s))
        elif isinstance(tree.s, unicode):
            self.write(repr(tree.s).lstrip("u"))
        else:
            assert False, "shouldn't get here"

    def _Name(self, t):
        self.write(t.id)

    def _Repr(self, t):
        # Python 2 backquote repr: `value`
        self.write("`")
        self.dispatch(t.value)
        self.write("`")

    def _Num(self, t):
        repr_n = repr(t.n)
        # Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2.
        if repr_n.startswith("-"):
            self.write("(")
        # Substitute overflowing decimal literal for AST infinities.
        self.write(repr_n.replace("inf", INFSTR))
        if repr_n.startswith("-"):
            self.write(")")

    def _List(self, t):
        self.write("[")
        interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write("]")

    def _ListComp(self, t):
        self.write("[")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("]")

    def _GeneratorExp(self, t):
        self.write("(")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write(")")

    def _SetComp(self, t):
        self.write("{")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("}")

    def _DictComp(self, t):
        self.write("{")
        self.dispatch(t.key)
        self.write(": ")
        self.dispatch(t.value)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("}")

    def _comprehension(self, t):
        # one " for target in iter [if cond]*" clause of a comprehension
        self.write(" for ")
        self.dispatch(t.target)
        self.write(" in ")
        self.dispatch(t.iter)
        for if_clause in t.ifs:
            self.write(" if ")
            self.dispatch(if_clause)

    def _IfExp(self, t):
        self.write("(")
        self.dispatch(t.body)
        self.write(" if ")
        self.dispatch(t.test)
        self.write(" else ")
        self.dispatch(t.orelse)
        self.write(")")

    def _Set(self, t):
        assert(t.elts) # should be at least one element
        self.write("{")
        interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write("}")

    def _Dict(self, t):
        self.write("{")
        def write_pair(pair):
            (k, v) = pair
            self.dispatch(k)
            self.write(": ")
            self.dispatch(v)
        interleave(lambda: self.write(", "), write_pair, zip(t.keys, t.values))
        self.write("}")

    def _Tuple(self, t):
        self.write("(")
        # a one-element tuple needs its trailing comma: (x,)
        if len(t.elts) == 1:
            (elt,) = t.elts
            self.dispatch(elt)
            self.write(",")
        else:
            interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write(")")

    unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"}
    def _UnaryOp(self, t):
        self.write("(")
        self.write(self.unop[t.op.__class__.__name__])
        self.write(" ")
        # If we're applying unary minus to a number, parenthesize the number.
        # This is necessary: -2147483648 is different from -(2147483648) on
        # a 32-bit machine (the first is an int, the second a long), and
        # -7j is different from -(7j).  (The first has real part 0.0, the second
        # has real part -0.0.)
        if isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num):
            self.write("(")
            self.dispatch(t.operand)
            self.write(")")
        else:
            self.dispatch(t.operand)
        self.write(")")

    binop = { "Add":"+", "Sub":"-", "Mult":"*", "Div":"/", "Mod":"%",
                    "LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&",
                    "FloorDiv":"//", "Pow": "**"}
    def _BinOp(self, t):
        self.write("(")
        self.dispatch(t.left)
        self.write(" " + self.binop[t.op.__class__.__name__] + " ")
        self.dispatch(t.right)
        self.write(")")

    cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=",
                        "Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"}
    def _Compare(self, t):
        # chained comparison: left op1 c1 op2 c2 ...
        self.write("(")
        self.dispatch(t.left)
        for o, e in zip(t.ops, t.comparators):
            self.write(" " + self.cmpops[o.__class__.__name__] + " ")
            self.dispatch(e)
        self.write(")")

    # keyed by class here, unlike the name-keyed tables above
    boolops = {ast.And: 'and', ast.Or: 'or'}
    def _BoolOp(self, t):
        self.write("(")
        s = " %s " % self.boolops[t.op.__class__]
        interleave(lambda: self.write(s), self.dispatch, t.values)
        self.write(")")

    def _Attribute(self,t):
        self.dispatch(t.value)
        # Special case: 3.__abs__() is a syntax error, so if t.value
        # is an integer literal then we need to either parenthesize
        # it or add an extra space to get 3 .__abs__().
        if isinstance(t.value, ast.Num) and isinstance(t.value.n, int):
            self.write(" ")
        self.write(".")
        self.write(t.attr)

    def _Call(self, t):
        # emits positional args, keywords, *starargs, **kwargs, in that order
        self.dispatch(t.func)
        self.write("(")
        comma = False
        for e in t.args:
            if comma: self.write(", ")
            else: comma = True
            self.dispatch(e)
        for e in t.keywords:
            if comma: self.write(", ")
            else: comma = True
            self.dispatch(e)
        if t.starargs:
            if comma: self.write(", ")
            else: comma = True
            self.write("*")
            self.dispatch(t.starargs)
        if t.kwargs:
            if comma: self.write(", ")
            else: comma = True
            self.write("**")
            self.dispatch(t.kwargs)
        self.write(")")

    def _Subscript(self, t):
        self.dispatch(t.value)
        self.write("[")
        self.dispatch(t.slice)
        self.write("]")

    # slice
    def _Ellipsis(self, t):
        self.write("...")

    def _Index(self, t):
        self.dispatch(t.value)

    def _Slice(self, t):
        # lower:upper[:step], each part optional
        if t.lower:
            self.dispatch(t.lower)
        self.write(":")
        if t.upper:
            self.dispatch(t.upper)
        if t.step:
            self.write(":")
            self.dispatch(t.step)

    def _ExtSlice(self, t):
        interleave(lambda: self.write(', '), self.dispatch, t.dims)

    # others
    def _arguments(self, t):
        first = True
        # normal arguments
        # defaults apply to the last len(t.defaults) arguments; pad the
        # front with None so both lists line up
        defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults
        for a,d in zip(t.args, defaults):
            if first:first = False
            else: self.write(", ")
            self.dispatch(a),
            if d:
                self.write("=")
                self.dispatch(d)

        # varargs
        if t.vararg:
            if first:first = False
            else: self.write(", ")
            self.write("*")
            self.write(t.vararg)

        # kwargs
        if t.kwarg:
            if first:first = False
            else: self.write(", ")
            self.write("**"+t.kwarg)

    def _keyword(self, t):
        self.write(t.arg)
        self.write("=")
        self.dispatch(t.value)

    def _Lambda(self, t):
        self.write("(")
        self.write("lambda ")
        self.dispatch(t.args)
        self.write(": ")
        self.dispatch(t.body)
        self.write(")")

    def _alias(self, t):
        # "name [as asname]" in import statements
        self.write(t.name)
        if t.asname:
            self.write(" as "+t.asname)
def roundtrip(filename, output=sys.stdout):
    """Parse the Python file `filename' and write the source regenerated
    from its AST to `output' (stdout by default)."""
    with open(filename, "r") as source_file:
        code = source_file.read()
    # ast.parse is compile(..., "exec", ast.PyCF_ONLY_AST)
    Unparser(ast.parse(code, filename, "exec"), output)
def testdir(a):
try:
names = [n for n in os.listdir(a) if n.endswith('.py')]
except OSError:
sys.stderr.write("Directory not readable: %s" % a)
else:
for n in names:
fullname = os.path.join(a, n)
if os.path.isfile(fullname):
output = cStringIO.StringIO()
print 'Testing %s' % fullname
try:
roundtrip(fullname, output)
except Exception as e:
print ' Failed to compile, exception is %s' % repr(e)
elif os.path.isdir(fullname):
testdir(fullname)
def main(args):
    """
    Command-line entry point.

    `args' is the argument list without the program name: either Python
    files whose regenerated source is printed to stdout, or '--testdir'
    followed by directories whose .py files are round-tripped.
    An empty list prints a usage message on stderr instead of crashing.
    """
    if not args:
        sys.stderr.write("Usage: unparse.py <path to source file>\n")
        return
    if args[0] == '--testdir':
        for a in args[1:]:
            testdir(a)
    else:
        for a in args:
            roundtrip(a)
if __name__ == '__main__':
    # drop the program name: main() only takes the actual arguments
    main(args=sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment