Skip to content

Instantly share code, notes, and snippets.

@elazarg
Created June 6, 2016 03:04
Show Gist options
  • Save elazarg/cec3a94ca1520947dc1f78e5d762807c to your computer and use it in GitHub Desktop.
Save elazarg/cec3a94ca1520947dc1f78e5d762807c to your computer and use it in GitHub Desktop.
Use `with`: transform `f = open(...); ...; f.close()` into `with open(...) as f: ...`
# -*- coding: utf-8 -*-
"""
codegen
~~~~~~~
Extension to ast that allows ast -> Python code generation.
:copyright: Copyright 2008 by Armin Ronacher.
:license: BSD.
"""
from ast import *
# Source spellings of the ast operator node types, keyed by node class.
BOOLOP_SYMBOLS = {
    And: 'and',
    Or: 'or'
}
BINOP_SYMBOLS = {
    Add: '+',
    Sub: '-',
    Mult: '*',
    Div: '/',
    FloorDiv: '//',
    Mod: '%',
    LShift: '<<',
    RShift: '>>',
    BitOr: '|',
    BitAnd: '&',
    BitXor: '^',
    Pow: '**',
}
try:
    # Matrix multiplication (PEP 465) only exists in Python 3.5+ ast modules;
    # without this entry a BinOp using `@` raised KeyError when rendered.
    BINOP_SYMBOLS[MatMult] = '@'
except NameError:
    pass
CMPOP_SYMBOLS = {
    Eq: '==',
    Gt: '>',
    GtE: '>=',
    In: 'in',
    Is: 'is',
    IsNot: 'is not',
    Lt: '<',
    LtE: '<=',
    NotEq: '!=',
    NotIn: 'not in'
}
UNARYOP_SYMBOLS = {
    Invert: '~',
    # Trailing space separates the keyword from its operand.
    Not: 'not ',
    UAdd: '+',
    USub: '-'
}
def enclose(st):
    """Build ``SourceGenerator.write`` keyword arguments from a bracket spec.

    The first element of *st* becomes ``start``, the last becomes ``end`` and
    everything in between (when there is anything) becomes ``sep``::

        enclose('()')   -> {'start': '(', 'end': ')'}
        enclose('{,}')  -> {'start': '{', 'end': '}', 'sep': ','}
        enclose('[, ]') -> {'start': '[', 'end': ']', 'sep': ', '}

    Previously only length-3 specs produced a ``sep``, so multi-character
    separators such as ``'[, ]'`` were silently ignored.
    """
    d = {'start': st[0], 'end': st[-1]}
    if len(st) > 2:
        d['sep'] = st[1:-1]
    return d
def to_source(node, indent_with=' ' * 4, add_line_information=False):
    """Convert an AST node back into Python source code.

    This is useful for debugging, especially when dealing with custom ASTs not
    produced by Python itself. The generated source may be evaluable even when
    the AST itself is not compilable, because the AST carries extra data that
    is dropped during conversion.

    *node* may be a ``Module``, a single statement, or an expression. A lone
    statement is wrapped in a temporary ``Module`` so its nested blocks are
    indented correctly.

    Each indentation level is rendered as *indent_with* (four spaces by
    default, as suggested by PEP 8).

    *add_line_information* is accepted for compatibility and passed to
    `SourceGenerator`, which currently stores it without emitting any
    line-number comments.
    """
    import ast

    if isinstance(node, ast.stmt):
        node = ast.Module(body=[node])
    generator = SourceGenerator(indent_with, add_line_information)
    generator.visit(node)
    source = ''.join(map(str, generator.result))
    # A Module's statements are rendered as a block that opens with ':' and a
    # newline. That prefix belongs to no statement, so drop it -- but only
    # when present, so expression output is no longer truncated.
    if source.startswith(':\n'):
        source = source[2:]
    return source
def combine(ls1, ls2):
    """Interleave two lists as ``ls1[0], ls2[0], ls1[1], ls2[1], ...``.

    *ls1* must have as many elements as *ls2*, or exactly one more (e.g. the
    operands and operators of a comparison chain); otherwise the
    extended-slice assignment raises ValueError.
    """
    slots = [None] * len(ls1 + ls2)
    slots[0::2] = ls1
    slots[1::2] = ls2
    return slots
class SourceGenerator(NodeVisitor):
    """Visitor that turns a well-formed syntax tree back into Python source.

    Output fragments are appended to ``self.result``; joining them yields the
    source text. The module-level `to_source` function wraps this class and
    documents its options.
    """

    def __init__(self, indent_with, add_line_information=False):
        self.result = []
        self.indent_with = indent_with
        # Stored for API compatibility; no visit_* method below reads it.
        self.add_line_information = add_line_information
        # Starts at -1 so the Module's own statement block lands at depth 0.
        self.indentation = -1

    def write(self, *xlist, start='', end='', sep=' ', prepend='', endeach=''):
        """Append each item of *xlist* to ``self.result``.

        Items are handled by kind:

        * falsy items (``None``, ``''``, ``[]``) are skipped;
        * AST nodes are visited;
        * a list is a statement block: ``':'`` followed by one statement per
          line, indented one level deeper;
        * a 4-tuple ``(pre, key, join, val)`` is written as ``pre key join val``
          with no separators; it is dropped entirely when *key* is falsy, and
          *join* is dropped when *val* is falsy;
        * anything else (a string) is appended verbatim.

        *start* and *end* wrap the whole call, *prepend* and *endeach* wrap
        every item, and *sep* goes between two consecutive non-block items.
        Any of these that equals a newline starts a new line at the current
        indentation.
        """
        def app_or_line(c):
            if c == '\n':
                self.result.append('\n' + self.indent_with * self.indentation)
            else:
                self.result.append(c)

        putsep = False
        app_or_line(start)
        for x in xlist:
            if not x or (isinstance(x, tuple) and len(x) == 4 and not x[1]):
                continue
            if putsep and not isinstance(x, list):
                app_or_line(sep)
            putsep = True
            app_or_line(prepend)
            if isinstance(x, AST):
                self.visit(x)
            elif isinstance(x, list):
                self.indentation += 1
                # Statements are separated by the '\n' prepend alone; a
                # separator here would only leave trailing whitespace.
                self.write(*x, prepend='\n', start=':', sep='')
                self.indentation -= 1
                putsep = False
                continue
            elif isinstance(x, tuple) and len(x) == 4:
                pre, key, join, val = x
                self.write(pre, key, join if val else None, val, sep='')
            else:
                self.result.append(x)
            app_or_line(endeach)
        app_or_line(end)

    def body_or_else(self, node):
        """Write ``node.body``, then an ``else:`` block if ``node.orelse`` is non-empty."""
        self.write(node.body)
        self.write(('else', node.orelse, None, None), prepend='\n')

    def visit_arg(self, node):
        self.write(node.arg, node.annotation, sep=':')

    def visit_NameConstant(self, node):
        self.write('{0}'.format(node.value))

    def sigwrite(self, name, args, keywords):
        """Write ``name(args..., key=value..., **mapping)``.

        *args* is copied, never extended in place: it is usually a field of
        the tree being rendered (``Call.args`` / ``ClassDef.bases``).
        """
        items = list(args)
        for keyword in keywords:
            if keyword.arg is None:
                # ``**mapping`` unpacking has no keyword name.
                items.append(('**', keyword.value, None, None))
            else:
                items.append((None, keyword.arg, '=', keyword.value))
        self.write(name)
        self.write(*items, sep=', ', start='(', end=')')

    def _argument_items(self, node):
        """Return the write() items for an ``arguments`` node, in source order."""
        posonly = list(getattr(node, 'posonlyargs', []))
        positional = posonly + list(node.args)
        defaults = list(node.defaults)
        # Defaults belong to the *last* len(defaults) positional parameters.
        first_default = len(positional) - len(defaults)
        items = []
        for index, arg in enumerate(positional):
            if index >= first_default:
                items.append((None, arg, '=', defaults[index - first_default]))
            else:
                items.append(arg)
            if posonly and index == len(posonly) - 1:
                items.append('/')
        if node.vararg is not None:
            items.append(('*', node.vararg, None, None))
        elif node.kwonlyargs:
            # Keyword-only parameters need a bare `*` when there is no *args.
            items.append('*')
        for arg, default in zip(node.kwonlyargs, node.kw_defaults):
            items.append((None, arg, '=', default))
        if node.kwarg is not None:
            items.append(('**', node.kwarg, None, None))
        return items

    def visit_arguments(self, node):
        self.write(*self._argument_items(node), sep=', ', start='(', end=')')

    def decorators(self, node):
        self.write(*node.decorator_list, prepend='@', endeach='\n', sep='')

    # Statements

    def visit_Assign(self, node):
        # Several targets mean a chained assignment ``a = b = value``.
        self.write(*node.targets, sep=' = ')
        self.write('=', node.value, start=' ')

    def visit_AnnAssign(self, node):
        self.write((None, node.target, ': ', node.annotation),
                   (' = ', node.value, None, None), sep='')

    def visit_AugAssign(self, node):
        self.write(node.target, BINOP_SYMBOLS[type(node.op)] + '=', node.value)

    def visit_Import(self, node):
        self.write(*node.names, sep=',', start='import', prepend=' ')

    def visit_ImportFrom(self, node):
        # node.module is None for ``from . import x``.
        module = '.' * (node.level or 0) + (node.module or '')
        self.write('from', module, endeach=' ', sep='')
        self.visit_Import(node)

    def visit_Module(self, node):
        self.write(node.body)

    def visit_Expr(self, node):
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.decorators(node)
        self.write('def ', node.name, node.args,
                   (' -> ', node.returns, None, None),
                   node.body, sep='', end='\n')

    def visit_ClassDef(self, node):
        self.decorators(node)
        self.write('class ')
        self.sigwrite(node.name, node.bases, node.keywords)
        self.write(node.body, end='\n')

    def visit_If(self, node):
        self.write('if ', node.test, node.body, sep='')
        orelse = node.orelse
        # An `else` holding exactly one `if` is an `elif`; follow the whole
        # chain so the final `else` of the last `elif` is not lost.
        while len(orelse) == 1 and isinstance(orelse[0], If):
            branch = orelse[0]
            self.write('elif ', branch.test, branch.body, sep='', start='\n')
            orelse = branch.orelse
        if orelse:
            self.write('else', orelse, start='\n')

    def visit_For(self, node):
        self.write('for', node.target, 'in', node.iter)
        self.body_or_else(node)

    def visit_While(self, node):
        self.write('while', node.test)
        self.body_or_else(node)

    def visit_With(self, node):
        # `as name` only when the item binds a target; items are comma-separated.
        items = [(None, item.context_expr, ' as ', item.optional_vars)
                 for item in node.items]
        self.write(*items, start='with ', sep=', ')
        self.write(node.body)

    def visit_Pass(self, node):
        self.write('pass')

    def visit_Delete(self, node):
        self.write(*node.targets, sep=',', start='del', prepend=' ')

    def visit_TryExcept(self, node):
        # Python < 3.3 node: try/except/else.
        self.write('try', node.body)
        self.write(*node.handlers + [('else', node.orelse, None, None)],
                   prepend='\n', sep='')

    def visit_ExceptHandler(self, node):
        self.write('except', node.type, ('as ', node.name, None, None), node.body)

    def visit_Try(self, node):
        # Python 3.3+: one node for try/except/else/finally.
        self.write('try', node.body)
        self.write(*node.handlers + [('else', node.orelse, None, None),
                                     ('finally', node.finalbody, None, None)],
                   prepend='\n', sep='')

    def visit_TryFinally(self, node):
        # Python < 3.3 node: try/finally.
        self.write('try', node.body)
        self.write(('finally', node.finalbody, None, None), prepend='\n')

    def write_list(self, start, args):
        self.write(*args, start=start, sep=',', prepend=' ')

    def visit_Global(self, node):
        self.write_list('global', node.names)

    def visit_Nonlocal(self, node):
        self.write_list('nonlocal', node.names)

    def visit_Assert(self, node):
        self.write_list('assert', [node.test, node.msg])

    def visit_Return(self, node):
        self.write('return', node.value)

    def visit_Break(self, node):
        self.write('break')

    def visit_Continue(self, node):
        self.write('continue')

    def visit_Raise(self, node):
        self.write('raise')
        if getattr(node, 'exc', None) is not None:
            # `from cause` only when a cause is present.
            self.write((None, node.exc, ' from ', node.cause), start=' ')
        elif getattr(node, 'type', None) is not None:
            # Python 2 style raise (type, inst, tback).
            self.write(node.type, node.inst, node.tback, sep=', ', start=' ')

    # Expressions

    def visit_Attribute(self, node):
        self.write(node.value, '.', node.attr, sep='')

    def visit_Call(self, node):
        self.sigwrite(node.func, node.args, node.keywords)

    def visit_Name(self, node):
        self.write(node.id)

    @staticmethod
    def _string_literal(value):
        """Return a literal for *value*, triple-quoting multi-line strings.

        Strings whose repr contains an escaped backslash keep the plain repr:
        rewriting '\\n' there would turn a literal backslash-n into a newline.
        """
        text = repr(value)
        if '\\n' in text and len(text) > 5 and '\\\\' not in text:
            text = text[0] * 2 + text.replace('\\n', '\n') + text[0] * 2
        return text

    def visit_Constant(self, node):
        # Python 3.8+ parses every literal as Constant.
        if node.value is Ellipsis:
            self.write('...')
        elif isinstance(node.value, str):
            self.write(self._string_literal(node.value))
        else:
            self.write(repr(node.value))

    def visit_Str(self, node):
        self.write(self._string_literal(node.s))

    def visit_Bytes(self, node):
        self.write(repr(node.s))

    def visit_Num(self, node):
        self.write(repr(node.n))

    def visit_Tuple(self, node):
        self.write(*node.elts, start='(', sep=', ')
        if len(node.elts) == 1:
            self.write(',')
        self.write(')')

    def visit_List(self, node):
        self.write(*node.elts, start='[', end=']', sep=', ')

    def visit_Set(self, node):
        self.write(*node.elts, start='{', end='}', sep=', ')

    def visit_Dict(self, node):
        # A None key marks ``**mapping`` unpacking.
        items = [('**', v, None, None) if k is None else (None, k, ': ', v)
                 for k, v in zip(node.keys, node.values)]
        self.write(*items, start='{', end='}', sep=', ')

    def visit_BinOp(self, node):
        # Parenthesised, like Compare/BoolOp/UnaryOp, so precedence survives.
        self.write(node.left, BINOP_SYMBOLS[type(node.op)], node.right,
                   start='(', end=')')

    def visit_BoolOp(self, node):
        self.write(*node.values, sep=' ' + BOOLOP_SYMBOLS[type(node.op)] + ' ',
                   start='(', end=')')

    def visit_Compare(self, node):
        row = combine([node.left] + node.comparators,
                      [CMPOP_SYMBOLS[type(op)] for op in node.ops])
        self.write(*row, start='(', end=')')

    def visit_UnaryOp(self, node):
        self.write(UNARYOP_SYMBOLS[type(node.op)], node.operand,
                   start='(', end=')', sep='')

    def visit_Subscript(self, node):
        self.write(node.value)
        self.write(node.slice, start='[', end=']')

    def visit_Slice(self, node):
        self.write(node.lower, ':', node.upper, (':', node.step, None, None), sep='')

    def visit_ExtSlice(self, node):
        self.write(*node.dims, sep=', ')

    def visit_Yield(self, node):
        self.write('yield', node.value)

    def visit_YieldFrom(self, node):
        self.write('yield from', node.value)

    def visit_Lambda(self, node):
        # Lambda parameters are not parenthesised.
        self.write('lambda')
        items = self._argument_items(node.args)
        if items:
            self.write(*items, start=' ', sep=', ')
        self.write(':', node.body)

    def visit_Ellipsis(self, node):
        self.write('...')

    def generator_visit(self, node, left, right):
        self.write(node.elt, *node.generators, start=left, end=right)

    def visit_ListComp(self, node):
        self.generator_visit(node, '[', ']')

    def visit_GeneratorExp(self, node):
        self.generator_visit(node, '(', ')')

    def visit_SetComp(self, node):
        self.generator_visit(node, '{', '}')

    def visit_DictComp(self, node):
        self.write((None, node.key, ': ', node.value), *node.generators,
                   start='{', end='}')

    def visit_IfExp(self, node):
        self.write(node.body, 'if', node.test, 'else', node.orelse,
                   start='(', end=')')

    def visit_Starred(self, node):
        self.write('*', node.value, sep='')

    # Helper Nodes

    def visit_alias(self, node):
        self.write(node.name, node.asname, sep=' as ')

    def visit_comprehension(self, node):
        self.write('for', node.target, 'in', node.iter)
        self.write(*node.ifs, prepend=' if ')
from ast import *
import collections
def get_close_names(body):
    """Lazily yield ``(index, name)`` for every ``name.close(...)`` statement in *body*.

    Only top-level expression statements of *body* are inspected. The call
    must be an attribute named ``close`` on a plain variable; its arguments
    are not checked.
    """
    def is_close_statement(stmt):
        if type(stmt) is not Expr:
            return False
        call = stmt.value
        if type(call) is not Call or type(call.func) is not Attribute:
            return False
        return type(call.func.value) is Name and call.func.attr == 'close'

    return ((position, stmt.value.func.value.id)
            for position, stmt in enumerate(body)
            if is_close_statement(stmt))
def get_assign_with_names(body, names):
    """Yield ``(index, name, call)`` for each ``name = call(...)`` in *body* whose name is in *names*.

    Only single-target assignments to a plain variable whose value is a call
    expression are reported; the call node itself is yielded as *call*.
    """
    for position, stmt in enumerate(body):
        if type(stmt) is not Assign or type(stmt.value) is not Call:
            continue
        if len(stmt.targets) != 1:
            continue
        target = stmt.targets[0]
        if type(target) is Name and target.id in names:
            yield position, target.id, stmt.value
def make_with(name, call, sub_body):
    """Build ``with <call> as <name>: <sub_body>`` positioned at *call*'s source location."""
    target = Name(id=name, ctx=Store())
    item = withitem(context_expr=call, optional_vars=target)
    statement = With(items=[item], body=sub_body)
    return copy_location(statement, call)
def prep_body(body):
    """Rewrite ``name = call(...); ...; name.close()`` runs in a statement list as ``with`` blocks.

    Each single-target assignment ``name = <call>`` that is followed later in
    the same list by a ``name.close()`` statement is paired with that close.
    The assignment, the statements in between and the close call are replaced
    by ``with <call> as name:`` wrapping the in-between statements. *body* is
    not modified; a new list is returned.

    NOTE(review): the match is purely syntactic. Any call is accepted, not
    only ``open``, so the call's result must support the context-manager
    protocol. A ``close()`` inside a nested block is not seen.
    """
    body = list(body)
    closes = list(get_close_names(body))
    for start, name, call in get_assign_with_names(body, [nm for _, nm in closes]):
        # Pair with the first close of the same name *after* the assignment.
        # A close that precedes it belongs to an earlier value; pairing with it
        # used to recurse forever.
        end = next((j for j, nm in closes if nm == name and j > start), None)
        if end is None:
            continue
        # Rewrite pairs nested inside the new block too. An empty block is not
        # valid Python, so fall back to `pass`.
        inner = prep_body(body[start + 1:end]) or [Pass()]
        body[start:end + 1] = [make_with(name, call, inner)]
        # Indices are stale after the splice: rescan the shortened list.
        return prep_body(body)
    return body
class Wither(NodeTransformer):
    """NodeTransformer that applies `prep_body` to every statement list, innermost first."""

    def generic_visit(self, node):
        """Transform *node*'s children, then rewrite its statement lists.

        Only the ``body``, ``orelse`` and ``finalbody`` fields are rewritten, and
        only when they hold a list. ``IfExp.body`` and ``Lambda.body`` are
        single expressions and are left alone. *node* is updated in place and
        returned.
        """
        node = super().generic_visit(node)
        for field in ('body', 'orelse', 'finalbody'):
            statements = getattr(node, field, None)
            # Statement fields are lists; `collections.Iterable` (used before)
            # no longer exists on Python 3.10+.
            if isinstance(statements, list):
                setattr(node, field, prep_body(statements))
        return node
def wither(node):
    """Return the tree rooted at *node* with open/close statement pairs turned into ``with`` blocks.

    A fresh `Wither` transformer is used. Parts of the input tree may be
    modified in place.
    """
    transformer = Wither()
    return transformer.visit(node)
def transform_and_print(filename):
    """Parse *filename*, apply `wither`, and print the regenerated source.

    Source generation is delegated to ``pretty_print.to_source``.
    NOTE(review): ``pretty_print`` is presumably the code-generator half of
    this gist saved under that name; confirm it is importable.
    """
    import pretty_print

    # Read raw bytes so parse() honours the file's PEP 263 coding declaration
    # (UTF-8 by default) instead of the locale encoding. `with` closes the
    # file even if reading fails.
    with open(filename, 'rb') as f:
        source = f.read()
    node = parse(source, filename)
    print(pretty_print.to_source(wither(node)))
if __name__ == '__main__':
    import sys
    # Transform the file named on the command line. With no argument, fall
    # back to transforming this script itself (the original behaviour).
    transform_and_print(sys.argv[1] if len(sys.argv) > 1 else sys.argv[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment