Skip to content

Instantly share code, notes, and snippets.

@pervognsen
Last active February 7, 2024 11:51
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save pervognsen/96d116fff14d95ffe51cf084c8604d64 to your computer and use it in GitHub Desktop.
Save pervognsen/96d116fff14d95ffe51cf084c8604d64 to your computer and use it in GitHub Desktop.
// Heavily based on ideas from https://github.com/LuaJIT/LuaJIT/blob/v2.1/src/lj_opt_fold.c
// The most fundamental deviation is that I eschew the big hash table and the lj_opt_fold()
// trampoline for direct tail calls. The biggest problem with a trampoline is that you lose
// the control flow context. Another problem is that there's too much short-term round-tripping
// of data through memory. It's also easier to do ad-hoc sharing between rules with my approach.
// From what I can tell, it also isn't possible to do general reassociation with LJ's fold engine
// since that requires non-tail recursion, so LJ does cases like (x + n1) + n2 => x + (n1 + n2)
// but not (x + n1) + (y + n2) => x + (y + (n1 + n2)) which is common in address generation. The
// code below has some not-so-obvious micro-optimizations for register passing and calling conventions,
// e.g. the unary_cse/binary_cse parameter order, the use of long fields in ValueRef.
// Assembly listing for add: https://gist.github.com/pervognsen/592365b12c1295fa49907349054c3089
#include <assert.h>
#include <stdint.h>
// Opcodes stored in Value.op. NUM_OPS is the number of opcodes and sizes latest[].
enum {
    NUM,     // constant; payload lives in Value.num
    NEG,     // unary negation; operand position in Value.left_pos
    ADD,     // binary addition; operand positions in left_pos/right_pos
    SUB,     // subtraction; built by the opaque sub()
    NUM_OPS
};
// Total number of Value slots in buffer[]. VALUE splits them evenly between
// constants (negative positions) and op nodes (non-negative positions).
enum { MAX_VALUES = 1 << 16 };
// INLINE asks the compiler to always inline the small helpers it prefixes.
// NOTE(review): GCC documents the always_inline guarantee for functions that are
// also declared `inline`; the helpers using INLINE are not, so GCC may warn
// "always_inline function might not be inlinable". Confirm this is intended.
#if defined(_MSC_VER)
#define INLINE __forceinline
#else
#define INLINE __attribute__((always_inline))
#endif
// Larger of two values. Evaluates the selected argument twice, so only pass
// side-effect-free expressions.
#define MAX(x, y) ((x) < (y) ? (y) : (x))
// Exchange two ValueRef lvalues.
#define SWAP(x, y) do { ValueRef swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
// Base pointer into the middle of buffer, so VALUE[pos] == buffer[MAX_VALUES/2 + pos]
// for both negative (constant) and non-negative (op node) positions.
#define VALUE (&buffer[MAX_VALUES / 2])
// One slot of buffer[]. A NUM slot uses `num`. An op node uses `prev` (position
// of the previous node with the same op, linked from latest[]) and the operand
// positions `left_pos` / `right_pos`.
typedef struct Value {
    uint8_t op;
    int16_t prev;
    union {
        uint32_t num;            // payload when op == NUM
        struct {
            int16_t left_pos;    // first operand position (op nodes)
            int16_t right_pos;   // second operand position (binary op nodes)
        };
    };
} Value;

// A value's position plus a copy of its op, so the rewrite rules can switch on
// .op without loading the Value. This spreads across two registers on SysV
// (LP64) and one on Win64 (LLP64) for calls.
typedef struct ValueRef {
    long pos;
    long op;
} ValueRef;
// latest[op]: position of the most recently allocated node of that op (head of
// its chain). Zero-initialized; 0 is also the `prev` stored by a chain's first node.
int16_t latest[NUM_OPS];
// Storage for every Value, addressed through VALUE.
Value buffer[MAX_VALUES];
// Op nodes grow upward from position 0; `top` is the next free position.
int16_t top;
// Constants grow downward from position 0; `bottom` is the lowest allocated one.
int16_t bottom;
// Payload of the constant referenced by `ref` (ref must refer to a NUM slot; not checked).
INLINE uint32_t getnum(ValueRef ref) {
    const Value *slot = VALUE + ref.pos;
    return slot->num;
}
// First operand of the op node referenced by `ref`, with the operand's op looked up.
INLINE ValueRef getleft(ValueRef ref) {
    long operand = VALUE[ref.pos].left_pos;
    ValueRef result = {operand, VALUE[operand].op};
    return result;
}
// Second operand of the binary op node referenced by `ref`, with the operand's op looked up.
INLINE ValueRef getright(ValueRef ref) {
    long operand = VALUE[ref.pos].right_pos;
    ValueRef result = {operand, VALUE[operand].op};
    return result;
}
long num_pos(uint32_t num_) {
for (long pos = bottom; pos < 0; pos++) {
if (VALUE[pos].num == num_) return pos;
}
long pos = --bottom;
VALUE[pos] = (Value){NUM, .num = num_};
return pos;
}
// ValueRef for the constant `num_`; the slot is found or allocated by num_pos().
INLINE ValueRef num(uint32_t num_) {
    ValueRef ref;
    ref.pos = num_pos(num_);
    ref.op = NUM;
    return ref;
}
// Returns the position of the unary node `op(left)`. If a node with the same op
// and operand already exists it is reused; otherwise a new node is appended at `top`.
//
// The chain for `op` starts at latest[op] and follows `prev` links through
// strictly decreasing positions. Every node is allocated after its operands, so
// no node at or below left.pos can use `left`, and the walk stops there.
//
// Position 0 is both a real slot (the first op node ever allocated) and the
// end-of-chain link (latest[] starts at 0 and a chain's first node stores
// prev == 0). When left.pos < 0 the walk reaches position 0. The op check
// rejects a node of a different op that happens to sit there. The explicit stop
// prevents looping forever on VALUE[0].prev == 0.
long unary_cse_pos(ValueRef left, long op) {
    long pos = latest[op];
    while (pos > left.pos) {
        Value *v = &VALUE[pos];
        if (v->left_pos == left.pos && v->op == op)
            return pos;
        if (pos == 0)
            break;  // end of chain
        pos = v->prev;
    }
    // top is int16_t: after position INT16_MAX is used, top++ does not fit and
    // (implementation-defined) typically wraps negative, into the constant
    // region. Catch that in debug builds before overwriting anything.
    assert(top >= 0 && "unary_cse_pos: op-node region of buffer exhausted");
    pos = top++;
    VALUE[pos] = (Value){op, latest[op], .left_pos = left.pos};
    latest[op] = pos;
    return pos;
}
// ValueRef for the unary node `op(left)`, deduplicated by unary_cse_pos().
INLINE ValueRef unary_cse(ValueRef left, long op) {
    long pos = unary_cse_pos(left, op);
    return (ValueRef){.pos = pos, .op = op};
}
// Returns the position of the binary node `op(left, right)` (operands compared
// in the given order). If a matching node exists it is reused; otherwise a new
// node is appended at `top`.
//
// The chain for `op` starts at latest[op] and follows `prev` links through
// strictly decreasing positions. Every node is allocated after its operands, so
// no node at or below the larger operand position can use both operands, and
// the walk stops there.
//
// Position 0 is both a real slot (the first op node ever allocated) and the
// end-of-chain link (latest[] starts at 0 and a chain's first node stores
// prev == 0). When both operands are constants (limit < 0) the walk reaches
// position 0. The op check rejects a node of a different op that happens to sit
// there. The explicit stop prevents looping forever on VALUE[0].prev == 0.
long binary_cse_pos(ValueRef left, ValueRef right, long op) {
    long limit = MAX(left.pos, right.pos);
    long pos = latest[op];
    while (pos > limit) {
        Value *v = &VALUE[pos];
        if (v->left_pos == left.pos && v->right_pos == right.pos && v->op == op)
            return pos;
        if (pos == 0)
            break;  // end of chain
        pos = v->prev;
    }
    // top is int16_t: after position INT16_MAX is used, top++ does not fit and
    // (implementation-defined) typically wraps negative, into the constant
    // region. Catch that in debug builds before overwriting anything.
    assert(top >= 0 && "binary_cse_pos: op-node region of buffer exhausted");
    pos = top++;
    VALUE[pos] = (Value){op, latest[op], .left_pos = left.pos, .right_pos = right.pos};
    latest[op] = pos;
    return pos;
}
// ValueRef for the binary node `op(left, right)`, deduplicated by binary_cse_pos().
INLINE ValueRef binary_cse(ValueRef left, ValueRef right, long op) {
    long pos = binary_cse_pos(left, right, op);
    return (ValueRef){.pos = pos, .op = op};
}
// Declared only: the definitions of sub() and neg() are intentionally left out
// of this demo, so add() treats them as opaque calls.
ValueRef sub(ValueRef left, ValueRef right);
ValueRef neg(ValueRef left);
// Builds `left + right`, applying simplifications before falling back to a
// deduplicated ADD node. Constant arithmetic is uint32_t and wraps modulo 2^32.
//
// Operands are first put in canonical order (smaller position on the left), so a
// constant (negative position) is always the left operand of an ADD node.
//
// Reassociation only pulls a constant out of an ADD operand:
//   n1 + (n2 + z) => (n1 + n2) + z     (the two constants fold)
//   x  + (n  + z) => n + (x + z)
//   (n + x) + y   => n + (x + y)
// so constants migrate outward where later additions can fold them.
//
// Termination: a rewrite either recurses on strictly smaller operand trees, or
// recurses with a constant on the left. Constant-left calls only descend into
// getright() (a strictly lower position). The unrestricted rule
// x + (y + z) => (x + y) + z is deliberately not used: the new node (x + y)
// always has the highest position, so the canonical swap moves it back to the
// right and the rule fires again. CSE hits then close a cycle; for example,
// a + (b + c) over three non-constant leaves recurses until the stack overflows.
ValueRef add(ValueRef left, ValueRef right) {
    // Canonical commutation: x + y => y + x (constants have negative pos and move to the left)
    if (left.pos > right.pos) SWAP(left, right);
    switch (left.op) {
    case NEG:
        // Strength reduction: (-x) + y => y - x
        return sub(right, getleft(left));
    case NUM:
        // Strength reduction: 0 + x => x
        if (getnum(left) == 0) return right;
        // Constant folding: n1 + n2 => n
        if (right.op == NUM) return num(getnum(left) + getnum(right));
        break;
    default:
        break;
    }
    switch (right.op) {
    case NEG:
        // Strength reduction: x + (-y) => x - y
        return sub(left, getleft(right));
    case ADD: {
        ValueRef n = getleft(right);
        if (n.op != NUM) break;
        ValueRef z = getright(right);
        // Reassociation + folding: n1 + (n2 + z) => (n1 + n2) + z
        if (left.op == NUM) return add(num(getnum(left) + getnum(n)), z);
        // Reassociation: x + (n + z) => n + (x + z)
        return add(n, add(left, z));
    }
    default:
        break;
    }
    if (left.op == ADD) {
        ValueRef n = getleft(left);
        // Reassociation: (n + x) + y => n + (x + y)
        if (n.op == NUM) return add(n, add(getright(left), right));
    }
    return binary_cse(left, right, ADD);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment