Skip to content

Instantly share code, notes, and snippets.

@lpereira
Last active January 29, 2023 20:12
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save lpereira/3390dd11c17653d16049b505096a3f93 to your computer and use it in GitHub Desktop.
Save lpereira/3390dd11c17653d16049b505096a3f93 to your computer and use it in GitHub Desktop.
/* Maps pages of the requested size near the text page that contains
* a symbol. Useful for hand-rolled JIT compilers, which can then reach
* symbols in a program with direct rel32 calls on x86-64 instead of
* indirect calls. (Part of the code here is a JIT compiler that encodes
* JSON based on a struct descriptor, but it is not finished yet.)
*
* Written by L. Pereira <l@tia.mat.br>
*
* Thanks to Ole André V. Ravnås for the idea to use the NOREPLACE flag
* in mmap(), and to Paul Khuong for helping me make it more robust in
* systems where that flag might not exist.
*
*/
#include <assert.h>
#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/mman.h>
#include <unistd.h>
#include <string.h>
#include <stdbool.h>
#include <stdarg.h>
#include <stdlib.h>
#include <limits.h>
#include "lwan-array.h"
#include "json.h"
/* Reads the next "start-end" address range from a /proc/self/maps stream.
 *
 * Lines longer than the local buffer (e.g. entries with very long paths) are
 * consumed up to their newline, so a fragment of a path is never mistaken
 * for the beginning of another entry. Returns false at end of file or on a
 * read error. */
static bool maps_next_range(FILE *maps, uintptr_t *start, uintptr_t *end)
{
    char line[512];

    while (fgets(line, sizeof(line), maps)) {
        unsigned long lo, hi;
        const bool parsed = sscanf(line, "%lx-%lx", &lo, &hi) == 2;

        if (!strchr(line, '\n')) {
            int c;

            while ((c = fgetc(maps)) != EOF && c != '\n')
                ;
        }

        if (parsed) {
            *start = (uintptr_t)lo;
            *end = (uintptr_t)hi;
            return true;
        }
    }

    return false;
}

/* Finds a page-aligned, currently unmapped range of at least `want_size`
 * bytes lying just below the mapping that contains `symbol`, so that code
 * placed there is likely within rel32 reach of `symbol` on x86-64 (the
 * distance itself is not checked here).
 *
 * The search starts immediately below the mapping containing `symbol` and
 * slides downwards past every mapping that overlaps the candidate range.
 *
 * The answer is only a snapshot of /proc/self/maps: the range may be taken
 * by someone else before the caller maps it, so callers must map it without
 * clobbering existing mappings (e.g. with MAP_FIXED_NOREPLACE).
 *
 * Returns NULL if /proc/self/maps cannot be read, `symbol` is not inside any
 * mapping, `want_size` is 0 or too large to round up, or no range was found.
 */
static void *find_free_page_near_symbol(void *symbol, size_t want_size)
{
    const uintptr_t symbol_addr = (uintptr_t)symbol;
    const long page_size = sysconf(_SC_PAGE_SIZE);
    uintptr_t candidate = 0;
    uintptr_t start, end;
    size_t page_mask;
    FILE *maps;

    if (page_size <= 0)
        return NULL;

    page_mask = (size_t)page_size - 1;
    if (want_size == 0 || want_size > SIZE_MAX - page_mask)
        return NULL;
    want_size = (want_size + page_mask) & ~page_mask;

    maps = fopen("/proc/self/maps", "re");
    if (!maps)
        return NULL;

    /* First candidate: the range that ends where the mapping containing the
     * symbol begins. Ranges in /proc/self/maps are end-exclusive. */
    while (maps_next_range(maps, &start, &end)) {
        if (symbol_addr >= start && symbol_addr < end) {
            if (start >= want_size)
                candidate = start - want_size;
            break;
        }
    }

    /* Move the candidate below any mapping that overlaps it (even partially)
     * and rescan from the top. Each move strictly lowers the candidate, so
     * this terminates; worst case is O(n^2) in the number of mappings. */
    while (candidate) {
        bool overlaps = false;

        rewind(maps);
        while (maps_next_range(maps, &start, &end)) {
            if (start < candidate + want_size && end > candidate) {
                overlaps = true;
                break;
            }
        }
        if (!overlaps)
            break;

        candidate = start >= want_size ? start - want_size : 0;
    }

    fclose(maps);
    return (void *)candidate;
}
/* Older libc headers do not define MAP_FIXED_NOREPLACE (added in Linux 4.17).
 * 0x100000 is its value in the generic Linux UAPI headers used on x86-64,
 * the only target this JIT emits code for. Kernels that predate the flag
 * ignore it and treat the address as a plain hint; the loop below detects
 * that case and falls back to hint-only mapping. */
#ifndef MAP_FIXED_NOREPLACE
#define MAP_FIXED_NOREPLACE 0x100000
#endif

/* Maps `want_size` bytes of private anonymous memory with protection `prot`
 * as close as possible below `symbol`, so that code placed there can reach
 * `symbol` with rel32 calls.
 *
 * Returns the new mapping, or NULL if no free candidate range is found or
 * 32 attempts fail.
 */
static void *try_map_near_symbol(void *symbol, size_t want_size, int prot)
{
    int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED_NOREPLACE;
    void *addr = symbol;

    for (int attempt = 0; attempt < 32; attempt++) {
        void *ptr;

        /* On retries `addr` is the candidate we failed to obtain, which is
         * expected to be occupied by now, so the search continues below it. */
        addr = find_free_page_near_symbol(addr, want_size);
        if (!addr)
            return NULL;

        ptr = mmap(addr, want_size, prot, flags, -1, 0);
        if (ptr == addr)
            return ptr;

        if (ptr == MAP_FAILED) {
            /* EEXIST only comes from a kernel that honours NOREPLACE: we
             * lost a race for the range. Look for another one, keeping the
             * flag enabled. */
            if (errno == EEXIST && (flags & MAP_FIXED_NOREPLACE))
                continue;
            /* Any other error: retry below by passing the address as a hint
             * and checking whether mmap() returns it. */
        } else {
            /* The address was treated as a hint and the mapping landed
             * elsewhere, which is of no use to us. */
            munmap(ptr, want_size);
        }

        /* This kernel probably has no NOREPLACE support; disable it for the
         * remaining attempts. */
        flags &= ~MAP_FIXED_NOREPLACE;
    }

    return NULL;
}
/* A call site whose rel32 displacement can only be filled in once the final
 * address of the generated code is known (see jit_generate()). */
struct jit_call_patch {
void *target; /* address the call must land on */
size_t patch; /* offset, within the code buffer, of the 4-byte immediate to patch */
};
/* Array types generated by lwan-array.h's DEFINE_ARRAY_TYPE, used in this file
 * through their _init/_append/_len/_reset helpers:
 * jit_code holds the emitted machine-code bytes, and jit_call holds the call
 * sites still waiting for their displacement to be patched. */
DEFINE_ARRAY_TYPE(jit_code, uint8_t)
DEFINE_ARRAY_TYPE(jit_call, struct jit_call_patch)
/* State of one JIT compilation, from jit_init() to jit_destroy(). */
struct jit {
struct jit_code code; /* machine code emitted so far (reset by jit_generate()) */
struct jit_call calls; /* call sites awaiting rel32 patching */
/* Output callback invoked by the generated code; the executable mapping is
 * also placed near it so rel32 calls can reach it. */
int (*append_bytes)(const void *buf, size_t len, void *user_data);
size_t aux_func_encode_str_loc; /* offset in `code` of the string-encoding helper */
size_t size_for_unmapping; /* length of the executable mapping, 0 if there is none */
void *mapped; /* executable mapping made by jit_generate(), or NULL */
};
/* Appends `n` bytes to the code buffer. Every variadic argument is one byte;
 * it arrives promoted to int and is truncated back to uint8_t. */
static void jit_emit(struct jit *j, int n, ...)
{
    va_list bytes;

    va_start(bytes, n);
    for (int emitted = 0; emitted != n; emitted++) {
        uint8_t *slot = jit_code_append(&j->code);

        *slot = (uint8_t)va_arg(bytes, int);
    }
    va_end(bytes);
}
/* JIT32_ARGS/JIT64_ARGS expand to the little-endian bytes of `v` as a
 * comma-separated list of uint8_t values, ready to be passed to jit_emit().
 * `v` is evaluated once per byte, so it must not have side effects. */
#define JIT_BYTE_AT(v, shift) (uint8_t)(((uint64_t)(v) >> (shift)) & 0xff)
#define JIT32_ARGS(v)                                                          \
    JIT_BYTE_AT((uint32_t)(v), 0), JIT_BYTE_AT((uint32_t)(v), 8),              \
        JIT_BYTE_AT((uint32_t)(v), 16), JIT_BYTE_AT((uint32_t)(v), 24)
#define JIT64_ARGS(v) JIT32_ARGS((uint64_t)(v)), JIT32_ARGS((uint64_t)(v) >> 32)
/* Emits a `call rel32` to `func`. The 4-byte displacement is a placeholder;
 * its offset is recorded in j->calls so jit_generate() can fill it in once
 * the code has been copied to its final address. */
static void jit_emit_call(struct jit *j, void *func)
{
    struct jit_call_patch *site = jit_call_append(&j->calls);

    jit_emit(j, 1, 0xe8); /* opcode: call rel32 */
    site->target = func;
    /* The immediate starts right after the opcode byte. */
    site->patch = jit_code_len(&j->code);
    jit_emit(j, 4, JIT32_ARGS(0xbebacafe)); /* placeholder, patched later */
}
/* Initializes `j` and emits the prologue of the generated function.
 *
 * The generated function is called as int (*)(const void *ptr,
 * void *user_data) (see main()). `append_bytes` is the output callback that
 * the generated code calls. It is stored in a field of type
 * int (*)(const void *buf, size_t len, void *user_data), so it must have
 * that signature.
 *
 * Stack frame built by the prologue (0x20 bytes below %rbp):
 *   -0x04(%rbp)  int holding the OR of all append_bytes() results
 *   -0x18(%rbp)  first argument (%rdi)
 *   -0x20(%rbp)  second argument (%rsi), passed on to append_bytes as user_data
 */
static void jit_init(struct jit *j, void *append_bytes)
{
    jit_code_init(&j->code);
    jit_call_init(&j->calls);
    j->append_bytes = append_bytes;
    /* Assigned for real by jit_emit_aux_func_encode_string(); do not leave
     * the field indeterminate until then. */
    j->aux_func_encode_str_loc = 0;
    j->size_for_unmapping = 0;
    j->mapped = NULL;

    /* Prologue */
    jit_emit(j, 1, 0x55);                            /* push rbp */
    jit_emit(j, 3, 0x48, 0x89, 0xe5);                /* mov rbp, rsp */
    jit_emit(j, 4, 0x48, 0x83, 0xec, 0x20);          /* sub rsp, 0x20 */
    jit_emit(j, 4, 0x48, 0x89, 0x7d, 0xe8);          /* mov -0x18(rbp), rdi */
    jit_emit(j, 4, 0x48, 0x89, 0x75, 0xe0);          /* mov -0x20(rbp), rsi */
    jit_emit(j, 7, 0xc7, 0x45, 0xfc, JIT32_ARGS(0)); /* mov -0x4(rbp), 0 */
}
/* Emits code equivalent to
 *     acc |= append_bytes(str, strlen(str), user_data);
 * where acc lives at -0x4(%rbp) and user_data at -0x20(%rbp).
 * Only the pointer `str` is embedded in the code, so the string must outlive
 * the generated function. The length is encoded as a 32-bit immediate. */
static void jit_emit_append_const_strz(struct jit *j, const char *str)
{
    const size_t str_len = strlen(str);

    /* 3rd argument (rdx): user_data saved by the prologue. */
    jit_emit(j, 4, 0x48, 0x8b, 0x45, 0xe0); /* mov rax, -0x20(rbp) */
    jit_emit(j, 3, 0x48, 0x89, 0xc2);       /* mov rdx, rax */
    /* 2nd argument (esi, zero-extended into rsi): the length. */
    jit_emit(j, 5, 0xbe, JIT32_ARGS(str_len)); /* mov esi, len */
    /* 1st argument (rdi): the string, as a 64-bit immediate. */
    jit_emit(j, 10, 0x48, 0xbf, JIT64_ARGS(str)); /* mov rdi, str */
    jit_emit_call(j, j->append_bytes);
    jit_emit(j, 3, 0x09, 0x45, 0xfc); /* or -0x4(rbp), eax */
}
/* Pads the code buffer with single-byte NOPs until its length is a multiple
 * of 32 bytes. Emits nothing if the length is already aligned. */
static void jit_emit_alignment_nops(struct jit *j)
{
    const size_t align = 32;
    const size_t past_boundary = jit_code_len(&j->code) & (align - 1);
    /* Bytes missing to reach the *next* boundary, not the bytes already
     * past the previous one. */
    const size_t n_nops = past_boundary ? align - past_boundary : 0;

    for (size_t i = 0; i < n_nops; i++)
        jit_emit(j, 1, 0x90); /* nop */
}
/*
 * Routines has_zero() and has_value() are from
 * https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
 */

/* Returns non-zero iff some byte of `v` is zero. The 0x80 bit of each zero
 * byte is set in the result. Bytes more significant than a zero byte may
 * also be flagged (borrow propagation), but the least significant flagged
 * byte is always a genuine zero. */
static inline uint64_t has_zero(uint64_t v)
{
    return (v - UINT64_C(0x0101010101010101)) & ~v &
           UINT64_C(0x8080808080808080);
}

/* Like has_zero(), but flags the bytes of `x` equal to `n`. `n` is converted
 * to an unsigned byte before being broadcast to all eight lanes, so values
 * >= 0x80 work even where plain char is signed. */
static inline uint64_t has_value(uint64_t x, char n)
{
    return has_zero(x ^ (UINT64_MAX / 255 * (uint8_t)n));
}

/* Returns the letter that follows the backslash when `chr` is escaped in a
 * JSON string ('"', '\\', 'b', 'f', 'n', 'r' or 't'), or 0 if `chr` has no
 * short escape sequence. */
static char escape_as(char chr)
{
    /* Byte i of `escapable`, counting from the least significant end, is the
     * character whose escape letter is escaped[i]. */
    static const char escaped[] = {'t', 't', 'r', 'n', 'f', 'b', '\\', '"'};
    const uint64_t escapable = UINT64_C(0x225c080c0a0d0909);
    const uint64_t mask = has_value(escapable, chr);

    /* Use the least significant flagged byte: it is always a real match,
     * unlike bytes above it which has_zero() may flag spuriously. */
    return mask == 0 ? 0 : escaped[__builtin_ctzll(mask) / 8];
}
/* Emits an out-of-line helper meant to append a NUL-terminated string with
 * JSON escaping. It was transliterated from compiler output: the trailing
 * comments are the original disassembly, addresses included. The helper's
 * start offset within the code buffer is stored in
 * j->aux_func_encode_str_loc.
 *
 * Per the instructions below, the string pointer arrives in %rdi (it is
 * dereferenced byte by byte until a NUL). escape_as() is called with the
 * current character in %edi. %r12d accumulates the OR of the callback
 * results and is returned in %eax.
 *
 * NOTE(review): unfinished; it must not be called as-is:
 *  - append_bytes is called without its arguments being set up (see FIXME);
 *  - the rel8 jump displacements are hard-coded from the original layout,
 *    but jit_emit_alignment_nops() inserts a length-dependent amount of
 *    padding at the two alignment points, so the jumps crossing them only
 *    land correctly if that padding happens to match the original;
 *  - the early-exit path (the first `je`, landing after the second
 *    alignment point) pops %rbp and %r12 but not %rbx, which was pushed
 *    before the jump. This looks like a missing `pop %rbx` (0x5b) — confirm;
 *  - nothing in this file emits a call to this helper yet.
 */
static void jit_emit_aux_func_encode_string(struct jit *j)
{
/* FIXME: arguments to append_bytes! */
j->aux_func_encode_str_loc = jit_code_len(&j->code);
jit_emit(j, 2, 0x41, 0x54); /* push %r12 */
jit_emit(j, 1, 0x55); /* push %rbp */
jit_emit(j, 3, 0x48, 0x89, 0xfd); /* mov %rdi,%rbp */
jit_emit(j, 1, 0x53); /* push %rbx */
jit_emit(j, 3, 0x0f, 0xbe, 0x3f); /* movsbl (%rdi),%edi */
jit_emit(j, 3, 0x40, 0x84, 0xff); /* test %dil,%dil */
jit_emit(j, 2, 0x74, 0x71); /* je 14c00 <foobar2000+0x80> */
jit_emit(j, 3, 0x49, 0x89, 0xeb); /* mov %rbp,%r11 */
jit_emit(j, 3, 0x45, 0x31, 0xe4); /* xor %r12d,%r12d */
jit_emit(j, 2, 0xeb, 0x27); /* jmp 14bbe <foobar2000+0x3e> */
/* NOTE(review): the jumps above assume the padding emitted here matches the
 * original layout. */
jit_emit_alignment_nops(j);
jit_emit_call(j, j->append_bytes);
jit_emit(j, 3, 0x41, 0x09, 0xc4); /* or %eax,%r12d */
jit_emit_call(j, j->append_bytes);
jit_emit(j, 3, 0x48, 0x89, 0xdd); /* mov %rbx,%rbp */
jit_emit(j, 3, 0x41, 0x09, 0xc4); /* or %eax,%r12d */
jit_emit(j, 3, 0x0f, 0xbe, 0x3b); /* movsbl (%rbx),%edi */
jit_emit(j, 3, 0x49, 0x89, 0xdb); /* mov %rbx,%r11 */
jit_emit(j, 3, 0x40, 0x84, 0xff); /* test %dil,%dil */
jit_emit(j, 2, 0x74, 0x25); /* je 14be3 <foobar2000+0x63> */
/*14bbe:*/
jit_emit_call(j, escape_as);
jit_emit(j, 4, 0x49, 0x8d, 0x5b, 0x01); /* lea 0x1(%r11),%rbx */
jit_emit(j, 2, 0x84, 0xc0); /* test %al,%al */
jit_emit(j, 2, 0x74, 0xe8); /* je 14bb3 <foobar2000+0x33> */
jit_emit(j, 3, 0x4c, 0x39, 0xdd); /* cmp %r11,%rbp */
jit_emit(j, 2, 0x75, 0xd0); /* jne 14ba0 <foobar2000+0x20> */
jit_emit_call(j, j->append_bytes);
jit_emit(j, 3, 0x0f, 0xbe, 0x3b); /* movsbl (%rbx),%edi */
jit_emit(j, 3, 0x41, 0x09, 0xc4); /* or %eax,%r12d */
jit_emit(j, 3, 0x49, 0x89, 0xdb); /* mov %rbx,%r11 */
jit_emit(j, 3, 0x40, 0x84, 0xff); /* test %dil,%dil */
jit_emit(j, 2, 0x75, 0xdb); /* jne 14bbe <foobar2000+0x3e> */
/* 14be3: */
jit_emit(j, 3, 0x48, 0x39, 0xdd); /* cmp %rbx,%rbp */
jit_emit(j, 2, 0x74, 0x08); /* je 14bf0 <foobar2000+0x70> */
jit_emit_call(j, j->append_bytes);
jit_emit(j, 3, 0x41, 0x09, 0xc4); /* or %eax,%r12d */
/* 14bf0: */
jit_emit(j, 1, 0x5b); /* pop %rbx */
jit_emit(j, 1, 0x5d); /* pop %rbp */
jit_emit(j, 3, 0x44, 0x89, 0xe0); /* mov %r12d,%eax */
jit_emit(j, 2, 0x41, 0x5c); /* pop %r12 */
jit_emit(j, 1, 0xc3); /* retq */
jit_emit_alignment_nops(j);
/* 14c00: early exit for an empty string (NOTE(review): %rbx not popped) */
jit_emit(j, 3, 0x45, 0x31, 0xe4); /* xor %r12d,%r12d */
jit_emit(j, 1, 0x5d); /* pop %rbp */
jit_emit(j, 3, 0x44, 0x89, 0xe0); /* mov %r12d,%eax */
jit_emit(j, 2, 0x41, 0x5c); /* pop %r12 */
jit_emit(j, 1, 0xc3); /* retq */
}
/* Resolves the rel32 displacement of every call recorded in j->calls inside
 * the copy of the code at `mapped`. Returns false if some target lies
 * outside the +/-2 GiB reach of a rel32 call, in which case the copied code
 * must not be executed. */
static bool jit_patch_calls(struct jit *j, void *mapped)
{
    const struct jit_call_patch *patch;

    LWAN_ARRAY_FOREACH (&j->calls, patch) {
        /* rel32 is relative to the end of the 4-byte immediate, i.e. the
         * address of the next instruction. */
        const intptr_t next_insn =
            (intptr_t)mapped + (intptr_t)patch->patch + 4;
        const intptr_t rel = (intptr_t)patch->target - next_insn;
        int32_t imm32;

        if (rel < INT32_MIN || rel > INT32_MAX)
            return false;

        imm32 = (int32_t)rel;
        memcpy((char *)mapped + patch->patch, &imm32, sizeof(imm32));
    }

    return true;
}

/* Finalizes the generated code. It emits the epilogue and the auxiliary
 * helper, then copies the code into a mapping placed near j->append_bytes
 * (so rel32 calls can reach it). It then resolves the call displacements
 * and flips the mapping from read+write to read+execute.
 *
 * On success, returns the entry point (the start of the mapping) and records
 * the mapping in j->mapped / j->size_for_unmapping for jit_destroy().
 * Returns NULL if no suitable mapping could be made, a call target is out of
 * rel32 range, or mprotect() fails. The code and call buffers are reset in
 * every case.
 */
static void *jit_generate(struct jit *j)
{
    void *mapped;
    size_t code_len;

    assert(j->size_for_unmapping == 0);

    /* Epilogue: return the OR of all append_bytes() results. */
    jit_emit(j, 3, 0x8b, 0x45, 0xfc); /* mov eax, -0x4(rbp) */
    jit_emit(j, 1, 0xc9);             /* leave */
    jit_emit(j, 1, 0xc3);             /* ret */
    jit_emit_alignment_nops(j);
    jit_emit_aux_func_encode_string(j);

    code_len = jit_code_len(&j->code);
    mapped = try_map_near_symbol(j->append_bytes, code_len,
                                 PROT_READ | PROT_WRITE);
    if (mapped) {
        memcpy(mapped, j->code.base.base, code_len);

        /* Writable while patching, executable (and no longer writable)
         * afterwards. */
        if (!jit_patch_calls(j, mapped) ||
            mprotect(mapped, code_len, PROT_READ | PROT_EXEC) < 0) {
            munmap(mapped, code_len);
            mapped = NULL;
        }
    }

    if (mapped) {
        j->size_for_unmapping = code_len;
        printf("TRY IN GDB: disassemble %p, %p\n", mapped,
               (char *)mapped + code_len);
    } else {
        j->size_for_unmapping = 0;
    }

    j->mapped = mapped;
    jit_code_reset(&j->code);
    jit_call_reset(&j->calls);

    return mapped;
}
/* Releases the code and call buffers and, if jit_generate() succeeded, the
 * executable mapping. The mapping fields are cleared afterwards, so calling
 * this again does not unmap the same range twice. */
void jit_destroy(struct jit *j)
{
    jit_code_reset(&j->code);
    jit_call_reset(&j->calls);

    if (j->size_for_unmapping)
        munmap(j->mapped, j->size_for_unmapping);

    j->size_for_unmapping = 0;
    j->mapped = NULL;
}
/* Emits code that appends the JSON value described by `descr`.
 * Only JSON_TOK_STRING is supported so far, and it currently produces an
 * empty string literal (""). Any other type prints a message and terminates
 * the process with exit status 1. */
void jit_emit_json_value(struct jit *j, const struct json_obj_descr *descr)
{
    if (descr->type != JSON_TOK_STRING) {
        printf("can't generate code for type %d\n", descr->type);
        exit(1);
    }

    /* TODO: emit the escaped string contents between the quotes
     * (previously sketched as jit_emit_json_escaped_str(j, descr)). */
    jit_emit_append_const_strz(j, "\"");
    jit_emit_append_const_strz(j, "\"");
}
/* Emits code that serializes an object with the `n_descr` fields described
 * by `descrs`, producing "{" key value ("," key value)* "}". */
void jit_emit_json_object(struct jit *j, const struct json_obj_descr *descrs, size_t n_descr)
{
    jit_emit_append_const_strz(j, "{");

    for (size_t idx = 0; idx < n_descr; idx++) {
        const struct json_obj_descr *field = &descrs[idx];

        /* Separator goes before every field except the first. */
        if (idx != 0)
            jit_emit_append_const_strz(j, ",");

        /* NOTE(review): `field_name + field_name_len` points just past the
         * bare name. This presumably relies on json.h's descriptor macros
         * storing the pre-quoted key there (e.g. "message":). Confirm
         * against json.h; otherwise this emits an empty key. */
        jit_emit_append_const_strz(j, field->field_name + field->field_name_len);
        jit_emit_json_value(j, field);
    }

    jit_emit_append_const_strz(j, "}");
}
/* Demo output callback: writes `len` bytes of `buf` to standard output,
 * retrying on partial writes and EINTR, and stops on any other error.
 * Always returns 42 so the demo can show that the generated code ORs and
 * returns the callback results. `user_data` is unused. */
static int user_provided_callback(const void *buf, size_t len, void *user_data)
{
    const char *pos = buf;

    (void)user_data;

    while (len > 0) {
        const ssize_t written = write(STDOUT_FILENO, pos, len);

        if (written <= 0) {
            if (written < 0 && errno == EINTR)
                continue;
            break;
        }
        pos += written;
        len -= (size_t)written;
    }

    return 42;
}
/* Example schema used by main(): an object with a single string field. */
struct hello_world_json {
const char *message;
};
/* Field descriptors for struct hello_world_json, handed to
 * jit_emit_json_object() by main(). */
static const struct json_obj_descr hello_world_json_desc[] = {
JSON_OBJ_DESCR_PRIM(struct hello_world_json, message, JSON_TOK_STRING),
};
/* Demo driver: JIT-compiles an encoder for hello_world_json, runs it once
 * (with NULL data and user_data, since values are not encoded yet) and
 * prints the callback results the generated code ORed together. */
int main(void)
{
    struct jit j;
    int (*func)(const void *ptr, void *user_data);
    int r;

    jit_init(&j, user_provided_callback);
    jit_emit_json_object(&j, hello_world_json_desc,
                         sizeof(hello_world_json_desc) /
                             sizeof(hello_world_json_desc[0]));
    func = jit_generate(&j);

#ifdef JIT_DEBUG_TRAP
    /* Break into an attached debugger before running the generated code.
     * Opt-in only: without a debugger, int3 raises SIGTRAP and kills the
     * process. */
    __asm__ __volatile__("int3");
#endif

    if (!func) {
        printf("Could not generate code\n");
        jit_destroy(&j);
        return 1;
    }

    r = func(NULL, NULL);
    printf("Function returned %d\n", r);

    jit_destroy(&j);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment