Instantly share code, notes, and snippets.

@brianairb /ast.py Secret
Created Aug 27, 2016

Embed
What would you like to do?
from pyparsing import col, lineno
VOID = intern('void')
INT = intern('int')
FLOAT = intern('float')
STRING = intern('string')
OBJECT = intern('object')
FUNCTION = intern('function')
ANY = intern('any')
GLOBAL = intern('GLOBAL')
class AstException(Exception):
def __init__(self, node, msg):
if node.loc:
Exception.__init__(self, '%s (line %d, col %d)' % (msg, node.lineno, node.col))
else:
Exception.__init__(self, '%s' % msg)
class Node(object):
def __init__(self, loc=None):
self.children = []
self.loc = loc
self._scope = None
self.comment = None
def __repr__(self):
return '<%s>' % type(self).__name__
@property
def lineno(self):
return None if self.loc is None else self.loc[0]
@property
def col(self):
return None if self.loc is None else self.loc[1]
@property
def scope(self):
if self._scope is None:
return self.parent.scope
return self._scope
@property
def func(self):
return self.parent.func
@property
def vartype(self):
return VOID
@staticmethod
def traverse(node, preop, postop, depth=0):
if preop:
node = preop(node, depth)
if node is not None:
for i in xrange(len(node.children)):
if node.children[i] is not None:
node.children[i].parent = node
node.children[i] = Node.traverse(node.children[i], preop, postop, depth+1)
if postop:
node = postop(node, depth)
return node
def scope_visitor(self):
self._scope = None
def validate(self):
pass
def find_ancestor(self, cls_list):
if any([isinstance(self, cls) for cls in cls_list]):
return self
if self.parent is None:
return None
return self.parent.find_ancestor(cls_list)
def is_constant(self):
constant = all([node.is_constant() for node in self.children])
return constant and len(self.children) > 0
class LiteralNode(Node):
def __init__(self, value, **kwargs):
super(LiteralNode, self).__init__(**kwargs)
self.keyword = None
self.value = value
def __repr__(self):
return str(self.value)
@property
def vartype(self):
if isinstance(self.value, str):
return STRING
elif isinstance(self.value, float):
return FLOAT
else:
return INT
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], loc=(lineno(loc, st), col(loc, st)))
def is_constant(self):
return True
def copy(self):
return LiteralNode(self.value)
class VarNode(Node):
def __init__(self, name, **kwargs):
super(VarNode, self).__init__(**kwargs)
self.name = name
self.lval = False
def __repr__(self):
return '<VarNode %s>' % self.name
@property
def decl(self):
return self.scope.get(self.name)
@property
def vartype(self):
node = self.decl
if node is None:
raise AstException(self, '%s not defined' % self.name)
elif isinstance(node, DeclNode) and node.size > 1:
raise AstException(self, 'Array must be subscripted')
return node.decltype
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], loc=(lineno(loc, st), col(loc, st)))
def copy(self):
node = LiteralNode(self.name)
node.lval = self.lval
return node
class BinOpNode(Node):
def __init__(self, op, lhs, rhs, **kwargs):
super(BinOpNode, self).__init__(**kwargs)
self.op = op
self.children = [lhs, rhs]
def __repr__(self):
return '<BinOp %s>' % self.op
@property
def lhs(self):
return self.children[0]
@property
def rhs(self):
return self.children[1]
@property
def vartype(self):
lhs_vartype = self.lhs.vartype
rhs_vartype = self.rhs.vartype
if self.op in '< > <= >= != =='.split(' '):
# comparisons always produce an integer
return INT
if lhs_vartype is ANY:
return rhs_vartype
else:
return lhs_vartype
@classmethod
def from_tokens(cls, st, loc, toks):
assert len(toks) == 1
toks = toks[0]
node = toks[0]
for x in xrange(1, len(toks), 2):
node = cls(toks[x], node, toks[x+1], loc=(lineno(loc, st), col(loc, st)))
return node
def validate(self):
if self.op in '= += -= *= /= %='.split(' '):
if isinstance(self.lhs, VarNode) or isinstance(self.lhs, SubscriptNode):
self.lhs.lval = True
else:
raise AstException(self, 'Invalid lvalue')
lhs_vartype = self.lhs.vartype
rhs_vartype = self.rhs.vartype
if lhs_vartype is not rhs_vartype and lhs_vartype is not ANY and rhs_vartype is not ANY:
raise AstException(self, '\'%s\' with wrong type (%s is not %s)' % (self.op, lhs_vartype, rhs_vartype))
def copy(self):
return BinOpNode(self.op, self.lhs.copy(), self.rhs.copy())
class AssignNode(BinOpNode):
def __init__(self, op, lhs, rhs, decltype=None, **kwargs):
super(AssignNode, self).__init__(op, lhs, rhs, **kwargs)
self.decltype = intern(decltype) if decltype else None
def __repr__(self):
return '<AssignNode %s %s>' % (self.op, self.decltype)
@property
def size(self):
return 1
@property
def name(self):
return self.lhs.name
@classmethod
def from_tokens(cls, st, loc, toks):
return cls('=', toks[1], toks[2], toks[0], loc=(lineno(loc, st), col(loc, st)))
@classmethod
def from_tokens_op(cls, st, loc, toks):
return cls(toks[1], toks[0], toks[2], loc=(lineno(loc, st), col(loc, st)))
#assert len(toks) == 1
#toks = toks[0]
#assert len(toks) == 3
#return cls(toks[1], toks[0], toks[2], loc=(lineno(loc, st), col(loc, st)))
def is_global(self):
return self.func is None or self.func.name is GLOBAL
def scope_visitor(self):
if self.decltype is not None:
self.scope.add(self.lhs.name, self)
def validate(self):
super(AssignNode, self).validate()
if self.func == None:
if not self.decltype:
raise AstException(self, '%s missing declaration type' % self.name)
if not self.rhs.is_constant():
raise AstException(self, '%s initializer must be a constant expression' % self.name)
def copy(self):
return BinOpNode(self.op, self.lhs.copy(), self.rhs.copy(), self.decltype)
class UnOpNode(Node):
def __init__(self, op, rhs, **kwargs):
super(UnOpNode, self).__init__(**kwargs)
self.op = op
self.children = [rhs]
def __repr__(self):
return '<UnOpNode %s>' % self.op
@property
def rhs(self):
return self.children[0]
@property
def vartype(self):
return self.rhs.vartype
@classmethod
def from_tokens(cls, st, loc, toks):
assert len(toks) == 1
toks = toks[0]
assert len(toks) == 2
return cls(toks[0], toks[1], loc=(lineno(loc, st), col(loc, st)))
def copy(self):
return UnOpNode(self.op, self.rhs.copy())
class CallNode(Node):
def __init__(self, callee, args, **kwargs):
super(CallNode, self).__init__(**kwargs)
self.callee = callee
self.children = args
def __repr__(self):
return '<CallNode %s>' % self.callee
@property
def args(self):
return self.children
@property
def decl(self):
return self.scope.get(self.callee)
@property
def vartype(self):
node = self.decl
return node.rettype
def scope_visitor(self):
# Once we have a scope, fix order of arguments
if not isinstance(self.decl, BuiltinNode):
self.children.reverse()
def validate(self):
node = self.decl
if node is None:
raise AstException(self, '%s is not defined' % self.callee)
if not isinstance(node, FuncNode) and not isinstance(node, BuiltinNode):
raise AstException(self, '%s is not a function' % self.callee)
params = [x.decltype for x in node.params]
if len(params) != len(self.children):
raise AstException(self, 'Invalid arguments for \'%s\': expected %d, got %d' % (node.name, len(params), len(self.children)))
if not isinstance(node, BuiltinNode):
params.reverse()
for i in xrange(len(params)):
if params[i] is not ANY and self.children[i].vartype is not ANY and params[i] is not self.children[i].vartype:
raise AstException(self, 'Wrong type for \'%s\': expected %s, got %s' % (node.name, params[i], self.children[i].vartype))
elif params[i] is FUNCTION and self.children[i].decl.rettype is not VOID:
raise AstException(self, 'Function callbacks must have a void return type')
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], toks[1:], loc=(lineno(loc, st), col(loc, st)))
def is_constant(self):
return False
class ReturnNode(Node):
def __init__(self, value=None, **kwargs):
super(ReturnNode, self).__init__(**kwargs)
self.children = [value]
@property
def value(self):
return self.children[0]
def validate(self):
vartype = self.value.vartype if self.value is not None else VOID
node = self.func
if node.rettype is not vartype:
raise AstException(self, 'Function return type mismatch: expected %s, got %s' % (node.rettype, self.value.vartype))
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
class BreakNode(Node):
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(loc=(lineno(loc, st), col(loc, st)))
class ContinueNode(Node):
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(loc=(lineno(loc, st), col(loc, st)))
class GotoNode(Node):
def __init__(self, target, cond=None, **kwargs):
super(GotoNode, self).__init__(**kwargs)
self.target = target
self.children = [cond]
def __repr__(self):
return '<GotoNode %s%s>' % (self.target, ' conditional' if self.cond else '')
@property
def cond(self):
return self.children[0]
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], loc=(lineno(loc, st), col(loc, st)))
class LabelNode(Node):
def __init__(self, label, **kwargs):
super(LabelNode, self).__init__(**kwargs)
self.label = label
def __repr__(self):
return '<LabelNode %s>' % self.label
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], loc=(lineno(loc, st), col(loc, st)))
def scope_visitor(self):
self.func.scope.add(self.label, self)
class DeclNode(Node):
def __init__(self, decltype, name, size=None, **kwargs):
super(DeclNode, self).__init__(**kwargs)
self.decltype = intern(decltype)
self.name = name
if size is not None and size <= 1:
raise AstException(self, 'Array size must be greater than 1')
self.size = size or 1
def __repr__(self):
return '<DeclNode %s %s %d>' % (self.decltype, self.name, self.size)
def is_global(self):
return self.func is None or self.func.name is GLOBAL
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
def scope_visitor(self):
self.scope.add(self.name, self)
class SubscriptNode(Node):
def __init__(self, name, sub, **kwargs):
super(SubscriptNode, self).__init__(**kwargs)
self.name = name
self.children = [sub]
self.lval = False
def __repr__(self):
return '<SubscriptNode %s>' % self.name
@property
def subscript(self):
return self.children[0]
@property
def decl(self):
return self.scope.get(self.name)
@property
def vartype(self):
node = self.decl
if node is None:
raise AstException(self, '%s is not defined' % self.name)
elif node.size <= 1:
raise AstException(self, '%s is not an array' % self.name)
return node.decltype
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
class IfNode(Node):
def __init__(self, cond, ifthen, ifelse=None, **kwargs):
super(IfNode, self).__init__(**kwargs)
self.children = [cond, ifthen, ifelse]
@property
def cond(self):
return self.children[0]
@cond.setter
def cond(self, value):
self.children[0] = value
@property
def ifthen(self):
return self.children[1]
@ifthen.setter
def ifthen(self, value):
self.children[1] = value
@property
def ifelse(self):
return self.children[2]
@ifelse.setter
def ifelse(self, value):
self.children[2] = value
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
def validate(self):
if self.cond.vartype is not INT:
raise AstException(self, 'Condition must be int-type')
class WhileNode(Node):
def __init__(self, cond, body, **kwargs):
super(WhileNode, self).__init__(**kwargs)
self.children = [cond, body]
@property
def cond(self):
return self.children[0]
@property
def body(self):
return self.children[1]
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
def scope_visitor(self):
self._scope = Scope(self.parent.scope)
def validate(self):
if self.cond.vartype is not INT:
raise AstException(self, 'Condition must be int-type')
class ForNode(Node):
def __init__(self, init, cond, inc, body, **kwargs):
super(ForNode, self).__init__(**kwargs)
self.children = [init, cond, inc, body]
@property
def init(self):
return self.children[0]
@property
def cond(self):
return self.children[1]
@property
def inc(self):
return self.children[2]
@property
def body(self):
return self.children[3]
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
def scope_visitor(self):
self._scope = Scope(self.parent.scope)
def validate(self):
if self.cond is not None:
if self.cond.vartype is not INT:
raise AstException(self, 'Condition must be int-type')
class BlockNode(Node):
def __init__(self, nodes, **kwargs):
super(BlockNode, self).__init__(**kwargs)
self.children = nodes
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(*toks, loc=(lineno(loc, st), col(loc, st)))
def scope_visitor(self):
self._scope = Scope(self.parent.scope)
class BuiltinNode(Node):
def __init__(self, num, rettype, name, params):
self.num = num
self.rettype = intern(rettype)
self.name = name
self.params = [DeclNode(param, None) for param in params]
class FuncNode(Node):
def __init__(self, rettype, name, params, body, **kwargs):
super(FuncNode, self).__init__(**kwargs)
self.rettype = intern(rettype)
self.name = name
self.tmp_id = 0
self.children = params + [body]
self.decltype = FUNCTION
def __repr__(self):
return '<FuncNode %s %s>' % (self.rettype, self.name)
@property
def func(self):
return self
@property
def params(self):
return self.children[:-1]
@property
def body(self):
return self.children[-1]
@body.setter
def body(self, value):
self.children[-1] = value
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks[0], toks[1], toks[2:-1], toks[-1], loc=(lineno(loc, st), col(loc, st)))
def new_label(self):
node = LabelNode('.L%d' % self.tmp_id)
self.tmp_id += 1
self.scope.add(node.label, node)
return node
def scope_visitor(self):
self._scope = Scope(self.parent.scope)
self.parent.scope.add(self.name, self)
class GlobalNode(Node):
def __init__(self, children):
super(GlobalNode, self).__init__()
self.children = children if isinstance(children, list) else children.asList()
self.parent = None
@property
def func(self):
return None
@classmethod
def from_tokens(cls, st, loc, toks):
return cls(toks)
def scope_visitor(self):
self._scope = Scope(None)
class Scope(object):
def __init__(self, parent):
self.parent = parent
self.names = {}
def get(self, name):
node = self.names.get(name, None)
if node is None and self.parent is not None:
node = self.parent.get(name)
return node
def add(self, name, node):
if self.get(name) is not None:
raise AstException(node, '%s is already defined' % name)
self.names[name] = node
def remove(self, name, node):
assert name in self.names and self.names[name] == node
del self.names[name]
def print_ast(root):
def visitor(node, depth):
if node is None:
return node
print '%s%s (%s)' % (' '*depth*2, node, node.vartype)
return node
Node.traverse(root, visitor, None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment