# GitHub gist impiaaa/70e650c5cc34e59d69e9cf0f9abfeb54 (last active October 23, 2023 15:03).
# Description: Demonstrating conversion from Python to LLVM IR.
# (GitHub offered to save this gist locally and open it in GitHub Desktop.)
# GitHub warning: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# See GitHub's documentation to learn more about bidirectional Unicode characters.
import dis, inspect, typing | |
import math, sys | |
import os | |
from llvmlite import ir | |
#def putchar(n: int) -> int: pass | |
# Sample input for the translator: recursive factorial.
# Only comments are added here, because the driver at the bottom of this file
# compiles this function's *bytecode* (via dis.get_instructions), and
# translateBlock raises ValueError on any opcode it does not handle.
# The `int` annotations map to typeMap[int] for the IR signature.
def fact(n: int) -> int:
    if n <= 1: return 1
    else: return n*fact(n-1)
# Sample input: an if/elif chain followed by an unconditional call, which gives
# several basic blocks that merge. 49, 50 and 63 are the ASCII codes of
# '1', '2' and '?'.
# No return annotation, so the driver gives the IR function a void return type.
# NOTE(review): there is no Python-level putchar (the stub above is commented
# out, and the module-level name is later bound to an llvmlite ir.Function), so
# this function is only meant to be compiled, not called from Python.
def test1(n: int):
    if n == 1: putchar(49)
    elif n == 2: putchar(50)
    putchar(63)
# Sample input with a loop: calls putchar(i) for i = 33, 34, ..., n-1
# (33 is ASCII '!'), exercising a backward jump between basic blocks.
# Like test1, it is only meant to be compiled, never called from Python.
def test2(n: int):
    i = 33
    while i < n:
        putchar(i)
        i = i+1
# Sample input: returns 16 for even n and 2 for odd n.
# The conditional expression pushes its value in one of two blocks and the
# STORE_FAST consumes it in a third, which presumably is meant to exercise
# translateBlock's stack-deficit / phi handling.
def test3(n: int) -> int:
    x = 14 if n%2 == 0 else 0
    x = x+2
    return x
# Opcodes that always transfer control, i.e. never fall through to the next
# instruction. JUMP_ABSOLUTE exists only up to Python 3.10 and JUMP_BACKWARD*
# only from 3.11, so each name is included only if the running interpreter
# defines it. A direct dis.opmap['JUMP_ABSOLUTE'] raises KeyError on 3.11+.
alwaysjump = {dis.opmap[name]
              for name in ('JUMP_FORWARD', 'JUMP_ABSOLUTE',
                           'JUMP_BACKWARD', 'JUMP_BACKWARD_NO_INTERRUPT')
              if name in dis.opmap}
# Every opcode whose argument is a jump target (relative or absolute),
# conditional or not.
hasjump = set(dis.hasjrel + dis.hasjabs)
def splitIntoBasicBlocks(instructions):
    """Yield the basic blocks of a bytecode instruction stream.

    A new block begins at every jump target, and a block is closed right
    after any jump instruction (an opcode in ``hasjump``) or RETURN_VALUE.
    Each yielded block is a non-empty list of instructions in their
    original order.
    """
    pending = []
    for ins in instructions:
        # A jump target has to be the first instruction of its block.
        if ins.is_jump_target and pending:
            yield pending
            pending = []
        pending.append(ins)
        # Jumps and returns terminate the block they appear in.
        ends_block = ins.opcode in hasjump or ins.opname == 'RETURN_VALUE'
        if ends_block:
            yield pending
            pending = []
    # Whatever is left over (no trailing jump/return) is the final block.
    if pending:
        yield pending
def findSourcesDestinations(basicBlocks):
    """Compute the control-flow edges between basic blocks.

    Returns a list with one ``(sources, (instDest, nextDest))`` pair per
    block, in block order:

    * sources  -- set of indices of the blocks that can transfer control
      into this block;
    * instDest -- index of the block whose first instruction is the target
      of this block's final jump, or None (no jump, or no block starts at
      the target offset);
    * nextDest -- index of the fall-through successor (always i + 1), or
      None when the block is the last one or ends in an unconditional jump
      or RETURN_VALUE.
    """
    count = len(basicBlocks)
    # Map each block's starting bytecode offset to its index, so jump targets
    # are resolved in O(1) rather than by scanning every block. setdefault
    # keeps the first match, as the original linear search did.
    blockAtOffset = {}
    for j, blockInsts in enumerate(basicBlocks):
        blockAtOffset.setdefault(blockInsts[0].offset, j)

    sources = [set() for _ in range(count)]
    destinations = []
    for i, blockInsts in enumerate(basicBlocks):
        last = blockInsts[-1]
        nextDest = None
        instDest = None
        # RETURN_VALUE leaves the function, so it must not be recorded as
        # falling through: that would make this block a bogus predecessor of
        # block i + 1 (and a bogus phi incoming edge there).
        if (i < count - 1 and last.opcode not in alwaysjump
                and last.opname != 'RETURN_VALUE'):
            nextDest = i + 1
            sources[i + 1].add(i)
        if last.opcode in hasjump:
            instDest = blockAtOffset.get(last.argval)
            if instDest is not None:
                sources[instDest].add(i)
        destinations.append((instDest, nextDest))
    return list(zip(sources, destinations))
def printInstruction(inst, rwidth):
    """Print one disassembled instruction on a single line.

    The opcode name comes first. If the instruction has an argument, it is
    right-aligned so that name and argument together span ``rwidth``
    columns, followed by ``argrepr`` in parentheses when that is non-empty.
    """
    text = inst.opname
    if inst.arg is not None:
        text += str(inst.arg).rjust(rwidth - len(inst.opname))
        if len(inst.argrepr) > 0:
            text += ' (%s)' % inst.argrepr
    print(text)
def printInstructions(insts):
    """Print a sequence of instructions as an aligned listing.

    The column width is the longest opcode name plus the longest argument
    (as text) plus one separating space. Any iterable is accepted; an empty
    one prints nothing.
    """
    # Materialize first: the widths below need two passes over the input.
    insts = list(insts)
    if not insts:
        return  # max() over an empty sequence would raise ValueError
    maxwidth = (max(len(inst.opname) for inst in insts)
                + max(len(str(inst.arg)) for inst in insts) + 1)
    for inst in insts:
        printInstruction(inst, maxwidth)
# LLVM type used for Python's None: an i8*. LOAD_CONST None produces a
# constant of this type, and RETURN_VALUE compares against it.
NONE_TYPE = ir.PointerType(ir.IntType(8))
# llvmlite asserts when an instruction is given a void value, but we need a
# placeholder in order to handle phi instructions, so we use a 0-width integer.
# Values of this type are retyped once their real type becomes known.
DUMMY_TYPE = ir.IntType(0)


def isDummy(t):
    """Return True if *t* is the 0-width placeholder type DUMMY_TYPE."""
    return t == DUMMY_TYPE


# Python type -> LLVM type. ints get the platform word width:
# sys.maxsize == 2**(w - 1) - 1, so bit_length() + 1 == w (64 on 64-bit builds).
typeMap = {
    int: ir.IntType(sys.maxsize.bit_length() + 1),
    float: ir.DoubleType(),
    bool: ir.IntType(1),
    type(None): NONE_TYPE,
}
def translateBlock(insts, instDest, nextDest, block, names):
    """Translate one Python basic block into LLVM IR appended to *block*.

    Python's evaluation stack is simulated with a list of llvmlite values.
    A value popped from below the bottom of the local stack was pushed by a
    predecessor block. For each such value a phi node is created at the
    start of *block* and recorded in ``stackDeficit``, so the caller can add
    its incoming values once all blocks are translated.

    Args:
        insts: the dis.Instruction objects forming one basic block.
        instDest: ir.Block targeted by the block's final jump, or None.
        nextDest: ir.Block to fall through to, or None.
        block: the ir.Block to emit into.
        names: dict mapping local/global names to llvmlite values. Allocas
            for newly stored locals are added to it.

    Returns:
        (stackDeficit, stack): the phi nodes created for inherited values
        (first popped first), and the values still on the simulated stack at
        the end of the block (bottom first).

    Raises:
        ValueError: for an opcode this translator does not support.

    NOTE(review): reads the module-level global ``scope`` (the current
    DISubprogram) for debug locations, and only understands the pre-3.11
    opcode set (CALL_FUNCTION, BINARY_*).
    """
    stackDeficit = []
    stack = []
    builder = ir.IRBuilder(block)
    # BINARY_* opcode -> IRBuilder method; called as method(left, right).
    # NOTE(review): srem/sdiv truncate toward zero, whereas Python's % and //
    # floor; extract_value needs a constant index; fdiv needs float operands.
    binaryOps = {
        'BINARY_MULTIPLY': builder.mul,
        'BINARY_MODULO': builder.srem,
        'BINARY_ADD': builder.add,
        'BINARY_SUBTRACT': builder.sub,
        'BINARY_SUBSCR': builder.extract_value,
        'BINARY_FLOOR_DIVIDE': builder.sdiv,
        'BINARY_TRUE_DIVIDE': builder.fdiv,
        'BINARY_LSHIFT': builder.shl,
        # Python's >> on ints is an arithmetic (sign-preserving) shift.
        'BINARY_RSHIFT': builder.ashr,
        'BINARY_AND': builder.and_,
        'BINARY_XOR': builder.xor,
        'BINARY_OR': builder.or_,
    }

    def pop():
        """Pop the simulated stack, or make a phi for an inherited value."""
        if stack:
            return stack.pop()
        # The value comes from a predecessor block. LLVM requires phi nodes
        # to precede every other instruction of a block, so insert it at the
        # start and then resume appending at the end.
        # NOTE(review): the phi's type is the current instruction's argtype
        # (derived from its argval); width-0 placeholders are fixed later by
        # the caller.
        builder.position_at_start(block)
        arg = builder.phi(argtype)
        builder.position_at_end(block)
        stackDeficit.append(arg)
        return arg

    for inst in insts:
        if inst.starts_line is not None:
            builder.debug_metadata = builder.module.add_debug_info('DILocation',
                {'scope': scope, 'line': inst.starts_line})
        # Type of the instruction's argument value: used for LOAD_CONST and,
        # via pop(), for any phi created while handling this instruction.
        argtype = typeMap.get(type(inst.argval), DUMMY_TYPE)
        if inst.opname == 'LOAD_CONST':
            stack.append(ir.Constant(argtype, inst.argval))
        elif inst.opname == 'STORE_FAST':
            val = pop()
            n = names.get(inst.argval, None)
            if n is None:
                n = builder.alloca(val.type, name=inst.argval)
                names[inst.argval] = n
            # Resolve a placeholder type on either side from the other side.
            if val.type != n.type.pointee:
                if isDummy(val.type):
                    val.type = n.type.pointee
                elif isDummy(n.type.pointee):
                    n.type.pointee = val.type
            builder.store(val, n)
        elif inst.opname == 'LOAD_FAST':
            stack.append(builder.load(names[inst.argval]))
        elif inst.opcode in alwaysjump:
            builder.branch(instDest)
        elif inst.opcode in hasjump:
            cond = pop()
            # *_IF_FALSE jumps when the condition is false, so "true" falls through.
            if 'FALSE' in inst.opname:
                builder.cbranch(cond, nextDest, instDest)
            else:
                builder.cbranch(cond, instDest, nextDest)
        elif inst.opname == 'LOAD_GLOBAL':
            stack.append(names[inst.argval])
        elif inst.opname == 'CALL_FUNCTION':
            # Arguments come off the stack last-first; restore call order.
            args = [pop() for _ in range(inst.arg)][::-1]
            # Compiled functions take parameters by pointer, so spill each
            # argument to its own stack slot and pass the slots.
            slots = [builder.alloca(arg.type) for arg in args]
            for arg, slot in zip(args, slots):
                builder.store(arg, slot)
            stack.append(builder.call(pop(), slots))
        elif inst.opname == 'POP_TOP':
            # NOTE(review): stack.pop() (not pop()) raises IndexError if the
            # discarded value was pushed by a predecessor block.
            stack.pop()
        elif inst.opname == 'RETURN_VALUE':
            val = pop()
            if isinstance(builder.function.return_value.type, ir.VoidType):
                assert val.type == NONE_TYPE
                builder.ret_void()
            elif val.type != NONE_TYPE:
                builder.ret(val)
            # Otherwise (None returned from a non-void function) nothing is
            # emitted; the driver deletes blocks that end up empty.
        elif inst.opname == 'COMPARE_OP':
            a, b = pop(), pop()  # a is the right operand, b the left
            stack.append(builder.icmp_signed(dis.cmp_op[inst.arg], b, a))
        elif inst.opname.startswith('BINARY_'):
            a, b = pop(), pop()  # a is the right operand, b the left
            if a.type != b.type:
                if isDummy(a.type):
                    a.type = b.type
                elif isDummy(b.type):
                    b.type = a.type
            stack.append(binaryOps[inst.opname](b, a))
        else:
            raise ValueError(inst.opname)
    # A block that does not end in a jump or return falls through.
    if insts[-1].opcode not in hasjump and insts[-1].opname != 'RETURN_VALUE':
        builder.branch(nextDest)
    return stackDeficit, stack
module = ir.Module(name=__file__)
# forward declaration
putchar = ir.Function(module, ir.FunctionType(ir.IntType(32), [ir.PointerType(ir.IntType(64))]), name="putchar")
# magic to get llvm to pay attention to debug info
it = ir.IntType(32)
mf = module.add_named_metadata('llvm.module.flags')
mf.add(module.add_metadata([ir.Constant(it, 2), "Debug Info Version", ir.Constant(it, 3)]))
fi = module.add_debug_info('DIFile', {
    'filename': os.path.basename(__file__),
    'directory': os.path.dirname(os.path.abspath(__file__))
})
cu = module.add_debug_info('DICompileUnit', {
    'file': fi,
    'language': ir.DIToken('DW_LANG_Python'),
    'emissionKind': ir.DIToken('FullDebug')
}, is_distinct=True)
module.add_named_metadata('llvm.dbg.cu', cu)
ti = module.add_debug_info('DISubroutineType', {'types': module.add_metadata([])})

for func in (fact, test1, test2, test3):
    # Split the bytecode into basic blocks and print the control-flow graph.
    blocks = list(splitIntoBasicBlocks(dis.get_instructions(func)))
    srcdest = findSourcesDestinations(blocks)
    print(func.__name__)
    for i, (block, (sources, destinations)) in enumerate(zip(blocks, srcdest)):
        print(i, 'comes from', ', '.join(map(str, sources)))
        printInstructions(block)
        print('goes to', ', '.join(map(str, destinations)))
        print()

    # Build the IR signature from the annotations: parameters are passed by
    # pointer, and a missing/unknown return annotation becomes void.
    sig = inspect.signature(func)
    types = typing.get_type_hints(func)
    fnty = ir.FunctionType(typeMap.get(sig.return_annotation, ir.VoidType()),
                           [ir.PointerType(typeMap.get(types.get(parm, None), ir.VoidType())) for parm in sig.parameters])
    # Module-level global: translateBlock reads `scope` for debug locations.
    scope = module.add_debug_info('DISubprogram', {
        'name': func.__name__,
        'unit': cu,
        'scope': fi,
        'file': fi,
        'type': ti
    }, is_distinct=True)
    irFunc = ir.Function(module, fnty, name=func.__name__)
    irFunc.set_metadata('dbg', scope)
    for i, parm in enumerate(sig.parameters):
        irFunc.args[i].name = parm
    irBlocks = [irFunc.append_basic_block() for _ in blocks]
    names = {parm: irArg for parm, irArg in zip(sig.parameters, irFunc.args)}
    names[func.__name__] = irFunc
    names["putchar"] = putchar
    assembly = [translateBlock(block,
                               None if instDest is None else irBlocks[instDest],
                               None if nextDest is None else irBlocks[nextDest],
                               irBlock,
                               names)
                for block, irBlock, (sources, (instDest, nextDest)) in zip(blocks, irBlocks, srcdest)]

    # Wire up phi nodes BEFORE removing any block, while the indices stored
    # in srcdest still line up with irBlocks and assembly.
    for (stackDeficit, _), (sources, _) in zip(assembly, srcdest):
        for sourceIdx in sources:
            # A block that emitted no instructions has no terminator, so it
            # never transfers control here; it is deleted below.
            if len(irBlocks[sourceIdx].instructions) == 0:
                continue
            surplus = assembly[sourceIdx][1]
            for stackIdx, phi in enumerate(stackDeficit):
                # stackDeficit[0] was popped first, i.e. it is the TOP of the
                # predecessor's stack, which is the LAST element of surplus.
                incoming = surplus[-1 - stackIdx]
                phi.add_incoming(incoming, irBlocks[sourceIdx])
                if isDummy(phi.type):
                    phi.type = incoming.type

    # Drop blocks that emitted no IR (never the entry block). Walking
    # backwards keeps irFunc.basic_blocks[i] aligned with irBlocks[i].
    # NOTE(review): a deleted block that some other block branches to would
    # leave a dangling reference.
    for i in range(len(irBlocks) - 1, 0, -1):
        if len(irFunc.basic_blocks[i].instructions) == 0:
            del irFunc.basic_blocks[i]
            del irBlocks[i]

print(module)
# (GitHub page footer: sign up for free, or sign in if you already have an
# account, to join this conversation on GitHub and comment.)