Skip to content

Instantly share code, notes, and snippets.

@danielwaterworth
Created December 7, 2012 18:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danielwaterworth/4235457 to your computer and use it in GitHub Desktop.
Save danielwaterworth/4235457 to your computer and use it in GitHub Desktop.
Idris bytecode interpreter
import sys, os
# Debugging switches; all are off (the trailing "#True" is the value to
# swap in to enable one).
# `trace` is checked by the operations' run() methods and by VM.ret, which
# print the instruction being executed when it is set.
trace = False#True
# `debug` is checked by VM.run, which dumps registers and stack when set.
debug = False#True
# `step` is not read anywhere in the visible source.
step = False#True
class Data(object):
    """Common base class of the value classes (Type, Constructor, Int, ...)."""
class Type(Data):
    """A type constant, identified only by its name."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

    def __eq__(self, other):
        # Only another Type carrying the same name compares equal.
        if not isinstance(other, Type):
            return False
        return self.name == other.name
class Constructor(Data):
    """A constructor value: an integer tag plus a list of field values."""

    def __init__(self, tag, args):
        self.tag = tag
        self.args = args

    def __eq__(self, other):
        # Structural equality: same tag, same arity, pairwise-equal fields.
        if not isinstance(other, Constructor):
            return False
        if not self.tag == other.tag:
            return False
        if len(self.args) != len(other.args):
            return False
        index = 0
        while index < len(self.args):
            if not self.args[index].__eq__(other.args[index]):
                return False
            index += 1
        return True

    def __repr__(self):
        shown = [field.__repr__() for field in self.args]
        return "Con %d %s" % (self.tag, '[' + ', '.join(shown) + ']')
class Int(Data):
    """An integer value, stored in `n`."""

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return str(self.n)

    def __eq__(self, other):
        if not isinstance(other, Int):
            return False
        return self.n == other.n
class Float(Data):
    """A floating-point value, stored in `f`."""

    def __init__(self, f):
        self.f = f

    def __repr__(self):
        return str(self.f)

    def __eq__(self, other):
        if not isinstance(other, Float):
            return False
        return self.f == other.f
class Ptr(Data):
    """A wrapper around an opaque `ptr` payload; always displayed as PTR."""

    def __init__(self, ptr):
        self.ptr = ptr

    def __repr__(self):
        return 'PTR'

    def __eq__(self, other):
        if not isinstance(other, Ptr):
            return False
        return self.ptr == other.ptr
class VMPtr(Data):
    """A value holding a reference to a VM; always displayed as VMPTR."""

    def __init__(self, vm):
        self.vm = vm

    def __repr__(self):
        return 'VMPTR'

    def __eq__(self, other):
        if not isinstance(other, VMPtr):
            return False
        return self.vm == other.vm
class String(Data):
    """A string value, stored in `s`."""

    def __init__(self, s):
        self.s = s

    def __repr__(self):
        # Show newlines as the two characters "\n" so the value stays on
        # one line, then wrap the text in double quotes.
        escaped = replace(self.s, '\n', '\\n')
        return '"' + escaped + '"'

    def __eq__(self, other):
        if not isinstance(other, String):
            return False
        return self.s == other.s
class Unit(Data):
    """The unit value; every Unit instance equals every other."""

    def __eq__(self, other):
        return isinstance(other, Unit)

    def __repr__(self):
        return 'UNIT'
class RegisterRef(object):
    """Common base class of the register references (Tmp, RVal, L<n>, T<n>)."""
class TmpRef(RegisterRef):
    """Reference to the VM's single `tmp` register."""

    def __repr__(self):
        return "Tmp"
class RValRef(RegisterRef):
    """Reference to the VM's single `rval` register."""

    def __repr__(self):
        return "RVal"
class LRef(RegisterRef):
    """Reference to stack slot `n`, counted from the VM's `base`."""

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return "L%d" % (self.n,)
class TRef(RegisterRef):
    """Reference to stack slot `n`, counted from the VM's `top`."""

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return "T%d" % (self.n,)
class Operation(object):
    """Common base class of the bytecode instructions; each defines run(vm)."""
class Assign(Operation):
    """ASSIGN: copy the value of register `b` into register `a`."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def run(self, vm):
        if trace:
            print("ASSIGN %s %s" % (self.a.__repr__(), self.b.__repr__()))
        value = vm.get_register(self.b)
        vm.set_register(self.a, value)
class AssignConst(Operation):
    """ASSIGNCONST: store the constant value `b` into register `a`."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def run(self, vm):
        if trace:
            shown = (self.a.__repr__(), self.b.__repr__())
            print("ASSIGNCONST %s %s" % shown)
        vm.set_register(self.a, self.b)
class MakeCon(Operation):
    """MKCON: build a Constructor from register values and store it in `reg`.

    reg  -- destination register
    ty   -- constructor tag
    args -- registers whose current values become the fields
    pos  -- number recorded by the parser; only shown when tracing
    """

    def __init__(self, reg, ty, args, pos):
        self.reg = reg
        self.ty = ty
        self.args = args
        self.pos = pos

    def run(self, vm):
        if trace:
            print("MAKECON %s %s %s %d" % (self.reg.__repr__(), self.ty, self.args, self.pos))
        fields = [vm.get_register(source) for source in self.args]
        vm.set_register(self.reg, Constructor(self.ty, fields))
class Case(Operation):
    """CASE: branch on the constructor tag of the value in `reg`.

    cases   -- dict mapping a tag to the bytecode list for that branch
    default -- bytecode list used when no tag matches (may be None)
    """

    def __init__(self, reg, cases, default):
        self.reg = reg
        self.cases = cases
        self.default = default

    def run(self, vm):
        if trace:
            print("CASE")
        # Save the current position so the VM resumes after this CASE once
        # the selected branch runs off its end.
        vm.start_case()
        scrutinee = vm.get_register(self.reg)
        if not isinstance(scrutinee, Constructor):
            # Non-constructor values can only take the default branch.
            assert self.default
            vm.bcs = self.default
        else:
            vm.bcs = self.cases.get(scrutinee.tag, self.default)
            assert vm.bcs
        vm.pc = 0
class ConstCase(Operation):
    """CONSTCASE: branch by comparing the value in `reg` against constants.

    cases   -- dict mapping a constant value to the bytecode list for it
    default -- bytecode list used when no constant matches (may be None)
    """

    def __init__(self, reg, cases, default):
        self.reg = reg
        self.cases = cases
        self.default = default

    def run(self, vm):
        if trace:
            print("CONSTCASE")
        vm.start_case()
        scrutinee = vm.get_register(self.reg)
        # The candidates are compared one by one with __eq__ rather than
        # looked up by key.
        for candidate, block in self.cases.items():
            if candidate.__eq__(scrutinee):
                assert block
                vm.bcs = block
                vm.pc = 0
                return
        # No constant matched: fall back to the default branch.
        assert self.default
        vm.bcs = self.default
        vm.pc = 0
class Reserve(Operation):
    """RESERVE: ask the VM to reserve `n` stack slots (see VM.reserve)."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        count = self.n
        if trace:
            print("RESERVE %d" % count)
        vm.reserve(count)
class AddTop(Operation):
    """ADDTOP: move the VM's `top` index by `n`."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("ADDTOP %d" % (self.n,))
        vm.top = vm.top + self.n
class TopBase(Operation):
    """TOPBASE: set the VM's `top` to `base + n`."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("TOPBASE %d" % (self.n,))
        new_top = vm.base + self.n
        vm.top = new_top
class BaseTop(Operation):
    """BASETOP: set the VM's `base` to `top + n`."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("BASETOP %d" % (self.n,))
        new_base = vm.top + self.n
        vm.base = new_base
class StoreOld(Operation):
    """STOREOLD: record the current `base` in the VM's `new_frame_ptr`."""

    def run(self, vm):
        if trace:
            print("STOREOLD")
        current_base = vm.base
        vm.new_frame_ptr = current_base
class Rebase(Operation):
    """REBASE: reset the VM's `base` to its `frame_ptr`."""

    def run(self, vm):
        if trace:
            print("REBASE")
        saved_base = vm.frame_ptr
        vm.base = saved_base
class Slide(Operation):
    """SLIDE: copy registers T0..T(n-1) into L0..L(n-1), in that order."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("SLIDE %d" % (self.n,))
        slot = 0
        while slot < self.n:
            moved = vm.get_register(TRef(slot))
            vm.set_register(LRef(slot), moved)
            slot += 1
class Project(Operation):
    """PROJECT: unpack constructor fields into consecutive local registers.

    The first `arity` fields of the constructor in `reg` are written to
    L(i), L(i+1), ..., L(i+arity-1).
    """

    def __init__(self, reg, i, arity):
        self.reg = reg
        self.i = i
        self.arity = arity

    def run(self, vm):
        if trace:
            print("PROJECT %s %d %d" % (self.reg.__repr__(), self.i, self.arity))
        source = vm.get_register(self.reg)
        offset = 0
        while offset < self.arity:
            vm.set_register(LRef(offset + self.i), source.args[offset])
            offset += 1
class Op(Operation):
    """OP: evaluate a primitive operator over registers and store the result.

    reg  -- destination register
    op   -- primitive name: LPlus/LBPlus, LMinus/LBMinus, LVMPtr, LEq or
            LStrConcat
    args -- list of source registers
    """

    def __init__(self, reg, op, args):
        self.reg = reg
        self.op = op
        self.args = args

    def _operands(self, vm):
        """Fetch the current value of every argument register, in order."""
        return [vm.get_register(arg) for arg in self.args]

    def run(self, vm):
        if trace:
            shown = [arg.__repr__() for arg in self.args]
            print("OP %s %s %s" % (self.reg.__repr__(), self.op, '[' + ', '.join(shown) + ']'))
        name = self.op
        if name == 'LBPlus' or name == 'LPlus':
            assert len(self.args) == 2
            operands = self._operands(vm)
            result = Int(operands[0].n + operands[1].n)
        elif name == 'LBMinus' or name == 'LMinus':
            assert len(self.args) == 2
            operands = self._operands(vm)
            result = Int(operands[0].n - operands[1].n)
        elif name == 'LVMPtr':
            assert len(self.args) == 0
            result = VMPtr(vm)
        elif name == 'LEq':
            assert len(self.args) == 2
            operands = self._operands(vm)
            # Equality yields the integer 1 or 0 rather than a boolean.
            result = Int(1 if operands[0].__eq__(operands[1]) else 0)
        elif name == 'LStrConcat':
            # Accepts any number of string operands.
            result = String(''.join([vm.get_register(arg).s for arg in self.args]))
        else:
            raise Exception("unknown prim op: %s" % self.op)
        vm.set_register(self.reg, result)
class Null(Operation):
    """NULL: clear register `reg` by storing None in it."""

    def __init__(self, reg):
        self.reg = reg

    def run(self, vm):
        if trace:
            print("NULL %s" % (self.reg.__repr__(),))
        vm.set_register(self.reg, None)
class Error(Operation):
    """ERROR: abort execution by raising Exception with the message `err`."""

    def __init__(self, err):
        self.err = err

    def run(self, vm):
        message = self.err
        raise Exception(message)
class Call(Operation):
    """CALL: invoke the declaration named `fn` through VM.call."""

    def __init__(self, fn):
        self.fn = fn

    def run(self, vm):
        callee = self.fn
        if trace:
            print("CALL %s" % callee)
        vm.call(callee)
class TailCall(Operation):
    """TAILCALL: jump to the declaration named `fn` through VM.tail_call."""

    def __init__(self, fn):
        self.fn = fn

    def run(self, vm):
        callee = self.fn
        if trace:
            print("TAILCALL %s" % callee)
        vm.tail_call(callee)
class ProjectInto(Operation):
    """PROJECTINTO: copy field `i` of the constructor in `src` into `dst`."""

    def __init__(self, dst, src, i):
        self.dst = dst
        self.src = src
        self.i = i

    def run(self, vm):
        if trace:
            print("PROJECTINTO %s %s %d" % (self.dst, self.src, self.i))
        container = vm.get_register(self.src)
        vm.set_register(self.dst, container.args[self.i])
class ForeignCall(Operation):
    """FOREIGNCALL: run one of the foreign functions built into the VM.

    reg  -- register receiving the result (left untouched by putStr)
    fn   -- foreign function name: putStr, idris_numArgs or idris_getArg
    ret  -- declared return type; only shown when tracing
    args -- list of (register, type-name) pairs

    Raises Exception for any other function name.
    """

    def __init__(self, reg, fn, ret, args):
        self.reg = reg
        self.fn = fn
        self.ret = ret
        self.args = args

    def run(self, vm):
        if trace:
            shown = []
            for arg in self.args:
                shown.append('(' + arg[0].__repr__() + ', ' + arg[1] + ')')
            print("FOREIGNCALL %s %s %s %s" % (self.reg.__repr__(), self.fn, self.ret, '[' + ', '.join(shown) + ']'))
        if self.fn == 'putStr':
            # Write the string exactly as given.  The `print` statement used
            # before appended a newline of its own, so text that already
            # ends in "\n" came out followed by a blank line.
            os.write(1, vm.get_register(self.args[0][0]).s)
        elif self.fn == 'idris_numArgs':
            # Stub: always reports a single command-line argument.
            vm.set_register(self.reg, Int(1))
        elif self.fn == 'idris_getArg':
            # Stub: always returns a fixed program name.
            vm.set_register(self.reg, String("./test"))
        else:
            raise Exception("unknown foreign call: %s" % self.fn)
def lstrip(s):
    """Return `s` without leading whitespace ('' if it is all whitespace)."""
    start = 0
    while start < len(s):
        if not s[start].isspace():
            return s[start:]
        start += 1
    return ''
def rstrip(s):
    """Return `s` without trailing whitespace ('' if it is all whitespace)."""
    remaining = len(s)
    while remaining > 0:
        last = remaining - 1
        assert last >= 0
        if not s[last].isspace():
            return s[:last + 1]
        remaining -= 1
    return ''
def strip(s):
    """Return `s` with whitespace removed from both ends."""
    without_tail = rstrip(s)
    return lstrip(without_tail)
def split(s, separator):
    """Split `s` at each non-overlapping occurrence of `separator`.

    Occurrences are taken left to right.  The result always has at least
    one element; an input without the separator yields [s].

    Raises ValueError when `separator` is empty (the earlier hand-written
    scan never advanced in that case and looped forever).
    """
    if not separator:
        raise ValueError("empty separator")
    output = []
    start = 0
    while True:
        # find() resumes from `start`, so no suffix copy is made per
        # character as the earlier s[i:].startswith(...) scan did.
        hit = s.find(separator, start)
        if hit < 0:
            break
        assert hit >= 0
        output.append(s[start:hit])
        start = hit + len(separator)
    output.append(s[start:])
    return output
def replace(s, b, a):
    """Return `s` with every occurrence of `b` replaced by `a`."""
    pieces = split(s, b)
    return a.join(pieces)
def readline(fd):
    """Read from `fd` up to and including the next newline, or to EOF.

    Reads one byte at a time, so the descriptor's offset ends exactly
    after the returned text.
    """
    chunks = []
    while True:
        byte = os.read(fd, 1)
        chunks.append(byte)
        # Stop on EOF (empty read) or once the newline has been stored.
        if not byte or byte == '\n':
            break
    return ''.join(chunks)
def get_indent(line):
    """Return the indentation level of `line`: its leading spaces over 4.

    Raises Exception when `line` is empty or consists only of spaces.
    """
    for position, char in enumerate(line):
        if char != ' ':
            # `position` equals the number of spaces seen so far.
            return position / 4
    raise Exception("unexpected entirely spaces")
def parse_reg(reg):
    """Turn a register name into the matching RegisterRef.

    Accepts 'RVal', 'Tmp', 'L<n>' and 'T<n>'.  Raises Exception for any
    other text, including the empty string.
    """
    if reg == 'RVal':
        return RValRef()
    if reg == 'Tmp':
        return TmpRef()
    # startswith() is False for '', so an empty operand reaches the
    # descriptive error below instead of failing with IndexError on reg[0].
    if reg.startswith('L'):
        return LRef(int(reg[1:]))
    if reg.startswith('T'):
        return TRef(int(reg[1:]))
    raise Exception("cannot parse register: %s" % reg)
def parse_string(s):
    """Return the text between the double quotes that must enclose `s`."""
    last = len(s) - 1
    assert last >= 0
    assert s[0] == '"'
    assert s[last] == '"'
    return s[1:last]
def parse_type(s):
    """Return the type name `s` as it is; types are kept as plain strings."""
    return s
def parse_record(r):
    """Parse 'REG : TYPE' into a (RegisterRef, type-name) pair."""
    fields = split(r, ' : ')
    assert len(fields) == 2
    return (parse_reg(fields[0]), fields[1])
def parse_reg_type_list(s):
    """Parse '[REG : TYPE, REG : TYPE, ...]' into a list of pairs.

    Each element is what parse_record returns.  '[]' gives an empty list,
    matching parse_reg_list; it used to fail an assertion because the
    empty text between the brackets was parsed as a record.
    """
    if s == '[]':
        return []
    n = len(s) - 1
    assert n >= 0
    assert s[0] == '[' and s[n] == ']'
    return [parse_record(v) for v in split(s[1:n], ', ')]
def isdigits(s):
    """Return True when every character of `s` is a digit (True for '')."""
    for char in s:
        if char.isdigit():
            continue
        return False
    return True
def isalnums(s):
    """Return True when every character of `s` is alphanumeric (True for '')."""
    for char in s:
        if char.isalnum():
            continue
        return False
    return True
def parse_const(s):
    """Parse a constant operand into a String, Int or Type value.

    Double-quoted text becomes a String, all-digit text an Int, and other
    purely alphanumeric text a Type.  Raises Exception for anything else.
    """
    last = len(s) - 1
    assert last >= 0
    quoted = s[0] == '"' and s[last] == '"'
    if quoted:
        return String(s[1:last])
    if isdigits(s):
        return Int(int(s))
    if isalnums(s):
        return Type(s)
    raise Exception("unknown const %s" % s)
def parse_reg_list(s):
    """Parse '[REG, REG, ...]' into a list of RegisterRef ('[]' gives [])."""
    if s == '[]':
        return []
    last = len(s) - 1
    assert last >= 1
    # Drop the first and last character (the brackets) before splitting.
    return [parse_reg(item) for item in split(s[1:last], ', ')]
class Clause(object):
    """Common base class of the results returned by parse_clause."""
class SomeClause(Clause):
    """A parsed CASE clause: its `case` label and its bytecode list `bcs`."""

    def __init__(self, case, bcs):
        self.bcs = bcs
        self.case = case
class EmptyClause(Clause):
    """Returned by parse_clause when the next line is not a tagged clause."""
class ConstClause(object):
    """Common base class of the results returned by parse_const_clause."""
class SomeConstClause(ConstClause):
    """A parsed CONSTCASE clause: constant `case` and bytecode list `bcs`."""

    def __init__(self, case, bcs):
        self.bcs = bcs
        self.case = case
class EmptyConstClause(ConstClause):
    """Returned by parse_const_clause when the next line is not a clause."""
def parse_clause(indent, fd):
    """Try to parse one 'TAG:' clause of a CASE at level `indent`.

    Returns a SomeClause holding the integer tag and the nested block.
    When the next line has a different indentation or is the 'default:'
    clause, the file offset is restored and an EmptyClause is returned.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    stripped = strip(line)
    case = stripped[:-1]
    if get_indent(line) != indent or case == 'default':
        # Not ours: rewind so the caller can re-read the line.
        os.lseek(fd, start, os.SEEK_SET)
        return EmptyClause()
    assert stripped[-1] == ':'
    return SomeClause(int(case), parse_block(indent + 1, fd))
def parse_const_clause(indent, fd):
    """Try to parse one 'CONST:' clause of a CONSTCASE at level `indent`.

    Returns a SomeConstClause holding the parsed constant and the nested
    block.  When the next line has a different indentation or is the
    'default:' clause, the file offset is restored and an
    EmptyConstClause is returned.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    stripped = strip(line)
    case = stripped[:-1]
    if get_indent(line) != indent or case == 'default':
        # Not ours: rewind so the caller can re-read the line.
        os.lseek(fd, start, os.SEEK_SET)
        return EmptyConstClause()
    assert stripped[-1] == ':'
    return SomeConstClause(parse_const(case), parse_block(indent + 1, fd))
def parse_default_clause(indent, fd):
    """Parse an optional 'default:' clause at level `indent`.

    Returns the nested bytecode list, or None (with the file offset
    restored) when the next line is not a 'default:' at that level.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    if get_indent(line) == indent and strip(line) == 'default:':
        return parse_block(indent + 1, fd)
    os.lseek(fd, start, os.SEEK_SET)
    return None
def parse_clauses(indent, fd):
    """Parse the clauses of a CASE: returns (tag -> block dict, default).

    Tagged clauses are read until one fails to parse; a later clause with
    a repeated tag replaces the earlier one.  `default` is None when no
    'default:' clause follows.
    """
    table = {}
    while True:
        clause = parse_clause(indent, fd)
        if not isinstance(clause, SomeClause):
            break
        table[clause.case] = clause.bcs
    return (table, parse_default_clause(indent, fd))
def parse_const_clauses(indent, fd):
    """Parse the clauses of a CONSTCASE: returns (const -> block dict, default).

    Constant clauses are read until one fails to parse.  `default` is None
    when no 'default:' clause follows.
    """
    table = {}
    while True:
        clause = parse_const_clause(indent, fd)
        if not isinstance(clause, SomeConstClause):
            break
        table[clause.case] = clause.bcs
    return (table, parse_default_clause(indent, fd))
def parse_bytecode(indent, fd):
    """Parse one instruction at nesting level `indent` from `fd`.

    Returns the Operation built from the next line.  Returns None, with
    the file offset moved back to the start of that line, when the line is
    blank (or at EOF) or its indentation level differs from `indent`.
    CASE and CONSTCASE also consume the clause blocks nested below them.
    Raises Exception for an unknown mnemonic.
    """
    # Remember where the line starts so it can be "un-read".
    pos = os.lseek(fd, 0, os.SEEK_CUR)
    line = rstrip(readline(fd))
    if line == '':
        os.lseek(fd, pos, os.SEEK_SET)
        return None
    else:
        n = get_indent(line)
        if n != indent:
            os.lseek(fd, pos, os.SEEK_SET)
            return None
        line = strip(line)
        # The mnemonic is the first space-separated word; operands that may
        # themselves contain spaces are re-joined with ' ' below.
        pieces = line.split(' ')
        op = pieces[0]
        args = pieces[1:]
        if op == 'NULL':
            assert len(args) == 1
            return Null(parse_reg(args[0]))
        elif op == 'MKCON':
            # File offset just past this line; MakeCon only shows it when
            # tracing.
            n = os.lseek(fd, 0, os.SEEK_CUR)
            return MakeCon(parse_reg(args[0]), int(args[1]), parse_reg_list(' '.join(args[2:])), n)
        elif op == 'STOREOLD':
            assert len(args) == 0
            return StoreOld()
        elif op == 'REBASE':
            assert len(args) == 0
            return Rebase()
        elif op == 'ASSIGN':
            assert len(args) == 2
            return Assign(parse_reg(args[0]), parse_reg(args[1]))
        elif op == 'PROJECT':
            assert len(args) == 3
            return Project(parse_reg(args[0]), int(args[1]), int(args[2]))
        elif op == 'PROJECTINTO':
            assert len(args) == 3
            return ProjectInto(parse_reg(args[0]), parse_reg(args[1]), int(args[2]))
        elif op == 'RESERVE':
            assert len(args) == 1
            return Reserve(int(args[0]))
        elif op == 'TOPBASE':
            assert len(args) == 1
            return TopBase(int(args[0]))
        elif op == 'BASETOP':
            assert len(args) == 1
            return BaseTop(int(args[0]))
        elif op == 'ADDTOP':
            assert len(args) == 1
            return AddTop(int(args[0]))
        elif op == 'SLIDE':
            assert len(args) == 1
            return Slide(int(args[0]))
        elif op == 'CALL':
            # The rest of the line is the function name (spaces allowed).
            return Call(' '.join(args))
        elif op == 'TAILCALL':
            return TailCall(' '.join(args))
        elif op == 'ERROR':
            return Error(' '.join(args))
        elif op == 'OP':
            return Op(parse_reg(args[0]), args[1], parse_reg_list(' '.join(args[2:])))
        elif op == 'CASE':
            assert len(args) == 1
            # Drop the operand's last character (presumably the ':' that
            # ends the line) before parsing the register.
            reg = parse_reg(args[0][:-1])
            clauses = parse_clauses(indent + 1, fd)
            return Case(reg, clauses[0], clauses[1])
        elif op == 'CONSTCASE':
            assert len(args) == 1
            # Same trailing-character handling as CASE.
            reg = parse_reg(args[0][:-1])
            clauses = parse_const_clauses(indent + 1, fd)
            return ConstCase(reg, clauses[0], clauses[1])
        elif op == 'ASSIGNCONST':
            reg = parse_reg(args[0])
            val = parse_const(' '.join(args[1:]))
            return AssignConst(reg, val)
        elif op == 'FOREIGNCALL':
            reg = parse_reg(args[0])
            fn = parse_string(args[1]) # FIXME: the string may not be contained within args[1]
            ret = parse_type(args[2])
            args = parse_reg_type_list(' '.join(args[3:]))
            return ForeignCall(reg, fn, ret, args)
        else:
            raise Exception("unknown operation: %s" % op)
def parse_block(indent, fd):
    """Parse consecutive instructions at level `indent` into a list.

    Stops at the first line parse_bytecode declines (it returns None).
    """
    bcs = []
    bc = parse_bytecode(indent, fd)
    while bc:
        bcs.append(bc)
        bc = parse_bytecode(indent, fd)
    return bcs
def parse_declaration(fd):
    """Parse one declaration: a header line, its block, then a blank line.

    Returns (name, bytecode list), where name is the header without its
    last character.  Returns (None, []) when the header line is empty
    after stripping.
    """
    header = strip(readline(fd))
    if len(header) < 1:
        return (None, [])
    last = len(header) - 1
    assert last >= 0
    name = header[:last]
    bcs = parse_block(1, fd)
    # A declaration must be followed by a blank line (or EOF).
    trailer = strip(readline(fd))
    assert trailer == ''
    return (name, bcs)
def parse(filename):
    """Parse the bytecode file `filename` into a dict of declarations.

    Returns a dict mapping each declaration name to its bytecode list.
    Reading stops at the first declaration that comes back without a name.
    The descriptor is closed even when parsing raises.
    """
    fd = os.open(filename, os.O_RDONLY, 0o777)
    decls = {}
    try:
        while True:
            name, bcs = parse_declaration(fd)
            if not name:
                break
            decls[name] = bcs
    finally:
        # Previously a parse error skipped the close and leaked the fd.
        os.close(fd)
    return decls
class StackFrame(object):
    """Common base class of the entries kept on the VM's call stack."""
class CallFrame(StackFrame):
    """State saved by VM.call: code position plus both frame pointers."""

    def __init__(self, bcs, pc, frame_ptr, new_frame_ptr):
        # Where to resume.
        self.bcs = bcs
        self.pc = pc
        # Frame pointers restored by VM.ret.
        self.frame_ptr = frame_ptr
        self.new_frame_ptr = new_frame_ptr
class CaseFrame(StackFrame):
    """State saved by VM.start_case: the code position to resume at."""

    def __init__(self, bcs, pc):
        self.pc = pc
        self.bcs = bcs
class VM(object):
    """The interpreter state and the loop that executes the bytecode.

    NOTE(review): the source this was recovered from had lost all
    indentation; the nesting below is reconstructed and should be checked
    against the original, in particular the debug dump in run().
    """

    def __init__(self, declarations):
        """Start at the first instruction of the "{runMain0}" declaration.

        `declarations` maps a declaration name to its bytecode list; a
        missing "{runMain0}" key raises KeyError.
        """
        # The two named registers.
        self.tmp = None
        self.rval = None
        # Value stack; LRef slots are relative to `base`, TRef to `top`.
        self.stack = []
        self.base = 0
        self.top = 0
        self.frame_ptr = 0
        self.new_frame_ptr = 0
        # CallFrame / CaseFrame entries to return to.
        self.callstack = []
        self.declarations = declarations
        # Currently executing bytecode list and the index of the next op.
        self.bcs = declarations["{runMain0}"]
        self.pc = 0

    def start_case(self):
        """Push the current code position so a case branch can return to it."""
        self.callstack.append(CaseFrame(self.bcs, self.pc))

    def ret(self):
        """Return from the current bytecode list.

        Pops one frame and restores the state saved in it.  Returns True
        when the call stack was already empty (nothing left to run) and
        False otherwise.
        """
        if trace:
            print "RET"
        if self.callstack:
            frame = self.callstack.pop()
            if isinstance(frame, CallFrame):
                self.bcs = frame.bcs
                self.pc = frame.pc
                self.frame_ptr = frame.frame_ptr
                self.new_frame_ptr = frame.new_frame_ptr
            elif isinstance(frame, CaseFrame):
                # Case frames only carry a code position.
                self.bcs = frame.bcs
                self.pc = frame.pc
            else:
                raise TypeError("unknown frame type")
            return False
        else:
            return True

    def call(self, fn):
        """Save the current state and start executing declaration `fn`."""
        self.callstack.append(CallFrame(self.bcs, self.pc, self.frame_ptr, self.new_frame_ptr))
        self.frame_ptr = self.new_frame_ptr
        self.bcs = self.declarations[fn]
        self.pc = 0

    def tail_call(self, fn):
        """Jump to declaration `fn` without pushing a frame."""
        self.bcs = self.declarations[fn]
        self.pc = 0

    def run(self):
        """Execute instructions until ret() reports an empty call stack."""
        n = 0
        while True:
            n += 1
            if self.pc >= len(self.bcs):
                # Ran off the end of the current list: return to the caller.
                if self.ret():
                    break
            else:
                bc = self.bcs[self.pc]
                # Advance first; the instruction may overwrite bcs/pc.
                self.pc += 1
                bc.run(self)
            # NOTE(review): whether this dump belongs at loop level or
            # inside the `else` branch above could not be recovered from
            # the flattened source -- confirm against the original.
            if debug:
                print "step: %d" % n
                print "tmp: %s" % (self.tmp.__repr__() if self.tmp else 'None')
                print "rval: %s" % (self.rval.__repr__() if self.rval else 'None')
                stack = []
                for s in self.stack:
                    stack.append(s.__repr__() if s else 'None')
                print "stack: %s" % ('[' + ', '.join(stack) + ']')
                print "base: %d" % self.base
                print "top: %d" % self.top
                print "oldbase: %d" % self.frame_ptr
                print "myoldbase: %d" % self.new_frame_ptr

    def get_register(self, r):
        """Return the value referenced by the RegisterRef `r`.

        Raises TypeError for an unknown reference class.
        """
        if isinstance(r, RValRef):
            return self.rval
        elif isinstance(r, TmpRef):
            return self.tmp
        elif isinstance(r, LRef):
            return self.stack[self.base + r.n]
        elif isinstance(r, TRef):
            return self.stack[self.top + r.n]
        else:
            raise TypeError("Unknown reference type")

    def set_register(self, r, v):
        """Store `v` in the location referenced by the RegisterRef `r`.

        Raises TypeError for an unknown reference class.
        """
        if isinstance(r, RValRef):
            self.rval = v
        elif isinstance(r, TmpRef):
            self.tmp = v
        elif isinstance(r, LRef):
            self.stack[self.base + r.n] = v
        elif isinstance(r, TRef):
            self.stack[self.top + r.n] = v
        else:
            raise TypeError("Unknown reference type")

    def reserve(self, n):
        """Make sure `n` slots exist from `top` upward and reset them to None."""
        # Grow the stack as needed ...
        while len(self.stack) < self.top + n:
            self.stack.append(None)
        # ... then clear the reserved slots, including pre-existing ones.
        for i in xrange(n):
            self.stack[self.top + i] = None
def entry_point(argv):
    """Run the bytecode file named by argv[1].

    Returns 1 after printing a usage message when no filename was given,
    otherwise 0 once the program has finished.
    """
    if len(argv) < 2:
        print("You must supply a filename")
        return 1
    filename = argv[1]
    machine = VM(parse(filename))
    machine.run()
    return 0
def target(*args):
    """Return (entry_point, None), ignoring all arguments.

    NOTE(review): presumably the hook an RPython translation toolchain
    looks for -- confirm.
    """
    return (entry_point, None)
if __name__ == "__main__":
    # Pass entry_point's status to the shell; it was discarded before, so a
    # missing filename still exited with status 0.
    sys.exit(entry_point(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment