Skip to content

Instantly share code, notes, and snippets.

@BachiLi
Created June 16, 2020 04:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BachiLi/30c7ada4fcba62fe220fd29597a31407 to your computer and use it in GitHub Desktop.
Save BachiLi/30c7ada4fcba62fe220fd29597a31407 to your computer and use it in GitHub Desktop.
import ast
import inspect
import sys
import textwrap

from adt import ADT
from adt import memo as ADTMemo
# Define the grammar
# The grammar string declares the node types of the small "cohe" kernel
# language used by parse() and Codegen below:
#   float_expr - a float constant, or the sum of two float expressions
#   int_expr   - an int constant, or the index of the current thread
#   stmt       - FloatAddTo: add float expression `e` into `target[index]`
#   block      - a sequence of statements
#   kernel     - a name, input/output argument names, and a body block
# The rest of this file reaches the constructors and sum types as attributes
# of the result (e.g. cohe.FloatConst, cohe.stmt).
cohe = ADT("""
module cohe {
float_expr = FloatConst ( float val )
| FloatAdd ( float_expr lhs, float_expr rhs )
int_expr = IntConst ( int val )
| ThreadIdx ( )
stmt = FloatAddTo ( string target, int_expr index, float_expr e )
block = Block ( stmt* s )
kernel = Kernel ( string name, string* in_args, string* out_args, block body )
}
""")
# Inject memoization
# NOTE(review): presumably this lets structurally equal nodes be used as
# dictionary keys (Codegen.expr_dict is keyed by expression nodes) -- confirm
# against the adt module.
ADTMemo(cohe, ['FloatConst',
               'FloatAdd',
               'IntConst',
               'ThreadIdx',
               'FloatAddTo',
               'Block',
               'Kernel'])
# Utilities
def is_const(expr):
    """Tell whether `expr` is a FloatConst, IntConst or ThreadIdx node.

    Codegen writes these node kinds inline instead of giving them a
    temporary variable.
    """
    inline_kinds = (cohe.FloatConst, cohe.IntConst, cohe.ThreadIdx)
    return isinstance(expr, inline_kinds)
def get_const(expr):
    """Return the value to write inline for a constant-like node.

    FloatConst and IntConst yield their stored `val`; ThreadIdx yields the
    string 'thread_index'. Any other node yields None.
    """
    if isinstance(expr, (cohe.FloatConst, cohe.IntConst)):
        return expr.val
    if isinstance(expr, cohe.ThreadIdx):
        return 'thread_index'
    return None
class Input:
    """Annotation marker: parse() records parameters annotated with it as kernel inputs."""
class Output:
    """Annotation marker: parse() records parameters annotated with it as kernel outputs."""
def parse(kernel):
    """
    Given a Python function kernel, parse it to a cohe AST.

    The function's source text is retrieved with inspect.getsource and walked
    with the ast module. The supported subset is:

    * every parameter is annotated with `Input` or `Output` (bare or as an
      attribute such as `mod.Input`); *args and **kwargs are rejected;
    * every body statement has the form `name[index] += value`;
    * an expression is an int literal, a float literal, or a call to
      `ThreadIdx` (bare or as an attribute such as `cohe.ThreadIdx`).

    Returns:
        A cohe.Kernel holding the function name, the input names, the output
        names, and a cohe.Block of the translated statements.

    Raises:
        AssertionError: the source uses anything outside the subset above.
            Raised explicitly rather than through `assert`, so the checks
            are still enforced under `python -O`.
    """
    def fail(message):
        raise AssertionError(message)

    def annotation_name(node):
        # Accept both `Input` and `module.Input`, mirroring how call targets
        # are resolved in visit_expr.
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    def numeric_literal(node):
        """Return (True, value) for a numeric literal node, else (False, None)."""
        if isinstance(node, ast.Constant):
            value = node.value
        elif sys.version_info < (3, 8) and isinstance(node, ast.Num):
            # Before 3.8 the parser produced ast.Num; the class is deprecated
            # on newer versions, so it is only touched on old interpreters.
            value = node.n
        else:
            return False, None
        # bool is a subclass of int, but True/False are not numeric literals.
        if isinstance(value, bool) or not isinstance(value, (int, float, complex)):
            return False, None
        return True, value

    def visit_FunctionDef(node):
        args = node.args
        if args.vararg is not None or args.kwarg is not None:
            fail('*args and **kwargs are not supported in a kernel')
        # Check & load the arguments
        inputs = []
        outputs = []
        for arg in args.args:
            if arg.annotation is None:
                fail(f'Kernel argument {arg.arg!r} must be annotated with Input or Output')
            role = annotation_name(arg.annotation)
            if role == 'Input':
                inputs.append(arg.arg)
            elif role == 'Output':
                outputs.append(arg.arg)
            else:
                fail(f'Kernel argument {arg.arg!r} must be annotated with Input or Output')
        body = [visit_stmt(stmt) for stmt in node.body]
        return cohe.Kernel(node.name, inputs, outputs, cohe.Block(body))

    def visit_expr(node):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                name = func.id
            elif isinstance(func, ast.Attribute):
                name = func.attr
            else:
                fail(f'Unknown Call node function {type(func).__name__}')
            if name == 'ThreadIdx':
                return cohe.ThreadIdx()
            fail('Unimplement function call')
        is_number, value = numeric_literal(node)
        if is_number:
            if isinstance(value, int):
                return cohe.IntConst(value)
            if isinstance(value, float):
                return cohe.FloatConst(value)
            fail(f'Unknown Num.n {type(value)}')
        fail(f'Unknown expr {type(node).__name__}')

    def visit_lhs(node):
        if not isinstance(node, ast.Subscript):
            fail(f'Unknown left hand side {type(node).__name__}')
        if not isinstance(node.value, ast.Name):
            fail(f'Unknown subscript target {type(node.value).__name__}')
        index_node = node.slice
        if sys.version_info < (3, 9):
            # Before 3.9 a plain index was wrapped in ast.Index; from 3.9 on
            # the index expression is stored in `slice` directly.
            if not isinstance(index_node, ast.Index):
                fail('Only a single index expression is supported')
            index_node = index_node.value
        return node.value.id, visit_expr(index_node)

    def visit_stmt(node):
        if not isinstance(node, ast.AugAssign):
            fail(f'Unknown statement {type(node).__name__}')
        target, index = visit_lhs(node.target)
        if not isinstance(node.op, ast.Add):
            fail('Only += is supported')
        value = visit_expr(node.value)
        return cohe.FloatAddTo(target, index, value)

    # Dedent so that kernels defined inside another function or a class body
    # (whose source text is indented) still parse as a module.
    module = ast.parse(textwrap.dedent(inspect.getsource(kernel)))
    if len(module.body) != 1 or not isinstance(module.body[0], ast.FunctionDef):
        fail('Expected the source of exactly one function definition')
    return visit_FunctionDef(module.body[0])
# Codegen
class Codegen:
    """
    cohe AST to ispc code

    Generated text accumulates in `self.code`. Each FloatAdd node is evaluated
    into its own temporary `_t<id>`; `self.expr_dict` maps every node that
    already has a temporary to that id. Constants and the thread index are
    written inline and never get a temporary.
    """

    def __init__(self):
        # FloatAdd node -> integer id of the temporary holding its value
        self.expr_dict = {}
        # ispc source emitted so far
        self.code = ''
        # current nesting depth, in tabs
        self.tab_count = 0

    @staticmethod
    def _temp_name(expr_id):
        """Return the name of the temporary with the given id.

        Shared by the declaration site and by get_handle so both always
        spell the name the same way.
        """
        return f'_t{expr_id}'

    def get_handle(self, expr):
        """Return the text that refers to `expr` inside emitted code.

        Constant-like nodes give their inline value; any other node gives the
        name of its temporary, so it must have been emitted beforehand.

        Raises:
            KeyError: `expr` is not constant-like and has no temporary yet.
        """
        if is_const(expr):
            return get_const(expr)
        # The dictionary stores the bare id; turn it into the variable name so
        # that a nested add reads `_t0 + ...` and not `0 + ...`.
        return self._temp_name(self.expr_dict[expr])

    def emit_tabs(self):
        """Append the indentation for the current nesting depth."""
        self.code += '\t' * self.tab_count

    def emit_expr(self, expr):
        """Emit the temporaries needed to evaluate `expr`, operands first.

        Node kinds other than constants, ThreadIdx and FloatAdd are left
        untouched.
        """
        if is_const(expr):
            # Const is always inlined to the expression
            return
        if isinstance(expr, cohe.FloatAdd):
            if expr in self.expr_dict:
                # Skip generated exprs: the existing temporary is reused.
                # Emitting again would reuse an id that is still taken.
                return
            self.emit_expr(expr.lhs)
            self.emit_expr(expr.rhs)
            lhs = self.get_handle(expr.lhs)
            rhs = self.get_handle(expr.rhs)
            expr_id = len(self.expr_dict)
            self.expr_dict[expr] = expr_id
            self.emit_tabs()
            self.code += f'float {self._temp_name(expr_id)} = {lhs} + {rhs};\n'

    def emit_stmt(self, stmt):
        """Emit `target[index] += e;` for a FloatAddTo statement."""
        assert isinstance(stmt, cohe.stmt)
        assert isinstance(stmt, cohe.FloatAddTo)
        # Temporaries for the operands must precede the statement that uses them.
        self.emit_expr(stmt.e)
        self.emit_expr(stmt.index)
        self.emit_tabs()
        e = self.get_handle(stmt.e)
        index = self.get_handle(stmt.index)
        self.code += f'{stmt.target}[{index}] += {e};\n'

    def emit_block(self, block):
        """Emit every statement of `block` in order."""
        assert isinstance(block, cohe.block)
        for stmt in block.s:
            self.emit_stmt(stmt)

    def emit_kernel(self, kernel):
        """Emit a complete function for `kernel`.

        Every input and output becomes a `uniform float *` parameter, followed
        by a trailing `uniform int num_threads`. The body runs inside a
        `foreach` loop over `thread_index`, the name ThreadIdx nodes expand to.
        """
        assert isinstance(kernel, cohe.kernel)
        params = kernel.in_args + kernel.out_args
        assert len(params) > 0
        signature = ', '.join('uniform float *' + param for param in params)
        self.code += f'void {kernel.name}({signature}, uniform int num_threads) {{\n'
        self.tab_count += 1
        self.emit_tabs()
        self.code += 'foreach (thread_index = 0 ... num_threads) {\n'
        self.tab_count += 1
        self.emit_block(kernel.body)
        self.tab_count -= 1
        self.emit_tabs()
        self.code += '}\n'
        self.tab_count -= 1
        self.code += '}\n'
# Program
# Example kernel: adds 1.0 to the element of `out` selected by the thread index.
# parse() reads this function's source text rather than calling it, so the
# body has to stay within the statement forms parse() accepts. Keep notes in
# `#` comments: a docstring would reach parse() as an unsupported statement.
def foo(out: Output):
    out[cohe.ThreadIdx()] += 1.0
# Translate the example kernel to a cohe AST ...
prog = parse(foo)
# ... generate the ispc source for it ...
cg = Codegen()
cg.emit_kernel(prog)
# ... and write the generated source to stdout.
print(cg.code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment