Skip to content

Instantly share code, notes, and snippets.

@usbuild
Last active December 18, 2023 09:11
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save usbuild/b2d0fccf11afcbe8f6878007626865c2 to your computer and use it in GitHub Desktop.
Save usbuild/b2d0fccf11afcbe8f6878007626865c2 to your computer and use it in GitHub Desktop.
A naive calculator with an interpreter and a JIT compiler, intended only for teaching and demonstration.
#include <assert.h>
#include <ctype.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
/* Size limits shared by the compiler, interpreter and JIT. */
enum {
    STACK_MAX = 100,             /* fixed capacity of evaluation buffers */
    CODE_SIZE = 2 * 1024 * 1024, /* bytes mapped for JIT-generated code */
};
/* Raw byte type used by the machine-code emitters. */
typedef unsigned char uchar;
/*
 * Binding strength of an entry on the shunting-yard operator stack.
 * '(' is 0, so an incoming operator never pops past it.
 * '+' and '-' are 1; '*' and '/' are 2.
 * Any other character returns -1, the lowest possible precedence, instead of
 * falling off the end of the function.
 */
int precedent(char op) {
    switch (op) {
    case '(':
        return 0;
    case '-':
    case '+':
        return 1;
    case '*':
    case '/':
        return 2;
    default:
        return -1;
    }
}
/*
 * Tell whether op is one of the four binary operators the calculator knows.
 * Returns 1 for '+', '-', '*' or '/', and 0 for everything else
 * (including parentheses and '\0').
 */
int isoperator(char op) {
    switch (op) {
    case '+':
    case '-':
    case '*':
    case '/':
        return 1;
    default:
        return 0;
    }
}
/*
 * Convert an infix expression to postfix (reverse Polish) notation using the
 * shunting-yard algorithm.
 *
 * Operands are runs of alphanumeric characters and are copied verbatim.
 * '+', '-', '*' and '/' are left-associative, with '*' and '/' binding
 * tighter. Parentheses group sub-expressions. Spaces and any other
 * characters are skipped. Every emitted token is followed by one space,
 * e.g. "1+2*3" -> "1 2 3 * + ".
 *
 * Unbalanced parentheses are tolerated. A ')' without a matching '(' pops
 * the pending operators and is otherwise ignored. An unclosed '(' is dropped
 * at the end.
 *
 * Returns a heap-allocated, NUL-terminated string that the caller must
 * free(), or NULL if memory allocation fails.
 */
char *postfixify(const char *buf) {
    size_t len = strlen(buf);
    /* Each input character yields at most two output characters (itself or
     * an operator plus a separating space), so 2*len+1 bytes always suffice. */
    char *output = calloc(2 * len + 1, 1);
    /* Only '(' and operators are pushed: at most one per input character. */
    char *stack = malloc(len + 1);
    char *sp;
    char *op;
    if (!output || !stack) {
        free(output);
        free(stack);
        return NULL;
    }
    sp = stack;
    op = output;
    while (*buf) {
        char c = *buf++;
        int digit = 0;
        if (c == ' ') {
            continue;
        }
        /* Copy an operand token; c ends up as the first character after it. */
        while (isalnum((unsigned char)c)) {
            digit = 1;
            *(op++) = c;
            c = *buf++;
        }
        if (digit) {
            *(op++) = ' ';
        }
        if (!c) {
            break; /* the operand ended the input; buf is now past the NUL */
        }
        if (c == '(') {
            *(sp++) = c;
            continue;
        }
        if (c == ')') {
            /* Stop at an empty stack so an unmatched ')' cannot read before
             * the start of the operator stack. */
            while (sp > stack && sp[-1] != '(') {
                *(op++) = *--sp;
                *(op++) = ' ';
            }
            if (sp > stack) {
                --sp; /* discard the matching '(' */
            }
            continue;
        }
        if (isoperator(c)) {
            int p = precedent(c);
            /* Left associativity: pop operators of equal or higher precedence. */
            while (sp > stack && precedent(sp[-1]) >= p) {
                *(op++) = *--sp;
                *(op++) = ' ';
            }
            *(sp++) = c;
        }
        /* Any other character is ignored. */
    }
    while (sp > stack) {
        char top = *--sp;
        if (top == '(') {
            continue; /* an unclosed '(' has nothing to emit */
        }
        *(op++) = top;
        *(op++) = ' ';
    }
    free(stack);
    return output;
}
/*
 * Bytecode opcodes produced by bccompile() and consumed by interpret() and
 * jitcompile(). PUSHV and PUSHC are each followed by one operand word; the
 * rest take none. NOP is 0, which the interpreter loop also treats as the end
 * of the program.
 */
enum OP {
    NOP   = 0,
    RET   = 1,
    PUSHV = 2, /* operand: variable index ('a' == 0) */
    PUSHC = 3, /* operand: constant value */
    ADD   = 4,
    SUB   = 5,
    MUL   = 6,
    DIV   = 7,
};
/*
 * Compile an infix expression into bytecode for interpret() / jitcompile().
 *
 * The expression is first converted to postfix with postfixify(). Each token
 * is then translated as follows:
 *   - a lowercase letter v     -> PUSHV, v - 'a'
 *   - a decimal literal n      -> PUSHC, n
 *     (a token starting with another alphanumeric, e.g. an uppercase
 *     letter, yields PUSHC 0)
 *   - '+', '-', '*', '/'       -> ADD, SUB, MUL, DIV
 * A final RET is appended.
 *
 * The buffer is sized from the postfix length, so there is no fixed limit on
 * expression length. Returns a heap-allocated program that the caller must
 * free(), or NULL if memory allocation fails.
 */
int *bccompile(const char *buf) {
    char *postfix = postfixify(buf);
    int *bc;
    int *bcp;
    const char *p;
    if (!postfix) {
        return NULL;
    }
    /* Each postfix character emits at most two ints (opcode + operand);
     * one extra slot holds the trailing RET. */
    bc = calloc(2 * strlen(postfix) + 1, sizeof *bc);
    if (!bc) {
        free(postfix);
        return NULL;
    }
    bcp = bc;
    p = postfix;
    while (*p) {
        char c = *p++;
        if (isalnum((unsigned char)c)) {
            if (islower((unsigned char)c)) {
                *(bcp++) = PUSHV;
                *(bcp++) = c - 'a';
            } else {
                int val = 0;
                while (isdigit((unsigned char)c)) {
                    val = val * 10 + (c - '0');
                    c = *p++;
                }
                *(bcp++) = PUSHC;
                *(bcp++) = val;
                /* The literal ran to the end of the string: p already points
                 * past the NUL, so stop before reading beyond it. */
                if (!c) {
                    break;
                }
            }
        }
        switch (c) {
        case '+':
            *(bcp++) = ADD;
            break;
        case '-':
            *(bcp++) = SUB;
            break;
        case '*':
            *(bcp++) = MUL;
            break;
        case '/':
            *(bcp++) = DIV;
            break;
        default:
            break; /* separators, letters already handled, anything else */
        }
    }
    *(bcp++) = RET;
    free(postfix);
    return bc;
}
/* Report a malformed program or arithmetic fault and abort. */
static void interpret_fail(const char *msg) {
    fprintf(stderr, "interpret: %s\n", msg);
    abort();
}

/*
 * Evaluate bytecode produced by bccompile() on a STACK_MAX-entry operand stack.
 *
 * @param bcp  Program to run. Execution stops at RET or at an opcode of 0 (NOP).
 * @param args Variable values; PUSHV <i> reads args[i]. Must have an entry
 *             for every variable index the program uses.
 * @return The value on top of the stack when the program ends, or 0 if the
 *         stack is empty.
 *
 * Operand-stack overflow or underflow and division by zero abort with a
 * message instead of invoking undefined behaviour. Signed overflow in
 * +, - and * is not checked. Unknown opcodes are ignored.
 */
int interpret(const int *bcp, const int *args) {
    int stack[STACK_MAX] = {0};
    int *sp = stack;
    /* Walk forward from bcp itself; stepping the pointer before the start of
     * the program would form an out-of-range pointer. */
    for (; *bcp != NOP; ++bcp) {
        switch (*bcp) {
        case RET:
            return sp > stack ? sp[-1] : 0;
        case PUSHV:
            if (sp == stack + STACK_MAX) {
                interpret_fail("operand stack overflow");
            }
            *(sp++) = args[*++bcp];
            break;
        case PUSHC:
            if (sp == stack + STACK_MAX) {
                interpret_fail("operand stack overflow");
            }
            *(sp++) = *++bcp;
            break;
        case ADD:
        case SUB:
        case MUL:
        case DIV: {
            int lhs;
            int rhs;
            if (sp - stack < 2) {
                interpret_fail("operand stack underflow");
            }
            rhs = *--sp;
            lhs = sp[-1];
            if (*bcp == ADD) {
                sp[-1] = lhs + rhs;
            } else if (*bcp == SUB) {
                sp[-1] = lhs - rhs;
            } else if (*bcp == MUL) {
                sp[-1] = lhs * rhs;
            } else {
                if (rhs == 0) {
                    interpret_fail("division by zero");
                }
                sp[-1] = lhs / rhs;
            }
            break;
        }
        default:
            break;
        }
    }
    /* The program ended with an opcode of 0 instead of RET. */
    return sp > stack ? sp[-1] : 0;
}
/*
 * X-macro listing the sixteen x86-64 general-purpose registers in hardware
 * encoding order (the 4-bit number split across ModRM and REX fields).
 */
#define GPRDEF(_)                                              \
    _(RAX) _(RCX) _(RDX) _(RBX) _(RSP) _(RBP) _(RSI) _(RDI)    \
    _(R8) _(R9) _(R10) _(R11) _(R12) _(R13) _(R14) _(R15)
/* Expand one GPRDEF entry into an enumerator named RID_<reg>. */
#define RIDENUM(name) RID_##name,
/* RID_RAX == 0 ... RID_R15 == 15; RID_MAX is the number of registers. */
enum {
    GPRDEF(RIDENUM)
    RID_MAX
};
/* REX prefix with only REX.W set (0100 1000b): 64-bit operand size.
 * Written in hex because 0b literals are not standard C before C23. */
#define REX_64 0x48
/* Opcode 89 /r: MOV r/m64, r64 (copies the ModRM.reg register into ModRM.rm). */
#define MOV_RR 0x89
/* ModRM byte for register-direct addressing (mod = 11b). Arguments are fully
 * parenthesized so expression arguments expand correctly. Only the low three
 * bits are encoded; bit 3 of a register id belongs in the REX prefix. */
#define MODRM(name1, name2) (0xC0 | (((name1) & 0x7) << 3) | ((name2) & 0x7))
/*
 * Emit a 64-bit register-to-register instruction "REX.W [0F] op /r".
 * `src` is encoded in ModRM.reg and `target` in ModRM.rm with mod = 11b
 * (register direct). A register id with bit 3 set turns on REX.R (src) or
 * REX.B (target). A non-zero `ext` inserts the 0x0F two-byte-opcode escape.
 * Returns the position just past the last emitted byte.
 */
uchar *bin_rr(uchar *code, uchar op, uchar src, uchar target, int ext) {
    const uchar rex = (uchar)(REX_64 | ((src & 0x8) ? 0x4 : 0) |
                              ((target & 0x8) ? 0x1 : 0));
    const uchar modrm =
        (uchar)(0xC0 | ((src & ~0x8) << 3) | (target & ~0x8));
    uchar *out = code;

    *out++ = rex;
    if (ext)
        *out++ = 0x0F;
    *out++ = op;
    *out++ = modrm;
    return out;
}
/*
 * Emit a 64-bit register <-> memory instruction "REX.W [0F] op /r" whose
 * memory operand is [target + offset] with an 8-bit signed displacement
 * (mod = 01b).
 *
 * `src` is encoded in ModRM.reg and the base register `target` in ModRM.rm.
 * Register ids with bit 3 set turn on REX.R / REX.B. With op 0x89 (MOV_RR)
 * this stores src to memory; with 0x8B it loads memory into src.
 * Returns the position just past the last emitted byte.
 */
uchar *bin_mr(uchar *code, uchar op, uchar src, uchar target, char offset,
              int ext) {
    uchar rex = REX_64;
    if (src & 0x8) {
        rex |= 1 << 2; /* REX.R extends ModRM.reg */
        src &= ~(uchar)(0x8);
    }
    if (target & 0x8) {
        rex |= 1; /* REX.B extends ModRM.rm */
        target &= ~(uchar)(0x8);
    }
    *(code++) = rex;
    if (ext) {
        *(code++) = 0x0F;
    }
    *(code++) = op;
    *(code++) = (uchar)(0x40 | (src << 3) | target);
    /* rm = 100b (RSP/R12 as base) means "SIB byte follows"; 0x24 encodes
     * base = that register with no index. */
    if (target == 0x4) {
        *(code++) = 0x24;
    }
    *(code++) = (uchar)offset;
    return code;
}
/*
 * Emit "movabs $val, %target": REX.W + (B8 + register) followed by an 8-byte
 * immediate. The immediate is copied in host memory order, which is the
 * little-endian order the instruction needs when running on x86-64.
 * Register ids with bit 3 set turn on REX.B.
 * Returns the position just past the emitted bytes.
 */
uchar *movq_ir(uchar *code, long val, uchar target) {
    const int extended = (target & 0x8) != 0;
    uchar *out = code;

    *out++ = (uchar)(extended ? (REX_64 | 1) : REX_64);
    *out++ = (uchar)(0xB8 + (target & ~0x8));
    memcpy(out, &val, 8);
    return out + 8;
}
void jitcompile(const int *bcp, uchar *code) {
uchar *cp = code;
uchar *cu;
const char *reg_stack[] = {"rbx", "rcx", "rdi", "rsi", "r8", "r9",
"r10", "r11", "r12", "r13", "r14", "r15"};
uchar reg_stacki[] = {RID_RBX, RID_RCX, RID_RDI, RID_RSI, RID_R8,
RID_R9, RID_R10, RID_R11, RID_R11, RID_R12,
RID_R13, RID_R14, RID_R15};
bcp--;
int sp = 0;
FILE *f = fopen("/tmp/test_jit.s", "w");
fprintf(f, ".globl myfunc\n");
fprintf(f, ".type myfunc, @function\n");
fprintf(f, "myfunc:\n");
fprintf(f, "push %rbp\n");
*(cp++) = 0x50 + RID_RBP;
fprintf(f, "movq %rsp, %rbp\n");
cp = bin_rr(cp, MOV_RR, RID_RSP, RID_RBP, 0);
fprintf(f, "movq %rdi, -8(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_RDI, RID_RBP, -8, 0);
fprintf(f, "movq %rdi, -16(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_RSI, RID_RBP, -16, 0);
fprintf(f, "movq %rdi, -24(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_RDX, RID_RBP, -24, 0);
fprintf(f, "movq %rdi, -32(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_R10, RID_RBP, -32, 0);
fprintf(f, "movq %rdi, -40(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_R8, RID_RBP, -40, 0);
fprintf(f, "movq %rdi, -48(%rbp)\n");
cp = bin_mr(cp, MOV_RR, RID_R9, RID_RBP, -48, 0);
while (*++bcp) {
switch (*bcp) {
case NOP:
break;
case RET:
fprintf(f, "movq %rbx, %rax\n");
cp = bin_rr(cp, MOV_RR, RID_RBX, RID_RAX, 0);
fprintf(f, "popq %rbp\n");
*(cp++) = 0x58 + RID_RBP;
fprintf(f, "retq\n");
*(cp++) = 0xc3;
for (cu = code; cu < cp; ++cu) {
fprintf(f, "%02x ", *cu);
}
fclose(f);
return;
case PUSHV:
++bcp;
int offset = (*bcp + 1) * -8;
fprintf(f, "movq %d(%rbp), %%%s\n", offset, reg_stack[sp]);
cp = bin_mr(cp, 0x8B, reg_stacki[sp], RID_RBP, offset, 0);
sp++;
break;
case PUSHC:
++bcp;
fprintf(f, "movq $%d, %%%s\n", *bcp, reg_stack[sp]);
cp = movq_ir(cp, *bcp, reg_stacki[sp]);
sp++;
break;
case ADD:
fprintf(f, "addq %%%s, %%%s\n", reg_stack[sp - 1],
reg_stack[sp - 2]);
cp =
bin_rr(cp, 0x01, reg_stacki[sp - 1], reg_stacki[sp - 2], 0);
sp--;
break;
case SUB:
fprintf(f, "subq %%%s, %%%s\n", reg_stack[sp - 1],
reg_stack[sp - 2]);
cp =
bin_rr(cp, 0x29, reg_stacki[sp - 1], reg_stacki[sp - 2], 0);
sp--;
break;
case MUL:
fprintf(f, "imulq %%%s, %%%s\n", reg_stack[sp - 1],
reg_stack[sp - 2]);
cp =
bin_rr(cp, 0xAF, reg_stacki[sp - 2], reg_stacki[sp - 1], 1);
sp--;
break;
case DIV:
fprintf(f, "movq %rax, %%%s\n", reg_stack[sp]);
fprintf(f, "movq %rdx, %%%s\n", reg_stack[sp + 1]);
fprintf(f, "xor %rdx, %rdx\n");
fprintf(f, "movq %%%s, %rax\n", reg_stack[sp - 2]);
fprintf(f, "movq %%%s, -56(%rbp)\n", reg_stack[sp - 1]);
fprintf(f, "cqto\n");
fprintf(f, "idivq -56(%rbp)\n");
fprintf(f, "movq %%rax, %%%s\n", reg_stack[sp - 2]);
fprintf(f, "movq %%%s, %rax\n", reg_stack[sp]);
fprintf(f, "movq %%%s, %rdx\n", reg_stack[sp + 1]);
cp = bin_rr(cp, MOV_RR, RID_RAX, reg_stacki[sp], 0);
cp = bin_rr(cp, MOV_RR, RID_RDX, reg_stacki[sp + 1], 0);
cp = bin_rr(cp, 0x31, RID_RDX, RID_RDX, 0);
cp = bin_rr(cp, MOV_RR, reg_stacki[sp - 2], RID_RAX, 0);
// addr
cp = bin_mr(cp, MOV_RR, reg_stacki[sp - 1], RID_RBP, -56, 0);
*(cp++) = REX_64;
*(cp++) = 0x99; // cqto
*(cp++) = REX_64;
*(cp++) = 0xf7;
*(cp++) = ((0b01 << 6) | (7 << 3) | (RID_RBP));
*(cp++) = -56;
cp = bin_rr(cp, MOV_RR, RID_RAX, reg_stacki[sp - 2], 0);
cp = bin_rr(cp, MOV_RR, reg_stacki[sp], RID_RAX, 0);
cp = bin_rr(cp, MOV_RR, reg_stacki[sp + 1], RID_RDX, 0);
sp--;
break;
}
}
}
/* Type of the JIT-generated entry point; it returns a 64-bit result. */
typedef long (*MyFunc)();

/*
 * Compile `val` to machine code, run it once and return the result
 * truncated to int.
 *
 * The generated function is called with no arguments, so any variables in
 * the expression read unspecified register contents.
 *
 * The code buffer is mapped read+write, filled, and then switched to
 * read+execute, so it is never writable and executable at the same time.
 * Failure to compile, map or re-protect memory aborts with a message.
 */
int calc(const char *val) {
    int *bc = bccompile(val);
    void *ptr;
    MyFunc f;
    int ret;

    if (!bc) {
        fprintf(stderr, "calc: failed to compile \"%s\"\n", val);
        abort();
    }
    ptr = mmap(NULL, CODE_SIZE, PROT_READ | PROT_WRITE,
               MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (ptr == MAP_FAILED) {
        perror("calc: mmap");
        free(bc);
        abort();
    }
    jitcompile(bc, ptr);
    free(bc);
    if (mprotect(ptr, CODE_SIZE, PROT_READ | PROT_EXEC) != 0) {
        perror("calc: mprotect");
        munmap(ptr, CODE_SIZE);
        abort();
    }
    /* Object-to-function pointer conversion is not ISO C, but POSIX systems
     * support it (dlsym relies on it). */
    f = (MyFunc)ptr;
    ret = (int)f();
    munmap(ptr, CODE_SIZE);
    return ret;
}
/*
 * Evaluate `src` with both back ends (the JIT via calc() and the bytecode
 * interpreter) and compare each result with `expected`, the value the C
 * compiler computed for the same expression text.
 *
 * A mismatch is reported on stderr and exits with failure status. Unlike
 * assert(), this check is not compiled out under NDEBUG.
 */
static void check_expr(const char *src, int expected) {
    int args[26] = {0}; /* one slot per possible variable 'a'..'z' */
    int jit = calc(src);
    int interp;
    int *bc = bccompile(src);

    if (!bc) {
        fprintf(stderr, "FAIL %s: bytecode compilation failed\n", src);
        exit(EXIT_FAILURE);
    }
    interp = interpret(bc, args);
    free(bc);
    if (jit != expected || interp != expected) {
        fprintf(stderr, "FAIL %s: expected %d, jit %d, interpreter %d\n", src,
                expected, jit, interp);
        exit(EXIT_FAILURE);
    }
}
/* Check that expression text x evaluates to what C computes for x.
 * No trailing semicolon, so the macro behaves like a single statement. */
#define SIMPLE_TEST_ASSERT(x) check_expr(#x, (x))
/* Built-in self-tests covering precedence, associativity and grouping. */
void test_all(void) {
    SIMPLE_TEST_ASSERT(1+2+3);
    SIMPLE_TEST_ASSERT(2/2);
    SIMPLE_TEST_ASSERT((2 + 2)/2);
    SIMPLE_TEST_ASSERT((2 + 2)*2);
    SIMPLE_TEST_ASSERT(((2 + 2)*2)/4+1);
    SIMPLE_TEST_ASSERT(1+2+3*4-5+6);
    SIMPLE_TEST_ASSERT(10-4-3);
    SIMPLE_TEST_ASSERT(8/2/2);
    SIMPLE_TEST_ASSERT(7/2);
    SIMPLE_TEST_ASSERT(2-5*3);
    SIMPLE_TEST_ASSERT(100-(3*(4+5)));
}
/* Program entry point: runs the built-in self-tests. Command-line arguments
 * are accepted but not used. */
int main(int argc, const char *argv[]) {
    (void)argc;
    (void)argv;

    test_all();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment