Skip to content

Instantly share code, notes, and snippets.

@hisui
Created March 24, 2012 12:51
Show Gist options
  • Save hisui/2182361 to your computer and use it in GitHub Desktop.
Save hisui/2182361 to your computer and use it in GitHub Desktop.
簡易言語のJITコンパイラー方式の言語処理系
// programmed by hisui(https://github.com/hisui)
/**
*
* jitexpr は簡易言語のJITコンパイラー方式の言語処理系です
*
* スクリプトの記述例:
* def fib(n) {
* if(n < 3) 1 else fib(n-1) + fib(n-2);
* }
* i = 0;
* while(i < 30) dump(i, fib(i = i + 1));
*
* ビルド:
* g++ -std=c++0x jitexpr.cpp -o jitexpr.exe
*
* 使い方:
* jitexpr [i/c] FILENAME
* i: バイトコードインタープリターとして実行
* c: ネイティブコードに変換して実行
*
*/
#include <algorithm>
#include <deque>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <windows.h>
namespace sf {
namespace jx {
using std::unique_ptr;
using std::shared_ptr;
using std::make_shared;
using std::make_pair;
using std::make_tuple;
//-------------------------------------
// 構文木の定義
//-------------------------------------
// AST nodes
// Forward declarations of every concrete AST node type, so that
// Node::Visitor can declare one visit() overload per node kind before the
// node classes themselves are defined.
namespace node
{
class Value;  // integer literal:  2525
class GetVar; // variable read:    XXX
class SetVar; // variable write:   XXX = ...
class Call;   // call / operator:  XXX(...)
class List;   // sequence:         ...; ...
class Loop;   // loop:             while(...) ...
class Cond;   // conditional:      if(...) ... else ...
class Func;   // function:         def XXX(...) ...
}
/// Abstract base class of all AST nodes.
///
/// Nodes support double dispatch through Visitor (see accept()) and can
/// render themselves as a human-readable string for debugging.
class Node
{
public:
    /// Visitor over the closed set of node kinds declared in namespace node.
    class Visitor
    {
    public:
        // Virtual so a visitor can be destroyed safely through a base pointer.
        virtual ~Visitor() = default;
        virtual void visit(node::Func &node) = 0;
        virtual void visit(node::Cond &node) = 0;
        virtual void visit(node::Loop &node) = 0;
        virtual void visit(node::List &node) = 0;
        virtual void visit(node::Call &node) = 0;
        virtual void visit(node::SetVar &node) = 0;
        virtual void visit(node::GetVar &node) = 0;
        virtual void visit(node::Value &node) = 0;
    };
    // Nodes are owned polymorphically (shared_ptr<Node>), so deleting one
    // through a base pointer must run the derived destructor.
    virtual ~Node() = default;
    /// Calls the visit() overload matching this node's dynamic type.
    virtual void accept(Visitor &visitor) = 0;
    /// Returns a debug representation of the subtree rooted at this node.
    virtual std::string stringify() = 0;
};
/// CRTP helper that implements Node::accept() once for every node type.
///
/// `Sub` is the concrete node class deriving from this template and `Base`
/// is the class it extends (Node by default, List for Func).  accept()
/// downcasts `this` to Sub so the visitor receives the exact node type.
template<typename Sub, typename Base=Node>
class NodeTemplate: public Base
{
public:
    void accept(Node::Visitor &visitor)
    {
        Sub &self = static_cast<Sub&>(*this);
        visitor.visit(self);
    }
};
namespace node {
/// Conditional expression: an ordered list of (condition, body) pairs.
///
/// The first pair whose condition is non-zero is taken.  The parser stores
/// an `else` branch as a pair whose condition is the constant 1.
class Cond: public NodeTemplate<Cond>
{
public:
    Cond()
    {
    }
    std::string stringify()
    {
        std::ostringstream out;
        out << "cond {";
        const char *separator = "";
        for(size_t k = 0; k < branches.size(); ++k) {
            out << separator
                << branches[k].first->stringify() << " => "
                << branches[k].second->stringify();
            separator = ", ";
        }
        out << "}";
        return out.str();
    }
    std::vector<std::pair<
        shared_ptr<Node>,
        shared_ptr<Node>>> branches; // (condition, body), tested in order
};
/// `while(head) body` loop node.
class Loop: public NodeTemplate<Loop>
{
public:
    Loop(const shared_ptr<Node> &head, const shared_ptr<Node> &body)
        :head(head), body(body)
    {
    }
    std::string stringify()
    {
        return "loop(" + head->stringify() + ") {" + body->stringify() + "}";
    }
    shared_ptr<Node> head; // loop condition
    shared_ptr<Node> body; // executed while the condition is non-zero
};
/// Statement sequence `a; b; c`, kept flat (nested lists are spliced in).
class List: public NodeTemplate<List>
{
public:
    List()
    {
    }
    /// Appends `node`.  If it is itself a List (or a subclass), its children
    /// are appended instead of the list node.
    void append(const shared_ptr<Node> &node)
    {
        if(auto nested = std::dynamic_pointer_cast<List>(node)) {
            nodes.insert(nodes.end(), nested->nodes.begin(), nested->nodes.end());
        }
        else {
            nodes.push_back(node);
        }
    }
    std::string stringify()
    {
        std::string text = "{";
        for(const auto &child : nodes) {
            text += child->stringify();
            text += ";";
        }
        return text + "}";
    }
    std::vector<shared_ptr<Node>> nodes; // statements in execution order
};
/// Function definition: parameter names plus a body inherited from List.
class Func: public NodeTemplate<Func,List>
{
public:
    Func()
    {
    }
    std::string stringify()
    {
        std::string params;
        for(size_t k = 0; k < args.size(); ++k) {
            if(k != 0) {
                params += ", ";
            }
            params += args[k];
        }
        return "(" + params + ") -> " + List::stringify();
    }
    std::vector<std::string> args; // parameter names, in declaration order
};
/// Call of a function, builtin or binary operator.
///
/// Binary operators are represented as calls named after the operator,
/// e.g. `a + b` is Call("+", a, b).
class Call: public NodeTemplate<Call>
{
public:
    Call(const std::string &name)
        :name(name)
    {
    }
    /// Binary-operator form: `lhs name rhs` becomes name(lhs, rhs).
    Call(const std::string &name,
        const shared_ptr<Node> &lhs,
        const shared_ptr<Node> &rhs)
        :name(name)
        ,args{lhs, rhs}
    {
    }
    std::string stringify()
    {
        std::ostringstream out;
        out << name << "(";
        for(auto it = args.begin(); it != args.end(); ++it) {
            if(it != args.begin()) {
                out << ", ";
            }
            out << (*it)->stringify();
        }
        out << ")";
        return out.str();
    }
    std::string name;                   // callee or operator spelling
    std::vector<shared_ptr<Node>> args; // argument expressions, left to right
};
/// Assignment `name = node`.
class SetVar: public NodeTemplate<SetVar>
{
public:
    SetVar(const std::string &name, const shared_ptr<Node> &node)
        :name(name), node(node)
    {
    }
    std::string stringify()
    {
        return "(" + name + " = " + node->stringify() + ")";
    }
    std::string name;      // assigned variable
    shared_ptr<Node> node; // value expression
};
/// Read of the variable `name`.
class GetVar: public NodeTemplate<GetVar>
{
public:
    GetVar(const std::string &name)
        :name(name)
    {}
    std::string stringify() { return name; }
    std::string name; // variable or parameter name
};
/// Integer literal.
class Value: public NodeTemplate<Value>
{
public:
    Value(int value)
        :value(value)
    {}
    std::string stringify()
    {
        return std::to_string(value);
    }
    int value; // the literal's value
};
} //namespace
//-------------------------------------
// パーサーと字句解析器
//-------------------------------------
/// Raised by the lexer and parser when the script is malformed.
class SyntaxError: public std::runtime_error
{
public:
    SyntaxError(const std::string &message)
        : std::runtime_error(message)
    {}
};
/// A lexical token.
///
/// `kind` is "id", "int", "eof" or the operator character itself; `data`
/// holds the source spelling of identifiers and integers (empty otherwise).
class Token
{
public:
    Token() {}
    Token(const std::string &kind, const std::string &data)
        : kind(kind), data(data)
    {}
    std::string kind;
    std::string data;
};
/// Hand-written lexer with one token of lookahead.
///
/// The current token is available through head(); walkAndGet() and
/// getAndWalk() advance past it.  Whitespace is skipped.  Recognized tokens:
///   id   - a letter followed by letters/digits
///   int  - a run of decimal digits
///   one of + - * / { } ( ) , ; & | = < >  (kind is the character itself)
///   eof  - end of input
class Reader
{
public:
    /// Starts lexing `src` at byte offset `off` and reads the first token.
    /// @throws SyntaxError if that token starts with an invalid character.
    Reader(const std::string &src, size_t off)
        :src(src)
        ,off(off)
    {
        analyze();
    }
    /// Returns the current (not yet consumed) token.
    const Token &head()
    {
        return _head;
    }
    /// Advances to the next token and returns it.
    const Token &walkAndGet()
    {
        walk();
        return _head;
    }
    /// Returns a copy of the current token, then advances.
    Token getAndWalk()
    {
        auto head = _head;
        walk();
        return head;
    }
    /// Consumes the current token if its kind is `kind`.
    /// @throws SyntaxError otherwise.
    void expect(const std::string &kind)
    {
        if(_head.kind != kind) {
            throw SyntaxError(
                "expected `"+ kind +"', but `"+ _head.kind +"' found.");
        }
        walk();
    }
private:
    void walk()
    {
        analyze();
    }
    // The <ctype.h> classifiers require an argument representable as
    // unsigned char (or EOF).  Passing a negative plain `char`, e.g. a byte
    // of a UTF-8 sequence, is undefined behaviour, so convert first.
    static bool isSpace(char c) { return isspace(static_cast<unsigned char>(c)) != 0; }
    static bool isAlpha(char c) { return isalpha(static_cast<unsigned char>(c)) != 0; }
    static bool isAlnum(char c) { return isalnum(static_cast<unsigned char>(c)) != 0; }
    static bool isDigit(char c) { return isdigit(static_cast<unsigned char>(c)) != 0; }
    /// Reads the token starting at `off` into `_head` and moves `off` past it.
    void analyze()
    {
        // skip whitespace; produce `eof` once the input is exhausted
        for(;;++off) {
            if(off >= src.size()) {
                _head.kind = "eof";
                _head.data = "";
                return;
            }
            if(!isSpace(src[off])) break;
        }
        auto i = off++;
        auto c = src[i];
        if(isAlpha(c)) { // identifier
            while(off < src.size() && isAlnum(src[off])) ++off;
            _head.kind = "id";
            _head.data = src.substr(i, off - i);
            return;
        }
        if(isDigit(c)) { // integer literal
            while(off < src.size() && isDigit(src[off])) ++off;
            _head.kind = "int";
            _head.data = src.substr(i, off - i);
            return;
        }
        switch(c) { // single-character operators
        case '+': case '-':
        case '*': case '/':
        case '{': case '}':
        case '(': case ')':
        case ',': case ';':
        case '&': case '|':
        case '=':
        case '<': case '>':
            _head.kind = std::string(&c, 1);
            _head.data = "";
            return;
        }
        throw SyntaxError("Invalid char: `"+ std::string(&c, 1) +"'");
    }
    std::string src; // the whole script
    size_t off;      // offset of the next unread byte
    Token _head;     // current lookahead token
};
/// Recursive-descent parser that turns a script into one node::Func AST per
/// function.
///
/// Informal grammar (lowest precedence first):
///   program   := { 'def' id '(' [ id { ',' id } ] ')' stmt | stmt }
///   stmt      := expr ';'            (no ';' after a `{ ... }` block)
///   expr      := '{' { stmt } '}' | assign
///   assign    := cmp [ '=' assign ]  (left side must be a plain variable)
///   cmp       := add { ('<' | '>') add }
///   add       := mul { ('+' | '-') mul }
///   mul       := innermost { ('*' | '/') innermost }
///   innermost := '(' expr ')'
///              | 'if' '(' expr ')' expr [ 'else' expr ]
///              | 'while' '(' expr ')' expr
///              | int | id | id '(' [ expr { ',' expr } ] ')'
/// Binary operators become node::Call nodes named after the operator.
/// Top-level statements are collected into the function "<main>".
class Parser: private Reader
{
public:
    Parser(const std::string &src)
        :Reader(src, 0)
    {
    }
    /// Parses the whole script and fills `functions`.
    /// @throws SyntaxError on malformed input, a duplicated function name or
    ///         a duplicated parameter name.
    void parse()
    {
        // the main function (entry point)
        auto main = make_shared<node::Func>();
        functions.insert(std::make_pair("<main>", main));
        while(head().kind != "eof") {
            // a definition of function
            if(head().data == "def") {
                auto name = walkAndGet().data;
                if(functions.find(name) != functions.end()) {
                    throw SyntaxError("Duplicated function name: `"+ name +"'");
                }
                auto func = make_shared<node::Func>();
                expect("id");
                expect("(");
                if(head().kind != ")") {
                    for(;;) {
                        auto arg = head().data;
                        expect("id");
                        // Later stages assume parameter names are distinct:
                        // each one maps to a single argument slot, and the
                        // local-variable count is derived from the size of
                        // the variable table.
                        auto &params = func->args;
                        if(std::find(params.begin(), params.end(), arg) != params.end()) {
                            throw SyntaxError("Duplicated argument name: `"+ arg +"'");
                        }
                        params.push_back(arg);
                        if(head().kind != ",") {
                            break;
                        }
                        walkAndGet();
                    }
                }
                expect(")");
                func->append(newStmt());
                functions.insert(std::make_pair(name, func));
                continue;
            }
            // main program fragment
            main->append(newStmt());
        }
    }
    /// stmt := expr ';'. The terminating ';' is not required after a block.
    shared_ptr<Node> newStmt()
    {
        bool is_block;
        auto stmt = newExpr(is_block);
        if(!is_block) {
            expect(";");
        }
        return stmt;
    }
    /// expr := block | assignment.  Sets `is_block` if the expression was a
    /// `{ ... }` block.
    shared_ptr<Node> newExpr(bool &is_block)
    {
        // { ...; ...; }
        is_block = ( head().kind == "{" );
        if(is_block) {
            walkAndGet();
            auto block = make_shared<node::List>();
            while(head().kind != "}") {
                if(head().kind == "eof") {
                    throw SyntaxError("unexpected EOF.");
                }
                block->append(newStmt());
            }
            walkAndGet();
            return block;
        }
        // expression
        return infixSet();
    }
    shared_ptr<Node> newExpr()
    {
        bool unused;
        return newExpr(unused);
    }
    /// Right-associative assignment `var = expr`.
    shared_ptr<Node> infixSet()
    {
        auto lhs = infixCmp();
        if(head().kind == "=") {
            walkAndGet();
            auto var = std::dynamic_pointer_cast<node::GetVar>(lhs);
            if(!var) {
                throw SyntaxError("cannot assign to non-variable expression.");
            }
            lhs = make_shared<node::SetVar>(var->name, infixSet());
        }
        return lhs;
    }
    /// Left-associative comparisons.  The Reader only produces
    /// single-character tokens, so in practice only '<' and '>' match here.
    shared_ptr<Node> infixCmp()
    {
        auto lhs = infixAdd();
        for(;;) {
            auto kind = head().kind;
            if(kind != "=="
            && kind != "!="
            && kind != ">="
            && kind != "<="
            && kind != "<"
            && kind != ">") {
                break;
            }
            walkAndGet();
            lhs = make_shared<node::Call>(kind, lhs, infixAdd());
        }
        return lhs;
    }
    /// Left-associative '+' and '-'.
    shared_ptr<Node> infixAdd()
    {
        auto lhs = infixMul();
        for(;;) {
            auto kind = head().kind;
            if(kind != "+" && kind != "-") {
                break;
            }
            walkAndGet();
            lhs = make_shared<node::Call>(kind, lhs, infixMul());
        }
        return lhs;
    }
    /// Left-associative '*' and '/'.
    shared_ptr<Node> infixMul()
    {
        auto lhs = innermost();
        for(;;) {
            auto kind = head().kind;
            if(kind != "*" && kind != "/") {
                break;
            }
            walkAndGet();
            lhs = make_shared<node::Call>(kind, lhs, innermost());
        }
        return lhs;
    }
    /// Primary expressions: parentheses, if, while, literals, variables and
    /// calls.
    shared_ptr<Node> innermost()
    {
        // ( ... )
        if(head().kind == "(") {
            walkAndGet();
            auto expr = newExpr();
            expect(")");
            return expr;
        }
        // if(...) expr [else expr]
        // (no ';' between the branches; `else` is stored as condition 1)
        if(head().data == "if") {
            walkAndGet();
            expect("(");
            auto e = newExpr();
            expect(")");
            auto cond = make_shared<node::Cond>();
            cond->branches.push_back(make_pair(e, newExpr()));
            if(head().data == "else") {
                walkAndGet();
                cond->branches.push_back(
                    make_pair(make_shared<node::Value>(1), newExpr()));
            }
            return cond;
        }
        // while(...) expr
        if(head().data == "while") {
            walkAndGet();
            expect("(");
            auto e = newExpr();
            expect(")");
            return make_shared<node::Loop>(e, newExpr());
        }
        // Value: 2525
        if(head().kind == "int") {
            return make_shared<node::Value>(atoi(getAndWalk().data.c_str()));
        }
        // variable or function call
        if(head().kind == "id") {
            auto name = getAndWalk().data;
            if(head().kind != "(") {
                return make_shared<node::GetVar>(name);
            }
            auto call = make_shared<node::Call>(name);
            if(walkAndGet().kind != ")") {
                for(;;) {
                    call->args.push_back(newExpr());
                    if(head().kind != ",") {
                        break;
                    }
                    walkAndGet();
                }
            }
            expect(")");
            return call;
        }
        throw SyntaxError("unexpected token: "+ head().kind);
    }
    // function name -> AST; "<main>" holds the top-level statements
    std::unordered_map<std::string,
        shared_ptr<node::Func>> functions;
};
//-------------------------------------
// スタックマシン
//-------------------------------------
// VM instructions
// Opcodes of the bytecode stack machine.  Opcodes that take an operand are
// followed in the code stream by one extra word:
//   setvar/getvar          local-variable index
//   setarg/getarg          1-based argument number
//   invoke                 procedure id
//   iftrue/iffalse/jump    branch offset relative to the next instruction
//   push                   immediate value
//   dump                   number of arguments
// OP_invalid and OP_cmp are never emitted by BytecodeAssembler.
enum: uintptr_t
{
    OP_invalid = 0,
    OP_noop    = 1,
    OP_setvar  = 2,
    OP_getvar  = 3,
    OP_setarg  = 4,
    OP_getarg  = 5,
    OP_return  = 6,
    OP_invoke  = 7,
    OP_add     = 8,
    OP_sub     = 9,
    OP_mul     = 10,
    OP_div     = 11,
    OP_lt      = 12,
    OP_gt      = 13,
    OP_cmp     = 14,
    OP_iftrue  = 15,
    OP_iffalse = 16,
    OP_jump    = 17,
    OP_push    = 18,
    OP_pop     = 19,
    OP_dup     = 20,
    OP_dump    = 21,
};
/// One compiled unit of the script: a user function or `<main>`.
///
/// Holds the bytecode (`code`) and, once generated, the native x86 code
/// (`native`).  `native` is released with VirtualFree in the destructor.
/// Copying is disabled because two copies would free the same block twice.
class Procedure
{
public:
    Procedure(const std::string &name, size_t id,
        size_t argNum=0,
        size_t varNum=0)
        :name(name)
        ,id(id)
        ,argNum(argNum)
        ,varNum(varNum)
        ,stackUsage(0) // not computed yet; initialized so it is never garbage
        ,code(0)
        ,native(nullptr)
    {
    }
    ~Procedure()
    {
        if(native) {
            VirtualFree(native, 0, MEM_RELEASE);
        }
    }
    Procedure(const Procedure &) = delete;
    Procedure &operator=(const Procedure &) = delete;
    std::string name;           // procedure name
    size_t id;                  // index used by OP_invoke
    size_t argNum;              // number of arguments
    size_t varNum;              // number of local variables
    size_t stackUsage;          // max value of stack size on which this procedure runs
    std::vector<intptr_t> code; // compiled bytecode
    uint8_t *native;            // executable machine code, or null
};
/// A compiled program: owns every Procedure, stored at index Procedure::id.
///
/// The constructor (defined after CodeGen_x86) parses the script and
/// compiles each function to bytecode and to native code.
class State
{
public:
    State(const std::string &script);
    /// Returns the procedure whose id is `i` (no bounds check).
    Procedure *lookupProcedure(size_t i)
    {
        const unique_ptr<Procedure> &slot = procedures[i];
        return slot.get();
    }
    /// Returns the first procedure named `name`, or nullptr if none exists.
    Procedure *lookupProcedure(const std::string &name)
    {
        for(const auto &candidate : procedures) {
            if(candidate->name == name) {
                return candidate.get();
            }
        }
        return nullptr;
    }
private:
    std::vector<unique_ptr<Procedure>> procedures;
};
/// Compiles one function's AST into stack-machine bytecode.
///
/// Each visit() emits code for one node.  Whether the node must leave a
/// value on the operand stack is passed through `valueRequired`: it is
/// pushed by intoSubTree() and popped by the visitor via wantvalue().
///
/// Variable numbering in `vars`:
/// - Parameters map to -1, -2, ... in declaration order. OP_setarg and
///   OP_getarg carry the positive number.
/// - Locals map to 0, 1, ... in order of first use. OP_setvar and OP_getvar
///   carry that index.
///
/// Branch operands are offsets relative to the word after the operand.
/// They are emitted through Label objects and patched by run().
class BytecodeAssembler: public Node::Visitor
{
    // A branch target.  `pos` is the target's code index + 1 (0 = not placed
    // yet).  Until run() resolves the label, `ref` heads a linked list
    // threaded through the operand words that reference it: each such word
    // holds the 1-based position of the previous reference (0 ends the list).
    struct Label
    {
        Label()
            :pos(0)
            ,ref(0)
        {
        }
        intptr_t pos;
        intptr_t ref;
    };
public:
    /// @throws std::runtime_error if two parameters share a name.
    BytecodeAssembler(State *state, const shared_ptr<node::Func> &function)
        :state(state)
        ,function(function)
    {
        // Parameter n (1-based) gets index -n.  The index is computed in
        // signed arithmetic rather than by negating the unsigned vars.size().
        int number = 0;
        auto &args = function->args;
        for(auto i = args.begin(); i != args.end(); ++i) {
            ++number;
            if(!vars.insert(make_pair(*i, -number)).second) {
                throw std::runtime_error("duplicated argument name: `"+ *i +"'");
            }
        }
    }
    /// Compiles the whole function body into `code`, followed by OP_return,
    /// and then patches all branch operands.
    void run()
    {
        // the body's value is the function's return value
        valueRequired.push_back(true);
        // traces the syntax tree by DFS
        visit(static_cast<node::List&>(*function));
        emit(OP_return);
        // resolves all labels
        for(auto i = labels.begin(); i != labels.end(); ++i) {
            auto &label = **i;
            // a referenced label must have been placed with setLabel()
            assert(label.ref == 0 || label.pos != 0);
            for(auto ref = label.ref; ref; ) {
                // target index - (index of the operand word + 1)
                auto off = label.pos - ref - 1;
                std::swap(off, code[ref - 1]);
                ref = off;
            }
        }
    }
    void visit(node::Func &node)
    {
        // run() visits the function body as a List; nested Func nodes are
        // not expected here.
        assert( false );
    }
    void visit(node::Cond &node)
    {
        auto &branches = node.branches;
        assert(!branches.empty());
        bool b = wantvalue();
        Label *exit = newLabel();
        Label *next = nullptr;
        for(auto i = branches.begin(); i != branches.end(); ++i) {
            if(next) {
                setLabel(next);
            }
            // condition false -> try the next branch
            intoSubTree(*i->first, true);
            emit(OP_iffalse);
            emitLabel(( next = newLabel() ));
            intoSubTree(*i->second, b);
            emit(OP_jump);
            emitLabel(exit);
        }
        // no branch taken: the value of the expression is 0
        setLabel(next);
        if(b) {
            emit(OP_push);
            emit(0);
        }
        setLabel(exit);
        emit(OP_noop);
    }
    void visit(node::Loop &node)
    {
        auto head = newLabel();
        auto body = newLabel();
        bool b = wantvalue();
        // the value of a loop that never runs its body is 0
        if(b) {
            emit(OP_push);
            emit(0);
        }
        emit(OP_jump);
        emitLabel(head);
        // loop body (replaces the previous iteration's value)
        setLabel(body);
        if(b) {
            emit(OP_pop);
        }
        intoSubTree(*node.body, b);
        // loop head(condition expr)
        setLabel(head);
        intoSubTree(*node.head, true);
        // branch
        emit(OP_iftrue);
        emitLabel(body);
    }
    void visit(node::List &node)
    {
        bool b = wantvalue();
        auto &nodes = node.nodes;
        if(nodes.empty()) {
            if(b) {
                emit(OP_push);
                emit(0);
            }
            return;
        }
        // only the last statement's value is kept
        size_t i = 0;
        while(i + 1 < nodes.size()) {
            intoSubTree(*nodes[i++], false);
        }
        intoSubTree(*nodes[i], b);
    }
    /// @throws std::runtime_error for an undefined function or an argument
    ///         count mismatch.
    void visit(node::Call &node)
    {
        // arguments are pushed last-to-first, so the first one ends on top
        auto &args = node.args;
        for(auto i = args.rbegin(); i != args.rend(); ++i) {
            intoSubTree(**i, true);
        }
        if(node.name == "dump") {
            emit(OP_dump);
            emit(static_cast<intptr_t>(args.size()));
        }
        else if(node.name.size() == 1 && node.args.size() == 2) {
            switch(node.name[0]) {
            case '+': emit(OP_add); break;
            case '-': emit(OP_sub); break;
            case '*': emit(OP_mul); break;
            case '/': emit(OP_div); break;
            case '<': emit(OP_lt ); break;
            case '>': emit(OP_gt ); break;
            default:
                // a one-letter user function with two arguments
                goto proceed;
            }
        }
        else {
        proceed:
            auto proc = state->lookupProcedure(node.name);
            if(!proc) {
                throw std::runtime_error("undefined function: `"+ node.name +"'");
            }
            if(proc->argNum != node.args.size()) {
                throw std::runtime_error("argument number mismatch: `"+ node.name +"'");
            }
            emit(OP_invoke);
            emit(proc->id);
        }
        // every call produces a value; discard it if unused
        if(!wantvalue()) {
            emit(OP_pop);
        }
    }
    void visit(node::SetVar &node)
    {
        bool b = wantvalue();
        auto i = newIndex(node.name);
        intoSubTree(*node.node, true);
        // keep a copy as the assignment's own value
        if(b) {
            emit(OP_dup);
        }
        if(i < 0) {
            emit(OP_setarg);
            emit(-i);
        }
        else {
            emit(OP_setvar);
            emit(i);
        }
    }
    void visit(node::GetVar &node)
    {
        if(wantvalue()) {
            auto i = newIndex(node.name);
            if(i < 0) {
                emit(OP_getarg);
                emit(-i);
            }
            else {
                emit(OP_getvar);
                emit(i);
            }
        }
    }
    void visit(node::Value &node)
    {
        if(wantvalue()) {
            emit(OP_push);
            emit(node.value);
        }
    }
    /// Compiles `node`; it must leave exactly one value iff `wantvalue`.
    void intoSubTree(Node &node, bool wantvalue)
    {
        auto size = valueRequired.size();
        valueRequired.push_back(wantvalue);
        node.accept(*this);
        // drop the flag if the visitor did not consume it
        if(size != valueRequired.size()) {
            assert(valueRequired.size() > size);
            valueRequired.resize(size);
        }
    }
    /// Pops and returns the "value required" flag of the node being visited.
    bool wantvalue()
    {
        bool b = valueRequired.back();
        valueRequired.pop_back();
        return b;
    }
    /// Returns the slot index of `name`, allocating a new local on first use.
    int newIndex(const std::string &name)
    {
        auto i = vars.find(name);
        if(i != vars.end()) {
            return i->second;
        }
        // locals are numbered after the parameters already in `vars`
        const int index = static_cast<int>(vars.size() - function->args.size());
        vars.insert(make_pair(name, index));
        return index;
    }
    Label *newLabel()
    {
        labels.push_back(unique_ptr<Label>(new Label()));
        return labels.back().get();
    }
    /// Places `label` at the next instruction to be emitted.
    void setLabel(Label *label)
    {
        assert(label->pos == 0);
        label->pos = code.size() + 1;
    }
    void emit(intptr_t op)
    {
        code.push_back(op);
    }
    /// Emits a placeholder operand referencing `label` and links it into the
    /// label's reference list.
    void emitLabel(Label *label)
    {
        code.push_back(label->ref);
        label->ref = code.size();
    }
    // initial properties
    State *state;
    shared_ptr<node::Func> function;
    // data to be constructed
    std::vector<intptr_t> code; // into which compiled bytecodes is
    std::vector<unique_ptr<Label>> labels;
    std::deque<bool> valueRequired;
    std::unordered_map<std::string,int> vars; // variable table
};
/// Prints a human-readable listing of a bytecode sequence.
class BytecodeDisassembler
{
public:
    BytecodeDisassembler(const intptr_t *code, size_t size)
        :code(code)
        ,size(size)
    {
    }
    /// Writes one line per instruction to `out`: a tab, the instruction's
    /// index, ": ", the mnemonic and its operand.  Branch operands are shown
    /// as absolute instruction indices.
    /// @throws std::logic_error on an unrecognized opcode.
    void run(std::ostream &out)
    {
        const intptr_t *const stop = code + size;
        const intptr_t *pc = code;
        while(pc != stop) {
            out << "\t" << pc - code << ": ";
            const intptr_t opcode = *pc++;
            // fetches the word following the opcode
            auto operand = [&pc]() -> intptr_t { return *pc++; };
            // relative branch operand -> absolute index (offsets count from
            // the word after the operand)
            auto target = [&]() -> intptr_t {
                const intptr_t delta = operand();
                return pc + delta - code;
            };
            switch(opcode) {
            case OP_noop:    out << "noop"; break;
            case OP_setvar:  out << "setvar var:" << operand(); break;
            case OP_getvar:  out << "getvar var:" << operand(); break;
            case OP_setarg:  out << "setarg arg:" << operand(); break;
            case OP_getarg:  out << "getarg arg:" << operand(); break;
            case OP_return:  out << "return"; break;
            case OP_invoke:  out << "invoke proc:" << operand(); break;
            case OP_add:     out << "add"; break;
            case OP_sub:     out << "sub"; break;
            case OP_mul:     out << "mul"; break;
            case OP_div:     out << "div"; break;
            case OP_lt:      out << "lt"; break;
            case OP_gt:      out << "gt"; break;
            case OP_iftrue:  out << "iftrue code:" << target(); break;
            case OP_iffalse: out << "iffalse code:" << target(); break;
            case OP_jump:    out << "jump code:" << target(); break;
            case OP_push:    out << "push " << operand(); break;
            case OP_pop:     out << "pop"; break;
            case OP_dup:     out << "dup"; break;
            case OP_dump:    out << "dump " << operand(); break;
            default:
                throw std::logic_error("BytecodeDisassembler: unknown opcode!");
            }
            out << std::endl;
        }
    }
private:
    const intptr_t *code; // first instruction
    size_t size;          // number of words
};
/// Bytecode interpreter for the procedures of a State.
///
/// Operands live in a private stack of `stacksize` words, and `sp` points
/// one past the top element.  When a procedure is entered, the caller has
/// already pushed its arguments (last argument first).  The frame then
/// looks like this:
///   bp[-3-k]         argument k (1-based)
///   bp[-3]           caller's Procedure* (null for the entry procedure)
///   bp[-2]           return pc
///   bp[-1]           caller's bp
///   bp[0..varNum-1]  local variables
/// For binary operators the LEFT operand is on top of the stack and the
/// right operand directly below it, because BytecodeAssembler evaluates
/// call arguments in reverse order.
class Interpreter
{
public:
    Interpreter(State *state, size_t stacksize)
        :state(state)
        ,stack(stacksize)
    {
    }
    /// Runs procedure `ep` to completion and returns its result.
    /// @throws std::logic_error on an unknown opcode.
    /// @throws std::runtime_error on division by zero, or when a call frame
    ///         does not fit in the stack (expression temporaries are not
    ///         checked).
    int run(int ep)
    {
        Procedure *proc = state->lookupProcedure(ep);
        Procedure *prev = nullptr;
        intptr_t *const limit = stack.data() + stack.size();
        intptr_t *sp = &stack[0];
        intptr_t *pc = nullptr;
        intptr_t *bp = nullptr;
        // enter `proc` through the same frame-building code as OP_invoke
        goto start;
        for(;;) {
            switch(*pc++) {
            default:
                throw std::logic_error("unknown opcode!");
            case OP_noop:
                break;
            case OP_setvar: bp[*pc++] = *--sp; break;
            case OP_getvar: *sp++ = bp[*pc++]; break;
            case OP_setarg: bp[-*pc++ - 3] = *--sp; break;
            case OP_getarg: *sp++ = bp[-*pc++ - 3]; break;
            case OP_return:
                {
                    auto value = *--sp;
                    int varNum = proc->varNum;
                    int argNum = proc->argNum;
                    sp -= varNum;
                    bp = reinterpret_cast<intptr_t*>(*--sp);
                    pc = reinterpret_cast<intptr_t*>(*--sp);
                    proc = reinterpret_cast<Procedure*>(*--sp);
                    if(!proc) {
                        return static_cast<int>(value);
                    }
                    // drop the caller-pushed arguments, push the result
                    sp -= argNum;
                    *sp++ = value;
                }
                break;
            case OP_invoke:
                prev = proc;
                proc = state->lookupProcedure(*pc);
            start:
                // a frame needs 3 bookkeeping words plus the callee's locals
                if(limit - sp < static_cast<intptr_t>(3 + proc->varNum)) {
                    throw std::runtime_error("stack overflow");
                }
                *sp++ = intptr_t(prev);
                *sp++ = intptr_t(pc+1);
                *sp++ = intptr_t(bp);
                bp = sp;
                sp += proc->varNum;
                pc = &proc->code[0];
                break;
            // Binary operators pop the left operand (top of stack), combine
            // it with the right operand below it, and overwrite that slot
            // with the result.
            case OP_add: { auto lhs = *--sp; sp[-1] = lhs + sp[-1]; } break;
            case OP_sub: { auto lhs = *--sp; sp[-1] = lhs - sp[-1]; } break;
            case OP_mul: { auto lhs = *--sp; sp[-1] = lhs * sp[-1]; } break;
            case OP_div:
                {
                    auto lhs = *--sp;
                    if(sp[-1] == 0) {
                        throw std::runtime_error("division by zero");
                    }
                    sp[-1] = lhs / sp[-1];
                }
                break;
            case OP_lt: { auto lhs = *--sp; sp[-1] = lhs < sp[-1]; } break;
            case OP_gt: { auto lhs = *--sp; sp[-1] = lhs > sp[-1]; } break;
            // branch offsets are relative to the word after the operand
            case OP_iftrue : { auto t = *pc++; if( *--sp) pc += t; } break;
            case OP_iffalse: { auto t = *pc++; if(!*--sp) pc += t; } break;
            case OP_jump:    { auto t = *pc++; pc += t; } break;
            case OP_push:
                *sp++ = *pc++;
                break;
            case OP_pop:
                --sp;
                break;
            case OP_dup:
                *sp = sp[-1];
                ++sp;
                break;
            case OP_dump:
                {
                    int argc = *pc++;
                    printf("dump:");
                    // the first argument is on top of the stack
                    for(int i = 0; i < argc; ++i) {
                        printf(" %d", static_cast<int>(*--sp));
                    }
                    puts("");
                    // value of the dump(...) expression
                    *sp++ = 0;
                }
                break;
            }
        }
    }
private:
    State *state;
    std::vector<intptr_t> stack;
};
//-------------------------------------
// コード生成
//-------------------------------------
/// IA-32 general-purpose registers.  The values are the 3-bit register
/// numbers used in ModRM/SIB bytes and in the low bits of the one-byte
/// PUSH/POP opcodes.
enum RegisterKind
{
    EAX = 0,
    ECX = 1,
    EDX = 2,
    EBX = 3,
    ESP = 4,
    EBP = 5,
    ESI = 6,
    EDI = 7,
};
/// Immediate (constant) operand of an x86 instruction.
class Imm
{
public:
    Imm(int32_t value)
        : data(value)
    {}
    int32_t data; // encoded as 32 bits by the emitters
};
/// Absolute memory-address operand (not used by the emitters shown here).
class Mem
{
public:
    Mem(uint32_t address)
        : addr(address)
    {}
    uint32_t addr;
};
class Reg
{
public:
Reg(RegisterKind kind)
:kind(kind)
{
}
RegisterKind kind;
};
class Ref
{
public:
Ref(RegisterKind kind)
:kind(kind)
{
}
RegisterKind kind;
};
/// Memory operand `[base + index * 2^scale + disp]`.
///
/// The two-argument constructor describes plain `[base + disp]`.  It marks
/// the absence of an index with scale == -1, which the emitter checks
/// before encoding a SIB byte.
class SIB
{
public:
    SIB(RegisterKind base, RegisterKind index, int scale, int disp)
        :base(base)
        ,index(index)
        ,scale(scale)
        ,disp(disp)
    {
    }
    SIB(RegisterKind base, int disp)
        :base(base)
        // Unused when scale == -1, but initialized so the member is never
        // indeterminate.  ESP (100b) is the "no index" encoding in x86.
        ,index(ESP)
        ,scale(-1)
        ,disp(disp)
    {
    }
    RegisterKind base;
    RegisterKind index;
    int scale; // 2-bit scale exponent, or -1 for "no index"
    int disp;  // 32-bit displacement
};
/// Runtime helper called by the generated native code for `dump(...)`.
///
/// Prints "dump:" followed by the `argc` int arguments (each preceded by a
/// space) and a newline.  Always returns 0.
int dump(int argc, ...)
{
    va_list ap;
    va_start(ap, argc);
    printf("dump:");
    for(int i = 0; i < argc; ++i) {
        printf(" %d", va_arg(ap, int));
    }
    // every va_start must be paired with va_end
    va_end(ap);
    puts("");
    return 0;
}
// CodeGen_x86::emit_CALL takes the address of a variable that holds the
// callee's address (read at link time), so dump's address is kept here.
static auto dump_p = &dump;
/// Translates one Procedure's bytecode into IA-32 machine code.
///
/// The bytecode operand stack is mapped onto the machine stack with
/// PUSH/POP.  Each generated function builds a standard EBP frame:
/// - local i is at [EBP - 4*(i+1)];
/// - argument k (1-based) is at [EBP + 4*k + 4], above the saved EBP and
///   the return address.
/// The caller pushes the arguments and removes them after the call, and the
/// result is returned in EAX.  Only EAX, ECX and EDX are used as scratch
/// registers, so the callee-saved EBX/ESI/EDI of the calling C++ code are
/// left intact.
///
/// run() emits code with placeholder displacements.  Once every Procedure's
/// `native` address is known, link(base) patches branches and calls for
/// code that will be copied to `base`.
class CodeGen_x86
{
public:
    CodeGen_x86(Procedure *procedure, State *state)
        :procedure(procedure)
        ,state(state)
    {
    }
    /// Emits the prologue and one instruction sequence per bytecode op.
    /// @throws std::logic_error on an unknown opcode.
    void run()
    {
        auto pc = &procedure->code[0];
        auto end = pc+procedure->code.size();
        // prologue: EBP frame, then 4 bytes per local variable
        emit_PUSH(Reg(EBP));
        emit_MOV(Reg(EBP), Reg(ESP));
        if(procedure->varNum) {
            emit_SUB(Reg(ESP), Imm(4 * procedure->varNum));
        }
        while(pc != end) {
            // branch targets are looked up by bytecode address in link()
            putOffset(pc);
            switch(*pc++) {
            default:
                throw std::logic_error("CodeGen_x86: unknown opcode!");
            case OP_noop:
                break;
            case OP_setvar: emit_POP(Reg(EAX)); emit_MOV(SIB(EBP, (*pc++ + 1) * -4), Reg(EAX)); break;
            case OP_setarg: emit_POP(Reg(EAX)); emit_MOV(SIB(EBP, *pc++ * 4 + 4), Reg(EAX)); break;
            case OP_getvar: emit_PUSH(SIB(EBP, (*pc++ + 1) * -4)); break;
            case OP_getarg: emit_PUSH(SIB(EBP, *pc++ * 4 + 4)); break;
            case OP_return:
                emit_POP(Reg(EAX));
                emit_MOV(Reg(ESP), Reg(EBP));
                emit_POP(Reg(EBP));
                emit_RET();
                break;
            case OP_invoke:
                {
                    auto callee = state->lookupProcedure(*pc++);
                    emit_CALL(reinterpret_cast<void**>(&callee->native));
                    // caller removes the arguments, then pushes the result
                    emit_ADD(Reg(ESP), Imm(callee->argNum * 4));
                    emit_PUSH(Reg(EAX));
                }
                break;
            // Binary operators: the left operand is on top of the stack and
            // the right operand below it.
            case OP_add:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_ADD(Reg(EAX), Reg(ECX));
                emit_PUSH(Reg(EAX));
                break;
            case OP_sub:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_SUB(Reg(EAX), Reg(ECX));
                emit_PUSH(Reg(EAX));
                break;
            case OP_mul:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_IMUL(Reg(ECX)); // EDX:EAX = EAX * ECX
                emit_PUSH(Reg(EAX));
                break;
            case OP_div:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_CDQ();          // IDIV divides EDX:EAX, so sign-extend
                emit_IDIV(Reg(ECX)); // EAX = quotient
                emit_PUSH(Reg(EAX));
                break;
            case OP_lt:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_CMP(Reg(EAX), Reg(ECX));
                emit_SETL(Reg(EAX));               // AL = (lhs < rhs)
                emit_MOVZX8(Reg(EAX), Reg(EAX));   // EAX = 0 or 1
                emit_PUSH(Reg(EAX));
                break;
            case OP_gt:
                emit_POP(Reg(EAX));
                emit_POP(Reg(ECX));
                emit_CMP(Reg(EAX), Reg(ECX));
                emit_SETG(Reg(EAX));               // AL = (lhs > rhs)
                emit_MOVZX8(Reg(EAX), Reg(EAX));   // EAX = 0 or 1
                emit_PUSH(Reg(EAX));
                break;
            case OP_iftrue:
                {
                    auto t = *pc++;
                    emit_POP(Reg(EAX));
                    emit_TEST(Imm(0xffffffffu));
                    emit_JNZ(pc + t);
                }
                break;
            case OP_iffalse:
                {
                    auto t = *pc++;
                    emit_POP(Reg(EAX));
                    emit_TEST(Imm(0xffffffffu));
                    emit_JZ(pc + t);
                }
                break;
            case OP_jump:
                {
                    auto t = *pc++;
                    emit_JMP(pc + t);
                }
                break;
            case OP_push:
                emit_PUSH(Imm(*pc++));
                break;
            case OP_pop:
                emit_POP(Reg(EAX));
                break;
            case OP_dup:
                emit_POP(Reg(EAX));
                emit_PUSH(Reg(EAX));
                emit_PUSH(Reg(EAX));
                break;
            case OP_dump:
                {
                    auto count = *pc++;
                    emit_PUSH(Imm(count));
                    emit_CALL(reinterpret_cast<void**>(&dump_p));
                    // Pop the count word and all `count` arguments, then push
                    // dump()'s return value (0) as the expression's value.
                    // The net effect is one stack slot, as for any call.
                    emit_ADD(Reg(ESP), Imm((count + 1) * 4));
                    emit_PUSH(Reg(EAX));
                }
                break;
            }
        }
    }
    // ADD r32, r/m32: a += b
    void emit_ADD(const Reg &a, const Reg &b)
    {
        emitByte(0x03);
        emitModRM(3, a.kind, b.kind);
    }
    // ADD r/m32, imm32
    void emit_ADD(const Reg &reg, const Imm &imm)
    {
        emitByte(0x81);
        emitModRM(3, 0, reg.kind);
        emitUI32(imm.data);
    }
    // SUB r/m32, imm32
    void emit_SUB(const Reg &reg, const Imm &imm)
    {
        emitByte(0x81);
        emitModRM(3, 5, reg.kind);
        emitUI32(imm.data);
    }
    // SUB r32, r/m32: a -= b
    void emit_SUB(const Reg &a, const Reg &b)
    {
        emitByte(0x2b);
        emitModRM(3, a.kind, b.kind);
    }
    // CMP r32, r/m32: sets flags from a - b
    void emit_CMP(const Reg &a, const Reg &b)
    {
        emitByte(0x3b);
        emitModRM(3, a.kind, b.kind);
    }
    // SETL r/m8 (signed less).  Registers EAX..EBX select AL..BL.
    void emit_SETL(const Reg &reg)
    {
        emitByte(0x0f);
        emitByte(0x9c);
        emitModRM(3, 0, reg.kind);
    }
    // SETG r/m8 (signed greater).  Registers EAX..EBX select AL..BL.
    void emit_SETG(const Reg &reg)
    {
        emitByte(0x0f);
        emitByte(0x9f);
        emitModRM(3, 0, reg.kind);
    }
    // MOVZX r32, r/m8: dest = zero-extended low byte of src (EAX..EBX only)
    void emit_MOVZX8(const Reg &dest, const Reg &src)
    {
        emitByte(0x0f);
        emitByte(0xb6);
        emitModRM(3, dest.kind, src.kind);
    }
    // CDQ: sign-extend EAX into EDX:EAX
    void emit_CDQ()
    {
        emitByte(0x99);
    }
    // IMUL r/m32 (one operand): EDX:EAX = EAX * reg
    void emit_IMUL(const Reg &reg)
    {
        emitByte(0xf7);
        emitModRM(3, 5, reg.kind);
    }
    // IDIV r/m32: EAX = EDX:EAX / reg, EDX = remainder
    void emit_IDIV(const Reg &reg)
    {
        emitByte(0xf7);
        emitModRM(3, 7, reg.kind);
    }
    // NOT r/m32
    void emit_NOT(const Reg &reg)
    {
        emitByte(0xf7);
        emitModRM(3, 2, reg.kind);
    }
    // AND EAX, imm32
    void emit_AND(const Imm &imm)
    {
        emitByte(0x25);
        emitUI32(imm.data);
    }
    // MOV [dest], src
    void emit_MOV(const Ref &dest, const Reg &src)
    {
        emitByte(0x89);
        emitModRM(0, src.kind, dest.kind);
    }
    // MOV dest, src (register to register)
    void emit_MOV(const Reg &dest, const Reg &src)
    {
        emitByte(0x89);
        emitModRM(3, src.kind, dest.kind);
    }
    // MOV [base+disp], src (store register to memory)
    void emit_MOV(const SIB &dest, const Reg &src)
    {
        emitByte(0x89);
        emitSIB(src.kind, dest);
    }
    // MOV dest, [base+disp] (load register from memory)
    void emit_MOV(const Reg &dest, const SIB &src)
    {
        emitByte(0x8b);
        emitSIB(dest.kind, src);
    }
    // PUSH imm32
    void emit_PUSH(const Imm &imm)
    {
        emitByte(0x68);
        emitUI32(imm.data);
    }
    // PUSH r32 (one-byte form 0x50 + register number)
    void emit_PUSH(const Reg &reg)
    {
        switch(reg.kind) {
        case EAX: emitByte(0x50); break;
        case ECX: emitByte(0x51); break;
        case EDX: emitByte(0x52); break;
        case EBX: emitByte(0x53); break;
        case ESP: emitByte(0x54); break;
        case EBP: emitByte(0x55); break;
        case ESI: emitByte(0x56); break;
        case EDI: emitByte(0x57); break;
        default:
            throw std::logic_error("(*_*) bug?");
        }
    }
    // PUSH r/m32 (FF /6)
    void emit_PUSH(const SIB &src)
    {
        emitByte(0xff);
        emitSIB(RegisterKind(6), src);
    }
    // POP r32 (one-byte form 0x58 + register number)
    void emit_POP(const Reg &reg)
    {
        switch(reg.kind) {
        case EAX: emitByte(0x58); break;
        case ECX: emitByte(0x59); break;
        case EDX: emitByte(0x5a); break;
        case EBX: emitByte(0x5b); break;
        case ESP: emitByte(0x5c); break;
        case EBP: emitByte(0x5d); break;
        case ESI: emitByte(0x5e); break;
        case EDI: emitByte(0x5f); break;
        default:
            throw std::logic_error("(*_*) bug?");
        }
    }
    // POP r/m32 (8F /0)
    void emit_POP(const SIB &dest)
    {
        emitByte(0x8f);
        emitSIB(RegisterKind(0), dest);
    }
    // CALL rel32 to the address stored in *call (resolved by link())
    void emit_CALL(void **call)
    {
        emitByte(0xe8);
        emitUI32(0);
        calls.push_back(make_pair(code.size(), call));
    }
    // RET (near)
    void emit_RET()
    {
        emitByte(0xc3); // 0xcb
    }
    // TEST EAX, imm32
    void emit_TEST(const Imm &imm)
    {
        emitByte(0xa9);
        emitUI32(imm.data);
    }
    // JMP rel32 to the bytecode instruction `key`
    void emit_JMP(void *key)
    {
        emitByte(0xe9);
        emitJump(key);
    }
    // JNZ rel32 to the bytecode instruction `key`
    void emit_JNZ(void *key)
    {
        emitByte(0x0f);
        emitByte(0x85);
        emitJump(key);
    }
    // JZ rel32 to the bytecode instruction `key`
    void emit_JZ(void *key)
    {
        emitByte(0x0f);
        emitByte(0x84);
        emitJump(key);
    }
    void emitByte(uint8_t b)
    {
        code.push_back(b);
    }
    void emitUI16(uint16_t value)
    {
        // little-endian
        emitByte(value & 0xff); value >>= 8;
        emitByte(value & 0xff);
    }
    void emitUI32(uint32_t value)
    {
        writeUI32(code.size(), value);
    }
    /// Writes `value` little-endian at byte `pos`, growing `code` if needed.
    void writeUI32(size_t pos, uint32_t value)
    {
        if(code.size() < pos + 4) {
            code.resize(pos + 4);
        }
        // little-endian
        code[  pos] = value & 0xff; value >>= 8;
        code[++pos] = value & 0xff; value >>= 8;
        code[++pos] = value & 0xff; value >>= 8;
        code[++pos] = value & 0xff;
    }
    void emitModRM(int mod, int reg, int _rm)
    {
        emitByte(
            (mod & 0x03) << 6 |
            (reg & 0x07) << 3 |
            (_rm & 0x07));
    }
    void emitSIB(int base, int index, int scale)
    {
        emitByte(
            (scale & 0x03) << 6 |
            (index & 0x07) << 3 |
            ( base & 0x07));
    }
    /// Emits ModRM (mod = 2, disp32) with `kind` in the reg field, plus a
    /// SIB byte when `sib` has an index, followed by the displacement.
    void emitSIB(RegisterKind kind, const SIB &sib)
    {
        if(sib.scale != -1) {
            emitModRM(2, kind, 4);
            emitSIB(sib.base, sib.index, sib.scale);
        }
        else {
            emitModRM(2, kind, sib.base);
        }
        emitUI32(sib.disp);
    }
    /// Emits a rel32 placeholder and records it for link().
    void emitJump(void *key)
    {
        emitUI32(0);
        jumps.push_back(make_pair(code.size(), key));
    }
    /// Patches every rel32 for code that will be placed at `base`.
    /// Displacements are relative to the end of the instruction, which is
    /// where each recorded position points.
    /// @throws std::logic_error if a branch targets an unknown instruction.
    void link(uint8_t *base)
    {
        for(auto i = jumps.begin(); i != jumps.end(); ++i) {
            auto target = offsetMap.find(i->second);
            if(target == offsetMap.end()) {
                throw std::logic_error("CodeGen_x86: unresolved jump target");
            }
            writeUI32(i->first - 4, target->second - i->first);
        }
        for(auto i = calls.begin(); i != calls.end(); ++i) {
            writeUI32(i->first - 4, intptr_t(*i->second) - intptr_t(base + i->first));
        }
    }
    /// Records that the bytecode instruction `key` starts at the current
    /// machine-code offset.
    void putOffset(void *key)
    {
        offsetMap.insert(make_pair(key, code.size()));
    }
    Procedure *procedure;
    State *state;
    std::unordered_map<void*,size_t> offsetMap;        // bytecode address -> code offset
    std::vector<uint8_t> code;                         // generated machine code
    std::vector<std::pair<size_t,void *>> jumps;       // (end of rel32, target key)
    std::vector<std::pair<size_t,void**>> calls;       // (end of rel32, &callee address)
};
/// Compiles `script` all the way to native code.
///
/// Steps:
/// 1. Parse into one AST per function, plus "<main>" for top-level
///    statements.
/// 2. Create one Procedure per function; its id is its index.
/// 3. Assemble each AST to bytecode.
/// 4. Generate x86 code for every procedure.
/// 5. Allocate executable memory for each procedure, then link the code and
///    copy it in.
/// With DEBUG defined, the ASTs and bytecode are dumped to stderr and the
/// machine code is written to "aotexpr.bin".
///
/// @throws SyntaxError for malformed scripts.
/// @throws std::runtime_error for undefined functions, argument-count
///         mismatches, or when executable memory cannot be allocated.
State::State(const std::string &script)
{
    // parses the script to create abstract syntax tree for each functions
    Parser parser(script);
    parser.parse();
    auto &functions = parser.functions;
#ifdef DEBUG
    for(auto i = functions.begin(); i != functions.end(); ++i) {
        std::cerr
            << "FUNCTION: " << i->first << std::endl
            << "\t" << i->second->stringify()
            << std::endl
            << std::endl;
    }
#endif
    // creates empty procedures (linked to each other later); the id of a
    // procedure is its index in `procedures`
    for(auto i = functions.begin(); i != functions.end(); ++i) {
        auto &name = std::get<0>(*i);
        auto &func = std::get<1>(*i);
        procedures.push_back(unique_ptr<Procedure>(
            new Procedure(name, procedures.size(), func->args.size())));
    }
    // compiles the ASTs as bytecode and links functions
    for(auto i = functions.begin(); i != functions.end(); ++i) {
        auto &name = std::get<0>(*i);
        auto &func = std::get<1>(*i);
        auto proc = lookupProcedure(name);
        BytecodeAssembler assembler(this, func);
        assembler.run();
        proc->code = std::move(assembler.code);
        proc->varNum = assembler.vars.size() - proc->argNum;
    }
#ifdef DEBUG
    for(auto i = procedures.begin(); i != procedures.end(); ++i) {
        std::cerr << "PROCEDURE[" << (**i).id << "]: " << (**i).name << std::endl;
        BytecodeDisassembler disassembler(&(**i).code[0], (**i).code.size());
        disassembler.run(std::cerr);
        std::cerr << std::endl;
    }
#endif
    // generates x86 code
    std::vector<unique_ptr<CodeGen_x86>> a;
    for(auto i = procedures.begin(); i != procedures.end(); ++i) {
        auto g = new CodeGen_x86((*i).get(), this);
        a.push_back(unique_ptr<CodeGen_x86>(g));
        g->run();
    }
    // allocate every procedure's memory first, so that link() can compute
    // call displacements to any callee
    for(size_t i = 0; i < a.size(); ++i) {
        void *mem = VirtualAlloc(NULL, a[i]->code.size(),
            MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE);
        if(mem == NULL) {
            throw std::runtime_error(
                "cannot allocate executable memory for `"+ procedures[i]->name +"'");
        }
        procedures[i]->native = reinterpret_cast<uint8_t*>(mem);
    }
    for(size_t i = 0; i < a.size(); ++i) {
        auto &code = a[i]->code;
        auto dest = procedures[i]->native;
        a[i]->link(dest);
        memcpy(dest, &code[0], code.size());
        // the code was written through a data pointer; flush the
        // instruction cache before it is executed
        FlushInstructionCache(GetCurrentProcess(), dest, code.size());
    }
#ifdef DEBUG
    auto fp = fopen("aotexpr.bin", "wb");
    if(fp) {
        for(size_t i = 0; i < a.size(); ++i) {
            auto &code = a[i]->code;
            fwrite(&code[0], 1, code.size(), fp);
        }
        fclose(fp);
    }
#endif
}
}} // namespace
int main(int argc, char **argv)
{
if(argc < 3) {
fprintf(stderr, "usage:aotexpr [i/c] FILENAME\n");
return -1;
}
auto fp = fopen(argv[2], "r");
if(fp == NULL) {
fprintf(stderr, "cannot open file: %s\n", argv[1]);
return -1;
}
std::string content;
for(;;) {
char buf[4096];
auto read = fread(buf, 1, sizeof(buf), fp);
if(read > 0) {
content.append(buf, read);
}
if(read < sizeof(buf)) {
break;
}
}
fclose(fp);
try {
sf::jx::State state(content);
auto proc = state.lookupProcedure("<main>");
int value;
if(!strcmp("i", argv[1])) {
// インタープリターモード
sf::jx::Interpreter interp(&state, 4096);
value = interp.run(proc->id);
}
else {
// AOTコンパイラーモード
value = reinterpret_cast<int(*)()>(proc->native)();
}
printf("<main> returned: %d\n", value);
} catch(std::runtime_error &e) {
std::cerr << "Unhandled exception: " << e.what() << std::endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment