@paulsmith
Forked from leegao/Rationale.md
Created July 11, 2011 03:35
JIT for dummies: JIT compiling RPN in python

If you don't care about the explanation, scroll down to find the code; it's fifty-some-odd lines and written by someone who doesn't know any better. You have been warned.

What it does

This is a very simple proof-of-concept JITting RPN calculator implemented in python. Basically, it takes the source code, tokenizes it on whitespace, and asks itself one simple question: am I looking at a number or not?

First, let's talk about the underlying program flow. Pretend that you are a shoe connoisseur with a tiny desk. You may only have two individual shoes on that desk at any one time, but should you ever purchase a new one or get harassed by an unruly shoe salesman without realizing that you have the power to say no (or even maybe?), you can always sweep aside one of the two shoes on the desk (the one on the right, because you're a lefty and you feel that the left side is always superior) onto the messy floor, put the other shoe on the right hand side, and then place your newly acquired shoe in the most leftward side of the desk so that you may gloat about how your very best shoe is on the very best side. You are also cursed, which may inexplicably explain your obsession with shoes, such that every time you admire a shoe, that shoe somehow merges itself into the other shoe, at which point you must promptly get up, walk around the floor, and find the last shoe that was swept away and place that shoe into the rightmost slot. Your ultimate goal is to admire enough shoes such that you are left with a single highly desirable shoe.

Okay, so that was just a bit crazy, but now replace the shoes with 32-bit integers, the desk with two x86 registers (namely eax and ecx), the floor with the stack (finding shoes on the floor is considerably more time consuming than grabbing one of the two on your desk, which is exactly the relationship the stack has to registers on x86), buying/getting harassed by shoe salesmen with pushing a number onto the RPN stack, and admiration with popping two numbers, performing an operation, and pushing the result back. Well, okay, that's still kinda crazy.

Anyway, every time we load a number in, we push the second slot (ecx) onto the stack (not the RPN stack!), swap the contents of ecx and eax, then load the number into eax. Following this scheme, eax always holds the most recent value (and "coincidentally" the returned value), ecx the second most recent, and everything else is ordered by time of arrival. Similarly, for a binary operation, we "pop" the two most recent numbers, perform the operation, then "push" the result back. Since these numbers live in registers, and since (al)most all of the operations are performed inline, this scheme is perfect. We just need to remember to pop into ecx to refill the vacancy, and to keep track of the number of items on the RPN stack so we can clean up our call stack at the end of each run.
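The scheme above can be sketched in a few lines (hypothetical helper names; 32-bit x86 encodings):

```python
import struct

# For each number: push ecx (0x51), xchg eax, ecx (0x91),
# then mov eax, imm32 (0xB8 followed by a 4-byte little-endian immediate).
def load_const(n):
    return b"\x51\x91\xb8" + struct.pack("<I", n & 0xFFFFFFFF)

# For "+": add eax, ecx (0x03 0xC1), then pop ecx (0x59)
# to refill the vacancy left by consuming two operands.
ADD = b"\x03\xc1\x59"

# "3 4 +" compiles to the raw bytes a JIT would hand to the CPU:
code = load_const(3) + load_const(4) + ADD
print(code.hex())
```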

Other fun stuff

Okay, there's a small detail that I left out. How the hell do I run the x86 machine code in python?

Well, python comes with a large standard library, and with large standard libraries, comes great responsib...

Err, I mean, python's standard library is really powerful as it allows you to interface with native code without having to write a single line of C.

So now the problem becomes: How the hell do I run x86 machine code in plain C? Isn't there some kind of safeguard against executing arbitrary and potentially unsafe code for the hell of it? Yes, DEP prevents arbitrary execution of data initialized on the heap or any nonexecutable regions of memory, so that code like this:

char* evil_exploit_to_activate_skynet_do_not_execute = "\x33\xc0\xc3Haha, I just zeroed your eax, you're next Connor";
int ahhh_what_the_hell_are_you_doing = ((int(*)())evil_exploit_to_activate_skynet_do_not_execute)();

won't (or shouldn't, http://codepad.org/X2zXBm4O) work. (Actually, this example would be more correct if we allocated space on the heap, copied over the code, and then tried to run it)

However, not all code is inherently bent on the destruction of the world, and in the end, operating systems are more altruistic than not. If the program just asks really, really nicely, the system will allocate blocks of memory that are both writable and executable. Depending on the system, a program can call win32's VirtualAlloc with 0x40 as its protection flag (PAGE_EXECUTE_READWRITE, i.e. read-write-execute) or linux's mprotect with 0x7 as its protection flag (which again stands for read-write-execute), and if done nicely enough (you passed a page-aligned address? that's good enough for me, says the operating system), you'll get back a pointer to your brand new executable block of memory. (For the concerned few, this isn't really a security flaw, as you'd still need root privileges to mprotect another process's memory. For the most part, VirtualProtect and mprotect are used for things like shared memory and loaders rather than self-indulgent hacks.)
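In python, one way to pull this off is to ask the mmap module for such a page and jump into it with ctypes. Here's a minimal sketch, assuming an x86 or x86-64 Unix machine (the six bytes below encode mov eax, 42; ret, which happens to be valid in both 32- and 64-bit mode):

```python
import ctypes, mmap

# mov eax, 42; ret: valid x86 and x86-64 machine code
code = b"\xb8\x2a\x00\x00\x00\xc3"

# ask nicely for a read-write-execute page
buf = mmap.mmap(-1, mmap.PAGESIZE,
                prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
buf.write(code)

# turn the page's address into a callable that returns a C int
addr = ctypes.addressof(ctypes.c_char.from_buffer(buf))
fn = ctypes.CFUNCTYPE(ctypes.c_int)(addr)
print(fn())  # 42
```

This will segfault on a non-x86 machine and may be refused outright on hardened systems that forbid writable-and-executable mappings, which is rather the point of the DEP discussion above.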

Anyways, all the pieces are there, and the code below shows you how they fit together.

A general note on JIT compilation

Beware that this is a very primitive compiler. For one thing, it's monomorphically typed, with int32 as the only datatype. There are also no loops, jumps, or even subroutines. This essentially means that there's absolutely no point (besides curiosity) in compiling the code down to machine code, as the running time would be insignificantly small either way.

Now, general JIT compilers come in two flavors. Traditionally, we have the JIT compilers of strongly typed languages (such as java), which do very detailed static analysis on the IL bytecode and compile entire methods down into machine code, just like our example (minus the methods) does. Novel idea, no? So why would we need alternative methods? For one thing, static analysis is woefully inadequate for largely dynamically typed languages like javascript, python, lua, or even ruby. In python, for example, you can never statically tell whether a function such as this one

def add(a, b):
    return a + b

expects integers as its arguments, strings, or even instance objects with an overloaded add method. For this reason, code like this is often omitted from the compilation phase. For languages such as javascript and python, that essentially means a static compiler would cover nearly none of the code.
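To see why, note that the very same add happily performs three entirely different operations depending on what it's handed at runtime (Money is a made-up class for illustration):

```python
def add(a, b):
    return a + b

# a toy class with an overloaded __add__, just to make the point
class Money(object):
    def __init__(self, cents):
        self.cents = cents
    def __add__(self, other):
        return Money(self.cents + other.cents)

print(add(1, 2))                 # integer addition
print(add("spam", "eggs"))       # string concatenation
print(add(Money(100), Money(250)).cents)  # user-defined __add__
```

A method-at-a-time compiler would have to emit code that handles all three cases (and any future ones), which mostly means falling back to the interpreter's generic dispatch.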

And furthermore, static analysis in itself is a very expensive operation, so much so that most of the analysis is done rather shoddily for the sake of reducing its running time.

More recently, developers began pioneering a new field of dynamic compilation called trace compiling, or tracing. During compilation to the intermediate language, the compiler can mark the branches within the code: for example, a label in front of a return statement, a jmp at a break statement, or a call or return. The reasoning is simple: if you strip away all of the jumps and function calls, then no matter how long the program is, as long as it's written by a human, from an engineer's perspective it will always run in negligible time. Where programs do begin to eat up time, however, is in the various loops, recursive calls, and whatnot. Essentially, the introduction of branches makes the complexity of the program unpredictable, so by optimizing the frequently taken loops and function calls, we can effectively balance the efficiency of the program against the resources required to compile the code (which is still not trivial by any means).

Tracing also solves another major problem. Whenever the program counter (or instruction pointer) hits one of the branch markers, the execution flow from that point is traced and stored (and will eventually be compiled if that route is taken frequently enough). Because the program flow of a string concatenation differs from that of an integer addition or an overloaded method, each of these flows generates a unique trace with its own signature (the inputs). For this very reason, if we encounter a string concatenation before an integer addition, the program will not only identify the integer case as a completely different route, but will also be able to narrow down the environment so that further optimizations during the compilation of the integer route can safely assume its inputs are strictly integers. Neat, huh?
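As a rough sketch of the signature idea, here is a toy hit counter, not a real tracer (traced_add, HOT_THRESHOLD, and the dictionaries are all invented names): it buckets calls by argument-type signature and installs a per-signature fast path once a signature becomes "hot", which is where a real tracing JIT would instead compile the recorded trace to machine code.

```python
HOT_THRESHOLD = 3
counts = {}       # how often each type signature has been seen
specialized = {}  # signature -> "compiled" fast path

def generic_add(a, b):
    # the interpreter's slow, fully generic path
    return a + b

def traced_add(a, b):
    sig = (type(a), type(b))
    fast = specialized.get(sig)
    if fast is not None:
        return fast(a, b)  # take the "compiled" trace
    counts[sig] = counts.get(sig, 0) + 1
    if counts[sig] >= HOT_THRESHOLD:
        # "compile": bind a signature-specific path; knowing
        # sig == (int, int) would let a real JIT emit a bare
        # 32-bit add with no dispatch at all
        specialized[sig] = generic_add
    return generic_add(a, b)
```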

Note: I got a little lazy and, since I'm optimizing for space, decided it was better to push and pop ecx for no reason on illegal tokens than to handle them as an explicit case. Also, the only rudimentary safeguard against injecting malicious code is masking integers down to 32 bits. There are still (probably) a lot of vulnerabilities, but hey, it's just a simple hack.

import ctypes, mmap, struct

try:
    # Windows: VirtualAlloc(NULL, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE)
    _VirtualAlloc = ctypes.windll.kernel32.VirtualAlloc
    def valloc(size):
        addr = _VirtualAlloc(0, size, 0x1000, 0x40)
        if not addr:
            raise RuntimeError("Cannot allocate RWX memory")
        return addr
except AttributeError:
    # elsewhere: page-aligned allocation, then mprotect it read-write-execute
    libc = ctypes.CDLL("libc.so.6")
    libc.valloc.restype = ctypes.c_void_p  # don't truncate the pointer on 64-bit hosts
    def valloc(size):
        addr = libc.valloc(size)
        if not addr or libc.mprotect(ctypes.c_void_p(addr), size, 0x07):
            raise RuntimeError("Cannot allocate RWX memory")
        return addr

class RPN_jit:
    def __init__(self):
        self.size = mmap.PAGESIZE
        self.exepage = valloc(self.size) # a single page for execution at the very minimum

    def emit(self, code):
        # cdecl prologue, small endianness: push ebp; mov ebp, esp; sub esp, 0xcc; save ebx/esi/edi
        buffer = "\x55" + "\x8b\xec" + "\x81\xec\xcc\0\0\0" + "\x53" + "\x56" + "\x57" + "\x8d\xbd\x34\xff\xff\xff"
        def _num(o):
            try: return int(o)
            except ValueError: return int(o, 16)
        sp = 0 # stack count, used to clean up stack space
        for o in code.split():
            # valid tokens: integers, +, -, *, /, %; all operators are assumed binary
            try:
                o = _num(o)
                # eax and ecx hold the two most recent values; older ones spill to the stack:
                # push ecx; xchg eax, ecx; mov eax, imm32 (masked so it fits the immediate)
                buffer += "\x51" + "\x91" + "\xb8" + struct.pack("<I", o & 0xffffffff)
                sp += 1
            except ValueError:
                if sp < 2: raise RuntimeError("Stack will underflow.")
                # ecx holds the older operand, eax the newer one, so non-commutative
                # operators start with xchg eax, ecx (\x91) to make "a b -" mean a - b;
                # afterwards we pop the next value back into ecx. At the end of the run
                # eax holds the most recent result, perfect for \xc3 (ret)
                buffer += "\x03\xc1" if o in ("+", "add", "plus") else "\x91\x2b\xc1" if o in ("-", "sub", "minus") \
                    else "\x0f\xaf\xc1" if o in ("*", "mul", "mult") else "\x91\x99\xf7\xf9" if o in ("/", "div") \
                    else "\x91\x99\xf7\xf9\x92" if o in ("%", "mod", "rem") else "\x51" # mod is just idiv followed by xchg edx, eax; unknown tokens push ecx so the pop below is a no-op
                buffer += "\x59" # pop ecx
                sp -= 1
        if not sp: raise RuntimeError("Nothing to compile.")
        for _ in range(sp): buffer += "\x59" # pop ecx to clear the stack
        buffer += "\x5f\x5e\x5b\x8b\xe5\x5d\xc3" # restore edi/esi/ebx, rebase ebp&esp, and return with the result in eax
        if len(buffer) > self.size:
            raise RuntimeError("Input cannot fit into memory.")
        ctypes.memmove(self.exepage, buffer, len(buffer))
        return ctypes.CFUNCTYPE(ctypes.c_int32)(self.exepage)

a = RPN_jit()
print a.emit("3 10 mod 10 + 0x100 * 100 * 50 -")()
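For sanity-checking, here's a hypothetical plain-python reference interpreter for the same token set (rpn_eval is an invented name; it uses truncating division to mirror x86's idiv, and shows what a conventional RPN reading of the example should produce):

```python
def rpn_eval(src):
    # reference interpreter for the JIT's RPN dialect
    stack = []
    for tok in src.split():
        try:
            stack.append(int(tok, 0))  # base 0 accepts both "10" and "0x100"
            continue
        except ValueError:
            pass
        b, a = stack.pop(), stack.pop()
        if tok in ("+", "add", "plus"):     r = a + b
        elif tok in ("-", "sub", "minus"):  r = a - b
        elif tok in ("*", "mul", "mult"):   r = a * b
        elif tok in ("/", "div"):           r = int(a / b)  # truncate toward zero, like idiv
        elif tok in ("%", "mod", "rem"):    r = a - int(a / b) * b
        else:                               raise ValueError("unknown token: " + tok)
        stack.append(r)
    return stack[-1]

print(rpn_eval("3 10 mod 10 + 0x100 * 100 * 50 -"))  # 332750
```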