Skip to content

Instantly share code, notes, and snippets.

@alpha123
Last active April 24, 2023 04:27
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save alpha123/de4f22e6533a89b1443145c7042c184b to your computer and use it in GitHub Desktop.
Save alpha123/de4f22e6533a89b1443145c7042c184b to your computer and use it in GitHub Desktop.
Primitive SSA Bytecode Interpreter
#include <stdlib.h>
#include <stdbool.h>
#include <stddef.h>
#define _WITH_GETLINE
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <ctype.h>
#include <inttypes.h>
#include <assert.h>
/*
 * Opcode table as an X-macro: X(name) is expanded once per opcode, in
 * encoding order, so opcode_t below and the dispatch table in vm_exec are
 * generated from the same list. The comments give the assembly form; the
 * three operand fields are imm1, imm2 and imm3 in that order.
 */
#define LIST_OPCODES(X) \
    X(halt)   /* HALT   0000 0000 0000 -- stop; result is the top of the value stack */ \
    X(nop)    /* NOP    0000 0000 0000 */ \
    X(phi)    /* PHI    rfinal rblock1 rblock2 -- rfinal aliases rblock1 if bound, else rblock2 */ \
    X(mov)    /* MOV    rdest rsrc 0000 */ \
    X(call)   /* CALL   raddr raddroff iaddroff -- call raddr+raddroff+iaddroff */ \
    X(ret)    /* RET    0000 0000 0000 */ \
    X(push)   /* PUSH   rsrc 0000 0000 */ \
    X(pop)    /* POP    rdest 0000 0000 */ \
    X(loadi)  /* LOADI  rdest ihigh ilow -- rdest = (ihigh<<16)|ilow */ \
    X(load2n) /* LOAD2N rdest bexp:boff ioff -- rdest = 2^bexp+boff+ioff; bexp/boff are the high/low byte of imm2 */ \
    X(acci)   /* ACCI   rdest ihigh ilow -- rdest += (ihigh<<16)|ilow */ \
    X(acc2n)  /* ACC2N  rdest bexp:boff ioff -- rdest += 2^bexp+boff+ioff; same packing as LOAD2N */ \
    X(addi)   /* ADDI   rdest rop ival */ \
    X(muli)   /* MULI   rdest rop ival */ \
    X(add)    /* ADD    rdest rop1 rop2 */ \
    X(sub)    /* SUB    rdest rop1 rop2 */ \
    X(mul)    /* MUL    rdest rop1 rop2 */ \
    X(div)    /* DIV    rdest rop1 rop2 */ \
    X(cmp)    /* CMP    rop1 rop2 0000 -- r7 = -1, 0 or 1 as rop1 is <, == or > rop2 */ \
    X(jz)     /* JZ     raddr raddroff iaddroff -- jump raddr+raddroff+iaddroff if r7 == 0 */ \
    X(jnz)    /* JNZ    raddr raddroff iaddroff -- jump raddr+raddroff+iaddroff if r7 != 0 */

/* op_<name> for every opcode; the value is the instruction's low byte. */
#define OPCODE_ENUMERATOR(name) op_##name,
typedef enum {
    LIST_OPCODES(OPCODE_ENUMERATOR)
} opcode_t;
#undef OPCODE_ENUMERATOR
/* Complete interpreter state. */
struct vm {
    int r[256];        /* physical register file */
    int *sp;           /* value stack: next free slot (the stack grows downward) */
    int *s_top;        /* value stack: highest slot, i.e. sp when the stack is empty */
    int *rsp;          /* register-save stack used by CALL/RET: next free slot */
    int *rs_top;       /* register-save stack: highest slot */
    uint64_t *prog;    /* encoded program, one instruction per element */
    uint8_t *rmap;     /* virtual register number -> physical register index */
    int pc;            /* program counter */
    int nextr;         /* next physical register to hand out on a first write */
};
/*
 * Reset a VM and allocate its buffers:
 *   - a 1000-slot value stack,
 *   - a 64K-slot register-save stack,
 *   - a 64K-entry virtual->physical register map.
 * Virtual registers r0..r32 map one-to-one onto the physical registers with
 * the same numbers. Later first writes get physical registers starting at 33.
 * vm->prog is cleared; the caller installs the program.
 * Exits with a diagnostic if any allocation fails.
 */
void vm_init(struct vm *vm) {
    int *vstack = calloc(1000, sizeof *vstack);
    int *rstack = calloc(64*1024, sizeof *rstack);
    uint8_t *rmap = calloc(64*1024, sizeof *rmap);
    if (vstack == NULL || rstack == NULL || rmap == NULL) {
        fprintf(stderr, "vm_init: out of memory\n");
        free(vstack);
        free(rstack);
        free(rmap);
        exit(1);
    }
    vm->pc = 0;
    vm->nextr = 33;
    vm->prog = NULL;
    /* Both stacks grow downward from their last element; the stack pointer
     * always addresses the next free slot. */
    vm->sp = vm->s_top = vstack + 999;
    vm->rsp = vm->rs_top = rstack + (64*1024 - 1);
    vm->rmap = rmap;
    memset(vm->r, 0, sizeof vm->r);
    for (int i = 0; i < 33; i++)
        vm->rmap[i] = (uint8_t)i;
}
/*
 * Free the stacks and register map allocated by vm_init. vm->prog is not
 * freed here; main frees it separately.
 * s_top and rs_top point at the last element of their arrays. The offset
 * back to each array's first element is applied as one subtraction, so no
 * pointer before the start of the array is ever formed.
 * The freed pointers are cleared to avoid dangling use.
 */
void vm_destroy(struct vm *vm) {
    free(vm->s_top - 999);
    free(vm->rs_top - (64*1024 - 1));
    free(vm->rmap);
    vm->sp = vm->s_top = NULL;
    vm->rsp = vm->rs_top = NULL;
    vm->rmap = NULL;
}
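/*
 * Register conventions, in terms of the virtual register numbers used by
 * assembly code. vm_init maps r0..r32 one-to-one onto the physical registers
 * with the same numbers. Every other virtual register is bound to the next
 * unused physical register, starting at 33, the first time it is written.
 *
 * r0 - always 0 (writing it fails an assertion)
 * r1 - return address
 * r2 - arg1
 * r3 - arg2
 * r4 - number of args on stack
 * r5 - return value
 * r6 - number of return values on stack
 * r7 - last comparison result: -1, 0 or 1
 * r8..r32 - preserved across function calls
 * r33 and above - scratch
 */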
inline __attribute__((always_inline))
int rget(struct vm *vm, uint16_t n) {
return vm->r[vm->rmap[n]];
}
/*
 * Write x to virtual register n.
 * On the first write, n is bound to the next unused physical register
 * (vm->nextr). Writing r0 is a programming error, caught by the assertion.
 * rmap entries are 8 bits wide, so once physical registers up to 255 are used
 * up, handing out another one would wrap onto the fixed registers r0, r1, ...
 * Exit with a diagnostic instead.
 */
static inline __attribute__((always_inline))
void rput(struct vm *vm, uint16_t n, int x) {
    assert(n != 0);
    if (vm->rmap[n] == 0) {
        if (vm->nextr > UINT8_MAX) {
            fprintf(stderr, "out of physical registers\n");
            exit(1);
        }
        vm->rmap[n] = (uint8_t)vm->nextr++;
    }
    vm->r[vm->rmap[n]] = x;
}
/*
 * Dispatch helpers for vm_exec. They expect `vm`, `dispatch` and
 * imm1/imm2/imm3 (uint16_t) in scope.
 *
 * vm_goto(addr) decodes the 64-bit instruction at prog[addr] and branches to
 * its handler:
 *   - opcode in bits 0-7,
 *   - imm1 = bits 8-23,
 *   - imm2 = bits 24-39,
 *   - imm3 = bits 40-55.
 * NOTE(review): the opcode byte is not range-checked against dispatch[].
 */
#define vm_goto(addr) do { \
        uint64_t inst = vm->prog[(addr)]; \
        imm1 = (inst >> 8) & 0xffff; \
        imm2 = (inst >> 24) & 0xffff; \
        imm3 = inst >> 40; \
        goto *dispatch[inst & 0xff]; \
    } while (0)
/*
 * Transfer control to addr. The target runs with pc == addr+1, the same
 * state `next` leaves behind. This means the target is not executed a second
 * time by its own `next`, and a CALL at a jump target saves the address after
 * itself as the return address.
 */
#define jump(addr) do { vm->pc = (addr); next; } while (0)
/* Execute the instruction at pc and advance pc past it. */
#define next vm_goto(vm->pc++)
/*
 * Run vm->prog from address 0 until a HALT instruction executes.
 *
 * Control enters at I_nop, whose `next` fetches prog[0]. Every handler ends
 * by dispatching to another handler via `next` or `jump`, so the only return
 * is from I_halt.
 *
 * Returns the value on top of the value stack at HALT, or 0 if the value
 * stack is empty. Exits the process with a diagnostic on division by zero.
 */
int vm_exec(struct vm *vm) {
#define DEF_DISPATCH(opc) [op_##opc] = &&I_##opc,
    static void *dispatch[] = {
        LIST_OPCODES(DEF_DISPATCH)
    };
#undef DEF_DISPATCH
    uint16_t imm1, imm2, imm3;

I_nop:
    next;

I_halt:
    /* sp points at the next free slot, so the top value is sp[1]. On an empty
     * stack that would be one element past the allocation. */
    if (vm->sp == vm->s_top)
        return 0;
    return vm->sp[1];

I_phi:
    /* Make rfinal share rblock1's physical register if that virtual register
     * is bound, otherwise rblock2's. */
    vm->rmap[imm1] = vm->rmap[imm2] == 0 ? vm->rmap[imm3] : vm->rmap[imm2];
    next;

I_mov:
    rput(vm, imm1, rget(vm, imm2));
    next;

/* CALL raddr raddroff iaddroff -- call raddr+raddroff+iaddroff */
I_call: {
    int32_t addr = (int32_t)rget(vm, imm1) + (int32_t)rget(vm, imm2) + (int32_t)imm3;
    /* Frame, growing downward (offsets relative to the lowered rsp):
     *   - physical r33..r255 at rsp[1..223],
     *   - r1 at rsp[0].
     * rsp then moves to the next free slot below the frame. */
    vm->rsp -= 223;
    for (int i = 223; i > 0; i--)
        vm->rsp[i] = vm->r[32 + i];
    *(vm->rsp--) = vm->r[1];
    vm->r[1] = vm->pc;
    jump(addr);
}

/* RET 0000 0000 0000 */
I_ret: {
    int32_t addr = vm->r[1];
    /* Pop CALL's frame: r1 first, then r33..r255 from the 223 slots directly
     * above it, at the same offsets CALL stored them, and leave rsp where
     * CALL found it. */
    vm->r[1] = *(++vm->rsp);
    for (int i = 223; i > 0; i--)
        vm->r[32 + i] = vm->rsp[i];
    vm->rsp += 223;
    jump(addr);
}

I_push:
    /* Pushing consumes the register: its virtual->physical binding is dropped. */
    *(vm->sp--) = rget(vm, imm1);
    vm->rmap[imm1] = 0;
    next;

I_pop:
    rput(vm, imm1, *(++vm->sp));
    next;

/* LOADI rdest ihigh ilow
 * The shift is done in uint32_t so ihigh >= 0x8000 cannot overflow int. */
I_loadi:
    rput(vm, imm1, (int)(((uint32_t)imm2 << 16) | imm3));
    next;

/* LOAD2N rdest bexp boff ioff -- rdest = 2^bexp+boff+ioff
 * bexp/boff are the high/low byte of imm2. bexp is reduced mod 32 and the
 * shift is done in uint32_t, so no shift is out of range for the type. */
I_load2n:
    rput(vm, imm1, (int)((uint32_t)1 << ((imm2 >> 8) & 31)) + (int)(imm2 & 0xff) + (int)imm3);
    next;

/* ACCI rdest ihigh ilow -- rdest += (ihigh<<16)|ilow */
I_acci:
    rput(vm, imm1, rget(vm, imm1) + (int)(((uint32_t)imm2 << 16) | imm3));
    next;

/* ACC2N rdest bexp boff ioff -- rdest += 2^bexp+boff+ioff */
I_acc2n:
    rput(vm, imm1, rget(vm, imm1) + ((int)((uint32_t)1 << ((imm2 >> 8) & 31)) + (int)(imm2 & 0xff) + (int)imm3));
    next;

I_addi:
    rput(vm, imm1, rget(vm, imm2) + (int)imm3);
    next;

I_muli:
    rput(vm, imm1, rget(vm, imm2) * (int)imm3);
    next;

I_add:
    rput(vm, imm1, rget(vm, imm2) + rget(vm, imm3));
    next;

I_sub:
    rput(vm, imm1, rget(vm, imm2) - rget(vm, imm3));
    next;

I_mul:
    rput(vm, imm1, rget(vm, imm2) * rget(vm, imm3));
    next;

I_div: {
    int dividend = rget(vm, imm2);
    int divisor = rget(vm, imm3);
    if (divisor == 0) {
        fprintf(stderr, "division by zero\n");
        exit(1);
    }
    /* x / -1 is -x; negating in unsigned avoids the INT_MIN / -1 overflow. */
    int quotient = divisor == -1 ? (int)(0u - (unsigned)dividend) : dividend / divisor;
    rput(vm, imm1, quotient);
    next;
}

/* CMP rop1 rop2 0000 -- r7 = rop1 <> rop2 */
I_cmp: {
    int x = rget(vm, imm1), y = rget(vm, imm2);
    vm->r[7] = (x > y) - (x < y);
    next;
}

/* JZ raddr raddroff iaddroff -- jump raddr+raddroff+iaddroff if r7 == 0 */
I_jz:
    if (vm->r[7] == 0)
        jump((int32_t)rget(vm, imm1) + (int32_t)rget(vm, imm2) + (int32_t)imm3);
    else
        next;

/* JNZ raddr raddroff iaddroff -- jump raddr+raddroff+iaddroff if r7 != 0 */
I_jnz:
    if (vm->r[7] != 0)
        jump((int32_t)rget(vm, imm1) + (int32_t)rget(vm, imm2) + (int32_t)imm3);
    else
        next;
}
/*
 * MurmurHash3 (x86, 32-bit) of key[0..len) with the given seed.
 * Implementation notes:
 *   - 4-byte blocks are read with memcpy, so the key needs no particular
 *     alignment and no type-punning is needed.
 *   - Blocks are read in host byte order, which matches the reference
 *     vectors on little-endian hosts.
 *   - The 1-3 tail bytes are the ones after the last full block and are
 *     taken as unsigned.
 */
static uint32_t murmur3_32(const char *key, uint32_t len, uint32_t seed) {
    const unsigned char *data = (const unsigned char *)key;
    const uint32_t c1 = 0xcc9e2d51, c2 = 0x1b873593;
    const uint32_t nblocks = len / 4;
    uint32_t h = seed;

    for (uint32_t i = 0; i < nblocks; i++) {
        uint32_t k;
        memcpy(&k, data + 4 * (size_t)i, sizeof k);
        k *= c1;
        k = (k << 15) | (k >> 17);
        k *= c2;
        h ^= k;
        h = (h << 13) | (h >> 19);
        h = h * 5 + 0xe6546b64;
    }

    const unsigned char *tail = data + 4 * (size_t)nblocks;
    uint32_t k = 0;
    switch (len & 3) {
    case 3:
        k ^= (uint32_t)tail[2] << 16;
        /* fallthrough */
    case 2:
        k ^= (uint32_t)tail[1] << 8;
        /* fallthrough */
    case 1:
        k ^= tail[0];
        k *= c1;
        k = (k << 15) | (k >> 17);
        k *= c2;
        h ^= k;
        break;
    default:
        break;
    }

    /* Final avalanche. */
    h ^= len;
    h ^= h >> 16;
    h *= 0x85ebca6b;
    h ^= h >> 13;
    h *= 0xc2b2ae35;
    h ^= h >> 16;
    return h;
}

/* Hash of s[0..len) used by the label table: MurmurHash3 with a fixed seed. */
uint32_t hash(const char *s, uint32_t len) {
    return murmur3_32(s, len, 0xb16b00b5);
}
/*
 * One slot of the label table, an open-addressed hash table that resolves
 * collisions by linear probing. A slot whose key is NULL is empty.
 */
struct tbl_entry {
    const char *key;   /* label text; compared over exactly len bytes */
    uint32_t len;      /* length of key in bytes */
    uint32_t addr;     /* program address the label names */
};
/*
 * Insert label s[0..len) -> addr into tbl, which has tlen slots.
 * tlen must be a power of two: the home slot is hash & (tlen - 1).
 * Collisions probe linearly and wrap around the end of the table, visiting
 * each slot at most once.
 * The table stores its own heap copy of the key, so s may be a transient
 * buffer (the assembler passes a stack array).
 * Returns:
 *    0 on success,
 *    1 if the label is already present,
 *   -1 if the table is full or the key copy cannot be allocated.
 */
int insert_label(struct tbl_entry *tbl, uint32_t tlen, const char *s, uint32_t len, uint32_t addr) {
    const uint32_t mask = tlen - 1;
    uint32_t slot = hash(s, len) & mask;
    for (uint32_t probes = 0; probes < tlen; probes++, slot = (slot + 1) & mask) {
        struct tbl_entry *t = &tbl[slot];
        if (t->key == NULL) {
            char *copy = malloc((size_t)len + 1);
            if (copy == NULL)
                return -1;
            memcpy(copy, s, len);
            copy[len] = '\0';
            t->key = copy;
            t->len = len;
            t->addr = addr;
            return 0;
        }
        if (t->len == len && memcmp(t->key, s, len) == 0)
            return 1;
    }
    return -1;
}
uint32_t get_label(struct tbl_entry *tbl, uint32_t tlen, const char *s, uint32_t len) {
uint32_t h = hash(s, len);
struct tbl_entry *t = tbl + (h & tlen-1);
while (!(t->len == len && strncmp(t->key, s, len) == 0)) {
if (t == tbl + tlen)
return 0;
++t;
}
return t->addr;
}
/* Encode one instruction word from the locals i1, i2, i3 at the expansion
 * site (uint16_t at every use in assemble), giving this layout:
 *   - opcode in bits 0-7,
 *   - i1 in bits 8-23,
 *   - i2 in bits 24-39,
 *   - i3 in bits 40-55.
 * This is the layout vm_goto decodes. */
#define PACK_I(opc) (((uint64_t)(i3)<<40) | ((uint64_t)(i2)<<24) | ((uint64_t)(i1)<<8) | op_##opc)
/* Append an instruction at prog[plen] and advance plen (both must be in scope). */
#define EMIT(opc) (prog[plen++] = PACK_I(opc))
/* Smaller of a and b; one of the arguments is evaluated twice. */
#define min(a,b) ((a)>(b)?(b):(a))
/* Length of the mnemonic at src: the distance to the first ' ', '\n' or ')',
 * or to the terminating NUL when none of them occurs. */
static uint32_t mnemonic_len(const char *src) {
    return (uint32_t)strcspn(src, " \n)");
}

/* Parse the lowercase label operand that follows a mnemonic into lbl (at
 * most 49 letters plus the NUL). Exits with "<op> needs 1 operand" if absent. */
static void read_label_operand(const char *src, const char *op, char lbl[50]) {
    if (sscanf(src, " %49[a-z]", lbl) != 1) {
        fprintf(stderr, "%s needs 1 operand\n", op);
        exit(1);
    }
}

/* Look up a label defined earlier in the source (the assembler is
 * single-pass). Exits with a diagnostic if the label is unknown or its
 * address does not fit the 16-bit iaddroff field. */
static uint16_t resolve_label(struct tbl_entry *tbl, uint32_t tlen, const char *lbl, const char *what) {
    uint32_t addr = get_label(tbl, tlen, lbl, (uint32_t)strlen(lbl));
    if (addr == 0) {
        fprintf(stderr, "unknown label %s\n", lbl);
        exit(1);
    }
    if (addr > UINT16_MAX) {
        fprintf(stderr, "%s address must be under %u\n", what, (unsigned)UINT16_MAX);
        exit(1);
    }
    return (uint16_t)addr;
}

/*
 * Assemble parenthesised forms such as "(addi r5 r2 2)" into an encoded
 * program.
 *
 * The result is zero-filled, and a zero slot decodes as HALT.
 * Slots 0-1 are reserved for the entry stub written by "(main label)":
 * "cmp r0 r0" followed by "jz label".
 *
 * Input rules:
 *   - Labels are resolved in one pass, so a label must be defined before it
 *     is referenced.
 *   - Text outside parentheses is ignored.
 *   - Forms with unrecognised mnemonics are ignored.
 *
 * Returns a heap array that the caller must free(). Exits on any error.
 */
uint64_t *assemble(const char *src) {
    uint32_t paren_depth = 0, plen = 2, pcap = 200;
    uint64_t *prog = calloc(pcap, sizeof *prog);
    struct tbl_entry *l_tbl = calloc(128, sizeof *l_tbl);
    if (prog == NULL || l_tbl == NULL) {
        fprintf(stderr, "out of memory\n");
        exit(1);
    }
    while (*src) switch (*src) {
    case '(': {
        ++paren_depth;
        do ++src; while (isspace((unsigned char)*src));
        /* A form emits at most one instruction. Grow the zero-filled array
         * early enough that a zero (HALT) slot always follows the last
         * instruction, so execution falling off the end stops in bounds. */
        if (pcap - plen < 2) {
            uint64_t *grown = realloc(prog, (size_t)pcap * 2 * sizeof *prog);
            if (grown == NULL) {
                fprintf(stderr, "out of memory\n");
                exit(1);
            }
            memset(grown + pcap, 0, (size_t)pcap * sizeof *grown);
            prog = grown;
            pcap *= 2;
        }
        switch (mnemonic_len(src)) {
        case 5:
            if (strncmp(src, "label", 5) == 0) {
                src += 5;
                char lbl[50];
                read_label_operand(src, "label", lbl);
                if (insert_label(l_tbl, 128, lbl, (uint32_t)strlen(lbl), plen)) {
                    fprintf(stderr, "duplicate label %s\n", lbl);
                    exit(1);
                }
            }
            else if (strncmp(src, "loadi", 5) == 0) {
                src += 5;
                uint16_t i1, i2, i3;
                uint32_t n;
                if (sscanf(src, " r%" SCNu16 " %" SCNu32, &i1, &n) != 2) {
                    fprintf(stderr, "loadi needs 2 operands\n");
                    exit(1);
                }
                i2 = n >> 16;
                i3 = n & 0xffff;
                EMIT(loadi);
            }
            else if (strncmp(src, "acc2n", 5) == 0) {
                src += 5;
                uint16_t i1, i2, i3 = 0;
                if (sscanf(src, " r%" SCNu16 " %" SCNu16 " %" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "acc2n needs 3 operands\n");
                    exit(1);
                }
                EMIT(acc2n);
            }
            break;
        case 4:
            if (strncmp(src, "main", 4) == 0) {
                src += 4;
                uint16_t i1 = 0, i2 = 0, i3 = 0;
                prog[0] = PACK_I(cmp);
                char lbl[50];
                read_label_operand(src, "main", lbl);
                i3 = resolve_label(l_tbl, 128, lbl, "main");
                prog[1] = PACK_I(jz);
            }
            else if (strncmp(src, "halt", 4) == 0) {
                src += 4;
                uint16_t i1 = 0, i2 = 0, i3 = 0;
                EMIT(halt);
            }
            else if (strncmp(src, "push", 4) == 0) {
                src += 4;
                uint16_t i1, i2 = 0, i3 = 0;
                if (sscanf(src, " r%" SCNu16, &i1) != 1) {
                    fprintf(stderr, "push needs 1 operand\n");
                    exit(1);
                }
                EMIT(push);
            }
            else if (strncmp(src, "call", 4) == 0) {
                src += 4;
                uint16_t i1 = 0, i2 = 0, i3;
                char lbl[50];
                read_label_operand(src, "call", lbl);
                i3 = resolve_label(l_tbl, 128, lbl, "call");
                EMIT(call);
            }
            else if (strncmp(src, "acci", 4) == 0) {
                src += 4;
                uint16_t i1, i2, i3;
                uint32_t n;
                if (sscanf(src, " r%" SCNu16 " %" SCNu32, &i1, &n) != 2) {
                    fprintf(stderr, "acci needs 2 operands\n");
                    exit(1);
                }
                i2 = n >> 16;
                i3 = n & 0xffff;
                EMIT(acci);
            }
            else if (strncmp(src, "addi", 4) == 0) {
                src += 4;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " %" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "addi needs 3 operands\n");
                    exit(1);
                }
                EMIT(addi);
            }
            else if (strncmp(src, "muli", 4) == 0) {
                src += 4;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " %" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "muli needs 3 operands\n");
                    exit(1);
                }
                EMIT(muli);
            }
            break;
        case 3:
            if (strncmp(src, "nop", 3) == 0) {
                src += 3;
                uint16_t i1 = 0, i2 = 0, i3 = 0;
                EMIT(nop);
            }
            else if (strncmp(src, "ret", 3) == 0) {
                src += 3;
                uint16_t i1 = 0, i2 = 0, i3 = 0;
                EMIT(ret);
            }
            else if (strncmp(src, "mov", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3 = 0;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16, &i1, &i2) != 2) {
                    fprintf(stderr, "mov needs 2 operands\n");
                    exit(1);
                }
                EMIT(mov);
            }
            else if (strncmp(src, "phi", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " r%" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "phi needs 3 operands\n");
                    exit(1);
                }
                EMIT(phi);
            }
            else if (strncmp(src, "pop", 3) == 0) {
                src += 3;
                uint16_t i1, i2 = 0, i3 = 0;
                if (sscanf(src, " r%" SCNu16, &i1) != 1) {
                    fprintf(stderr, "pop needs 1 operand\n");
                    exit(1);
                }
                EMIT(pop);
            }
            else if (strncmp(src, "cmp", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3 = 0;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16, &i1, &i2) != 2) {
                    fprintf(stderr, "cmp needs 2 operands\n");
                    exit(1);
                }
                EMIT(cmp);
            }
            else if (strncmp(src, "add", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " r%" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "add needs 3 operands\n");
                    exit(1);
                }
                EMIT(add);
            }
            else if (strncmp(src, "sub", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " r%" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "sub needs 3 operands\n");
                    exit(1);
                }
                EMIT(sub);
            }
            else if (strncmp(src, "mul", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " r%" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "mul needs 3 operands\n");
                    exit(1);
                }
                EMIT(mul);
            }
            else if (strncmp(src, "div", 3) == 0) {
                src += 3;
                uint16_t i1, i2, i3;
                if (sscanf(src, " r%" SCNu16 " r%" SCNu16 " r%" SCNu16, &i1, &i2, &i3) != 3) {
                    fprintf(stderr, "div needs 3 operands\n");
                    exit(1);
                }
                EMIT(div);
            }
            else if (strncmp(src, "jnz", 3) == 0) {
                src += 3;
                uint16_t i1 = 0, i2 = 0, i3;
                char lbl[50];
                read_label_operand(src, "jnz", lbl);
                i3 = resolve_label(l_tbl, 128, lbl, "jnz");
                EMIT(jnz);
            }
            break;
        case 2:
            if (strncmp(src, "jz", 2) == 0) {
                src += 2;
                uint16_t i1 = 0, i2 = 0, i3;
                char lbl[50];
                read_label_operand(src, "jz", lbl);
                i3 = resolve_label(l_tbl, 128, lbl, "jz");
                EMIT(jz);
            }
            break;
        default:
            break;  /* no mnemonic has this length: the form is ignored */
        }
        /* Skip the rest of the form, stopping at the end of input if the
         * closing ')' is missing (reported below as unbalanced). */
        while (*src && *src != ')') ++src;
        break;
    }
    case ')':
        --paren_depth;
        do ++src; while (isspace((unsigned char)*src));
        break;
    default:
        ++src;
    }
    /* NOTE(review): only the slot array is released here. Any key strings
     * referenced by l_tbl entries are not freed by this function. */
    free(l_tbl);
    if (paren_depth != 0) {
        fprintf(stderr, "unbalanced parenthesis\n");
        exit(1);
    }
    if (prog[0] == 0) {
        fprintf(stderr, "no entry point found\n");
        exit(1);
    }
    return prog;
}
/* Assemble src, run it on a fresh VM and print its result as "out: N". */
static void run_program(const char *src) {
    struct vm vm;
    vm_init(&vm);
    vm.prog = assemble(src);
    int out = vm_exec(&vm);
    printf("out: %d\n", out);
    free(vm.prog);
    vm_destroy(&vm);
}

/* Read all of f into a NUL-terminated heap buffer, growing it as needed.
 * Returns NULL on a read error or allocation failure. */
static char *read_all(FILE *f) {
    size_t cap = 4096, len = 0;
    char *buf = malloc(cap);
    while (buf != NULL) {
        len += fread(buf + len, 1, cap - 1 - len, f);
        if (ferror(f)) {
            free(buf);
            return NULL;
        }
        if (len < cap - 1) {  /* short read without an error: end of file */
            buf[len] = '\0';
            return buf;
        }
        char *grown = realloc(buf, cap * 2);
        if (grown == NULL) {
            free(buf);
            return NULL;
        }
        buf = grown;
        cap *= 2;
    }
    return NULL;
}

/*
 * Usage:
 *   prog         each line of stdin is assembled and run as its own program.
 *   prog FILE    the whole file is assembled and run as one program.
 */
int main(int argc, char **argv) {
    if (argc == 1) {
        /* getline reuses and grows one buffer; free it once at the end. */
        char *line = NULL;
        size_t cap = 0;
        while (getline(&line, &cap, stdin) > 0)
            run_program(line);
        free(line);
    }
    else {
        FILE *f = fopen(argv[1], "r");
        if (f == NULL) {
            fprintf(stderr, "can't open file %s\n", argv[1]);
            exit(1);
        }
        char *src = read_all(f);
        fclose(f);
        if (src == NULL) {
            fprintf(stderr, "can't read file %s\n", argv[1]);
            exit(1);
        }
        run_program(src);
        free(src);
    }
    return 0;
}
; Sample program for the interpreter. Text outside parentheses is skipped by
; the assembler, so these lines act as comments. They must not contain
; parenthesis characters. Labels must be defined before they are referenced,
; which is why the entry label "start" comes last.
; Expected output: out: 52
;
; plustwo: r5 = r2 + 2
(label plustwo)
(addi r5 r2 2)
(ret)
; timestwo: r5 = r2 * 2
(label timestwo)
(muli r5 r2 2)
(ret)
; finish: r38 aliases r36 if r36 has been bound, otherwise r37.
; Its value is pushed, and HALT returns the top of the value stack.
(label finish)
(phi r38 r36 r37)
(push r38)
(halt)
; optone: r36 = plustwo applied to r35, then go to finish.
; cmp r0 r0 always sets r7 to 0, so the jz after it always jumps.
(label optone)
(mov r2 r35)
(call plustwo)
(mov r36 r5)
(cmp r0 r0)
(jz finish)
; opttwo: r37 = timestwo applied to r35, then go to finish.
(label opttwo)
(mov r2 r35)
(call timestwo)
(mov r37 r5)
(cmp r0 r0)
(jz finish)
; start: r35 = 10 + 40; go to optone if r35 equals 50, else to opttwo.
(label start)
(loadi r33 10)
(loadi r34 40)
(add r35 r33 r34)
(loadi r40 50)
(cmp r35 r40)
(jz optone)
(jnz opttwo)
; Entry point: program slots 0-1 jump to start.
(main start)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment