Create a gist now

Instantly share code, notes, and snippets.

@tibordp /jit.cc
Last active Oct 26, 2015

What would you like to do?
A JIT compiler for parametrized RPN expressions
#define NOMINMAX // Otherwise the min/max will be defined as macros on VS
#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/mman.h>
#endif
using namespace std; // Don't judge :)
/*
    A growable byte buffer used to assemble machine code. It is a plain
    std::vector<uint8_t> with one extra helper for emitting immediates.
*/
class code_vector : public std::vector<std::uint8_t>
{
public:
    /*
        Inserts the object representation of `what` (sizeof(T) bytes, in the
        host's byte order) in front of the position `whither`.

        The bytes are copied with memcpy rather than stored through a
        reinterpret_cast'ed T*: the insertion point is an arbitrary byte
        offset, so a direct typed store would be a misaligned access and an
        aliasing violation.

        `whither` is invalidated by the call, like any vector::insert.
    */
    template<typename T>
    void push_value(iterator whither, T what)
    {
        // Remember the offset first: insert() may reallocate the storage.
        const auto position = whither - begin();
        insert(whither, sizeof(T), std::uint8_t(0x00));
        std::memcpy(data() + position, &what, sizeof(T));
    }
};
/*--------------------------------------------------------------------- */
/*
    x_function<R(Args...)> owns a private block of memory obtained from the
    operating system (mmap / VirtualAlloc), fills it with machine code and
    provides a std::function-like call interface, like this:
        auto f = x_function<int(int)>(ptr_to_machine_code, length);
        cout << f(42);

    The block is either writable or executable, never both; use
    set_executable() to switch. Copying an x_function duplicates the block.

    Allocation or protection failures are reported with std::runtime_error.
*/
template<typename> class x_function;
template<typename R, typename... ArgTypes>
class x_function<R(ArgTypes...)>
{
    bool executable_;   // true: read+execute, false: read+write
    size_t size_;       // size of the block in bytes (0 for an empty object)
    void * data_;       // start of the block, nullptr when size_ == 0
public:
    /*
        Switches the block between read+execute (true) and read+write
        (false). Does nothing if the block is already in the requested
        state. Throws std::runtime_error if the OS refuses the change; the
        recorded state is left untouched in that case.
    */
    void set_executable(bool executable)
    {
        if (executable_ == executable)
            return;
        // An empty object has no pages to re-protect; only the flag changes.
        if (data_ != nullptr)
        {
#ifdef _WIN32
            DWORD old_protect;
            if (!VirtualProtect(
                    data_, size_,
                    executable ? PAGE_EXECUTE_READ : PAGE_READWRITE,
                    &old_protect))
                throw std::runtime_error("x_function: VirtualProtect failed");
#else
            if (mprotect(data_, size_,
                    PROT_READ | (executable ? PROT_EXEC : PROT_WRITE)) != 0)
                throw std::runtime_error("x_function: mprotect failed");
#endif
        }
        executable_ = executable;
    }

    /*
        Allocates a writable, non-executable block of `size` bytes.
        A size of 0 produces an empty object that owns no memory.
        Throws std::runtime_error if the allocation fails.
    */
    x_function(size_t size) :
        executable_(false),
        size_(size),
        data_(nullptr)
    {
        if (size == 0)
            return;
#ifdef _WIN32
        data_ = VirtualAlloc(nullptr, size,
            MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
        if (data_ == nullptr)
            throw std::runtime_error("x_function: VirtualAlloc failed");
#else
        // mmap reports failure with MAP_FAILED, not with a null pointer.
        void* mapping = mmap(nullptr, size,
            PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0);
        if (mapping == MAP_FAILED)
            throw std::runtime_error("x_function: mmap failed");
        data_ = mapping;
#endif
    }

    /*
        Copies `size` bytes of machine code from `data` into a fresh block
        and makes the block executable.
    */
    x_function(const void* data, size_t size) :
        x_function(size)
    {
        if (size != 0)
            std::memcpy(data_, data, size);
        set_executable(true);
    }

    ~x_function()
    {
        if (data_ == nullptr)
            return;             // empty object: nothing was mapped
#ifdef _WIN32
        VirtualFree(data_, 0, MEM_RELEASE);
#else
        munmap(data_, size_);
#endif
    }

    void swap(x_function& other) noexcept
    {
        std::swap(executable_, other.executable_);
        std::swap(size_, other.size_);
        std::swap(data_, other.data_);
    }

    /* Creates an empty object that owns no memory. */
    x_function() : x_function(0) {}

    /*
        Copying an x_function makes a real copy of the code. While it would
        probably be more efficient to keep a reference count, separate
        copies allow users to unprotect and change the machine code of one
        object without affecting the others.
    */
    x_function(const x_function& other) : x_function(other.size_)
    {
        if (size_ != 0)
            std::memcpy(data_, other.data_, size_);
        set_executable(other.executable_);
    }

    /* Takes over the block of `other`, leaving `other` empty. */
    x_function(x_function&& other) noexcept :
        executable_(other.executable_),
        size_(other.size_),
        data_(other.data_)
    {
        other.executable_ = false;
        other.size_ = 0;
        other.data_ = nullptr;
    }

    /* Copy-and-swap: the by-value parameter is the copy (or moved-in value). */
    x_function& operator=(x_function other) noexcept
    {
        swap(other);
        return *this;
    }

    /*
        Builds an executable function from a range of assembled bytes.
        The iterator type is the one code_vector inherits from
        std::vector<uint8_t>. An empty range yields an empty object.
    */
    x_function(std::vector<std::uint8_t>::iterator begin,
               std::vector<std::uint8_t>::iterator end) :
        x_function(begin == end ? nullptr : &*begin,
                   static_cast<size_t>(end - begin)) { }

    /*----------------------------------------------------------------- */
    void* data() const { return data_; }
    size_t size() const { return size_; }
    bool executable() const { return executable_; }

    /*
        Calls the block as a function of type R(ArgTypes...).
        The object must be non-empty and executable; this is not checked.
    */
    template<typename... RArgTypes>
    R operator()(RArgTypes&&... args) const
    {
        return reinterpret_cast<R(*)(ArgTypes...)>(data_)(
            std::forward<RArgTypes>(args)...);
    }
};
/* -------------------------------------------------------------------- */
/*
    Builds a function that returns its argument plus `value`.
    Generated code:
        movabs rax, %(value)
        add    rax, rcx        # 'add rax, rdi' on UNIX
        ret
*/
x_function<int64_t(int64_t)> add(int64_t value)
{
    // Only the ModRM byte of the 'add' differs between the two platforms.
#ifdef _WIN32
    const uint8_t add_modrm = 0xC8;     // add rax, rcx
#else
    const uint8_t add_modrm = 0xF8;     // add rax, rdi
#endif
    code_vector program;

    // movabs rax, imm64
    program.push_back(0x48);
    program.push_back(0xB8);
    program.push_value(program.end(), value);

    // add rax, <first argument> ; ret
    const uint8_t tail[] = { 0x48, 0x01, add_modrm, 0xC3 };
    program.insert(program.end(), tail, tail + sizeof(tail));

    return x_function<int64_t(int64_t)>(program.begin(), program.end());
}
/*
    Builds a function that adds `value` to the int64_t it receives by
    reference. Generated code:
        movabs rax, %(value)
        add    QWORD PTR [rcx], rax    # 'add QWORD PTR [rdi], rax' on UNIX
        ret
*/
x_function<void(int64_t&)> add_by_ref(int64_t value)
{
    // Only the ModRM byte of the 'add' differs between the two platforms.
#ifdef _WIN32
    const uint8_t add_modrm = 0x01;     // add QWORD PTR [rcx], rax
#else
    const uint8_t add_modrm = 0x07;     // add QWORD PTR [rdi], rax
#endif
    code_vector program;

    // movabs rax, imm64
    program.push_back(0x48);
    program.push_back(0xB8);
    program.push_value(program.end(), value);

    // add QWORD PTR [<first argument>], rax ; ret
    const uint8_t tail[] = { 0x48, 0x01, add_modrm, 0xC3 };
    program.insert(program.end(), tail, tail + sizeof(tail));

    return x_function<void(int64_t&)>(program.begin(), program.end());
}
/* -------------------------------------------------------------------- */
/*
    A simple RPN interpreter, provided for completeness and as a reference
    for the compiled version.

    expression: NUL-terminated RPN text. Tokens are numeric literals
                (anything strtod accepts that starts with a digit, '.', or
                '-' followed by a digit), the parameter 'x', and the binary
                operators + - * /. Characters with a code of 32 or less
                (blanks, control characters) separate tokens; a separator
                is not required between a literal and the next token.
    value:      the value substituted for 'x'.

    Returns the single value left on the operand stack.
    Throws std::runtime_error("Invalid expression") if an operator has
    fewer than two operands, an unknown operator is found, a literal cannot
    be parsed, or the expression does not leave exactly one value.
*/
double rpn(char const* expression, double value)
{
    // A vector instead of a fixed-size array: no hard limit on the depth.
    std::vector<double> operands;
    char const* cursor = expression;
    while (*cursor != '\0')
    {
        // <cctype> functions require a value representable as unsigned char.
        const unsigned char current = static_cast<unsigned char>(*cursor);
        const unsigned char following = static_cast<unsigned char>(cursor[1]);
        if (std::isdigit(current) || current == '.' ||
            (current == '-' && std::isdigit(following)))
        {
            char* after = nullptr;
            const double literal = std::strtod(cursor, &after);
            if (after == cursor)            // e.g. a lone '.'
                throw std::runtime_error("Invalid expression");
            operands.push_back(literal);
            // Continue exactly where strtod stopped, so that a token glued
            // to the literal ("1 2+") is not skipped.
            cursor = after;
            continue;
        }
        if (current == 'x')
        {
            operands.push_back(value);
        }
        else if (*cursor > 32)
        {
            if (operands.size() < 2)
                throw std::runtime_error("Invalid expression");
            const double rhs = operands.back();
            operands.pop_back();
            double& lhs = operands.back();
            switch (*cursor) {
            case '+': lhs += rhs; break;
            case '*': lhs *= rhs; break;
            case '-': lhs -= rhs; break;
            case '/': lhs /= rhs; break;
            default:
                throw std::runtime_error("Invalid expression");
            }
        }
        ++cursor;
    }
    if (operands.size() != 1)
        throw std::runtime_error("Invalid expression");
    return operands.front();
}
/*
    This is where the fun happens. This function takes a parametrized
    mathematical expression in RPN ('x' is the parameter) and compiles it
    to x86-64 machine code. Returns a function that maps x -> f(x).

    The accepted syntax is the same as for rpn(): numeric literals, 'x',
    and the binary operators + - * /; characters with a code of 32 or less
    separate tokens.

    Layout of the generated block:
        lea rcx, [rip + <offset of literal table>]
        <code for the expression>
        ret
        <zero padding up to a 16-byte boundary>
        <literal table, one 16-byte slot per literal>

    Register use:
        xmm0        the parameter x (and, at the end, the return value)
        xmm1-xmm5   the first five entries of the operand stack
        xmm6, xmm7  scratch; deeper operand-stack entries live on the
                    machine stack (16 bytes each) and are brought into
                    xmm6/xmm7 for processing

    Throws std::runtime_error("Invalid expression") for malformed input and
    std::runtime_error("Expression too large") if an offset does not fit
    into a 32-bit displacement.
*/
x_function<double(double)> rpn_compile(char const* expression)
{
    code_vector code;
    std::vector<double> literals;

    // Some helpers for commonly used instructions.

    // ModRM byte for a register-register operation: reg = n1, r/m = n2.
    auto xmm = [](uint8_t n1, uint8_t n2) -> uint8_t
        { return static_cast<uint8_t>(0xC0 + 8 * n1 + n2); };
    // movdqu xmm<whither>, [rsp] ; add rsp, 16
    auto pop_xmm = [&](uint8_t whither) {
        code.insert(code.end(), { 0xF3, 0x0F, 0x6F,
            static_cast<uint8_t>(0x04 + whither * 8), 0x24,
            0x48, 0x83, 0xC4, 0x10 });
    };
    // sub rsp, 16 ; movdqu [rsp], xmm<whence>
    auto push_xmm = [&](uint8_t whence) {
        code.insert(code.end(), { 0x48, 0x83, 0xEC, 0x10,
            0xF3, 0x0F, 0x7F,
            static_cast<uint8_t>(0x04 + whence * 8), 0x24 });
    };
    // <op>sd xmm<n1>, xmm<n2>   (xmm<n1> = xmm<n1> <op> xmm<n2>)
    auto operation = [&](uint8_t op, uint8_t n1, uint8_t n2) {
        code.insert(code.end(), { 0xF2, 0x0F, op, xmm(n1, n2) });
    };
    // movapd xmm<n1>, xmm<n2>
    auto movapd_xmm = [&](uint8_t n1, uint8_t n2) {
        code.insert(code.end(), { 0x66, 0x0F, 0x28, xmm(n1, n2) });
    };
    /*
        movdqa xmm<n>, xmmword ptr [rcx + 16 * slot]
        This requires the literal table to be aligned on a 16-byte
        boundary (the table offset is rounded up below, and the block
        itself comes from mmap/VirtualAlloc); otherwise we would have to
        use the unaligned form:
        movdqu xmm<n>, xmmword ptr [rcx + 16 * slot]
    */
    auto load_xmm = [&](uint8_t n, int32_t slot) {
        code.insert(code.end(), { 0x66, 0x0F, 0x6F,
            static_cast<uint8_t>(0x81 + n * 8) });
        code.push_value<int32_t>(code.end(), 16 * slot);
    };

    /*
        We store the base address of the literal table into rcx. It is
        addressed relative to the end of this instruction, and the table
        offset is only known once all code has been emitted, so we emit
            lea rcx, [rip + <PLACEHOLDER>]
        first and patch the 32-bit displacement at the end.
    */
    const size_t lea_length = 7;
    const size_t lea_disp_offset = 3;
    code.insert(code.end(), { 0x48, 0x8D, 0x0D, 0x00, 0x00, 0x00, 0x00 });

    /*
        We save the XMM5-7 register state, as the Windows calling
        convention obliges the callee to preserve registers it clobbers
        (strictly xmm6 and up; saving xmm5 as well is harmless).
    */
#ifdef _WIN32
    push_xmm(5); push_xmm(6); push_xmm(7);
#endif

    int depth = 0;      // current size of the operand stack
    char const* cursor = expression;
    while (*cursor != '\0')
    {
        // <cctype> functions require a value representable as unsigned char.
        const unsigned char current = static_cast<unsigned char>(*cursor);
        const unsigned char following = static_cast<unsigned char>(cursor[1]);
        if (std::isdigit(current) || current == '.' ||
            (current == '-' && std::isdigit(following)))
        {
            char* after = nullptr;
            const double literal = std::strtod(cursor, &after);
            if (after == cursor)            // e.g. a lone '.'
                throw std::runtime_error("Invalid expression");
            // 16 * slot must fit into the signed 32-bit displacement.
            if (literals.size() > static_cast<size_t>(INT32_MAX / 16))
                throw std::runtime_error("Expression too large");
            const int32_t slot = static_cast<int32_t>(literals.size());
            literals.push_back(literal);
            // Copy the value from the literal table into the register
            // for the new stack entry, or spill it through xmm6.
            if (depth + 1 >= 6)
            {
                load_xmm(6, slot);
                push_xmm(6);
            }
            else
            {
                load_xmm(static_cast<uint8_t>(depth + 1), slot);
            }
            ++depth;
            // Continue exactly where strtod stopped, so that a token glued
            // to the literal ("1 2+") is not skipped.
            cursor = after;
            continue;
        }
        if (current == 'x')
        {
            // The parameter is already in xmm0, so we just copy/push it.
            if (depth + 1 >= 6)
                push_xmm(0);
            else
                movapd_xmm(static_cast<uint8_t>(depth + 1), 0);
            ++depth;
        }
        else if (*cursor > 32)
        {
            uint8_t opcode;
            switch (*cursor) {
            case '+': opcode = 0x58; break;     // addsd
            case '*': opcode = 0x59; break;     // mulsd
            case '-': opcode = 0x5C; break;     // subsd
            case '/': opcode = 0x5E; break;     // divsd
            default:
                throw std::runtime_error("Invalid expression");
            }
            // If we have fewer than 2 operands on the stack, the
            // expression is malformed.
            if (depth < 2)
                throw std::runtime_error("Invalid expression");
            /*
                Bring both operands into registers. The result must be
                "second-from-top OP top", so the target register has to
                hold the second-from-top entry. The top entry is the first
                one to come off the machine stack, hence it goes to xmm7
                and the entry below it to xmm6.
            */
            uint8_t tgt_reg;
            uint8_t src_reg;
            if (depth > 6)
            {
                pop_xmm(7);     // top
                pop_xmm(6);     // second from top
                tgt_reg = 6;
                src_reg = 7;
            }
            else if (depth == 6)
            {
                pop_xmm(6);     // top; second from top is already in xmm5
                tgt_reg = 5;
                src_reg = 6;
            }
            else
            {
                tgt_reg = static_cast<uint8_t>(depth - 1);
                src_reg = static_cast<uint8_t>(depth);
            }
            operation(opcode, tgt_reg, src_reg);
            // If the result does not fit into xmm1-xmm5, it goes back
            // onto the machine stack.
            if (depth > 6)
                push_xmm(6);
            --depth;
        }
        ++cursor;
    }
    // If there is too little or too much left on the stack.
    if (depth != 1)
        throw std::runtime_error("Invalid expression");
    // The return value is passed in xmm0. We no longer need to hold onto
    // the value of x.
    movapd_xmm(0, 1);
#ifdef _WIN32
    // We restore the XMM5-7 register state.
    pop_xmm(7); pop_xmm(6); pop_xmm(5);
#endif
    code.push_back(0xC3); // ret

    // The literal table starts at the next 16-byte boundary after the code.
    const size_t literal_base = (code.size() + 15) / 16 * 16;
    if (literal_base - lea_length > static_cast<size_t>(INT32_MAX))
        throw std::runtime_error("Expression too large");
    // rip points just past the lea when the displacement is applied.
    const int32_t displacement =
        static_cast<int32_t>(literal_base - lea_length);
    std::memcpy(&code[lea_disp_offset], &displacement, sizeof displacement);

    /*
        We place all the floating point literals AFTER the code. Each
        literal occupies a 16-byte slot: the value followed by a zero.
    */
    code.reserve(literal_base + 16 * literals.size());
    code.resize(literal_base, 0x00);
    for (double val : literals)
    {
        code.push_value<double>(code.end(), val);
        code.push_value<double>(code.end(), 0.0);
    }
    return x_function<double(double)>(code.begin(), code.end());
}
/*
    Produces a random, well-formed RPN expression over 'x', literals drawn
    from [-1, 1) and the operators + - *. Every token is followed by a
    single space.

    length + 1 random tokens are generated first (an operator is only
    chosen while at least two operands are available); afterwards enough
    '+' tokens are appended to reduce the operand stack to a single value.
    A negative length yields an empty string.
*/
string random_polynomial(int length)
{
    // Token kinds 0..3; kind 4 is a random literal.
    static const char* const symbols[] = { "+", "-", "*", "x" };

    std::default_random_engine engine;
    engine.seed(std::random_device()());
    std::uniform_int_distribution<int> pick_token(0, 4);
    std::uniform_real_distribution<double> pick_literal(-1, 1);

    std::stringstream out;
    int operands = 0;   // operand-stack depth the text produces so far

    // Phase 1: the randomly chosen tokens.
    for (int remaining = length; remaining >= 0; --remaining)
    {
        int token = pick_token(engine);
        // Kinds 0..2 are operators; redraw until the pick is applicable.
        while (token <= 2 && operands < 2)
            token = pick_token(engine);

        if (token == 4)
            out << pick_literal(engine);
        else
            out << symbols[token];
        out << " ";

        if (token >= 3)
            ++operands;
        else
            --operands;
    }

    // Phase 2: add up whatever is still on the stack.
    for (; operands > 1; --operands)
        out << "+ ";

    return out.str();
}
/*
    Benchmark helper: invokes `callback` once and returns the time the call
    took, in seconds, as measured by chrono::high_resolution_clock.
*/
template<typename T>
double time(T callback)
{
    typedef std::chrono::high_resolution_clock stopwatch;
    /* TODO: Is it possible that the following lines somehow get reordered?
    There are some weird results on gcc with -O3 */
    const stopwatch::time_point before = stopwatch::now();
    callback();
    const stopwatch::duration elapsed = stopwatch::now() - before;
    // Converting to a floating-point duration in seconds keeps fractions.
    return std::chrono::duration<double>(elapsed).count();
}
int main()
{
// Generation
string poly;
auto generation = time([&] {
poly = random_polynomial(100 * 1000 * 1000);
});
// Compilation
auto expression = poly.c_str();
x_function<double(double)> f;
auto parsing = time([&] {
f = rpn_compile(expression);
});
// Interpreted evaluation
vector<double> results1;
int count1 = 0;
auto calculating1 = time([&] {
for (double x = -1; x <= 1; x += 0.5)
{
results1.push_back(rpn(expression, x));
++count1;
}
});
// Compiled evaluation
vector<double> results2;
int count2 = 0;
auto calculating2 = time([&] {
for (double x = -1; x <= 1; x += 0.5)
{
results2.push_back(f(x));
++count2;
}
});
cout << "Generation: " << generation << "s" << endl
<< "Parsing: " << parsing << "s" << endl
<< "Calculation (interpreted): " << (calculating1 / count1) << "s" << endl
<< "Calculation (compiled): " << (calculating2 / count2) << "s" << endl;
return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment