Skip to content

Instantly share code, notes, and snippets.

@SquidDev
Created August 19, 2014 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SquidDev/01479c11fef6406dd5ee to your computer and use it in GitHub Desktop.
Save SquidDev/01479c11fef6406dd5ee to your computer and use it in GitHub Desktop.
Fun stuff with AST trees

Fun stuff with AST trees in Python.

  • Converts calls to pow and min into calls to max
  • Converts tuples into lists
  • Converts print statements into list expressions
  • Converts return x,y,z into the assignment _ = [x,y,z]
from Dumper import Dump
from AstTools import *
import ast
# Globals the compiled code runs in. 'A' is referenced by the sample source;
# '_' is where the result lands once returns are rewritten as assignments.
Environment = {"A" : 2}
Mode = "exec"
# compile() rejects a 'return' outside a function, so the first DebugExecute
# below is expected to print an error instead of a result.
Simple = ast.parse('return pow(2, A + 5), (1, 2, 3)', mode = Mode)
#print(Dump(Simple.body[0], MaxDepth = 16, HideUnderscores = True))
print(DumpNode(Simple, AnnotateFields = True, Pad = True))
try:
    DebugExecute(Simple, Environment, Mode = Mode)
    print("Last result %s" % (Environment['_'],))
except Exception as e:
    # Deliberately broad: this is a demo that reports whatever went wrong.
    print(e)
# Transformer whose node handlers are registered further down.
Transfomer = GenericTransformer()
def PowerToMax(Node):
    """Rewrite calls to ``pow(...)`` or ``min(...)`` as calls to ``max(...)``.

    Intended as the handler for ``Call`` nodes (could work with 'print' too).
    Only calls whose callee is a bare name are considered. A call through an
    attribute (``obj.pow(...)``) or through any other expression is returned
    unchanged.

    Returns:
        A clone of ``Node`` whose callee is renamed to ``max`` (the argument
        values are shared with the original), or ``Node`` itself when no
        rewrite applies.
    """
    Func = Node.func
    # Only ast.Name callees have an ``id``. ``obj.f()`` has an Attribute callee
    # and ``f()()`` a Call callee, and both used to raise AttributeError here.
    if isinstance(Func, ast.Name) and Func.id in ("pow", "min"):
        return CloneNode(Node, func = CloneNode(Func, id = 'max'))
    return Node
def TupleToList(Node):
    """Replace a tuple display with the equivalent list display.

    The new ``ast.List`` reuses the tuple's element nodes and expression
    context, so an unpacking target stays a valid (list) unpacking target.
    It also takes over the tuple's source location.
    """
    NewNode = ast.List(elts = Node.elts, ctx = Node.ctx)
    ast.copy_location(NewNode, Node)
    # Fill in any location info still missing on the *new* node and its
    # children. This was previously applied to the old tuple, which left the
    # replacement without a lineno whenever the tuple itself had none.
    ast.fix_missing_locations(NewNode)
    return NewNode
def PrintToList(Node):
    """Replace a Python 2 ``print`` statement with a bare list expression.

    ``print a, b`` becomes the expression statement ``[a, b]``. Only
    ``Node.values`` is carried over, so any other field of the print node
    (such as its destination) is dropped.
    """
    # ast.Expr has a single field, ``value``, and no ``ctx``. The stray
    # ``ctx=`` keyword that used to be passed here was meaningless (and
    # Python 3.13+ warns about unknown constructor keywords).
    NewNode = ast.Expr(value = ast.List(elts = Node.values, ctx = ast.Load()))
    ast.copy_location(NewNode, Node)
    # The freshly built List has no location yet; inherit it from the Expr.
    ast.fix_missing_locations(NewNode)
    return NewNode
def ReturnToAssignment(Node):
    """Replace ``return <value>`` with the assignment ``_ = <value>``.

    A bare ``return`` (no value) becomes ``_ = None``, so the result is
    always a compilable assignment.
    """
    Value = Node.value
    if Value is None:
        # ast.Assign requires a value. Parsing 'None' produces the right node
        # type on every Python version (Name in 2.x, Constant in 3.8+).
        Value = ast.copy_location(ast.parse('None', mode = 'eval').body, Node)
    NewNode = ast.Assign(targets = [ast.Name(id = '_', ctx = ast.Store())], value = Value)
    ast.copy_location(NewNode, Node)
    # The new target Name has no location; inherit it from the Assign.
    ast.fix_missing_locations(NewNode)
    return NewNode
# Handlers are keyed by AST class name.
Transfomer.SetNodeHandler("Call", PowerToMax)
Transfomer.SetNodeHandler("Tuple", TupleToList)
Transfomer.SetNodeHandler("Print", PrintToList)
Transfomer.SetNodeHandler("Return", ReturnToAssignment)
# NodeTransformer.visit returns the (possibly replaced) root node; keep it.
Simple = Transfomer.visit(Simple)
print(DumpNode(Simple, AnnotateFields = True, Pad = True))
try:
    DebugExecute(Simple, Environment, Mode = Mode)
    print("Last result %s" % (Environment['_'],))
except Exception as e:
    # Deliberately broad: this is a demo that reports whatever went wrong.
    print(e)
from ast import AST, copy_location, NodeTransformer
#import types
class GenericTransformer(NodeTransformer):
    """A NodeTransformer whose per-node-type handlers are registered at runtime.

    Instead of defining ``visit_<Type>`` methods, call
    ``SetNodeHandler("<Type>", Function)``. When a node of that type is
    visited, ``Function(Node)`` is called and the transformer then recurses
    into whatever it returned. Node types without a handler fall back to
    ``NodeTransformer.generic_visit``.
    """

    def __init__(self, *args, **kwargs):
        super(GenericTransformer, self).__init__(*args, **kwargs)
        # Maps an AST class name (e.g. "Call") to its wrapped handler.
        self.NodeTypes = {}

    def __getattr__(self, Function):
        # Only reached when normal lookup fails. This is how NodeVisitor.visit's
        # request for a missing ``visit_<Type>`` method is served.
        if Function.startswith("visit_"):
            # Read through __dict__ so a lookup made before __init__ has run
            # cannot re-enter __getattr__ looking for NodeTypes.
            Handlers = self.__dict__.get("NodeTypes", {})
            Node = Function[len("visit_"):]
            if Node in Handlers:
                return Handlers[Node]
        return getattr(super(GenericTransformer, self), Function)

    def SetNodeHandler(self, Name, Function):
        """Register ``Function`` as the handler for AST nodes of class ``Name``.

        ``Function(Node)`` may return, following NodeTransformer conventions:
        - a node, which replaces the original and is then visited recursively;
        - ``None``, which removes the node;
        - a list of nodes, each of which is visited (only valid where the
          parent field holds a list).
        """
        def _Handle(Node):
            Result = Function(Node)
            if Result is None:
                # Deletion: there is nothing left to recurse into.
                return None
            if isinstance(Result, list):
                return [self.generic_visit(Item) for Item in Result]
            return self.generic_visit(Result)
        self.NodeTypes[Name] = _Handle
def CloneNode(OldNode, **Changes):
    """Clone an AST node, optionally overriding fields or changing its class.

    Args:
        OldNode: the node to copy.
        **Changes: field values to use instead of OldNode's. The special key
            ``NodeClass`` selects the class of the clone; it defaults to
            OldNode's class.

    Returns:
        A new node of the chosen class. Each of that class's fields is taken
        from Changes when given there, otherwise from OldNode when OldNode
        has it. Field values are shared, not deep-copied. The source location
        is copied from OldNode.
    """
    NodeClass = Changes.get('NodeClass', OldNode.__class__)
    Missing = object()
    Fields = {}
    # Iterate over the *target* class's fields. Changes for fields that
    # OldNode lacks are therefore honoured (they used to be silently
    # dropped), and old fields the new class does not have are not passed.
    for Key in NodeClass._fields:
        Value = Changes.get(Key, Missing)
        if Value is Missing:
            Value = getattr(OldNode, Key, Missing)
        if Value is not Missing:
            Fields[Key] = Value
    return copy_location(NodeClass(**Fields), OldNode)
def CompileExecute(Node, Environment, Mode = 'exec', Flags = 0):
    """Compile ``Node`` and run the resulting code object.

    Args:
        Node: an AST (or source) acceptable to ``compile`` in ``Mode``.
        Environment: globals dict the code runs in; it is modified in place.
        Mode: the ``compile`` mode ('exec', 'eval' or 'single').
        Flags: extra compiler flags forwarded to ``compile``.

    Returns:
        The expression's value in 'eval' mode; None for 'exec' mode.

    Note: this executes arbitrary code, so never feed it untrusted input.
    """
    CodeObject = compile(Node, "Simple", Mode, Flags)
    # eval() also accepts code compiled in 'exec' mode; it then returns None.
    return eval(CodeObject, Environment)
def DebugExecute(*Args, **Kwargs):
    """Call CompileExecute with the same arguments and print what happens.

    Prints a banner first, then ``Result: `` followed by str() of the value
    CompileExecute returned. Exceptions propagate after the banner.
    """
    print("##Executing node##")
    Outcome = CompileExecute(*Args, **Kwargs)
    print("Result: " + str(Outcome))
def DumpNode(Node, AnnotateFields = False, Pad = False):
    """Render an AST node, recursively, as a readable string.

    A node is written as ``ClassName (field, field, ...)``. Within a node,
    fields are always separated by ", ". Lists are written as ``[...]`` and
    any other value as its repr().

    Args:
        Node: the ast.AST node to render.
        AnnotateFields: when true, each field is written as ``name = value``;
            otherwise only the values are written.
        Pad: when true, a node's non-blank contents and a non-empty list's
            items are placed on their own tab-indented lines, and list items
            are separated by ",\\n" instead of ", ".

    Returns:
        The rendered string.

    Raises:
        TypeError: if ``Node`` is not an ast.AST instance.
    """
    if not isinstance(Node, AST):
        raise TypeError('expected AST, got %r' % Node.__class__.__name__)

    ItemSeparator = ',\n' if Pad else ', '

    def _FormatNode(Current):
        # Render each field, optionally labelled with its name.
        Parts = []
        for FieldName, FieldValue in IterFields(Current):
            Rendered = _Format(FieldValue)
            Parts.append('%s = %s' % (FieldName, Rendered) if AnnotateFields else Rendered)
        Body = ', '.join(Parts)
        if Pad and Body and not Body.isspace():
            Body = "\n" + ContinuousIndent(Body) + "\n"
        return Current.__class__.__name__ + " (" + Body + ")"

    def _FormatList(Items):
        # Render a list field; an empty list is always the literal "[]".
        if not Items:
            return '[]'
        Inner = ItemSeparator.join(_Format(Item) for Item in Items)
        if Pad:
            return '[\n' + ContinuousIndent(Inner) + '\n]'
        return '[' + Inner + ']'

    def _Format(Value):
        # Dispatch on the kind of value found in a field.
        if isinstance(Value, AST):
            return _FormatNode(Value)
        if isinstance(Value, list):
            return _FormatList(Value)
        return repr(Value)

    return _Format(Node)
def IterFields(Node):
    """Yield ``(name, value)`` for each name listed in ``Node._fields``.

    A field that is declared but not set on this particular node is skipped
    rather than raising AttributeError.
    """
    Absent = object()
    for FieldName in Node._fields:
        # getattr's default only covers AttributeError, matching a
        # try/except AttributeError around the lookup.
        FieldValue = getattr(Node, FieldName, Absent)
        if FieldValue is not Absent:
            yield FieldName, FieldValue
def ContinuousIndent(String, Indent = "\t"):
    """Prefix every line of ``String`` with ``Indent``.

    Lines are delimited by newline characters only. An empty string yields
    just ``Indent``, and a trailing newline produces a final line that holds
    only ``Indent``.
    """
    # Prefixing the first line and every character after a newline is the
    # same as splitting on newlines, indenting each piece and re-joining.
    return Indent + String.replace("\n", "\n" + Indent)
from __future__ import print_function
from textwrap import TextWrapper
# Alias for the builtin id(); _HandleChild uses it to key the Seen map.
_Id = id
# Scalar types that _HandleChild displays directly and never recurses into.
# (long and unicode exist only on Python 2.)
Generics = (int, str, float, long, bool, complex, unicode)
# Container types that _Dump walks by integer index instead of by attribute.
ComplexGenerics = (list, tuple, frozenset)
# Type names of builtin callables/wrappers: _HandleChild never recurses into
# them, and with Doc=True _Dump prints their docstring instead.
IgnoreNames = ("method-wrapper", "builtin_function_or_method", "wrapper_descriptor")
def _Dump(Object, Seen = None, MaxDepth = 2, Indent = '', Doc = False, HideUnderscores = False):
    """Recursively print the members of ``Object``, one ``name : value`` per line.

    Args:
        Object: the value to inspect. Dicts are walked by key, types in
            ComplexGenerics by index, and anything else by ``dir()`` name.
        Seen: id() -> placeholder text for objects already expanded, so shared
            or cyclic references are expanded only once. A fresh dict is
            created when omitted.
        MaxDepth: levels still allowed; at 0 or below nothing is printed.
        Indent: prefix written before every line at this level.
        Doc: when true, print the docstring of children whose type name is
            in IgnoreNames.
        HideUnderscores: when true, skip string names of length >= 5 that
            start and end with '__'.

    Returns:
        ``str(Object)`` when MaxDepth is exhausted, otherwise None.
    """
    if MaxDepth <= 0:
        return str(Object)
    if Seen is None:
        # The old ``Seen = {}`` default was a single dict shared by every
        # call, so each later Dump() reported objects from earlier dumps
        # (or objects reusing a freed id) as "Already seen".
        Seen = {}
    Children = []
    # Wraps string values; continuation lines go one level deeper.
    ChildIndent = TextWrapper(subsequent_indent = Indent + "\t")
    # Children are normally read from Object itself. A frozenset cannot be
    # subscripted, so it is indexed through a list snapshot instead.
    Container = Object
    if isinstance(Object, dict):
        ChildItems = Object.keys()
        ChildFunction = _ObjectGettter
    elif isinstance(Object, ComplexGenerics):
        if isinstance(Object, frozenset):
            Container = list(Object)
        ChildItems = range(len(Container))
        ChildFunction = _ObjectGettter
    else:
        ChildItems = sorted(dir(Object))
        ChildFunction = getattr
    # First pass: classify every child, which registers it in Seen, before
    # any recursion starts. Seen therefore already holds all siblings.
    for ChildName in ChildItems:
        IsDunder = (isinstance(ChildName, str) and len(ChildName) >= 5
                    and ChildName.startswith('__') and ChildName.endswith('__'))
        if HideUnderscores and IsDunder:
            continue
        Child, DoChildren = _HandleChild(ChildFunction(Container, ChildName), ChildName, Seen, ChildIndent)
        Children.append((ChildName, Child, DoChildren))
    # Second pass: print each child, then recurse or show its docstring.
    for ChildName, Child, DoChildren in Children:
        print(Indent + str(ChildName) + ' : ' + str(Child))
        if DoChildren:
            _Dump(Child, Seen, MaxDepth - 1, Indent + "\t", Doc, HideUnderscores)
        elif Doc and Child.__class__.__name__ in IgnoreNames and hasattr(Child, "__doc__"):
            print(ChildIndent.fill(('%r' % Child.__doc__).replace('\\n', '\n')))
def _HandleChild(Child, Name, Seen, Indent):
DoChildren = True
Id = _Id(Child)
if isinstance(Child, str):
Child = Indent.fill(('%r' % Child).replace('\\n', '\n'))
DoChildren = False
elif Child == None or isinstance(Child, Generics) or Child.__class__.__name__ in IgnoreNames or Name == "__class__":
DoChildren = False
elif Seen.has_key(Id):
Child = Seen[Id]
DoChildren = False
else:
Seen[Id] = "<Already seen: " + str(Id) + " : " + str(Child) + ">"
return Child, DoChildren
def _ObjectGettter(Object, ChildName):
return Object[ChildName]
def Dump(Object, **KeyArgs):
    """Print ``Object`` itself, then a recursive listing of its members.

    Keyword arguments are forwarded to _Dump (e.g. MaxDepth, Doc,
    HideUnderscores). ``Indent`` is always forced to a single tab.
    """
    print(Object)
    _Dump(Object, **dict(KeyArgs, Indent = "\t"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment