Skip to content

Instantly share code, notes, and snippets.

@jmikkola
Created December 22, 2019 16:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jmikkola/eeb18e8ae590db955d325a3cb74a5a70 to your computer and use it in GitHub Desktop.
Save jmikkola/eeb18e8ae590db955d325a3cb74a5a70 to your computer and use it in GitHub Desktop.
import re
from ctypes import c_long, CFUNCTYPE
import llvmlite.binding as llvm
from llvmlite import ir
# Initialize LLVM, the native target, and the native assembly printer at
# import time, before the JIT code further down in this file runs.
# NOTE(review): recent llvmlite releases deprecate the explicit initialize()
# call -- confirm against the installed version.
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
# Token grammar, tried in this order: '(' | ')' | ';;' comment | atom | whitespace.
# A comment runs from ';;' to the end of the line. The newline itself is NOT
# part of the comment token (it is picked up by the whitespace alternative),
# so a comment on the very last line of the input -- with no trailing
# newline -- is still recognized as a single comment token.
# If quoted strings are supported, add:
#   |"(?:[^"\\]|\\.)*"
# right after the pattern for comments
token_re = re.compile(r'([(]|[)]|;;[^\n]*|[^\s()]+|\s+)')
# An integer literal is one or more decimal digits (no sign).
int_re = re.compile(r'^\d+$')
# 64-bit integer type; used below for every function argument, return value
# and integer constant.
int64 = ir.IntType(64)
# Pointer to an 8-bit integer; used below as the type of printf's
# format-string parameter.
voidptr = ir.IntType(8).as_pointer()
def make_int(i):
    """Build an ``int64`` constant from *i* (anything ``int()`` accepts)."""
    value = int(i)
    return int64(value)
def make_fn_type(n_args):
    """Return the type of a function taking *n_args* int64s and returning int64."""
    params = (int64,) * n_args
    return ir.FunctionType(int64, params)
def tokenize(text):
    """Split *text* into tokens and drop the whitespace-only ones."""
    kept = []
    for piece in token_re.findall(text):
        if piece.isspace():
            continue
        kept.append(piece)
    return kept
def parse(tokens):
    """Turn a flat token sequence into nested lists of atoms.

    Each '(' opens a new list and each ')' closes the innermost open one.
    Tokens starting with ';;' are comments and are skipped. Every other
    token is kept as a string atom.

    Args:
        tokens: iterable of token strings (as produced by the tokenizer).

    Returns:
        The list of top-level forms; each form is either a string atom or
        a (possibly nested) list.

    Raises:
        ValueError: if the parentheses are unbalanced. A stray ')' used to
            surface as an opaque TypeError, and a missing ')' used to
            silently return only the innermost unfinished list.
    """
    # Linked stack of frames: (list being filled, enclosing frame).
    # The outermost frame is the only one whose parent is None.
    stack = ([], None)
    for t in tokens:
        if t == '(':
            stack = ([], stack)
        elif t == ')':
            finished_list, parent = stack
            if parent is None:
                raise ValueError("unbalanced input: unexpected ')'")
            parent[0].append(finished_list)
            stack = parent
        elif not t.startswith(';;'):
            stack[0].append(t)
    if stack[1] is not None:
        raise ValueError("unbalanced input: missing ')'")
    return stack[0]
class Matcher:
    """Tests one AST node against patterns, remembering the latest captures.

    ``match`` holds whatever the most recently tried pattern returned:
    its list of captures on success, ``None`` on failure.
    """

    def __init__(self, ast):
        self.ast, self.match = ast, None

    def matches(self, pattern):
        """Apply *pattern* to the stored node; store the result, report success."""
        captured = pattern(self.ast)
        self.match = captured
        return captured is not None
def int_pattern(ast):
    """Capture *ast* as an int when it is an all-digits atom, else return None."""
    if not isinstance(ast, str):
        return None
    if int_re.match(ast) is None:
        return None
    return [int(ast)]
def atom_pattern(ast):
    """Capture *ast* itself when it is an atom (a string), else return None."""
    return [ast] if isinstance(ast, str) else None
def match_definition(ast, definition):
    """Match the node *ast* against *definition* and collect the captures.

    Forms a definition can take:
      * ``'$'``   -- matches any node; the node is captured.
      * ``'$l'``  -- matches a list node; captured.
      * ``'$a'``  -- matches an atom (string) node; captured.
      * any other string -- a literal that the node must equal; captures
        nothing.
      * a ``set`` -- the node must be a member of the set; captured.
      * a list    -- matched element by element against a list node. If
        its last element is ``'*'``, the nodes after the fixed prefix are
        captured together as one (possibly empty) list.

    Returns:
        A flat list of captures on success, ``None`` when there is no match.
    """
    if isinstance(definition, str):
        if definition == '$':
            return [ast]
        if definition == '$l':
            return [ast] if isinstance(ast, list) else None
        if definition == '$a':
            return [ast] if isinstance(ast, str) else None
        # Don't bother including literals in the results
        return [] if ast == definition else None

    if isinstance(definition, set):
        try:
            return [ast] if ast in definition else None
        except TypeError:
            # A list node is unhashable, so it can never be a member;
            # treat that as "no match" instead of crashing.
            return None

    if isinstance(ast, str):
        return None

    # bool() guard: an empty definition has no last element to inspect.
    variadic = bool(definition) and definition[-1] == '*'
    if variadic:
        n_fixed = len(definition) - 1
        # The whole fixed prefix must be present; otherwise zip() below
        # would stop early and e.g. [] would wrongly match ['do', '*'].
        if len(ast) < n_fixed:
            return None
    elif len(definition) != len(ast):
        return None

    results = []
    for node, defn in zip(ast, definition):
        if defn == '*':
            break
        result = match_definition(node, defn)
        if result is None:
            return None
        results.extend(result)
    if variadic:
        results.append(ast[len(definition) - 1:])
    return results
def pattern(definition):
    """Wrap *definition* as a one-argument function that matches a node against it."""
    def matcher(ast):
        return match_definition(ast, definition)
    return matcher
# Operator atoms that are matched together with exactly two operands.
binary_operators = {'+', '-', '*', '/', '==', '>', '<'}
def indent(depth):
    """Return the whitespace prefix for nesting level *depth*."""
    unit = ' '
    return depth * unit
def print_indent(s, depth):
    """Debug trace hook: prints *s* indented by *depth*, but is switched off."""
    tracing = False  # hard-wired off; nothing is ever printed
    if not tracing:
        return
    print(indent(depth) + s)
class Context:
    """Mutable code-generation state shared while one module is being built."""

    def __init__(self, module):
        self.module = module
        # printf declaration and its format-string global; both start unset.
        self.printf = self.global_fmt = None
        # IR builder and basic block currently being filled; both start unset.
        self.builder = self.block = None

    def get_function(self, name):
        """Return the module function called *name*, or None if there is none."""
        return next((f for f in self.module.functions if f.name == name), None)
def recognize_top(ast, context, depth=0):
    """Compile one top-level form into ``context.module``.

    Two forms are understood:
      * ``(var NAME VALUE)``       -- VALUE is run through ``recognize`` with
        an empty scope; its result is discarded.
      * ``(defn NAME (ARGS) BODY)`` -- a new int64 function is added to the
        module, with BODY's value as its return value.

    Anything else only prints 'invalid top-level expression'.
    """
    m = Matcher(ast)

    if m.matches(pattern(['var', '$a', '$'])):
        print_indent('global var {} {}'.format(*m.match), depth)
        _var_name, value_ast = m.match
        recognize(value_ast, context, {}, depth + 1)
        return

    if m.matches(pattern(['defn', '$a', '$l', '$'])):
        print_indent('function {}({}) = {}'.format(*m.match), depth)
        fn_name, params, body = m.match
        func = ir.Function(context.module, make_fn_type(len(params)), name=fn_name)
        # Parameter names resolve to the function's argument values.
        scope = dict(zip(params, func.args))
        context.block = func.append_basic_block(name='entry')
        context.builder = ir.IRBuilder(context.block)
        builder = context.builder
        builder.ret(recognize(body, context, scope, depth + 1))
        # No function is being emitted any more.
        context.builder = context.block = None
        return

    print('invalid top-level expression')
def recognize(ast, context, scope, depth=0):
    """Emit IR for the expression *ast* and return the value it produces.

    Forms are tried in this order: integer literal, atom (variable),
    binary operator, ``if``, ``print``, ``do``, function call.

    Args:
        ast: a string atom or a nested list, as produced by ``parse``.
        context: the ``Context`` holding the module and the current builder.
        scope: dict mapping local names to their values.
        depth: nesting level, used only for the debug trace.

    Returns:
        The value of the expression, or ``None`` for an unrecognized form.

    Raises:
        NameError: a call names a function that is not in the module.
        Exception: a call passes the wrong number of arguments.
    """
    m = Matcher(ast)
    if m.matches(int_pattern):
        num = int(m.match[0])
        print_indent('int: {}'.format(num), depth)
        return make_int(num)
    elif m.matches(atom_pattern):
        atom = m.match[0]
        print_indent('var: {}'.format(atom), depth)
        if atom in scope:
            return scope[atom]
        return context.module.get_global(atom)
    elif m.matches(pattern([binary_operators, '$', '$'])):
        print_indent('bin_op: {} {} {}'.format(m.match[0], m.match[1], m.match[2]), depth)
        left = recognize(m.match[1], context, scope, depth+1)
        right = recognize(m.match[2], context, scope, depth+1)
        op = m.match[0]
        if op == '+':
            return context.builder.add(left, right)
        elif op == '-':
            return context.builder.sub(left, right)
        elif op == '*':
            return context.builder.mul(left, right)
        elif op == '/':
            # IRBuilder has no plain `div`; values are signed, so use sdiv.
            return context.builder.sdiv(left, right)
        elif op in ('<', '==', '>'):
            # NOTE(review): icmp yields an i1, unlike the int64 produced by
            # every other expression -- confirm it is only used as an `if` test.
            return context.builder.icmp_signed(op, left, right)
        else:
            print('todo, handle op ' + op)
    elif m.matches(pattern(['if', '$', '$', '$'])):
        print_indent('if2 {} {} {}'.format(*m.match), depth)
        test = recognize(m.match[0], context, scope, depth+1)
        with context.builder.if_else(test) as (then, otherwise):
            with then:
                then_result = recognize(m.match[1], context, scope, depth+1)
                # Read the block AFTER emitting the branch body: a nested
                # `if` moves the builder to a new block, and the phi must
                # name the block that actually jumps to the merge point.
                then_block = context.builder.block
            with otherwise:
                else_result = recognize(m.match[2], context, scope, depth+1)
                else_block = context.builder.block
        result = context.builder.phi(int64, name="ifresult")
        result.add_incoming(then_result, then_block)
        result.add_incoming(else_result, else_block)
        return result
    elif m.matches(pattern(['print', '$'])):
        print_indent('print {}'.format(m.match[0]), depth)
        val = recognize(m.match[0], context, scope, depth+1)
        fmt_arg = context.builder.bitcast(context.global_fmt, voidptr)
        context.builder.call(context.printf, [fmt_arg, val])
        # `print` itself evaluates to 0.
        return make_int(0)
    elif m.matches(pattern(['do', '*'])):
        print_indent('do {}'.format(m.match[0]), depth)
        # The value of a `do` is the value of its last statement.
        result = None
        for stmt in m.match[0]:
            result = recognize(stmt, context, scope, depth+1)
        if result is None:
            result = make_int(0)
        return result
    elif m.matches(pattern(['$a', '*'])):
        print_indent('call {} with {}'.format(*m.match), depth)
        fn_name, args = m.match
        function = context.get_function(fn_name)
        if function is None:
            # Previously this crashed with AttributeError on None below.
            raise NameError('unknown function ' + fn_name)
        if len(function.args) != len(args):
            raise Exception('wrong number of args for ' + fn_name)
        arg_values = [recognize(arg, context, scope, depth+1) for arg in args]
        return context.builder.call(function, arg_values, 'calltmp')
    else:
        print_indent("todo: {}".format(ast), depth)
def build(text):
    """Compile the program *text* into a new IR module and return it.

    The module gets an internal constant holding the printf format string
    and a declaration of the variadic C ``printf``. Every top-level form
    parsed from *text* is then compiled into it. The finished module is
    also printed to stdout.
    """
    module = ir.Module(name='testmodule')
    context = Context(module=module)

    # Values passed to printf are 64-bit integers, so the conversion must
    # be %lld; plain %d tells printf to read only an `int`.
    print_fmt = "%lld\n\0"
    fmt_bytes = bytearray(print_fmt.encode("utf8"))
    c_print_fmt = ir.Constant(
        ir.ArrayType(ir.IntType(8), len(fmt_bytes)),
        fmt_bytes,
    )
    global_fmt = ir.GlobalVariable(module, c_print_fmt.type, name="print_fmt")
    global_fmt.linkage = 'internal'
    global_fmt.global_constant = True
    global_fmt.initializer = c_print_fmt
    context.global_fmt = global_fmt

    # Declaration only: int printf(i8*, ...)
    printf_ty = ir.FunctionType(ir.IntType(32), [voidptr], var_arg=True)
    printf = ir.Function(module, printf_ty, name="printf")
    context.printf = printf

    parsed = parse(tokenize(text))
    for ast in parsed:
        recognize_top(ast, context)

    print(module)
    return module
def create_execution_engine():
    """Create an MCJIT engine for the default target, backed by an empty module."""
    target_machine = llvm.Target.from_default_triple().create_target_machine()
    empty_module = llvm.parse_assembly('')
    return llvm.create_mcjit_compiler(empty_module, target_machine)
def compile_ir(engine, llvm_ir):
    """Parse and verify the IR text *llvm_ir*, then add it to *engine* and finalize."""
    parsed = llvm.parse_assembly(llvm_ir)
    parsed.verify()
    engine.add_module(parsed)
    # Finalize the object code first, then run static constructors.
    engine.finalize_object()
    engine.run_static_constructors()
# Demo program: defines `add`, a recursive `fib`, and a `main` that prints
# (add 1 2) and (fib 37). It is compiled to IR as soon as this file runs.
module = build('''
(defn add (a b)
(+ a b))
(defn fib (n)
(if (< n 3)
1
(+ (fib (- n 1)) (fib (- n 2)))))
(defn main ()
(do
(print (add 1 2))
(print (fib 37))))
''')
# JIT-compile the textual IR and look up the address of the generated `main`.
engine = create_execution_engine()
compile_ir(engine, str(module))
func_ptr = engine.get_function_address("main")
# Wrap the raw address as a zero-argument C function and call it.
# NOTE(review): the generated main returns a 64-bit integer, while c_long is
# only 32 bits on some platforms -- confirm this is intended.
main = CFUNCTYPE(c_long)(func_ptr)
main()
# target triple = "x86_64-pc-linux-gnu"
# @.str = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
# declare i32 @printf(i8*, ...) #1
# %1 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i32 0, i32 0), i32 12345)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment