Skip to content

Instantly share code, notes, and snippets.

Created August 10, 2010 22:45
Show Gist options
  • Save anonymous/518146 to your computer and use it in GitHub Desktop.
Save anonymous/518146 to your computer and use it in GitHub Desktop.
#include <sys/mman.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <deque>
#include <exception>
#include <forward_list>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <new>
#include <numeric>
#include <stack>
#include <stdexcept>
#include <utility>
#include <vector>
/// Invokes f() once and returns the elapsed time in whole milliseconds.
///
/// Uses steady_clock because it is monotonic. system_clock may be adjusted
/// while f runs (NTP, manual changes) and would then yield a bogus or even
/// negative duration.
template <class F>
std::chrono::milliseconds
take_time(F&& f) {
    namespace chrono = std::chrono;
    auto const begin = chrono::steady_clock::now();
    std::forward<F>(f)();
    auto const end = chrono::steady_clock::now();
    return chrono::duration_cast<chrono::milliseconds>(end - begin);
}
/// Stateless allocator that takes every allocation directly from an
/// anonymous mmap(2) mapping with read, write AND execute permission, so the
/// storage can hold JIT-generated machine code.
///
/// NOTE(review): RWX mappings violate W^X. Hardened systems may refuse them,
/// in which case mmap fails and allocate() throws std::bad_alloc.
template <class T>
struct mmap_allocator {
    typedef T* pointer;
    typedef T const* const_pointer;
    typedef T& reference;
    typedef T const& const_reference;
    typedef void* void_pointer;
    typedef void const* const_void_pointer;
    typedef T value_type;
    typedef size_t size_type;
    typedef ptrdiff_t difference_type;
    template <class U>
    struct rebind {
        typedef mmap_allocator<U> other;
    };

    mmap_allocator() noexcept {}

    /// Rebinding conversion required by the Allocator requirements. The
    /// allocator has no state, so there is nothing to copy.
    template <class U>
    mmap_allocator(mmap_allocator<U> const&) noexcept {}

    /// Maps uninitialized memory for n objects of type T.
    /// Returns nullptr for n == 0, because mmap rejects zero-length mappings.
    /// @throws std::bad_alloc if the byte size overflows or mmap fails.
    T* allocate(size_t const n) {
        if (n == 0) {
            return nullptr;
        }
        // Private anonymous memory: fd must be -1 for portability, and a
        // MAP_SHARED mapping would be shared with fork()ed children.
        void* const p = mmap(nullptr, mapping_length(n),
                             PROT_READ | PROT_WRITE | PROT_EXEC,
                             MAP_ANON | MAP_PRIVATE, -1, 0);
        if (p == MAP_FAILED) {
            throw std::bad_alloc();
        }
        return static_cast<T*>(p);
    }

    /// Unmaps memory obtained from allocate(n) with the same n.
    void deallocate(T* p, size_t n) {
        if (p != nullptr) {
            munmap(p, mapping_length(n));
        }
    }

    template <class... Args>
    void construct(T* p, Args&&... args) {
        new (p) T(std::forward<Args>(args)...);
    }

    void destroy(T* p) {
        p->~T();
    }

    /// All instances are interchangeable: memory from one can be freed by any.
    template <class U>
    bool operator==(mmap_allocator<U> const&) const noexcept { return true; }
    template <class U>
    bool operator!=(mmap_allocator<U> const&) const noexcept { return false; }

private:
    /// Byte length of the mapping for n objects, rounded up to a multiple of
    /// 4096. mmap itself rounds up to the real page size as well.
    /// @throws std::bad_alloc if n * sizeof(T) plus the rounding overflows.
    static size_t mapping_length(size_t const n) {
        size_t const page = 4096;
        if (n > (std::numeric_limits<size_t>::max() - (page - 1)) / sizeof(T)) {
            throw std::bad_alloc();
        }
        // The mask must be ~(page - 1). ~page only clears bit 12 and can
        // round the length down, even to zero.
        return (n * sizeof(T) + (page - 1)) & ~(page - 1);
    }
};
/// Minimal x86-64 (System V ABI) machine-code emitter used by the Brainfuck JIT.
///
/// Code is accumulated as a forward_list of byte blocks kept in REVERSE
/// order. The block at code_.begin() is the one currently being appended to.
/// A new block is started after every forward jump, so that jump's block can
/// be found again through a `marker` and its displacement patched once the
/// loop body is known.
///
/// Register usage inside the generated function `void f(int* tape)`:
///   rdi - pointer to the current tape cell (the first argument); cells are
///         32-bit ints.
///   r14 - address of static_putchar (callee-saved; saved/restored here).
///   r15 - address of static_getchar (callee-saved; saved/restored here).
struct x86_64_asm {
private:
    enum {
        REX_W = 1 << 6 | 1 << 3,
        REX_R = 1 << 6 | 1 << 2,
        REX_X = 1 << 6 | 1 << 1,
        REX_B = 1 << 6 | 1
    };
    typedef std::deque<unsigned char> block_type;
    typedef std::forward_list<block_type> code_type;
public:
    typedef code_type::iterator marker;
    // Encodings used below (Intel SDM notation):
    // PUSH r64:          50+ rd
    // POP r64:           58+ rd
    // RET:               C3
    // ADD r/m64, imm8:   REX.W 83 /0 ib   (imm sign-extended)
    // ADD r/m64, imm32:  REX.W 81 /0 id   (imm sign-extended)
    // ADD r/m32, imm8:   83 /0 ib
    // ADD r/m32, imm32:  81 /0 id
    // SUB r/m64, imm8:   REX.W 83 /5 ib
    // CMP r/m32, imm8:   83 /7 ib
    // JE rel32:          0F 84 imm32
    // JNE rel32:         0F 85 imm32
    // CALL r/m64:        FF /2
    // MOV r/m64, r64:    REX.W 89 /r
    // MOV r32, r/m32:    8B /r
    // MOV r/m32, r32:    89 /r
    // LEA r64, m:        REX.W 8D /r
    // MOV r64, imm64:    REX.W B8+ rq
    static_assert(sizeof(ptrdiff_t) == 8, "sizeof(ptrdiff_t) == 8");
    static_assert(sizeof(ptrdiff_t) == sizeof(void*), "sizeof(ptrdiff_t) == sizeof(void*)");

    /// Starts the first code block and emits the function prologue.
    void emit_enter() {
        auto bb = cons_block();
        // PUSH rbp
        bb->insert(bb->end(), {0x50 + 5});
        // MOV rbp, rsp
        bb->insert(bb->end(), {REX_W, 0x89, 0xE5});
        // PUSH r14
        bb->insert(bb->end(), {REX_B, 0x50 + 6});
        // PUSH r15
        bb->insert(bb->end(), {REX_B, 0x50 + 7});
        // SUB rsp, 8. On entry rsp % 16 == 8, and after the three pushes it
        // is 0. Every call site pushes rdi (8 bytes) before CALL, so bias by
        // 8 here to keep rsp 16-byte aligned at each CALL as the ABI requires.
        bb->insert(bb->end(), {REX_W, 0x83, 0xEC, 8});
        // MOV r14, &static_putchar. Use the wrappers rather than taking the
        // address of the standard library functions themselves.
        bb->insert(bb->end(), {REX_W | REX_B, 0xB8 + 6});
        write_imm64(bb, reinterpret_cast<ptrdiff_t>(&static_putchar));
        // MOV r15, &static_getchar
        bb->insert(bb->end(), {REX_W | REX_B, 0xB8 + 7});
        write_imm64(bb, reinterpret_cast<ptrdiff_t>(&static_getchar));
    }

    /// Emits the epilogue into the current block. It restores r15, r14 and
    /// rbp in the reverse of the order emit_enter() pushed them, then returns.
    void emit_leave() {
        auto bb = code_.begin();
        // LEA rsp, [rbp - 16]: point rsp at the saved r15, dropping the
        // alignment padding.
        bb->insert(bb->end(), {REX_W, 0x8D, 0x65, 0xF0});
        // POP r15 (pushed last, so popped first)
        bb->insert(bb->end(), {REX_B, 0x58 + 7});
        // POP r14
        bb->insert(bb->end(), {REX_B, 0x58 + 6});
        // POP rbp
        bb->insert(bb->end(), {0x58 + 5});
        // RET
        bb->insert(bb->end(), {0xC3});
    }

    /// Moves the tape pointer (rdi) by n cells; n may be negative.
    /// @throws std::runtime_error if the byte offset does not fit in 32 bits.
    void emit_modify_stack_pointer(ptrdiff_t n) {
        auto bb = code_.begin();
        n *= ptrdiff_t(sizeof(int));  // cells are ints
        // ADD sign-extends its immediate, so one ADD covers both directions,
        // including -128, which "SUB imm8" cannot express.
        if (fits<int8_t>(n)) {
            // ADD rdi, imm8
            bb->insert(bb->end(), {REX_W, 0x83, 0xC7, static_cast<unsigned char>(n)});
        } else if (fits<int32_t>(n)) {
            // ADD rdi, imm32
            bb->insert(bb->end(), {REX_W, 0x81, 0xC7});
            write_imm32(bb, n);
        } else {
            throw std::runtime_error(__func__);
        }
    }

    /// Adds n (possibly negative) to the 32-bit cell at [rdi].
    /// @throws std::runtime_error if n does not fit in 32 bits.
    void emit_modify_stack_value(ptrdiff_t n) {
        auto bb = code_.begin();
        if (fits<int8_t>(n)) {
            // ADD dword ptr [rdi], imm8 (sign-extended)
            bb->insert(bb->end(), {0x83, 0x07, static_cast<unsigned char>(n)});
        } else if (fits<int32_t>(n)) {
            // ADD dword ptr [rdi], imm32
            bb->insert(bb->end(), {0x81, 0x07});
            write_imm32(bb, n);
        } else {
            throw std::runtime_error(__func__);
        }
    }

    /// Emits a call that outputs the current cell via static_putchar.
    void emit_putchar() {
        auto bb = code_.begin();
        // PUSH rdi (caller-saved, so the callee may clobber it)
        bb->insert(bb->end(), {0x50 + 7});
        // MOV edi, dword ptr [rdi]: the argument is the current cell
        bb->insert(bb->end(), {0x8B, 0x3F});
        // CALL r14
        bb->insert(bb->end(), {REX_B, 0xFF, 0xD6});
        // POP rdi
        bb->insert(bb->end(), {0x58 + 7});
    }

    /// Emits a call to static_getchar and stores its result in the current cell.
    void emit_getc() {
        auto bb = code_.begin();
        // PUSH rdi
        bb->insert(bb->end(), {0x50 + 7});
        // CALL r15
        bb->insert(bb->end(), {REX_B, 0xFF, 0xD7});
        // POP rdi. This must precede the store: the call may have clobbered rdi.
        bb->insert(bb->end(), {0x58 + 7});
        // MOV dword ptr [rdi], eax
        bb->insert(bb->end(), {0x89, 0x07});
    }

    /// Emits "if (*cell == 0) jump past the matching ']'" with a zero
    /// placeholder displacement, and starts a new block.
    /// Returns a marker to the block ending in the JE, which
    /// emit_jump_backward() patches.
    marker emit_jump_forward() {
        auto bb = code_.begin();
        // CMP dword ptr [rdi], 0
        bb->insert(bb->end(), {0x83, 0x3F, 0});
        // JE rel32 (placeholder)
        bb->insert(bb->end(), {0x0F, 0x84});
        write_imm32(bb, 0);
        cons_block();
        return bb;
    }

    /// Emits "if (*cell != 0) jump back to the loop body start" and patches
    /// the forward JE in block m to jump just past this instruction.
    /// @throws std::runtime_error if the loop is too large for rel32.
    void emit_jump_backward(marker const& m) {
        auto bb = code_.begin();
        // The loop body is every block from the current one back to, but not
        // including, the block that ends with the matching JE.
        ptrdiff_t const body = std::accumulate(bb, m, ptrdiff_t(0),
            [] (ptrdiff_t const acc, block_type const& b) { return acc + ptrdiff_t(b.size()); });
        // +9 for the CMP (3 bytes) and JNE (6 bytes) emitted below.
        ptrdiff_t const rel = body + 9;
        if (!fits<int32_t>(rel)) {
            throw std::runtime_error(__func__);
        }
        // CMP dword ptr [rdi], 0
        bb->insert(bb->end(), {0x83, 0x3F, 0});
        // JNE rel32: relative to the end of JNE, back to the body start
        bb->insert(bb->end(), {0x0F, 0x85});
        write_imm32(bb, -rel);
        // Replace the JE placeholder: from the body start to past the JNE.
        m->erase(m->end() - 4, m->end());
        write_imm32(m, rel);
        cons_block();
    }

    /// Copies the assembled code, in program order, to the output iterator.
    /// Does not modify the assembler, so it may be called repeatedly.
    template <class Iterator>
    void write(Iterator it) const {
        std::vector<block_type const*> blocks;
        for (block_type const& b : code_) {
            blocks.push_back(&b);
        }
        // code_ holds the newest block first, so emit it back to front.
        std::for_each(blocks.rbegin(), blocks.rend(), [&] (block_type const* b) {
            it = std::copy(b->cbegin(), b->cend(), it);
        });
    }
private:
    static int static_getchar() {
        return getchar();
    }
    static void static_putchar(int ch) {
        putchar(ch);
    }
    /// True when n is representable in the integer type Int.
    template <class Int>
    static bool fits(ptrdiff_t const n) {
        return std::numeric_limits<Int>::min() <= n && n <= std::numeric_limits<Int>::max();
    }
    /// Starts a new, empty current block and returns it.
    marker cons_block() {
        code_.push_front(block_type());
        return code_.begin();
    }
    /// Appends the low 32 bits of x, little-endian.
    void write_imm32(marker bb, ptrdiff_t x) const {
        bb->insert(bb->end(), {
            static_cast<unsigned char>(x & 0xFF),
            static_cast<unsigned char>((x >> 8) & 0xFF),
            static_cast<unsigned char>((x >> 16) & 0xFF),
            static_cast<unsigned char>((x >> 24) & 0xFF)});
    }
    /// Appends all 64 bits of x, little-endian.
    void write_imm64(marker bb, ptrdiff_t x) const {
        bb->insert(bb->end(), {
            static_cast<unsigned char>(x & 0xFF),
            static_cast<unsigned char>((x >> 8) & 0xFF),
            static_cast<unsigned char>((x >> 16) & 0xFF),
            static_cast<unsigned char>((x >> 24) & 0xFF),
            static_cast<unsigned char>((x >> 32) & 0xFF),
            static_cast<unsigned char>((x >> 40) & 0xFF),
            static_cast<unsigned char>((x >> 48) & 0xFF),
            static_cast<unsigned char>((x >> 56) & 0xFF)});
    }
    code_type code_;
};
struct bf_function {
private:
typedef std::vector<unsigned char, mmap_allocator<unsigned char> > code_vector;
public:
template <class Iterator>
bf_function(Iterator begin, Iterator end)
: code_(std::make_shared<code_vector>(begin, end)) {
}
void operator ()(int* stack) {
auto f = reinterpret_cast<void (*)(int*)>(code_->data());
// printf("disass %p %p\n", &code_->front(), &code_->back() + 1);
// asm ("int3;");
f(stack);
}
private:
std::shared_ptr<
std::vector<
unsigned char,
mmap_allocator<unsigned char>
>
> code_;
};
/// Compiles Brainfuck source to native code through an Assembler back end
/// (x86_64_asm in this file). The result is a std::function that takes a
/// pointer to the tape of int cells.
///
/// The Assembler must provide emit_enter, emit_leave,
/// emit_modify_stack_pointer, emit_modify_stack_value, emit_putchar,
/// emit_getc, emit_jump_forward (returning a `marker`),
/// emit_jump_backward(marker) and write(OutputIterator).
template <class Assembler>
struct bf_compiler {
    typedef std::function<void (int*)> function_type;

    /// Compiles the characters in [begin, end).
    ///
    /// A run of the same '>', '<', '+' or '-' becomes a single instruction.
    /// Characters other than the eight Brainfuck commands are skipped.
    /// @throws std::runtime_error if '[' and ']' are not balanced.
    template <class It>
    function_type compile(It begin, It end) {
        Assembler assm;
        std::stack<typename Assembler::marker> markers;  // currently open '['
        assm.emit_enter();
        while (begin != end) {
            switch (*begin) {
            case '>':
            case '<':
            case '+':
            case '-': {
                // Count the run of identical commands starting here.
                ptrdiff_t n = 0;
                auto const ch = *begin;
                for (; begin != end && *begin == ch; ++begin, ++n) {}
                switch (ch) {
                case '>': assm.emit_modify_stack_pointer(n); break;
                case '<': assm.emit_modify_stack_pointer(-n); break;
                case '+': assm.emit_modify_stack_value(n); break;
                case '-': assm.emit_modify_stack_value(-n); break;
                }
                break;
            }
            case '.':
                assm.emit_putchar();
                ++begin;
                break;
            case ',':
                assm.emit_getc();
                ++begin;
                break;
            case '[':
                markers.push(assm.emit_jump_forward());
                ++begin;
                break;
            case ']':
                // top() on an empty std::stack is undefined behavior.
                if (markers.empty()) {
                    throw std::runtime_error("unmatched ']' in Brainfuck source");
                }
                assm.emit_jump_backward(markers.top());
                markers.pop();
                ++begin;
                break;
            default:
                ++begin;
            }
        }
        // An open '[' never got its emit_jump_backward() call, so its jump
        // target was never filled in.
        if (!markers.empty()) {
            throw std::runtime_error("unmatched '[' in Brainfuck source");
        }
        assm.emit_leave();
        std::vector<unsigned char> code;
        code.reserve(4096);
        assm.write(std::back_inserter(code));
        return bf_function(code.cbegin(), code.cend());
    }
};
/// Reads a Brainfuck program from stdin, JIT-compiles it, runs it on a
/// zero-initialized tape of 64 * 1024 int cells, and prints the combined
/// compile+run time to stderr.
///
/// NOTE(review): the program text is read from std::cin until EOF, and ','
/// then reads from the same stdin via getchar. ',' will therefore typically
/// see only EOF. Confirm whether that is intended.
///
/// Returns 0 on success, or 1 if compilation/execution throws (e.g.
/// unbalanced brackets, or an offset that does not fit the encoding).
int main() {
    bf_compiler<x86_64_asm> compiler;
    std::vector<int> stack(64 * 1024, 0);
    try {
        auto const time = take_time([&] {
            auto fn = compiler.compile(std::istreambuf_iterator<char>(std::cin),
                                       std::istreambuf_iterator<char>());
            fn(stack.data());
        });
        // The program writes through C stdio. Flush it so its output is
        // complete before the timing line is printed.
        std::fflush(stdout);
        std::cerr << time.count() << "ms" << std::endl;
    } catch (std::exception const& e) {
        std::fflush(stdout);
        std::cerr << "error: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment