Skip to content

Instantly share code, notes, and snippets.

@worldbeater
Last active May 30, 2024 07:55
Show Gist options
  • Save worldbeater/c6bfc37b6c3c1933fdb5a4f1aa190133 to your computer and use it in GitHub Desktop.
Save worldbeater/c6bfc37b6c3c1933fdb5a4f1aa190133 to your computer and use it in GitHub Desktop.
Unified Abstract Syntax Tree (UAST). Inspired by https://arxiv.org/abs/2106.09173, created with an intention to implement fuzzy code-to-code search algorithms in a language-agnostic fashion, see "An Approach to Generating Recommendations for Improving the Performance of Software for Heterogeneous Computing Platforms" http://injoit.org/index.php/…
import ast
import pprint
import pycparser
import pycparser.c_ast as cast
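# Backus–Naur Form of UAST.
# NOTE: the grammar below is abbreviated in two ways. Every tuple actually
# produced carries a trailing (line, column) source location as its last
# element, e.g. ('var', 'x', (3, 4)), and the entries of <params> are emitted
# as bare name strings rather than full <expr> nodes.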
# <body> ::= ('body', <stmt1>, <stmt2>, ...)
# <stmt> ::= ('func', <str>, <params>, <body>)
# | ('if', <expr>, <body>, <body>)
# | ('assign', <expr>, <expr>)
# | ('ret', <expr>)
# | ('loop', <expr>, <body>)
# <params> ::= ('params', <expr1>, <expr2>, ...)
# <args> ::= ('args', <expr1>, <expr2>, ...)
# <expr> ::= ('call', <expr>, <args>)
# | ('seq', <expr1>, <expr2>, ...)
# | ('get', <expr>, <expr>)
# | ('var', <str>)
# | ('const', <str>)
# | ('+', <expr>, <expr>)
# | ('-', <expr>, <expr>)
# | ('*', <expr>, <expr>)
# | ('/', <expr>, <expr>)
# | ('>', <expr>, <expr>)
# ...
# Python operator node classes -> UAST operator tags.
# Equality is written '=' in the UAST rather than '=='.
PYOPS = dict((
    # Comparison operators.
    (ast.GtE, '>='),
    (ast.LtE, '<='),
    (ast.Gt, '>'),
    (ast.Lt, '<'),
    (ast.Eq, '='),
    (ast.NotEq, '!='),
    # Arithmetic operators.
    (ast.Mult, '*'),
    (ast.Sub, '-'),
    (ast.Add, '+'),
    (ast.Pow, '**'),
    (ast.Div, '/'),
))
def pyloc(node):
    """Return the ``(lineno, col_offset)`` pair of a Python AST node."""
    line, column = node.lineno, node.col_offset
    return line, column
def py2ir(tree, loc=None):
match tree:
case ast.Assign([ast.Tuple(targets)], ast.Tuple(values)):
return [('assign', py2ir(target), py2ir(value), pyloc(tree))
for target, value
in zip(targets, values)]
case ast.For(ast.Name(id), ast.Call(ast.Name('range'), [arg]), body):
iis0 = ('assign', ('var', id, pyloc(tree)), ('const', '0', pyloc(tree)), pyloc(tree))
add1 = ('+', ('var', id, pyloc(tree)), ('const', '1', pyloc(tree)), pyloc(tree))
incr = ('assign', ('var', id, pyloc(tree)), add1, pyloc(tree))
cond = ('<', ('var', id, pyloc(tree)), py2ir(arg), pyloc(tree))
*bodyparts, loc = py2ir(body, pyloc(tree))
body = (*bodyparts, incr, loc)
loop = ('loop', cond, body, pyloc(tree))
return [iis0, loop]
case ast.Module(body):
return py2ir(body, (0, 0))
case ast.FunctionDef(name, args, body):
return ('func', name, py2ir(args, pyloc(tree)), py2ir(body, pyloc(tree)), pyloc(tree))
case ast.arguments(_, args):
return ('params', *map(py2ir, args), loc)
case ast.arg(name):
return name
case ast.If(test, body, orelse):
return ('if', py2ir(test), py2ir(body, pyloc(tree)), py2ir(orelse, pyloc(tree)), pyloc(tree))
case ast.Compare(left, [op], [right]) | ast.BinOp(left, op, right):
return (PYOPS[type(op)], py2ir(left), py2ir(right), pyloc(tree))
case ast.Assign(targets, value):
return ('assign', py2ir(targets[0]), py2ir(value), pyloc(tree))
case ast.Return(value):
return ('ret', py2ir(value), pyloc(tree))
case ast.Call(name, args):
return ('call', py2ir(name), ('args', *map(py2ir, args), pyloc(tree)), pyloc(tree))
case ast.While(test, body):
return ('loop', py2ir(test), py2ir(body, pyloc(tree)), pyloc(tree))
case ast.List(elts):
return ('seq', *map(py2ir, elts), pyloc(tree))
case ast.Subscript(value, slice):
return ('get', py2ir(value), py2ir(slice), pyloc(tree))
case ast.UnaryOp(ast.USub(), op):
return ('-', ('const', '0', pyloc(tree)), py2ir(op), pyloc(tree))
case ast.UnaryOp(ast.UAdd(), op):
return ('+', ('const', '0', pyloc(tree)), py2ir(op), pyloc(tree))
case ast.Attribute(value, attr):
return ('get', py2ir(value), ('var', attr, pyloc(tree)), pyloc(tree))
case ast.Name(id):
return ('var', id, pyloc(tree))
case ast.Constant(value):
return ('const', str(value), pyloc(tree))
case ast.Expr(value):
return py2ir(value)
case ast.Import():
return []
case list():
output = []
for node in map(py2ir, tree):
output += node if isinstance(node, list) else [node]
assert loc, str(loc)
return ('body', *output, loc)
case _:
print('Unknown Python AST node:')
print(ast.dump(tree, indent=1))
raise ValueError
# C operator spellings -> UAST operator tags. Every operator keeps its
# spelling except equality, which the UAST writes as '='.
COPS = {
    op: ('=' if op == '==' else op)
    for op in ('<=', '>=', '<', '>', '==', '!=', '*', '-', '+', '/')
}
def cloc(tree):
    """Return the ``(line, column)`` pair stored in a node's ``coord``."""
    coord = tree.coord
    return coord.line, coord.column
def c2ir(tree, loc=None):
match tree:
case cast.For(init=cast.DeclList(decls=decls), cond=cond, next=next, stmt=stmt):
*bodyparts, loc = c2ir(stmt, cloc(tree))
body = (*bodyparts, c2ir(next), loc)
loop = ('loop', c2ir(cond), body, cloc(tree))
return [*map(c2ir, decls), loop]
case cast.FileAST(ext=body):
return c2ir(body, (0, 0))
case cast.Compound(block_items=body):
return c2ir(body, loc)
case cast.FuncDef(decl=cast.Decl(name=name, type=cast.FuncDecl(args=args)), body=body):
args = c2ir(args) if args else ('params', cloc(tree))
return ('func', name, args, c2ir(body, cloc(tree)), cloc(tree))
case cast.ParamList(params=args):
return ('params', *map(c2ir, args), cloc(tree))
case cast.ExprList(exprs=args):
return ('args', *map(c2ir, args), cloc(tree))
case cast.Decl(name=name, init=None):
return name
case cast.If(cond=cond, iftrue=true, iffalse=false):
true = true or cast.Compound(block_items=[])
false = false or cast.Compound(block_items=[])
return ('if', c2ir(cond), c2ir(true, cloc(tree)), c2ir(false, cloc(tree)), cloc(tree))
case cast.BinaryOp(op=op, left=lhs, right=rhs):
return (COPS[op], c2ir(lhs), c2ir(rhs), cloc(tree))
case cast.Assignment(op=op, lvalue=lhs, rvalue=rhs):
return ('assign', c2ir(lhs), c2ir(rhs), cloc(tree))
case cast.Decl(name=name, init=init):
return ('assign', ('var', name, cloc(tree)), c2ir(init), cloc(tree))
case cast.Return(expr=expr):
return ('ret', c2ir(expr), cloc(tree))
case cast.FuncCall(name=name, args=args):
return ('call', c2ir(name), c2ir(args), cloc(tree))
case cast.ArrayRef(name=name, subscript=subscript):
return ('get', c2ir(name), c2ir(subscript), cloc(tree))
case cast.UnaryOp(op=op, expr=expr) if op in COPS:
return (COPS[op], ('const', '0', cloc(tree)), c2ir(expr), cloc(tree))
case cast.UnaryOp(op=op, expr=expr):
return ('call', op, ('args', c2ir(expr), cloc(tree)), cloc(tree))
case cast.ID(name=name):
return ('var', name, cloc(tree))
case cast.Constant(value=value):
return ('const', value, cloc(tree))
case cast.Typename(type=decl) | cast.PtrDecl(type=decl):
return c2ir(decl, cloc(tree))
case cast.TypeDecl(type=cast.IdentifierType(names=[name])):
return ('const', name, loc)
case cast.InitList(exprs=exprs):
return ('seq', *map(c2ir, exprs), cloc(tree))
case list():
output = []
for node in map(c2ir, tree):
output += node if isinstance(node, list) else [node]
assert loc
return ('body', *output, loc)
case _:
print('Unknown C AST node:')
print(tree)
raise ValueError
def unparse(tree, indent=1):
match tree:
case ('var' | 'const', value, _):
return value
case ('body', *stmts, _):
sep = f'\n{" " * indent}'
content = sep.join(unparse(stmt, indent + 1) for stmt in stmts)
return f'(body{sep}{content})'
case (op, *args, _):
content = ' '.join(unparse(arg, indent) for arg in args)
return f'({op} {content})'
case value:
return value
# Demo inputs: the same factorial routine written in C and in Python.
CDEMO = '''
int factorial(int n) {
    unsigned long long fact = 1;
    for (int i = 0; i < n; i = i + 1) {
        fact = fact * (i + 1);
    }
    return fact;
}
'''.strip()
# The Python sample must keep its block indentation: ast.parse() rejects a
# function body that is not indented.
PYDEMO = '''
def factorial(n):
    fact = 1
    for i in range(n):
        fact = fact * (i + 1)
    return fact
'''.strip()
# Parse both samples, then show the Python one as raw UAST tuples and the
# C one in its S-expression rendering.
pytree = ast.parse(PYDEMO)
ctree = pycparser.CParser().parse(CDEMO)
pprint.pprint(py2ir(pytree))
print()
print(unparse(c2ir(ctree)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment