Skip to content

Instantly share code, notes, and snippets.

@gmalysa
Created Feb 6, 2022
Embed
What would you like to do?
creating partially evaluated functions dynamically in C
// Build with -O2 (x86-64 only). __t2 hard-codes an offset that is correct only when
// its inline lea is the first instruction of the function; without optimization the
// compiler emits a prologue before it, the offset is wrong and the program crashes.
#include <errno.h>
#include <sys/mman.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define NUM_EVENTS 4

/* Handler signature: ctx is the opaque pointer supplied at registration. */
typedef int (*callback)(void *ctx);

/* One dispatch-table slot: handler plus its context. */
struct handler {
	callback fn;
	void *ctx;
};

/* Dispatch table indexed by event id; a NULL fn marks an empty slot. */
static struct handler handlers[NUM_EVENTS] = {0};

/*
 * Install fn/ctx as the handler for event `id`.
 * Returns 0 on success, -EINVAL if id is out of range, or -EBUSY if the slot
 * already has a handler (an existing registration is never replaced).
 */
int register_callback(uint32_t id, callback fn, void *ctx) {
	struct handler *slot;

	if (id < NUM_EVENTS) {
		slot = &handlers[id];
		if (slot->fn != NULL)
			return -EBUSY;
		slot->fn = fn;
		slot->ctx = ctx;
		return 0;
	}
	return -EINVAL;
}

/*
 * Invoke the handler registered for event `id` and return its result.
 * Returns -EINVAL if id is out of range, or -ENOSYS if no handler is set.
 */
int raise_event(uint32_t id) {
	const struct handler *slot;

	if (!(id < NUM_EVENTS))
		return -EINVAL;
	slot = &handlers[id];
	return slot->fn ? slot->fn(slot->ctx) : -ENOSYS;
}
/* Event handler: treats ctx as a NUL-terminated event name and prints it.
 * Always returns 0. */
int print_data(void *ctx) {
	printf("got event name = %s\n", (const char *)ctx);
	return 0;
}

/* Event names used as handler contexts by test_events(). */
static const char event0[] = "event 0";
static const char event1[] = "event 1";
/*
 * Adapter context that lets a two-argument function (name, id) be registered
 * as a single-argument `callback`: wrapped_callback() unpacks it and forwards.
 */
struct ctx_wrap {
	int (*fn)(const char *, int);
	const char *name;
	int id;
};

/* Callback-compatible trampoline; ctx must point to a struct ctx_wrap.
 * Returns whatever wrap->fn returns. */
int wrapped_callback(void *ctx) {
	struct ctx_wrap *wrap = ctx;
	return wrap->fn(wrap->name, wrap->id);
}

/*
 * Allocate a zeroed struct ctx_wrap and fill it in.
 * `name` is stored by pointer, not copied, so it must outlive the wrapper.
 * Returns NULL if allocation fails; release with free_wrap().
 */
struct ctx_wrap *alloc_wrap(int (*fn)(const char *, int), const char *name, int id) {
	/* calloc(nmemb, size): one element of sizeof(*ret) bytes. */
	struct ctx_wrap *ret = calloc(1, sizeof(*ret));

	if (!ret)
		return NULL;
	ret->fn = fn;
	ret->name = name;
	ret->id = id;
	return ret;
}

/*
 * Release a wrapper obtained from alloc_wrap(). NULL is accepted because
 * free(NULL) is a no-op. `name` is not freed. The caller must make sure no
 * registered handler still uses `wrap` as its context.
 */
void free_wrap(struct ctx_wrap *wrap) {
	free(wrap);
}
/*
 * Register print_data for events 0 and 1, using their names as context, and
 * raise both.
 *
 * The explicit (void *) casts are needed because event0/event1 are const.
 * Passing them directly to the `void *ctx` parameter discards the qualifier,
 * which is a constraint violation. The casts are safe because print_data only
 * reads its context.
 */
void test_events(void) {
	register_callback(0, print_data, (void *)event0);
	register_callback(1, print_data, (void *)event1);
	raise_event(0);
	raise_event(1);
}
/* Event name bound by test_events2(). */
static const char event2[] = "event 2";

/* Print an event's name and numeric id; always returns 0. This has the
 * two-argument shape that struct ctx_wrap adapts into a callback. */
int print_two_data(const char *name, int id) {
	const int status = 0;

	printf("got event name %s, id %d\n", name, id);
	return status;
}
void test_events2(void) {
struct ctx_wrap *wrap = alloc_wrap(print_two_data, event2, 100);
register_callback(2, wrapped_callback, wrap);
raise_event(2);
free_wrap(wrap);
}
/* Return x + 1 (the caller must keep x below INT_MAX to avoid signed overflow). */
int foo(int x) {
	int next = x;

	++next;
	return next;
}
/*
 * x86-64 trampoline template for partial(). partial() copies these bytes to
 * offset 16 of a fresh mapping. The mapping's first two 64-bit words hold the
 * target function (word 0) and its bound argument (word 1).
 *
 * `naked` suppresses the compiler-generated prologue, so the lea below is the
 * first instruction. The hard-coded displacement depends on that.
 */
__attribute__((naked,section("thunk")))
void __thunk() {
/* The lea encodes to 7 bytes, so %rip here is start+7. -0x17 (-23) then
 * yields start-16, the beginning of the two-word header written by partial(). */
asm volatile("lea -0x17(%rip), %rax");
/* Word 1, the bound argument, goes into %rdi, the first integer-argument register. */
asm volatile("mov 8(%rax), %rdi");
/* Word 0 is the target function pointer. */
asm volatile("mov 0(%rax), %rax");
/* Tail-jump: the target returns directly to whoever called the thunk. */
asm volatile("jmp *%rax");
}
/* Bounds of the "thunk" section, generated by the linker for sections whose
 * names are valid C identifiers. partial() uses them to size the copy. */
extern uint8_t __start_thunk;
extern uint8_t __stop_thunk;
/*
 * Header placed immediately before a copied trampoline body by partial2() and
 * partial3(). The trampoline reads fn and a0 back from the bytes preceding its
 * own code and calls fn(a0). On x86-64, fn sits at offset 0 and a0 at offset 8,
 * matching the buf[0]/buf[1] layout used by partial().
 */
struct partial_fn {
int (*fn)(int);     /* function to call */
uint64_t a0;        /* bound argument, stored widened from int */
char body[];        /* trampoline machine code is copied here */
};
/* Encoded length of the `lea disp32(%rip), reg64` instruction emitted by
 * __t2's inline asm: REX + opcode + ModRM + disp32 = 7 bytes. */
#define LEA_SIZE 0x07
/*
 * Trampoline template for partial2(). Its code is copied right after a struct
 * partial_fn header, and it calls header->fn(header->a0).
 *
 * The asm computes (%rip after the lea) - LEA_SIZE - offsetof(body). That
 * equals the header address ONLY if the lea is the very first instruction of
 * the function. Nothing in the source forces this. It holds at -O2 (no
 * frame-pointer prologue) but not at -O0.
 * NOTE(review): the constant expression is computed in size_t (unsigned) and
 * relies on GCC printing the wrapped value as -23; confirm with other compilers.
 */
__attribute__((section("t2")))
int __t2() {
struct partial_fn *ctx;
asm("lea %c1(%%rip), %0" : "=r" (ctx) : "i" (-LEA_SIZE - offsetof(struct partial_fn, body)));
return ctx->fn(ctx->a0);
}
/* Linker-generated bounds of the "t2" section, used by partial2() to size the copy. */
extern uint8_t __start_t2;
extern uint8_t __stop_t2;
/* Zero-argument function type returned by partial(), partial2() and partial3(). */
typedef int (*retint)(void);
/*
 * Bind x as the sole argument of fn. Returns a zero-argument function that,
 * when called, returns fn(x). x86-64 only.
 *
 * Layout of the new mapping: word 0 = fn, word 1 = x, then a copy of __thunk,
 * which addresses the two words immediately before itself.
 *
 * The mapping is filled while PROT_READ|PROT_WRITE and only then switched to
 * PROT_READ|PROT_EXEC, so it is never writable and executable at the same time
 * (W^X). Some hardened systems refuse PROT_WRITE|PROT_EXEC mappings outright.
 *
 * Returns NULL, after printing a diagnostic, on failure. The mapping is never
 * released.
 */
retint partial(int (*fn)(int), int x) {
	uint64_t *buf;
	size_t tsize;
	size_t size;

	tsize = &__stop_thunk - &__start_thunk;
	size = tsize + 2 * sizeof(uint64_t);
	buf = mmap(NULL, size, PROT_READ | PROT_WRITE,
		   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	/* mmap signals failure with MAP_FAILED; comparing a pointer with the
	 * integer -1 is not valid C. */
	if (buf == MAP_FAILED) {
		perror("mapping failed");
		return NULL;
	}
	buf[0] = (uint64_t)fn;
	buf[1] = x;
	memcpy(buf + 2, __thunk, tsize);
	if (mprotect(buf, size, PROT_READ | PROT_EXEC) != 0) {
		perror("mprotect failed");
		munmap(buf, size);
		return NULL;
	}
	return (retint)(buf + 2);
}
/*
 * Same contract as partial(), using __t2 as the trampoline template. The
 * mapping holds a struct partial_fn header {fn, x} followed by a copy of __t2's
 * code in result->body. The returned function is that copy. x86-64 only.
 *
 * The mapping is written while PROT_READ|PROT_WRITE and only then made
 * PROT_READ|PROT_EXEC (W^X). Returns NULL, after printing a diagnostic, on
 * failure. The mapping is never released.
 */
retint partial2(int (*fn)(int), int x) {
	struct partial_fn *result;
	size_t tsize;
	size_t size;

	tsize = &__stop_t2 - &__start_t2;
	/* sizeof(*result) >= offsetof(body), so this covers header + code. */
	size = tsize + sizeof(*result);
	result = mmap(NULL, size, PROT_READ | PROT_WRITE,
		      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (result == MAP_FAILED) {
		perror("mapping failed");
		return NULL;
	}
	result->fn = fn;
	result->a0 = x;
	memcpy(result->body, __t2, tsize);
	if (mprotect(result, size, PROT_READ | PROT_EXEC) != 0) {
		perror("mprotect failed");
		munmap(result, size);
		return NULL;
	}
	return (retint)result->body;
}
/* Linker-generated bounds of the "t3" section, used by partial3() to size the copy. */
extern uint8_t __start_t3;
extern uint8_t __stop_t3;

/*
 * Trampoline template for partial3(). A copy of this function's code is placed
 * at result->body, directly after a struct partial_fn header, and it calls
 * header->fn(header->a0).
 *
 * It locates the header by taking its own address and stepping back by
 * offsetof(struct partial_fn, body). That is the same offset partial3() uses
 * when it copies the code into result->body. sizeof(struct partial_fn) must not
 * be used here, because trailing padding can make it larger than that offset.
 * Byte arithmetic goes through char *; arithmetic on void * is a GNU extension.
 *
 * NOTE(review): the copy works only if the compiler computes &__t3 with
 * RIP-relative addressing (the usual code for PIE/PIC builds). An absolute
 * address would make the copy read the original's neighbours. Verify the
 * generated code for non-PIE builds.
 */
__attribute__((section("t3")))
int __t3(void) {
	char *self = (char *)&__t3;
	struct partial_fn *args =
		(struct partial_fn *)(self - offsetof(struct partial_fn, body));

	return args->fn(args->a0);
}
/*
 * Same contract as partial(), using __t3 as the trampoline template. The
 * mapping holds a struct partial_fn header {fn, x} followed by a copy of __t3's
 * code in result->body. The returned function is that copy. x86-64 only.
 *
 * The mapping is written while PROT_READ|PROT_WRITE and only then made
 * PROT_READ|PROT_EXEC (W^X). Returns NULL, after printing a diagnostic, on
 * failure. The mapping is never released.
 */
retint partial3(int (*fn)(int), int x) {
	struct partial_fn *result;
	size_t tsize;
	size_t size;

	tsize = &__stop_t3 - &__start_t3;
	/* sizeof(*result) >= offsetof(body), so this covers header + code. */
	size = tsize + sizeof(*result);
	result = mmap(NULL, size, PROT_READ | PROT_WRITE,
		      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (result == MAP_FAILED) {
		perror("mapping failed");
		return NULL;
	}
	result->fn = fn;
	result->a0 = x;
	memcpy(result->body, __t3, tsize);
	if (mprotect(result, size, PROT_READ | PROT_EXEC) != 0) {
		perror("mprotect failed");
		munmap(result, size);
		return NULL;
	}
	return (retint)result->body;
}
int main(int argc, char **argv) {
retint foo_partial;
foo_partial = partial(foo, 7);
printf("%d\n", foo_partial());
foo_partial = partial2(foo, 3);
printf("%d\n", foo_partial());
foo_partial = partial3(foo, 11);
printf("%d\n", foo_partial());
test_events();
test_events2();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment