Skip to content

Instantly share code, notes, and snippets.

@impiaaa
Last active October 23, 2023 15:03
Show Gist options
  • Save impiaaa/70e650c5cc34e59d69e9cf0f9abfeb54 to your computer and use it in GitHub Desktop.
Save impiaaa/70e650c5cc34e59d69e9cf0f9abfeb54 to your computer and use it in GitHub Desktop.
Demonstrating conversion from Python to LLVM IR
import dis, inspect, typing
import math, sys
import os
from llvmlite import ir
#def putchar(n: int) -> int: pass
# Sample input: recursive factorial (returns 1 for any n <= 1). Like the other
# sample functions it is compiled to LLVM IR by the loop at the bottom of the
# file, so its body must stay within the opcodes translateBlock understands.
def fact(n: int) -> int:
    if n <= 1: return 1
    else: return n*fact(n-1)
# Sample input: passes 49 ('1') to putchar when n == 1 and 50 ('2') when
# n == 2, then always passes 63 ('?').
# putchar resolves to the llvmlite declaration created further down the file,
# so this function is meant to be compiled by the loop at the bottom of the
# file, not called from Python.
def test1(n: int):
    if n == 1: putchar(49)
    elif n == 2: putchar(50)
    putchar(63)
# Sample input: a while loop that passes every value from 33 ('!') up to
# n - 1 to putchar. putchar is the llvmlite declaration made further down the
# file, so this is compiled to IR rather than called from Python.
def test2(n: int):
    i = 33
    while i < n:
        putchar(i)
        i = i+1
# Sample input: returns 16 for even n and 2 for odd n. The conditional
# expression leaves its value on the stack across a basic-block boundary,
# which is presumably the case translateBlock's phi placeholders exist for.
def test3(n: int) -> int:
    x = 14 if n%2 == 0 else 0
    x = x+2
    return x
# Opcodes that always transfer control and never fall through to the next
# instruction. JUMP_ABSOLUTE exists only before Python 3.11 and the
# JUMP_BACKWARD variants only from 3.11 on, so collect whichever of them this
# interpreter defines; indexing dis.opmap directly raises KeyError at import
# on newer interpreters.
alwaysjump = {dis.opmap[name]
              for name in ('JUMP_FORWARD', 'JUMP_ABSOLUTE',
                           'JUMP_BACKWARD', 'JUMP_BACKWARD_NO_INTERRUPT')
              if name in dis.opmap}
# Every opcode whose argument is a jump offset or target, conditional or not.
hasjump = set(dis.hasjrel + dis.hasjabs)
def splitIntoBasicBlocks(instructions):
    """Yield the basic blocks of a stream of dis.Instruction objects.

    A new block begins at every jump target, and the current block ends right
    after any instruction whose opcode is in ``hasjump`` or that is
    RETURN_VALUE. Blocks are yielded in program order as non-empty lists.
    """
    pending = []
    for instruction in instructions:
        # A jump target always opens a fresh block.
        if instruction.is_jump_target and pending:
            yield pending
            pending = []
        pending.append(instruction)
        endsBlock = (instruction.opcode in hasjump
                     or instruction.opname == 'RETURN_VALUE')
        if endsBlock:
            # pending is non-empty here: we just appended to it.
            yield pending
            pending = []
    if pending:
        yield pending
def findSourcesDestinations(basicBlocks):
    """Compute the control-flow edges between basic blocks.

    Args:
        basicBlocks: list of non-empty lists of dis.Instruction, as produced
            by splitIntoBasicBlocks.

    Returns:
        A list with one ``(sources, (instDest, nextDest))`` pair per block.
        ``sources`` is the set of indices of blocks that can transfer control
        into it. ``instDest`` is the index of the block targeted by its trailing
        jump (None if it has none, or the target isn't a block start), and
        ``nextDest`` is the index of the block it falls through to (None if it
        cannot fall through).
    """
    # Map each block's starting offset to its index so jump targets resolve
    # with one lookup instead of rescanning every block.
    blockAt = {}
    for j, blockInsts in enumerate(basicBlocks):
        blockAt.setdefault(blockInsts[0].offset, j)

    sources = [set() for _ in basicBlocks]
    destinations = []
    for i, blockInsts in enumerate(basicBlocks):
        last = blockInsts[-1]
        instDest = None
        nextDest = None
        # Control continues into the next block unless this one ends with an
        # unconditional jump or a return (which leaves the function).
        fallsThrough = (last.opcode not in alwaysjump
                        and last.opname != 'RETURN_VALUE')
        if i < len(basicBlocks) - 1 and fallsThrough:
            nextDest = i + 1
            sources[nextDest].add(i)
        if last.opcode in hasjump:
            instDest = blockAt.get(last.argval)
            if instDest is not None:
                sources[instDest].add(i)
        destinations.append((instDest, nextDest))
    return list(zip(sources, destinations))
def printInstruction(inst, rwidth):
    """Print one instruction as a single listing line.

    The opcode name is followed by its numeric argument, right-aligned so the
    name plus argument span at least *rwidth* characters. If dis supplies a
    human-readable form (``argrepr``), it follows in parentheses.
    Instructions without an argument print just their name.
    """
    if inst.arg is None:
        print(inst.opname)
        return
    argText = str(inst.arg).rjust(rwidth - len(inst.opname))
    detail = '' if len(inst.argrepr) == 0 else ' (%s)' % inst.argrepr
    print(inst.opname + argText + detail)
def printInstructions(insts):
    """Print a column-aligned listing of *insts*, one instruction per line.

    The alignment width is the longest opcode name plus the longest argument
    rendered with str() (``None`` for argument-less instructions), plus one.
    """
    nameWidth = max(len(inst.opname) for inst in insts)
    argWidth = max(len(str(inst.arg)) for inst in insts)
    for inst in insts:
        printInstruction(inst, nameWidth + argWidth + 1)
# LLVM type used for Python None values: a pointer to i8.
NONE_TYPE = ir.PointerType(ir.IntType(8))
# llvmlite asserts when an instruction is given a void value, but we need a
# placeholder in order to handle phi instructions, so we use a 0-width integer
DUMMY_TYPE = ir.IntType(0)


def isDummy(t):
    """Return whether LLVM type *t* is the placeholder DUMMY_TYPE."""
    return t == DUMMY_TYPE


# Python value type -> LLVM type. A Python int gets one bit more than
# sys.maxsize needs (i64 when sys.maxsize == 2**63 - 1).
typeMap = {
    int: ir.IntType(sys.maxsize.bit_length() + 1),
    float: ir.DoubleType(),
    bool: ir.IntType(1),
    type(None): NONE_TYPE,
}
def _signsDiffer(builder, x, y):
    """Emit an i1 that is true when integers *x* and *y* have opposite signs
    (zero counts as non-negative)."""
    return builder.icmp_signed('<', builder.xor(x, y), ir.Constant(x.type, 0))


def _floorDivide(builder, x, y):
    """Emit signed division rounding toward negative infinity, like ``//``."""
    quotient = builder.sdiv(x, y)
    remainder = builder.srem(x, y)
    # sdiv truncates toward zero; when the division is inexact and the signs
    # differ, Python's result is one lower.
    inexact = builder.icmp_signed('!=', remainder, ir.Constant(x.type, 0))
    adjust = builder.and_(inexact, _signsDiffer(builder, remainder, y))
    return builder.select(adjust,
                          builder.sub(quotient, ir.Constant(x.type, 1)),
                          quotient)


def _floorModulo(builder, x, y):
    """Emit a signed remainder taking the divisor's sign, like Python's ``%``."""
    remainder = builder.srem(x, y)
    # srem takes the dividend's sign; shift a nonzero remainder by the divisor
    # when its sign disagrees with the divisor's.
    nonzero = builder.icmp_signed('!=', remainder, ir.Constant(x.type, 0))
    adjust = builder.and_(nonzero, _signsDiffer(builder, remainder, y))
    return builder.select(adjust, builder.add(remainder, y), remainder)


def translateBlock(insts, instDest, nextDest, block, names):
    """Translate one basic block of CPython bytecode into LLVM IR.

    The bytecode value stack is simulated with a list of IR values. A value
    popped from an empty simulated stack was pushed by a predecessor block.
    For each such value a phi node typed DUMMY_TYPE is placed at the top of
    *block* and recorded, so the caller can add incoming values (and the real
    type) once every block has been translated.

    Args:
        insts: the block's dis.Instruction objects, in order.
        instDest: ir.Block targeted by the block's trailing jump, or None.
        nextDest: ir.Block reached by falling through, or None.
        block: the ir.Block to emit into.
        names: mapping from variable/global names to IR values. STORE_FAST
            adds an alloca here the first time a local is assigned.

    Returns:
        ``(stackDeficit, stack)``. ``stackDeficit`` holds the phi placeholders
        in the order they were popped, so index 0 is the predecessor's top of
        stack. ``stack`` holds the values still on the simulated stack when
        the block ends.

    Raises:
        ValueError: for an opcode this translator does not handle.

    NOTE: line-number debug locations are attached under the module-level
    ``scope`` variable, which must hold the current function's DISubprogram.
    """
    stackDeficit = []
    stack = []
    builder = ir.IRBuilder(block)

    def pop():
        """Pop a value, creating a phi placeholder if the stack is empty."""
        if stack:
            return stack.pop()
        # LLVM requires phi nodes to precede every other instruction in a
        # block: insert after any earlier placeholders, then return the
        # builder to the end of the block. The placeholder's real type comes
        # from its predecessors, so it starts as DUMMY_TYPE and is patched by
        # the caller.
        if stackDeficit:
            builder.position_after(stackDeficit[-1])
        else:
            builder.position_at_start(block)
        arg = builder.phi(DUMMY_TYPE)
        builder.position_at_end(block)
        stackDeficit.append(arg)
        return arg

    def unify(a, b):
        """Give a DUMMY_TYPE operand the type of the other operand."""
        if a.type != b.type:
            if isDummy(a.type):
                a.type = b.type
            elif isDummy(b.type):
                b.type = a.type

    # Operands are passed as (left, right).
    binaryOps = {
        'BINARY_MULTIPLY': builder.mul,
        'BINARY_MODULO': lambda x, y: _floorModulo(builder, x, y),
        'BINARY_ADD': builder.add,
        'BINARY_SUBTRACT': builder.sub,
        'BINARY_SUBSCR': builder.extract_value,
        'BINARY_FLOOR_DIVIDE': lambda x, y: _floorDivide(builder, x, y),
        'BINARY_TRUE_DIVIDE': builder.fdiv,
        'BINARY_LSHIFT': builder.shl,
        # Python's >> on signed integers is an arithmetic shift.
        'BINARY_RSHIFT': builder.ashr,
        'BINARY_AND': builder.and_,
        'BINARY_XOR': builder.xor,
        'BINARY_OR': builder.or_,
    }

    for inst in insts:
        if inst.starts_line is not None:
            builder.debug_metadata = builder.module.add_debug_info(
                'DILocation', {'scope': scope, 'line': inst.starts_line})
        argtype = typeMap.get(type(inst.argval), DUMMY_TYPE)
        op = inst.opname
        if op == 'LOAD_CONST':
            stack.append(ir.Constant(argtype, inst.argval))
        elif op == 'STORE_FAST':
            val = pop()
            slot = names.get(inst.argval, None)
            if slot is None:
                slot = builder.alloca(val.type, name=inst.argval)
                names[inst.argval] = slot
            if val.type != slot.type.pointee:
                # One side may still be a placeholder; adopt the other's type.
                if isDummy(val.type):
                    val.type = slot.type.pointee
                elif isDummy(slot.type.pointee):
                    slot.type.pointee = val.type
            builder.store(val, slot)
        elif op == 'LOAD_FAST':
            stack.append(builder.load(names[inst.argval]))
        elif inst.opcode in alwaysjump:
            builder.branch(instDest)
        elif inst.opcode in hasjump:
            cond = pop()
            if 'FALSE' in op:
                builder.cbranch(cond, nextDest, instDest)
            else:
                builder.cbranch(cond, instDest, nextDest)
        elif op == 'LOAD_GLOBAL':
            stack.append(names[inst.argval])
        elif op == 'CALL_FUNCTION':
            # Arguments come off the stack last-first; restore source order.
            args = [pop() for _ in range(inst.arg)]
            args.reverse()
            # Generated functions take every argument by pointer, so spill
            # each one to a stack slot.
            slots = [builder.alloca(arg.type) for arg in args]
            for arg, slot in zip(args, slots):
                builder.store(arg, slot)
            stack.append(builder.call(pop(), slots))
        elif op == 'POP_TOP':
            pop()
        elif op == 'RETURN_VALUE':
            val = pop()
            if isinstance(builder.function.return_value.type, ir.VoidType):
                assert val.type == NONE_TYPE
                builder.ret_void()
            elif val.type != NONE_TYPE:
                builder.ret(val)
            # Otherwise None is returned from a non-void function (presumably
            # CPython's implicit trailing `return None`): nothing is emitted,
            # and the caller deletes the resulting empty block.
        elif op == 'COMPARE_OP':
            rhs, lhs = pop(), pop()
            unify(rhs, lhs)
            stack.append(builder.icmp_signed(dis.cmp_op[inst.arg], lhs, rhs))
        elif op.startswith('BINARY_'):
            emit = binaryOps.get(op)
            if emit is None:
                raise ValueError(op)
            rhs, lhs = pop(), pop()
            unify(rhs, lhs)
            stack.append(emit(lhs, rhs))
        else:
            raise ValueError(op)

    # A block that neither jumps nor returns falls through to its successor.
    last = insts[-1]
    if last.opcode not in hasjump and last.opname != 'RETURN_VALUE':
        builder.branch(nextDest)
    return stackDeficit, stack
# LLVM module that receives every translated function; printed at the end.
module = ir.Module(name=__file__)
# forward declaration of an external putchar. Like the generated functions, it
# takes its argument by pointer (i64*); it returns i32.
# NOTE(review): libc's putchar takes an int by value, so linking this against
# libc would pass an address as the character -- confirm the intended runtime.
putchar = ir.Function(module, ir.FunctionType(ir.IntType(32), [ir.PointerType(ir.IntType(64))]), name="putchar")
# magic to get llvm to pay attention to debug info: the module carries a
# "Debug Info Version" module flag (the same one clang emits); per the LLVM
# docs, debug metadata without a matching version is dropped.
it = ir.IntType(32)
mf = module.add_named_metadata('llvm.module.flags')
mf.add(module.add_metadata([ir.Constant(it, 2), "Debug Info Version", ir.Constant(it, 3)]))
# Source file that every debug location refers to (this script).
fi = module.add_debug_info('DIFile', {
    'filename': os.path.basename(__file__),
    'directory': os.path.dirname(os.path.abspath(__file__))
})
# Compile unit for the whole module, registered under llvm.dbg.cu below.
cu = module.add_debug_info('DICompileUnit', {
    'file': fi,
    'language': ir.DIToken('DW_LANG_Python'),
    'emissionKind': ir.DIToken('FullDebug')
}, is_distinct=True)
module.add_named_metadata('llvm.dbg.cu', cu)
# One empty subroutine type, reused for every generated function's DISubprogram.
ti = module.add_debug_info('DISubroutineType', {'types': module.add_metadata([])})
# Translate each sample function to LLVM IR, dumping its recovered
# control-flow graph along the way, then print the finished module.
for func in (fact, test1, test2, test3):
    blocks = list(splitIntoBasicBlocks(dis.get_instructions(func)))
    srcdest = findSourcesDestinations(blocks)

    # Dump the control-flow graph recovered from the bytecode.
    print(func.__name__)
    for i, (block, (sources, destinations)) in enumerate(zip(blocks, srcdest)):
        print(i, 'comes from', ', '.join(map(str, sources)))
        printInstructions(block)
        print('goes to', ', '.join(map(str, destinations)))
        print()

    # Build the LLVM signature from the annotations. Every parameter is passed
    # by pointer; unknown or missing types fall back to void.
    sig = inspect.signature(func)
    types = typing.get_type_hints(func)
    fnty = ir.FunctionType(
        typeMap.get(sig.return_annotation, ir.VoidType()),
        [ir.PointerType(typeMap.get(types.get(parm, None), ir.VoidType()))
         for parm in sig.parameters])
    # `scope` must stay a module-level global: translateBlock reads it when
    # attaching line-number debug locations.
    scope = module.add_debug_info('DISubprogram', {
        'name': func.__name__,
        'unit': cu,
        'scope': fi,
        'file': fi,
        'type': ti
    }, is_distinct=True)
    irFunc = ir.Function(module, fnty, name=func.__name__)
    irFunc.set_metadata('dbg', scope)
    for irArg, parm in zip(irFunc.args, sig.parameters):
        irArg.name = parm

    irBlocks = [irFunc.append_basic_block() for _ in blocks]
    names = dict(zip(sig.parameters, irFunc.args))
    names[func.__name__] = irFunc
    names["putchar"] = putchar
    assembly = [translateBlock(block,
                               None if instDest is None else irBlocks[instDest],
                               None if nextDest is None else irBlocks[nextDest],
                               irBlock,
                               names)
                for block, irBlock, (_, (instDest, nextDest))
                in zip(blocks, irBlocks, srcdest)]

    # Blocks that emitted nothing (a None return from a non-void function)
    # are dropped below; the entry block is always kept.
    emptyBlocks = [i for i in range(1, len(irBlocks))
                   if len(irBlocks[i].instructions) == 0]
    emptySet = set(emptyBlocks)

    # Wire up phi placeholders while block indices still match `sources`
    # (deleting blocks first would shift them). A block's k-th placeholder is
    # the k-th value it popped, i.e. the k-th value from the TOP of each
    # predecessor's leftover stack.
    for (stackDeficit, _), (sources, _) in zip(assembly, srcdest):
        for sourceIdx in sources:
            if sourceIdx in emptySet:
                # An empty block has no terminator, so it never branches here.
                continue
            surplus = assembly[sourceIdx][1]
            for stackIdx, phi in enumerate(stackDeficit):
                incoming = surplus[-1 - stackIdx]
                phi.add_incoming(incoming, irBlocks[sourceIdx])
                if isDummy(phi.type):
                    phi.type = incoming.type

    for i in reversed(emptyBlocks):
        del irFunc.basic_blocks[i]
        del irBlocks[i]
        del assembly[i]
        del srcdest[i]

print(module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment