
JIT for dummies: JIT compiling RPN in python

Rationale.md

If you don't care about the explanation, scroll down to find the code; it's 50-some-odd lines and written by someone who doesn't know any better. You have been warned.

What it does

This is a very simple proof-of-concept JIT-compiling RPN calculator implemented in Python. Basically, it takes the source code, tokenizes it on whitespace, and asks itself one simple question about each token: am I looking at a number or not?

First, let's talk about the underlying program flow. Pretend that you are a shoe connoisseur with a tiny desk. You may only have two individual shoes on that desk at any one time, but should you ever purchase a new one or get harassed by an unruly shoe salesman without realizing that you have the power to say no (or even maybe?), you can always sweep aside one of the two shoes on the desk (the one on the right, because you're a lefty and you feel that the left side is always superior) onto the messy floor, put the other shoe on the right hand side, and then place your newly acquired shoe in the most leftward side of the desk so that you may gloat about how your very best shoe is on the very best side. You are also cursed, which may inexplicably explain your obsession with shoes, such that every time you admire a shoe, that shoe somehow merges itself into the other shoe, at which point you must promptly get up, walk around the floor, and find the last shoe that was swept away and place that shoe into the rightmost slot. Your ultimate goal is to admire enough shoes such that you are left with a single highly desirable shoe.

Okay, so that was just a bit crazy, but now replace shoes with 32-bit integers, the desk with two x86 registers (namely eax and ecx), the floor with the stack (finding shoes on the floor is considerably more time consuming than finding one of the two shoes on your desk, just as stack accesses are slower than register accesses on x86), buying shoes or getting harassed by shoe salesmen with pushing a number onto the RPN stack, and admiration with popping two numbers, performing an operation, and pushing the result back. Well, okay, that's still kinda crazy.

Anyways, every time we load a number in, we push the second slot (ecx) onto the stack (the machine stack, not the RPN stack!), move the contents of eax into ecx, then load the number into eax. Following this scheme, eax always holds the most recent value (and, "coincidentally", the return value), ecx the second most recent, and everything else sits on the stack ordered by time of arrival. Similarly, for a binary operation, we need to "pop" the two most recent numbers, perform the operation, then "push" the result back. Since these two numbers are already in registers, and since all (well, almost all) of the operations are performed inline, this scheme is perfect. We just need to remember to pop into ecx to refill the vacancy, and to keep track of the number of items on the RPN stack so we can clean up our call stack at the end of each run.
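To make the register shuffle concrete, here is a minimal pure-Python simulation of the scheme described above. It is only a sketch: `run`, `to_int32` and `OPS` are names made up for this illustration, only `+`, `-` and `*` are handled, and the operator is applied as `eax op ecx` (most recent value on the left), which is what the opcodes in rpnjit.py appear to do. A conventional RPN calculator would use the opposite operand order.

```python
import operator

# Hypothetical helper table for this sketch; only three operators are modelled.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}

def to_int32(n):
    """Wrap n to a signed 32-bit value, the way a 32-bit register would."""
    n &= 0xFFFFFFFF
    return n - 0x100000000 if n & 0x80000000 else n

def run(source):
    eax = ecx = 0   # the two "desk slots"
    stack = []      # the machine stack (the "floor")
    depth = 0       # number of items on the RPN stack
    for token in source.split():
        if token in OPS:
            if depth < 2:
                raise ValueError("stack underflow")
            eax = to_int32(OPS[token](eax, ecx))  # op eax, ecx
            ecx = stack.pop()                     # pop ecx refills the vacancy
            depth -= 1
        else:
            stack.append(ecx)                     # push ecx
            ecx = eax                             # xchg eax, ecx (old ecx already saved)
            eax = to_int32(int(token, 0))         # mov eax, imm32 (base 0 accepts 0x...)
            depth += 1
    if depth == 0:
        raise ValueError("nothing to compile")
    return eax                                    # eax doubles as the return value

print(run("2 3 4 * +"))
```

With these assumptions `run("2 3 4 * +")` evaluates to 14, while `run("3 10 -")` evaluates to 7 rather than -7 because of the operand order noted above.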

Other fun stuff

Okay, there's a small detail that I left out. How the hell do I run the x86 machine code in python?

Well, python comes with a large standard library, and with large standard libraries, comes great responsib...

Err, I mean, python's standard library is really powerful as it allows you to interface with native code without having to write a single line of C.

So now the problem becomes: How the hell do I run x86 machine code in plain C? Isn't there some kind of safeguard against executing arbitrary and potentially unsafe code for the hell of it? Yes, DEP prevents arbitrary execution of data initialized on the heap or any nonexecutable regions of memory, so that code like this:

char* evil_exploit_to_activate_skynet_do_not_execute = "\x33\xc0\xc3Haha, I just zeroed your eax, you're next Connor";
int ahhh_what_the_hell_are_you_doing = ((int(*)())evil_exploit_to_activate_skynet_do_not_execute)();

won't (or shouldn't, http://codepad.org/X2zXBm4O) work. (Actually, this example would be more correct if we allocated space on the heap, copied over the code, and then tried to run it)

However, not all code is inherently bent on the destruction of the world, and in the end, operating systems are more altruistic than not. If the program just asks really really nicely, the system will hand over blocks of memory that are both writable and executable. Depending on the system, a program can either call win32's VirtualAlloc with 0x40 as its protection flag (PAGE_EXECUTE_READWRITE, i.e. read-write-execute), which returns a pointer to a brand new executable block of memory, or call linux's mprotect with 0x7 as its protection flag (PROT_READ | PROT_WRITE | PROT_EXEC, again read-write-execute) on a page-aligned block it has already allocated, which returns 0 once the block is executable. If done nicely enough (you pushed twelve bytes onto the stack? that's good enough for me, says the operating system), you end up with the address of your executable block either way. (For the concerned few, this isn't really a security flaw: mprotect only changes the protection of pages in the calling process's own address space, so it can't be used to make some other process's heap executable.)
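As a rough illustration of the "asking nicely" step, the sketch below uses Python's standard `mmap` module instead of calling VirtualAlloc or mprotect directly. It assumes a Unix-like system, where `mmap` exposes the `PROT_*` constants; `alloc_page`, `RWX` and `CODE` are names invented for this example. The machine code is only copied into the page, never executed.

```python
import mmap

# On Linux and other Unix-likes PROT_READ=1, PROT_WRITE=2, PROT_EXEC=4,
# so together they form the 0x7 flag mentioned above.
RWX = mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC

def alloc_page(prot):
    """Map one anonymous, page-aligned page with the given protection."""
    return mmap.mmap(-1, mmap.PAGESIZE, prot=prot)

# xor eax, eax; ret: only copied into the page here, never executed.
CODE = b"\x33\xc0\xc3"

try:
    page = alloc_page(RWX)
except OSError:
    # Hardened systems may refuse writable+executable mappings.
    page = alloc_page(mmap.PROT_READ | mmap.PROT_WRITE)

page[:len(CODE)] = CODE
print(hex(RWX), bytes(page[:len(CODE)]))
```

Turning such a page into a callable would still need ctypes and a matching CPU architecture, which is the part rpnjit.py takes care of.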

Anyways, all the pieces are there, and the code below shows you how they fit together.

A General note on JIT Compilation

Beware that this is a very primitive compiler. For one thing, it has exactly one type, with int32 being the only datatype. There are also no loops, jumps, or even subroutines. This essentially means that there's absolutely no point (besides curiosity) in compiling the code down into machine code, as the running time would be insignificantly small either way.

Now, general JIT compilers come in two flavors. Traditionally, we have the JIT compilers for statically typed languages (such as Java) that do very detailed static analysis on their IL bytecode and compile entire methods down into machine code, just like our example does (minus the method part). Novel idea, no? So why would we need alternative methods? For one thing, static analysis is woefully inadequate for largely dynamically typed languages like JavaScript, Python, Lua, or even Ruby. In Python, for example, you can never tell if a function such as this one

def add(a, b):
    return a + b

expects integers as its arguments, strings, or even instance objects with an overloaded add method statically. For this reason, code such as this is often omitted from the compilation phase. For languages such as JavaScript and Python, this essentially means that a static compiler would cover nearly none of their code.
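To see the ambiguity in action, the same `add` can be called with three unrelated kinds of arguments. The `Money` class is a hypothetical example of an object that overloads addition through `__add__`.

```python
def add(a, b):
    return a + b

class Money:
    """Hypothetical class that overloads + via __add__."""
    def __init__(self, cents):
        self.cents = cents

    def __add__(self, other):
        return Money(self.cents + other.cents)

print(add(1, 2))                           # integer addition
print(add("foo", "bar"))                   # string concatenation
print(add(Money(150), Money(250)).cents)   # user-defined __add__
```

Nothing in the body of `add` distinguishes these cases; the behavior is chosen at run time from the types of the arguments.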

And furthermore, static analysis in itself is a very expensive operation, so much so that most of the analysis is done rather shoddily for the sake of reducing its running time.

More recently, developers began pioneering a new field in dynamic compilation called trace compiling, or tracing. During compilation into the intermediate language, the compiler can mark the branches within the code: for example, a label in front of a return statement, a jmp at a break statement, or a call or return. The reasoning is simple: if you strip away all of the jumps and function calls, then no matter how long the program is, as long as it's written by a human, it will run in negligible time from an engineer's perspective. Where programs begin to eat up time, however, is in the various loops, recursive calls, and whatnot. Essentially, the introduction of branches makes the complexity of the program unpredictable; hence, by optimizing only the frequently used loops and function calls, we can effectively balance the efficiency of the program against the resources required to compile the code (which is still not trivial by any means).

It also solves another major problem. Whenever the program counter (or instruction pointer) hits one of the branch markers, its entire execution flow is traced and stored (and will eventually be compiled if that route is taken frequently enough). Because the program flow of a string concatenation differs from that of an integer addition or even an overloaded function, each of these flows generates a unique trace with its own signature (the input). For this very reason, if we encounter a string concatenation before an integer addition, the program will not only be able to identify the integer case as a completely different route, but will also be able to break down the complexity of the environment, so that further optimizations during the compilation of the integer route can safely assume that its inputs are strictly integers. Neat, huh?
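The signature idea can be sketched in a few lines of Python. This is a toy, not a real tracing JIT: it records no instruction traces and generates no machine code. It only counts how often each tuple of argument types is seen and keeps a separate "compiled" entry per hot signature. `TypeSpecializer` and `HOT_THRESHOLD` are names invented for this illustration.

```python
from collections import Counter

HOT_THRESHOLD = 3  # hypothetical: calls before a signature counts as hot

class TypeSpecializer:
    """Toy model: one specialized entry per hot argument-type signature."""

    def __init__(self, func):
        self.func = func
        self.counts = Counter()   # signature -> times seen while interpreted
        self.compiled = {}        # signature -> specialized callable

    def __call__(self, *args):
        sig = tuple(type(a) for a in args)
        fast = self.compiled.get(sig)
        if fast is not None:
            return fast(*args)    # hot path: types already known
        self.counts[sig] += 1
        if self.counts[sig] >= HOT_THRESHOLD:
            self.compiled[sig] = self._compile(sig)
        return self.func(*args)   # cold path: plain interpretation

    def _compile(self, sig):
        # Stand-in for code generation: a real JIT would emit code that
        # assumes exactly these types; here we just reuse the function.
        return self.func

add = TypeSpecializer(lambda a, b: a + b)
for _ in range(3):
    add(1, 2)        # (int, int) becomes hot
add("foo", "bar")    # (str, str) is seen once and stays cold
print(sorted(t.__name__ for sig in add.compiled for t in sig))
```

The integer route and the string route end up under different keys, so anything specialized for `(int, int)` never has to consider strings.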

Note: I got a little lazy and, as I'm optimizing for space, decided that it's a better idea to push and pop ecx for no reason on illegal tokens rather than to handle them as an explicit case. Also, only rudimentary safeguards against injecting malicious code are implemented, namely masking integers to 32 bits. There are still (probably) a lot of vulnerabilities, but hey, it's just a simple hack.
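The 32-bit masking can be illustrated in isolation. The sketch below encodes a `mov eax, imm32` instruction (opcode 0xB8 followed by a little-endian 32-bit immediate); `mov_eax_imm32` is a hypothetical helper name, and the unsigned `"<I"` format is used here so that masked negative numbers still pack.

```python
import struct

def mov_eax_imm32(value):
    """Encode `mov eax, imm32`; the value is truncated to its low 32 bits."""
    return b"\xb8" + struct.pack("<I", value & 0xFFFFFFFF)

print(mov_eax_imm32(0x100))       # small positive constant
print(mov_eax_imm32(-1))          # negative numbers wrap to 0xffffffff
print(mov_eax_imm32(2**32 + 5))   # oversized input cannot spill past 4 bytes
```

Whatever integer comes in, the instruction is always exactly five bytes long, so an oversized token cannot overwrite the bytes of the next instruction.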

rpnjit.py
import ctypes, mmap, struct

try:
    _VirtualAlloc = ctypes.windll.kernel32.VirtualAlloc
    def valloc(size):
        # 0x1000 = MEM_COMMIT, 0x40 = PAGE_EXECUTE_READWRITE
        addr = _VirtualAlloc(0, size, 0x1000, 0x40)
        if not addr:
            raise RuntimeError("Cannot allocate RWX memory")
        return addr
except AttributeError:  # ctypes.windll only exists on Windows
    libc = ctypes.CDLL("libc.so.6")
    def valloc(size):
        # page-aligned allocation, then 0x07 = PROT_READ | PROT_WRITE | PROT_EXEC
        addr = libc.valloc(size)
        if not addr or libc.mprotect(addr, size, 0x07):
            raise RuntimeError("Cannot allocate RWX memory")
        return addr

class RPN_jit:
    def __init__(self):
        self.size = mmap.PAGESIZE
        self.exepage = valloc(self.size) # a single page for execution at the very minimum

    def emit(self, code):
        # we will use cdecl function declarations on 32-bit x86, assume little endianness
        # prologue: push ebp; mov ebp, esp; sub esp, 0xcc; push ebx; push esi; push edi; lea edi, [ebp-0xcc]
        buffer = b"\x55" + b"\x8b\xec" + b"\x81\xec\xcc\0\0\0" + b"\x53" + b"\x56" + b"\x57" + b"\x8d\xbd\x34\xff\xff\xff"
        def _num(o):
            try: return int(o)
            except ValueError: return int(o, 16)
        sp = 0 # stack count, used to clean up stack space
        for o in code.split():
            # valid tokens: integers, +, -, *, /, %; all operators are assumed binary
            try:
                o = _num(o)
                # eax, ecx serve as our registers for the most recent values, the stack comes later
                # every time we load in an integer, effectively, we push the content of ecx, bump eax's data into ecx, and then load into eax
                # push ecx; xchg eax, ecx; mov eax, imm32 (masked so we don't overflow the mov instruction)
                buffer += b"\x51" + b"\x91" + b"\xb8" + struct.pack("<I", o & 0xffffffff)
                sp += 1
            except ValueError:
                # eax is first param, ecx is second param, eax is storage, and then we pop into ecx
                # at the end of the run, eax is the most recently "pushed" item, perfect for \xc3 (ret)
                if sp < 2: raise RuntimeError("Stack will underflow.")
                buffer += b"\x03\xc1" if o in ("+", "add", "plus") else b"\x2b\xc1" if o in ("-", "sub", "minus") \
                    else b"\x0f\xaf\xc1" if o in ("*", "mul", "mult") else b"\x99\xf7\xf9" if o in ("/", "div") \
                    else b"\x99\xf7\xf9\x92" if o in ("%", "mod", "rem") \
                    else b"\x51" # mod is just idiv then xchg edx, eax; an illegal token becomes push ecx (popped right below)
                buffer += b"\x59" # pop ecx
                sp -= 1
        if not sp: raise RuntimeError("Nothing to compile.")
        for _ in range(sp): buffer += b"\x59" # pop ecx to clear the stack
        buffer += b"\x5f\x5e\x5b\x8b\xe5\x5d\xc3" # pops the saved registers, restores esp and ebp, and returns eax, which contains the last push
        if len(buffer) > self.size:
            raise RuntimeError("Input does not fit into memory.")
        ctypes.memmove(self.exepage, buffer, len(buffer))
        return ctypes.CFUNCTYPE(ctypes.c_int32)(self.exepage)

a = RPN_jit()
print(a.emit("3 10 mod 10 + 0x100 * 100 * 50 -")())

Actually, to be honest, I didn't really have to allocate the 0xcc bytes of locals below ebp. Dropping them would probably also eliminate the endianness problems.

im 12 and wat is this?

In python for example, you can never tell if a function such as this one expects integers as its arguments, strings, or even instance objects with an overloaded add method statically.

A common misconception about static analysis in dynamically-typed languages. You can't write a program that determines that property of add in general, for all programs containing it, but if this is the program in question, there already are many analyzers/compilers that determine that add only ever works on integers:

#!/usr/bin/env python
def add(a, b):
  return a + b
print(add(1, 2))
print(add(-4, 17))

See: Halting problem. HP doesn't say you can't prove the program I just gave never terminates, it says you can't prove it in the general case. Enormous amounts of work have been done to increase the set of programs you can prove to halt.

Nice hack but you are wrong about static analysis in dynamic languages, http://shed-skin.blogspot.com/

@michaeledgar and sean, Thanks, so the analysis is not context-free, which makes sense. Albeit would it be safe to assume that static analysis is expensive? I think most modern jitters use a hybrid between the two so I guess what I said was misleading

@flood lol -.-

@leegao, Yes, determining the types is a very expensive operation. See this post by the author of shedskin, http://shed-skin.blogspot.com/2010/12/shed-skin-07-type-inference-scalability.html Type-tracing JITs, on the other hand, are a good fit for existing dynamic languages where the main distribution mechanism is the unprocessed source, i.e. JavaScript, and where large portions of the code will not execute very often.

Shedskin is not intended to compile, say, arbitrary Django apps but to speed up kernels that are written with Shedskin in mind (which can be brought in as Python modules into CPython).

Your code is an excellent piece of systems programming that is often not explored in Python for fear of being unpythonic. Keep it up!

If you could pass values into your compiled functions it would be useful (it is useful now, just a little harder to use). See this stack overflow question on modifying and transforming the python AST, http://stackoverflow.com/questions/768634/python-parse-a-py-file-read-the-ast-modify-it-then-write-back-the-modified

function objects have a rich set of attributes that are worth exploring.

# add_t.py

def add(a,b):
    return a+b

print add.func_code.co_firstlineno
print add.func_code.co_filename

The above example is an easy way to get access to the file and linenumber of a function, then load in the src and hand it off to ast. Besides Shedskin take a look @ http://code.google.com/p/unpython/

I would implement the fractal function and then port a simple ray tracer over. There are many fine examples of ray tracers in Python

Another route to explore is converting the scene graph from one of the above ray tracers, constructing a textual program in a modified version of your RPN calculator, and compiling a single function that emits the rendered image. Then there is no need to pass in parameters; the entire program is the parameter.

@leegao Pretty sweet. I didn't realise you could call into libc like that from python. I will definitely be giving this stuff a whirl at some point in the near future. However, your code does segfault when run on my machine. I don't have time to track it down right now, but the environment is linux x86_64.

This is speculation as I don't actually know, I'll have to disassemble my original code first, but it's probably because I assembled on 32 bit machine without explicitly specifying the l postfix but still used eax, ecx and esp, which aren't general purpose in 64 bit. Same issues with the prologue and epilogue.

Codepad seems to work well http://codepad.org/aqLx1ePp

Ok good to know. I think if I give this a try I will have it generate the mnemonics and then assemble with gas so it will be slightly more portable.
