Created
December 7, 2012 18:48
-
-
Save danielwaterworth/4235457 to your computer and use it in GitHub Desktop.
idris bytecode interpreter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, os | |
# Debugging switches; flip a value to True to enable it.
trace = False  # True: every operation prints its mnemonic before running
debug = False  # True: VM.run dumps registers and stack while looping
step = False  # True (not read anywhere in the code visible here)
class Data(object):
    """Common base class of every runtime value the interpreter handles."""
class Type(Data):
    """A type constant, identified only by its name."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Type) and other.name == self.name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Instances are used as dictionary keys by parse_const_clauses, so
        # hashing has to agree with __eq__ (and must exist on Python 3).
        return hash(self.name)

    def __repr__(self):
        return self.name
class Constructor(Data):
    """A constructor application: an integer tag plus a list of field values."""

    def __init__(self, tag, args):
        self.tag = tag
        self.args = args

    def __eq__(self, other):
        """Structural equality: same tag, same arity, pairwise equal fields."""
        if not (isinstance(other, Constructor) and self.tag == other.tag):
            return False
        if len(self.args) != len(other.args):
            return False
        for i in range(len(self.args)):
            a = self.args[i]
            b = other.args[i]
            if a is None or b is None:
                # The NULL operation stores None in registers, so a field can
                # be None; None has no usable __eq__ method on Python 2.
                if a is not b:
                    return False
            elif not a.__eq__(b):
                return False
        return True

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Cheap hash that stays consistent with __eq__: equal constructors
        # always share tag and arity.
        return hash((self.tag, len(self.args)))

    def __repr__(self):
        args = [arg.__repr__() for arg in self.args]
        return "Con %d %s" % (self.tag, '[' + ', '.join(args) + ']')
class Int(Data):
    """An integer value wrapped for the interpreter."""

    def __init__(self, n):
        self.n = n

    def __eq__(self, other):
        return isinstance(other, Int) and self.n == other.n

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Instances are used as dictionary keys by parse_const_clauses, so
        # hashing has to agree with __eq__ (and must exist on Python 3).
        return hash(self.n)

    def __repr__(self):
        return str(self.n)
class Float(Data):
    """A floating point value wrapped for the interpreter."""

    def __init__(self, f):
        self.f = f

    def __eq__(self, other):
        return isinstance(other, Float) and self.f == other.f

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__; defining __eq__ alone makes
        # the class unhashable on Python 3.
        return hash(self.f)

    def __repr__(self):
        return str(self.f)
class Ptr(Data):
    """An opaque pointer value; two pointers are equal when their payloads are."""

    def __init__(self, ptr):
        self.ptr = ptr

    def __eq__(self, other):
        return isinstance(other, Ptr) and self.ptr == other.ptr

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__; defining __eq__ alone makes
        # the class unhashable on Python 3.
        return hash(self.ptr)

    def __repr__(self):
        return 'PTR'
class VMPtr(Data):
    """A value that refers to a VM instance (produced by the LVMPtr primitive)."""

    def __init__(self, vm):
        self.vm = vm

    def __eq__(self, other):
        return isinstance(other, VMPtr) and self.vm == other.vm

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Keep hashing consistent with __eq__; defining __eq__ alone makes
        # the class unhashable on Python 3.
        return hash(self.vm)

    def __repr__(self):
        return 'VMPTR'
class String(Data):
    """A string value wrapped for the interpreter."""

    def __init__(self, s):
        self.s = s

    def __eq__(self, other):
        return isinstance(other, String) and self.s == other.s

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Instances are used as dictionary keys by parse_const_clauses, so
        # hashing has to agree with __eq__ (and must exist on Python 3).
        return hash(self.s)

    def __repr__(self):
        # Quoted form with newlines shown as a backslash-n sequence.
        return '"' + replace(self.s, '\n', '\\n') + '"'
class Unit(Data):
    """The unit value; every Unit instance equals every other."""

    def __repr__(self):
        return 'UNIT'

    def __eq__(self, other):
        return isinstance(other, Unit)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # All units compare equal, so they must share one hash value.
        return 0
class RegisterRef(object):
    """Base class for the register references operations carry as operands."""
class TmpRef(RegisterRef):
    """Reference to the VM's single ``tmp`` register."""

    def __repr__(self):
        return "Tmp"
class RValRef(RegisterRef):
    """Reference to the VM's single ``rval`` register."""

    def __repr__(self):
        return "RVal"
class LRef(RegisterRef):
    """Reference to stack slot ``n`` counted from the VM's ``base``."""

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return "L%d" % self.n
class TRef(RegisterRef):
    """Reference to stack slot ``n`` counted from the VM's ``top``."""

    def __init__(self, n):
        self.n = n

    def __repr__(self):
        return "T%d" % self.n
class Operation(object):
    """Base class of all bytecode operations; subclasses provide ``run(vm)``."""
class Assign(Operation):
    """ASSIGN: copy the value of register ``b`` into register ``a``."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def run(self, vm):
        if trace:
            print("ASSIGN %s %s" % (self.a.__repr__(), self.b.__repr__()))
        value = vm.get_register(self.b)
        vm.set_register(self.a, value)
class AssignConst(Operation):
    """ASSIGNCONST: store the constant ``b`` in register ``a``."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def run(self, vm):
        if trace:
            print("ASSIGNCONST %s %s" % (self.a.__repr__(), self.b.__repr__()))
        vm.set_register(self.a, self.b)
class MakeCon(Operation):
    """MKCON: build a Constructor from register values and store it in ``reg``.

    ``ty`` is the constructor tag, ``args`` the registers supplying the
    fields, and ``pos`` a position value that is only shown when tracing.
    """

    def __init__(self, reg, ty, args, pos):
        self.reg = reg
        self.ty = ty
        self.args = args
        self.pos = pos

    def run(self, vm):
        if trace:
            print("MAKECON %s %s %s %d" % (self.reg.__repr__(), self.ty, self.args, self.pos))
        fields = [vm.get_register(v) for v in self.args]
        vm.set_register(self.reg, Constructor(self.ty, fields))
class Case(Operation):
    """CASE: branch on the constructor tag of the value held in ``reg``."""

    def __init__(self, reg, cases, default):
        self.reg = reg
        self.cases = cases      # dict: constructor tag -> list of operations
        self.default = default  # list of operations, or None when absent

    def run(self, vm):
        """Push a case frame and jump to the selected branch.

        Raises Exception when no clause matches and there is no default.
        """
        if trace:
            print("CASE")
        vm.start_case()
        value = vm.get_register(self.reg)
        if isinstance(value, Constructor):
            branch = self.cases.get(value.tag, self.default)
        else:
            # Non-constructor values can only take the default branch.
            branch = self.default
        # Explicit check instead of assert: it survives ``python -O`` and it
        # does not reject a branch whose body happens to be empty.
        if branch is None:
            raise Exception("CASE: no matching clause and no default")
        vm.bcs = branch
        vm.pc = 0
class ConstCase(Operation):
    """CONSTCASE: branch by comparing the value in ``reg`` with constants."""

    def __init__(self, reg, cases, default):
        self.reg = reg
        self.cases = cases      # dict: constant value -> list of operations
        self.default = default  # list of operations, or None when absent

    def run(self, vm):
        """Push a case frame and jump to the selected branch.

        Raises Exception when no constant matches and there is no default.
        """
        if trace:
            print("CONSTCASE")
        vm.start_case()
        value = vm.get_register(self.reg)
        branch = self.default
        # Keys are compared with an explicit __eq__ call rather than looked
        # up by hash, as the rest of the file does.
        for const, bcs in self.cases.items():
            if const.__eq__(value):
                branch = bcs
                break
        # Explicit check instead of assert: it survives ``python -O`` and it
        # does not reject a branch whose body happens to be empty.
        if branch is None:
            raise Exception("CONSTCASE: no matching clause and no default")
        vm.bcs = branch
        vm.pc = 0
class Reserve(Operation):
    """RESERVE: ask the VM to reserve ``n`` stack slots."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("RESERVE %d" % self.n)
        vm.reserve(self.n)
class AddTop(Operation):
    """ADDTOP: move the VM's ``top`` by ``n``."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("ADDTOP %d" % self.n)
        vm.top = vm.top + self.n
class TopBase(Operation):
    """TOPBASE: set the VM's ``top`` to ``base + n``."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("TOPBASE %d" % self.n)
        vm.top = self.n + vm.base
class BaseTop(Operation):
    """BASETOP: set the VM's ``base`` to ``top + n``."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("BASETOP %d" % self.n)
        vm.base = self.n + vm.top
class StoreOld(Operation):
    """STOREOLD: remember the current ``base`` in ``vm.new_frame_ptr``."""

    def run(self, vm):
        if trace:
            print("STOREOLD")
        saved = vm.base
        vm.new_frame_ptr = saved
class Rebase(Operation):
    """REBASE: restore ``base`` from ``vm.frame_ptr``."""

    def run(self, vm):
        if trace:
            print("REBASE")
        restored = vm.frame_ptr
        vm.base = restored
class Slide(Operation):
    """SLIDE: copy registers T0..T(n-1) into L0..L(n-1), in ascending order."""

    def __init__(self, n):
        self.n = n

    def run(self, vm):
        if trace:
            print("SLIDE %d" % self.n)
        slot = 0
        while slot < self.n:
            vm.set_register(LRef(slot), vm.get_register(TRef(slot)))
            slot += 1
class Project(Operation):
    """PROJECT: unpack the first ``arity`` fields of the value in ``reg``.

    Field ``k`` is written to local register ``L(i + k)``.
    """

    def __init__(self, reg, i, arity):
        self.reg = reg
        self.i = i
        self.arity = arity

    def run(self, vm):
        if trace:
            print("PROJECT %s %d %d" % (self.reg.__repr__(), self.i, self.arity))
        source = vm.get_register(self.reg)
        k = 0
        while k < self.arity:
            vm.set_register(LRef(self.i + k), source.args[k])
            k += 1
class Op(Operation):
    """OP: apply the primitive named ``op`` to ``args`` and store it in ``reg``.

    Supported primitives: LBPlus/LPlus, LBMinus/LMinus, LVMPtr, LEq and
    LStrConcat; any other name raises Exception.
    """

    def __init__(self, reg, op, args):
        self.reg = reg
        self.op = op
        self.args = args

    def _operands(self, vm):
        """Read the current value of every argument register, in order."""
        return [vm.get_register(arg) for arg in self.args]

    def run(self, vm):
        if trace:
            shown = [arg.__repr__() for arg in self.args]
            print("OP %s %s %s" % (self.reg.__repr__(), self.op, '[' + ', '.join(shown) + ']'))
        name = self.op
        if name == 'LBPlus' or name == 'LPlus':
            assert len(self.args) == 2
            values = self._operands(vm)
            result = Int(values[0].n + values[1].n)
        elif name == 'LBMinus' or name == 'LMinus':
            assert len(self.args) == 2
            values = self._operands(vm)
            result = Int(values[0].n - values[1].n)
        elif name == 'LVMPtr':
            assert len(self.args) == 0
            result = VMPtr(vm)
        elif name == 'LEq':
            assert len(self.args) == 2
            values = self._operands(vm)
            # Equality yields an Int: 1 for equal, 0 otherwise.
            if values[0].__eq__(values[1]):
                result = Int(1)
            else:
                result = Int(0)
        elif name == 'LStrConcat':
            pieces = [vm.get_register(arg).s for arg in self.args]
            result = String(''.join(pieces))
        else:
            raise Exception("unknown prim op: %s" % self.op)
        vm.set_register(self.reg, result)
class Null(Operation):
    """NULL: clear register ``reg`` by storing None in it."""

    def __init__(self, reg):
        self.reg = reg

    def run(self, vm):
        if trace:
            print("NULL %s" % self.reg.__repr__())
        vm.set_register(self.reg, None)
class Error(Operation):
    """ERROR: abort execution by raising Exception with the stored message."""

    def __init__(self, err):
        self.err = err

    def run(self, vm):
        message = self.err
        raise Exception(message)
class Call(Operation):
    """CALL: invoke the declaration named ``fn`` through ``vm.call``."""

    def __init__(self, fn):
        self.fn = fn

    def run(self, vm):
        if trace:
            print("CALL %s" % self.fn)
        vm.call(self.fn)
class TailCall(Operation):
    """TAILCALL: jump to the declaration named ``fn`` through ``vm.tail_call``."""

    def __init__(self, fn):
        self.fn = fn

    def run(self, vm):
        if trace:
            print("TAILCALL %s" % self.fn)
        vm.tail_call(self.fn)
class ProjectInto(Operation):
    """PROJECTINTO: copy field ``i`` of the value in ``src`` into ``dst``."""

    def __init__(self, dst, src, i):
        self.dst = dst
        self.src = src
        self.i = i

    def run(self, vm):
        if trace:
            print("PROJECTINTO %s %s %d" % (self.dst, self.src, self.i))
        source = vm.get_register(self.src)
        vm.set_register(self.dst, source.args[self.i])
class ForeignCall(Operation):
    """FOREIGNCALL: run one of the built-in foreign functions.

    ``args`` is a list of ``(register, type name)`` pairs and ``ret`` the
    return type name. Only ``putStr``, ``idris_numArgs`` and
    ``idris_getArg`` are implemented; any other name raises Exception.
    """

    def __init__(self, reg, fn, ret, args):
        self.reg = reg
        self.fn = fn
        self.ret = ret
        self.args = args

    def run(self, vm):
        if trace:
            shown = []
            for arg in self.args:
                shown.append('(' + arg[0].__repr__() + ', ' + arg[1] + ')')
            print("FOREIGNCALL %s %s %s %s" % (self.reg.__repr__(), self.fn, self.ret, '[' + ', '.join(shown) + ']'))
        if self.fn == 'putStr':
            # Write the string verbatim. The previous ``print`` appended a
            # newline of its own, so text that already ended in "\n" came
            # out with a blank line after it.
            # NOTE(review): os.write bypasses sys.stdout buffering, so trace
            # output may interleave differently when stdout is piped.
            os.write(1, vm.get_register(self.args[0][0]).s)
        elif self.fn == 'idris_numArgs':
            # Fixed answer: the argument count is always reported as 1.
            vm.set_register(self.reg, Int(1))
        elif self.fn == 'idris_getArg':
            # Fixed answer: every argument is reported as "./test".
            vm.set_register(self.reg, String("./test"))
        else:
            raise Exception("unknown foreign call: %s" % self.fn)
def lstrip(s):
    """Return ``s`` without its leading whitespace ('' if it is all whitespace)."""
    i = 0
    length = len(s)
    while i < length:
        if not s[i].isspace():
            return s[i:]
        i += 1
    return ''
def rstrip(s):
    """Return ``s`` without its trailing whitespace ('' if it is all whitespace)."""
    last = len(s) - 1
    while last >= 0:
        if not s[last].isspace():
            return s[:last + 1]
        last -= 1
    return ''
def strip(s):
    """Return ``s`` with whitespace removed from both ends."""
    without_tail = rstrip(s)
    return lstrip(without_tail)
def split(s, separator):
    """Split ``s`` at every non-overlapping occurrence of ``separator``.

    Occurrences are matched left to right. The result always has at least
    one element; an input without the separator yields ``[s]``.

    Raises ValueError for an empty separator, which previously made the
    loop spin forever while appending empty strings.
    """
    if not separator:
        raise ValueError("empty separator")
    output = []
    start = 0
    while True:
        # find() scans from ``start`` without copying the rest of the string
        # on every step, unlike the former s[i:].startswith(...) test.
        hit = s.find(separator, start)
        if hit < 0:
            break
        output.append(s[start:hit])
        start = hit + len(separator)
    output.append(s[start:])
    return output
def replace(s, b, a):
    """Return ``s`` with every occurrence of ``b`` replaced by ``a``."""
    pieces = split(s, b)
    return a.join(pieces)
def readline(fd):
    """Read one line from file descriptor ``fd``, one byte at a time.

    The trailing newline is kept. Returns '' at end of file. Reading byte by
    byte keeps the descriptor's offset exactly at the end of the line.
    """
    chunks = []
    while True:
        ch = os.read(fd, 1)
        chunks.append(ch)
        # Stop at end of file (empty read) or after the newline.
        if not ch or ch == '\n':
            break
    return ''.join(chunks)
def get_indent(line):
    """Return the indentation level of ``line``, four spaces per level.

    The level is the number of leading spaces divided by four, rounded
    down. An empty string (what ``readline`` returns at end of file) has
    level 0, the same as a bare newline.

    Raises Exception for a non-empty line made up only of spaces.
    """
    n = 0
    for c in line:
        if c != ' ':
            # Floor division keeps the result an int whatever the division
            # mode of the interpreter.
            return n // 4
        n += 1
    if n == 0:
        # Nothing was read at all: end of file rather than a blank line.
        return 0
    raise Exception("unexpected entirely spaces")
def parse_reg(reg):
    """Turn a register name into a RegisterRef.

    Accepts ``RVal``, ``Tmp``, ``L<number>`` and ``T<number>``. Any other
    text, including the empty string, raises Exception.
    """
    if reg == 'RVal':
        return RValRef()
    elif reg == 'Tmp':
        return TmpRef()
    # startswith() instead of reg[0]: an empty name now reaches the parse
    # error below rather than failing with IndexError.
    elif reg.startswith('L'):
        return LRef(int(reg[1:]))
    elif reg.startswith('T'):
        return TRef(int(reg[1:]))
    else:
        raise Exception("cannot parse register: %s" % reg)
def parse_string(s):
    """Strip the surrounding double quotes from ``s``; no escapes are processed."""
    last = len(s) - 1
    assert last >= 0
    assert s[0] == '"'
    assert s[last] == '"'
    return s[1:last]
def parse_type(s):
    """Return the type name unchanged; types are kept as plain strings."""
    return s
def parse_record(r):
    """Parse ``"<register> : <type>"`` into a ``(RegisterRef, type name)`` pair."""
    parts = split(r, ' : ')
    assert len(parts) == 2
    return (parse_reg(parts[0]), parts[1])
def parse_reg_type_list(s):
    """Parse ``"[reg : type, reg : type, ...]"`` into a list of pairs.

    Each element is what ``parse_record`` returns. An empty list ``"[]"``
    yields ``[]``, matching ``parse_reg_list``; it used to fail an assertion
    while parsing the empty record.
    """
    n = len(s) - 1
    assert n >= 0
    assert s[0] == '[' and s[n] == ']'
    inner = s[1:n]
    if inner == '':
        return []
    return [parse_record(v) for v in split(inner, ', ')]
def isdigits(s):
    """Return True when every character of ``s`` is a digit (True for '')."""
    i = 0
    while i < len(s):
        if not s[i].isdigit():
            return False
        i += 1
    return True
def isalnums(s):
    """Return True when every character of ``s`` is alphanumeric (True for '')."""
    i = 0
    while i < len(s):
        if not s[i].isalnum():
            return False
        i += 1
    return True
def parse_const(s):
    """Parse the text of a constant into a String, Int or Type value.

    * text in double quotes   -> String (quotes removed, no escapes handled)
    * digits, optionally with a leading '-' -> Int
    * other alphanumeric text -> Type

    Raises Exception for anything else, including the empty string and a
    lone quote character (which used to be read as an empty String).
    """
    if not s:
        # Explicit check instead of assert so it also holds under -O.
        raise Exception("unknown const %s" % s)
    n = len(s) - 1
    if n >= 1 and s[0] == '"' and s[n] == '"':
        return String(s[1:n])
    elif isdigits(s):
        return Int(int(s))
    elif n >= 1 and s[0] == '-' and isdigits(s[1:]):
        # Negative integer literal; previously rejected as unknown.
        return Int(int(s))
    elif isalnums(s):
        return Type(s)
    else:
        raise Exception("unknown const %s" % s)
def parse_reg_list(s):
    """Parse ``"[reg, reg, ...]"`` into a list of RegisterRef objects."""
    if s == '[]':
        return []
    last = len(s) - 1
    assert last >= 1
    # Drop the enclosing brackets, then parse each comma-separated name.
    return [parse_reg(v) for v in split(s[1:last], ', ')]
class Clause(object):
    """Base class for the result of trying to parse one CASE clause."""
class SomeClause(Clause):
    """A parsed CASE clause: the case label and the operations of its body."""

    def __init__(self, case, bcs):
        self.bcs = bcs
        self.case = case
class EmptyClause(Clause):
    """Marker returned by parse_clause when the next line is not a clause."""
class ConstClause(object):
    """Base class for the result of trying to parse one CONSTCASE clause."""
class SomeConstClause(ConstClause):
    """A parsed CONSTCASE clause: the constant and the operations of its body."""

    def __init__(self, case, bcs):
        self.bcs = bcs
        self.case = case
class EmptyConstClause(ConstClause):
    """Marker returned by parse_const_clause when the next line is not a clause."""
def parse_clause(indent, fd):
    """Try to parse one ``<tag>:`` clause of a CASE at level ``indent``.

    Returns a SomeClause holding the integer tag and the clause body. When
    the next line sits at another level or is the ``default`` clause, the
    file offset is rewound to the start of that line and an EmptyClause is
    returned.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    text = strip(line)
    case = text[:-1]  # label without the trailing ':'
    if get_indent(line) != indent or case == 'default':
        os.lseek(fd, start, os.SEEK_SET)
        return EmptyClause()
    assert text[-1] == ':'
    return SomeClause(int(case), parse_block(indent + 1, fd))
def parse_const_clause(indent, fd):
    """Try to parse one ``<constant>:`` clause of a CONSTCASE at level ``indent``.

    Returns a SomeConstClause holding the parsed constant and the clause
    body. When the next line sits at another level or is the ``default``
    clause, the file offset is rewound to the start of that line and an
    EmptyConstClause is returned.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    text = strip(line)
    case = text[:-1]  # label without the trailing ':'
    if get_indent(line) != indent or case == 'default':
        os.lseek(fd, start, os.SEEK_SET)
        return EmptyConstClause()
    assert text[-1] == ':'
    return SomeConstClause(parse_const(case), parse_block(indent + 1, fd))
def parse_default_clause(indent, fd):
    """Parse an optional ``default:`` clause at level ``indent``.

    Returns the list of operations in its body, or None (after rewinding
    the file offset) when the next line is not a default clause.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    line = readline(fd)
    if get_indent(line) != indent or strip(line) != 'default:':
        os.lseek(fd, start, os.SEEK_SET)
        return None
    return parse_block(indent + 1, fd)
def parse_clauses(indent, fd):
    """Parse all clauses of a CASE.

    Returns ``(table, default)``: ``table`` maps each tag to its body (a
    repeated tag keeps the last body) and ``default`` is the default body
    or None.
    """
    table = {}
    clause = parse_clause(indent, fd)
    while isinstance(clause, SomeClause):
        table[clause.case] = clause.bcs
        clause = parse_clause(indent, fd)
    default = parse_default_clause(indent, fd)
    return (table, default)
def parse_const_clauses(indent, fd):
    """Parse all clauses of a CONSTCASE.

    Returns ``(table, default)``: ``table`` maps each parsed constant to
    its body and ``default`` is the default body or None.
    """
    table = {}
    clause = parse_const_clause(indent, fd)
    while isinstance(clause, SomeConstClause):
        table[clause.case] = clause.bcs
        clause = parse_const_clause(indent, fd)
    default = parse_default_clause(indent, fd)
    return (table, default)
def parse_bytecode(indent, fd):
    """Parse the next line of ``fd`` as one operation at level ``indent``.

    Returns the Operation, or None when the line is blank or indented to a
    different level; in that case the file offset is moved back to the
    start of the line so the caller can read it again. CASE and CONSTCASE
    also consume their clause blocks.

    Raises Exception for an unknown mnemonic.
    """
    start = os.lseek(fd, 0, os.SEEK_CUR)
    raw = rstrip(readline(fd))
    if raw == '':
        os.lseek(fd, start, os.SEEK_SET)
        return None
    if get_indent(raw) != indent:
        os.lseek(fd, start, os.SEEK_SET)
        return None

    pieces = strip(raw).split(' ')
    op = pieces[0]
    args = pieces[1:]

    if op == 'NULL':
        assert len(args) == 1
        return Null(parse_reg(args[0]))
    if op == 'MKCON':
        # Offset just after the MKCON line; MakeCon only shows it in traces.
        after = os.lseek(fd, 0, os.SEEK_CUR)
        return MakeCon(parse_reg(args[0]), int(args[1]), parse_reg_list(' '.join(args[2:])), after)
    if op == 'STOREOLD':
        assert len(args) == 0
        return StoreOld()
    if op == 'REBASE':
        assert len(args) == 0
        return Rebase()
    if op == 'ASSIGN':
        assert len(args) == 2
        return Assign(parse_reg(args[0]), parse_reg(args[1]))
    if op == 'PROJECT':
        assert len(args) == 3
        return Project(parse_reg(args[0]), int(args[1]), int(args[2]))
    if op == 'PROJECTINTO':
        assert len(args) == 3
        return ProjectInto(parse_reg(args[0]), parse_reg(args[1]), int(args[2]))
    if op == 'RESERVE':
        assert len(args) == 1
        return Reserve(int(args[0]))
    if op == 'TOPBASE':
        assert len(args) == 1
        return TopBase(int(args[0]))
    if op == 'BASETOP':
        assert len(args) == 1
        return BaseTop(int(args[0]))
    if op == 'ADDTOP':
        assert len(args) == 1
        return AddTop(int(args[0]))
    if op == 'SLIDE':
        assert len(args) == 1
        return Slide(int(args[0]))
    if op == 'CALL':
        # The function name is the rest of the line and may contain spaces.
        return Call(' '.join(args))
    if op == 'TAILCALL':
        return TailCall(' '.join(args))
    if op == 'ERROR':
        return Error(' '.join(args))
    if op == 'OP':
        return Op(parse_reg(args[0]), args[1], parse_reg_list(' '.join(args[2:])))
    if op == 'CASE':
        assert len(args) == 1
        # The register name is followed by one extra character, dropped here.
        reg = parse_reg(args[0][:-1])
        table, default = parse_clauses(indent + 1, fd)
        return Case(reg, table, default)
    if op == 'CONSTCASE':
        assert len(args) == 1
        reg = parse_reg(args[0][:-1])
        table, default = parse_const_clauses(indent + 1, fd)
        return ConstCase(reg, table, default)
    if op == 'ASSIGNCONST':
        reg = parse_reg(args[0])
        val = parse_const(' '.join(args[1:]))
        return AssignConst(reg, val)
    if op == 'FOREIGNCALL':
        reg = parse_reg(args[0])
        fn = parse_string(args[1])  # FIXME: the string may not be contained within args[1]
        ret = parse_type(args[2])
        records = parse_reg_type_list(' '.join(args[3:]))
        return ForeignCall(reg, fn, ret, records)
    raise Exception("unknown operation: %s" % op)
def parse_block(indent, fd):
    """Collect consecutive operations at level ``indent`` into a list.

    Stops at the first line ``parse_bytecode`` declines; the list may be empty.
    """
    bcs = []
    bc = parse_bytecode(indent, fd)
    while bc:
        bcs.append(bc)
        bc = parse_bytecode(indent, fd)
    return bcs
def parse_declaration(fd):
    """Parse one declaration: a header line, a level-1 block, a blank line.

    Returns ``(name, operations)`` where ``name`` is the header without its
    last character. Returns ``(None, [])`` when the header line is empty,
    which is how the end of the input is reported.
    """
    header = strip(readline(fd))
    if len(header) < 1:
        return (None, [])
    last = len(header) - 1
    assert last >= 0
    name = header[:last]
    bcs = parse_block(1, fd)
    # Every declaration must be followed by an empty line (or end of file).
    trailer = strip(readline(fd))
    assert trailer == ''
    return (name, bcs)
def parse(filename):
    """Parse the bytecode file ``filename``.

    Returns a dict mapping each declaration name to its list of operations.
    Reading stops at the first declaration without a name. The file
    descriptor is closed even when parsing raises.
    """
    # 0o777 is the spelling of the old 0777 literal accepted by Python 2.6+
    # and Python 3 alike.
    fd = os.open(filename, os.O_RDONLY, 0o777)
    try:
        decls = {}
        while True:
            name, bcs = parse_declaration(fd)
            if not name:
                break
            decls[name] = bcs
    finally:
        os.close(fd)
    return decls
class StackFrame(object):
    """Base class of the entries kept on the VM's call stack."""
class CallFrame(StackFrame):
    """State saved by ``VM.call``: code, program counter and both frame pointers."""

    def __init__(self, bcs, pc, frame_ptr, new_frame_ptr):
        self.bcs, self.pc = bcs, pc
        self.frame_ptr = frame_ptr
        self.new_frame_ptr = new_frame_ptr
class CaseFrame(StackFrame):
    """State saved by ``VM.start_case``: the enclosing code and program counter."""

    def __init__(self, bcs, pc):
        self.bcs, self.pc = bcs, pc
class VM(object):
    """Interpreter state plus the fetch/execute loop for parsed bytecode.

    ``declarations`` maps names to lists of operations; execution starts at
    the declaration named ``{runMain0}``.
    """

    def __init__(self, declarations):
        # Registers.
        self.tmp = None
        self.rval = None
        # Value stack and the two offsets L and T registers are relative to.
        self.stack = []
        self.base = 0
        self.top = 0
        self.frame_ptr = 0
        self.new_frame_ptr = 0
        # Saved CallFrame / CaseFrame entries.
        self.callstack = []
        self.declarations = declarations
        # Code being executed and the index of the next operation in it.
        self.bcs = declarations["{runMain0}"]
        self.pc = 0

    def start_case(self):
        """Save the current code position before entering a case branch."""
        frame = CaseFrame(self.bcs, self.pc)
        self.callstack.append(frame)

    def ret(self):
        """Resume the most recently saved frame.

        Returns True when the call stack is already empty (execution is
        finished) and False after a frame has been restored.
        """
        if trace:
            print("RET")
        if not self.callstack:
            return True
        frame = self.callstack.pop()
        if isinstance(frame, CallFrame):
            self.bcs = frame.bcs
            self.pc = frame.pc
            self.frame_ptr = frame.frame_ptr
            self.new_frame_ptr = frame.new_frame_ptr
        elif isinstance(frame, CaseFrame):
            self.bcs = frame.bcs
            self.pc = frame.pc
        else:
            raise TypeError("unknown frame type")
        return False

    def call(self, fn):
        """Save the caller's state and start executing declaration ``fn``."""
        saved = CallFrame(self.bcs, self.pc, self.frame_ptr, self.new_frame_ptr)
        self.callstack.append(saved)
        self.frame_ptr = self.new_frame_ptr
        self.bcs = self.declarations[fn]
        self.pc = 0

    def tail_call(self, fn):
        """Start executing declaration ``fn`` without saving a frame."""
        self.pc = 0
        self.bcs = self.declarations[fn]

    def run(self):
        """Execute operations until the code and the call stack are exhausted."""
        n = 0
        while True:
            n += 1
            if self.pc < len(self.bcs):
                bc = self.bcs[self.pc]
                self.pc += 1
                bc.run(self)
            elif self.ret():
                break
            # NOTE(review): the source's indentation was lost, so the nesting
            # of this dump is assumed to be loop level (once per iteration) -
            # confirm against the original file.
            if debug:
                print("step: %d" % n)
                print("tmp: %s" % (self.tmp.__repr__() if self.tmp else 'None'))
                print("rval: %s" % (self.rval.__repr__() if self.rval else 'None'))
                stack = [s.__repr__() if s else 'None' for s in self.stack]
                print("stack: %s" % ('[' + ', '.join(stack) + ']'))
                print("base: %d" % self.base)
                print("top: %d" % self.top)
                print("oldbase: %d" % self.frame_ptr)
                print("myoldbase: %d" % self.new_frame_ptr)

    def get_register(self, r):
        """Return the value referred to by RegisterRef ``r``."""
        if isinstance(r, RValRef):
            return self.rval
        if isinstance(r, TmpRef):
            return self.tmp
        if isinstance(r, LRef):
            return self.stack[self.base + r.n]
        if isinstance(r, TRef):
            return self.stack[self.top + r.n]
        raise TypeError("Unknown reference type")

    def set_register(self, r, v):
        """Store ``v`` in the location referred to by RegisterRef ``r``."""
        if isinstance(r, RValRef):
            self.rval = v
        elif isinstance(r, TmpRef):
            self.tmp = v
        elif isinstance(r, LRef):
            self.stack[self.base + r.n] = v
        elif isinstance(r, TRef):
            self.stack[self.top + r.n] = v
        else:
            raise TypeError("Unknown reference type")

    def reserve(self, n):
        """Make sure ``n`` slots exist above ``top`` and reset them to None."""
        needed = self.top + n
        while len(self.stack) < needed:
            self.stack.append(None)
        for offset in range(n):
            self.stack[self.top + offset] = None
def entry_point(argv):
    """Run the bytecode file named by ``argv[1]``.

    Returns 1 after printing a message when no filename was given, and 0
    once the program has run to completion.
    """
    if len(argv) < 2:
        print("You must supply a filename")
        return 1
    filename = argv[1]
    declarations = parse(filename)
    VM(declarations).run()
    return 0
def target(*args):
    """Return ``(entry_point, None)``; all arguments are ignored."""
    return (entry_point, None)
if __name__ == "__main__":
    # Pass entry_point's status on to the shell; it used to be discarded, so
    # a missing filename still exited with status 0.
    sys.exit(entry_point(sys.argv))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment