Skip to content

Instantly share code, notes, and snippets.

@skeeto
Last active November 15, 2023 13:20
Show Gist options
  • Save skeeto/5df632bad47bd71f0034d5683e26c998 to your computer and use it in GitHub Desktop.
Save skeeto/5df632bad47bd71f0034d5683e26c998 to your computer and use it in GitHub Desktop.
Partial application JIT demo, with arena
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <windows.h>
// Allocate a zeroed array of n objects of type t from arena a. The count
// is parenthesized so that an expression argument such as `len + 1` is
// scaled as a whole, and the expansion is parenthesized so the result can
// be used inside a larger expression.
#define new(a, t, n) ((t *)alloc(a, sizeof(t)*(n)))

// A region of memory handed out from the top down: the bytes in
// [beg, end) are the part that is still free.
typedef struct {
    char *beg, *end;
} arena;

// Carve size zero-filled bytes off the high end of the arena and return a
// pointer to them. There is no error return: a request that does not fit,
// or a negative size (e.g. from a negative count passed to new()),
// deliberately crashes with a store through a null pointer. Without the
// negative-size check such a request would move `end` upward, out of the
// arena, and hand memset an enormous length.
static void *alloc(arena *a, ptrdiff_t size)
{
    ptrdiff_t avail = a->end - a->beg;
    if (size < 0 || avail < size) {
        *(volatile int *)0 = 0;  // trap: out of memory or invalid request
    }
    a->end -= size;
    return memset(a->end, 0, (size_t)size);
}
// Reserve and commit size bytes of readable, writable, and executable
// memory with VirtualAlloc and wrap the region in an arena suitable for
// holding generated functions. If VirtualAlloc fails, both arena pointers
// are left null, i.e. the arena has no free space.
static arena newjitarena(ptrdiff_t size)
{
    arena jit = {0};
    char *base = VirtualAlloc(
        0, size, MEM_RESERVE|MEM_COMMIT, PAGE_EXECUTE_READWRITE
    );
    if (base) {
        jit.beg = base;
        jit.end = base + size;
    }
    return jit;
}
// Bind the only argument of a 1-arg function, producing a 0-arg function.
//
// A small x86-64 thunk is written into memory taken from the jit arena
// (22 of the 24 allocated bytes are used):
//
//     mov arg, %rcx
//     mov target, %rax
//     jmp *%rax
//
// The returned pointer is the thunk's entry point.
static void *partial_1to0(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *code = new(jit, unsigned char, 24);
    unsigned char *out = code;
    uintptr_t addr = (uintptr_t)target;

    // mov arg, %rcx -- 64-bit immediate, least significant byte first
    *out++ = 0x48;
    *out++ = 0xb9;
    for (int shift = 0; shift < 64; shift += 8) {
        *out++ = (unsigned char)(arg >> shift);
    }

    // mov target, %rax
    *out++ = 0x48;
    *out++ = 0xb8;
    for (int shift = 0; shift < 64; shift += 8) {
        *out++ = (unsigned char)(addr >> shift);
    }

    // jmp *%rax
    *out++ = 0xff;
    *out++ = 0xe0;

    return code;
}
// Partial-left-apply a 2-arg function into a 1-arg function.
//
// The returned thunk, when called with one argument x, calls
// target(arg, x). It is written into memory taken from the jit arena
// (25 of the 32 allocated bytes are used):
//
//     mov %rcx, %rdx      ; incoming 1st argument becomes the 2nd
//     mov arg, %rcx       ; bound value becomes the 1st
//     mov target, %rax
//     jmp *%rax
//
// The register roles follow the Windows x64 convention used by the rest
// of this file (1st integer argument in %rcx, 2nd in %rdx).
static void *partial_left2to1(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *f = new(jit, unsigned char, 32);
    unsigned char *p = f;
    uintptr_t dst = (uintptr_t)target;

    *p++ = 0x48;  // mov %rcx, %rdx
    *p++ = 0x89;  // "
    *p++ = 0xca;  // "

    // The opcode byte is 0xb8 plus the register number, so 0xb9 selects
    // %rcx. Using 0xba here would load %rdx instead, overwriting the
    // argument that was just moved there and binding arg on the right.
    *p++ = 0x48;  // mov arg, %rcx
    *p++ = 0xb9;  // "
    for (int i = 0; i < 8; i++) {
        *p++ = (unsigned char)(arg >> (8*i));  // little-endian immediate
    }

    *p++ = 0x48;  // mov dst, %rax
    *p++ = 0xb8;  // "
    for (int i = 0; i < 8; i++) {
        *p++ = (unsigned char)(dst >> (8*i));
    }

    *p++ = 0xff;  // jmp *%rax
    *p++ = 0xe0;  // "

    return f;
}
// Bind the last argument of a 3-arg function, producing a 2-arg function.
//
// The thunk leaves its own two arguments where they are and only fills
// in the third before jumping to target. It is written into memory taken
// from the jit arena (22 of the 24 allocated bytes are used):
//
//     mov arg, %r8
//     mov target, %rax
//     jmp *%rax
static void *partial_right3to2(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *thunk = new(jit, unsigned char, 24);
    unsigned char *cursor = thunk;
    uintptr_t entry = (uintptr_t)target;
    int byte;

    // mov arg, %r8 -- 64-bit immediate, least significant byte first
    *cursor++ = 0x49;
    *cursor++ = 0xb8;
    for (byte = 0; byte < 8; byte++) {
        *cursor++ = (unsigned char)(arg >> (byte * 8));
    }

    // mov target, %rax
    *cursor++ = 0x48;
    *cursor++ = 0xb8;
    for (byte = 0; byte < 8; byte++) {
        *cursor++ = (unsigned char)(entry >> (byte * 8));
    }

    // jmp *%rax
    *cursor++ = 0xff;
    *cursor++ = 0xe0;

    return thunk;
}
// Demonstration of partial_1to0

// Type of the zero-argument functions built in this demonstration.
typedef int func0(void);

// Return x multiplied by itself.
static int square(int x)
{
    int squared = x * x;
    return squared;
}
// Build a table of n zero-argument functions in which entry i is square
// with its argument bound to i, so calling entry i yields i*i. Both the
// table and the generated functions are allocated from the jit arena.
static func0 **gensquarers(int n, arena *jit)
{
    func0 **table = new(jit, func0 *, n);
    int i = 0;
    while (i < n) {
        table[i] = partial_1to0(square, i, jit);
        i++;
    }
    return table;
}
// Generate n squaring functions and print the result of calling each.
// The arena is received by value, so the allocations made here do not
// change the caller's arena.
static void demo1(arena scratch, int n)
{
    func0 **squarers = gensquarers(n, &scratch);
    for (int idx = 0; idx < n; idx++) {
        int value = squarers[idx]();
        printf("square[%d]() = %d\n", idx, value);
    }
}
// Demonstration of partial_left2to1

// Type of the one-argument functions built in this demonstration.
typedef int func1(int);

// Return the sum of x and y.
static int add(int x, int y)
{
    int sum = x + y;
    return sum;
}
// Build a table of n one-argument functions in which entry i is add with
// one argument bound to i, so calling entry i with x yields x + i. Both
// the table and the generated functions are allocated from the jit arena.
static func1 **genbiasers(int n, arena *jit)
{
    func1 **funcs = new(jit, func1 *, n);
    func1 **slot = funcs;
    for (int bias = 0; bias < n; bias++) {
        *slot++ = partial_left2to1(add, bias, jit);
    }
    return funcs;
}
// Generate n biasing functions and print the result of calling each with
// the value 10. The arena is received by value, so the allocations made
// here do not change the caller's arena.
//
// Declared static like the other demos: it is only called from main in
// this file and has no prototype for external users.
static void demo2(arena scratch, int n)
{
    func1 **bias = genbiasers(n, &scratch);
    for (int i = 0; i < n; i++) {
        printf("bias[%d](10) = %d\n", i, bias[i](10));
    }
}
// Demonstration of partial_right3to2

// Sort direction selected by the third argument of intcmp.
typedef enum {ASCEND, DESCEND} sortdir;

// Three-way comparison of *a and *b. With ASCEND the result is negative,
// zero, or positive when *a is less than, equal to, or greater than *b;
// with DESCEND the sign is reversed. The ordering is computed from
// comparisons rather than a subtraction, so it cannot overflow for
// operands of opposite sign.
static int intcmp(int *a, int *b, sortdir dir)
{
    int order = (*a > *b) - (*a < *b);
    return dir == DESCEND ? -order : order;
}

// Sort a small array in both directions with qsort, using comparators
// made by binding intcmp's direction argument. Prints the array after
// each sort: ascending first, then descending.
static void demo3(arena scratch)
{
    int array[] = {4, 1, 3, 2};
    size_t count = sizeof(array) / sizeof(*array);
    typedef int (*qsortcmp)(const void *, const void *);

    // Each comparator is named after the direction it is bound to.
    qsortcmp ascend  = partial_right3to2(intcmp, ASCEND,  &scratch);
    qsortcmp descend = partial_right3to2(intcmp, DESCEND, &scratch);

    qsort(array, count, sizeof(*array), ascend);
    printf("%d %d %d %d\n", array[0], array[1], array[2], array[3]);

    qsort(array, count, sizeof(*array), descend);
    printf("%d %d %d %d\n", array[0], array[1], array[2], array[3]);
}
// Create one executable arena and run the three demonstrations. Each demo
// receives the arena by value, so every demo starts from the same state.
int main(void)
{
    enum { JIT_BYTES = 1<<21, COUNT = 6 };

    arena code = newjitarena(JIT_BYTES);
    demo1(code, COUNT);
    demo2(code, COUNT);
    demo3(code);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment