Skip to content

Instantly share code, notes, and snippets.

@pchiusano
Created September 21, 2017 20:51
Show Gist options
  • Save pchiusano/71dd8c7c35057f6f453ea1fc2974debf to your computer and use it in GitHub Desktop.
Save pchiusano/71dd8c7c35057f6f453ea1fc2974debf to your computer and use it in GitHub Desktop.
Code for Scala World 2017 talk on eliminating interpreter overhead via partial evaluation
package scalaworld.interpreters
/*
This file shows a simple language, an interpreter, and two
partial evaluators for that language, along with a profiling suite.
*/
/**
 * An instruction for a tiny register machine. Conceptually a program denotes a
 * `Vector[Double] => Vector[Double]`; the evaluators below implement it by
 * mutating a register file in place.
 */
trait Expr // denotes a Vector[Double] => Vector[Double]
object Expr {
  /** Set register `d` (for "destination") equal to `n`. */
  case class Num(d: Int, n: Double) extends Expr
  /** Set register `d` equal to register `i` + register `j`. */
  case class Plus(d: Int, i: Int, j: Int) extends Expr
  /** Decrement register `d`. */
  case class Decr(d: Int) extends Expr
  /** Set register `d` equal to register `i`. */
  case class Copy(d: Int, i: Int) extends Expr
  /** Run the instructions in `es` in sequence. An empty block is a no-op. */
  case class Block(es: List[Expr]) extends Expr
  /**
   * Execute `p` repeatedly until the `haltIf0` register is 0. The register is
   * tested before each iteration, so `p` may run zero times; if the register
   * never becomes exactly 0.0, the loop does not terminate.
   */
  case class Loop(haltIf0: Int, p: Expr) extends Expr
  /** Some syntax - variadic `Block.apply`. */
  object Block { def apply(es: Expr*): Expr = Block(es.toList) }

  /**
   * A simple interpreter for `Expr`. For efficiency, this mutates an `Array[Double]`
   * rather than transforming a `Vector[Double]`. Straightforward but inefficient:
   * every instruction is re-inspected each time it executes.
   *
   * @param e the program to run
   * @param m the register file, updated in place; every register index used by `e`
   *          must be a valid index into `m`
   */
  def interpret(e: Expr, m: Array[Double]): Unit = e match {
    case Num(d, n)        => m(d) = n
    case Decr(d)          => m(d) = m(d) - 1.0
    case Plus(d, i, j)    => m(d) = m(i) + m(j)
    case Copy(d, i)       => m(d) = m(i)
    case Loop(haltIf0, p) => interpretLoop(haltIf0, p, m)
    case Block(es)        => interpretBlock(es, m)
  }

  // Notice that we have interpreter overhead _on each execution of the loop body_.
  def interpretLoop(haltIf0: Int, p: Expr, m: Array[Double]): Unit =
    while (!(m(haltIf0) == 0)) interpret(p, m)

  /**
   * Runs `es` in order. Nested blocks are spliced into the work list so the
   * recursion stays in tail position.
   */
  @annotation.tailrec
  def interpretBlock(es: List[Expr], m: Array[Double]): Unit = es match {
    case Nil                => ()
    case Block(es0) :: rest => interpretBlock(es0 ++ rest, m)
    case e :: rest          => interpret(e, m); interpretBlock(rest, m)
  }

  /**
   * Here's a simple partial evaluator. We curry the `interpret` function,
   * but do all inspection of the syntax tree _before_ returning our
   * compiled form, an `Array[Double] => Unit`.
   *
   * `partialEval` could be called `compile` - we are producing a compiled
   * form with no interpreter overhead, as in the Futamura projections.
   */
  def partialEval(e: Expr): Array[Double] => Unit = e match {
    case Num(d, n) => m => m(d) = n
    case Decr(d) => m => m(d) = m(d) - 1.0
    // case Plus(d, i, j) if d == j => m => m(d) += m(i)
    case Plus(d, i, j) => m => m(d) = m(i) + m(j)
    case Copy(d, i) => m => m(d) = m(i)
    case Loop(haltIf0, p) =>
      val compiledBody = partialEval(p) // very important, we compile the body once!
      m => while (m(haltIf0) != 0.0) compiledBody(m) // ... and then execute it multiple times
    case Block(es) => partialEvalBlock(es)
  }

  /**
   * Compiles a sequence of instructions into one function that runs them in order.
   * An empty sequence compiles to a no-op, matching `interpretBlock(Nil, m)`
   * (previously it fell through the match and threw `MatchError`).
   */
  def partialEvalBlock(ps: List[Expr]): Array[Double] => Unit = ps match {
    case Nil => _ => ()
    case List(e) => partialEval(e)
    case p :: ps2 =>
      val cp = partialEval(p)
      val cps = partialEvalBlock(ps2)
      m => { cp(m); cps(m) }
  }

  // Performance of this approach is highly dependent on choice of compiled form.
  // An `Array[Double] => Unit` may require computing array offsets and doing array
  // bounds checks. To improve performance, we can move to a function that just
  // takes a mutable record of `Double` values:
  case class Machine(var r0: Double, var r1: Double, var r2: Double, var r3: Double)

  // Our compiled form will be `Machine => Unit` for this second partial evaluator.
  // Not being able to just use array offsets requires a bit more code than before.
  object Machine {
    /** Returns a reader for register `i`; only registers 0 to 3 exist (others throw `MatchError`). */
    def get(i: Int): Machine => Double = i match {
      case 0 => _.r0
      case 1 => _.r1
      case 2 => _.r2
      case 3 => _.r3
    }
    // experimented with this, doesn't seem to make a difference
    //abstract class R { def apply(m: Machine): Double }
    //def get(i: Int): R = i match {
    //  case 0 => new R { def apply(m: Machine) = m.r0 }
    //  case 1 => new R { def apply(m: Machine) = m.r1 }
    //  case 2 => new R { def apply(m: Machine) = m.r2 }
    //  case 3 => new R { def apply(m: Machine) = m.r3 }
    //}
  }

  /**
   * Second partial evaluator, compiling to `Machine => Unit`. Register numbers are
   * resolved to field accesses at compile time; only registers 0 to 3 are supported.
   */
  def partialEval2(e: Expr): Machine => Unit = e match {
    case Num(d, n) => d match {
      case 0 => m => m.r0 = n
      case 1 => m => m.r1 = n
      case 2 => m => m.r2 = n
      case 3 => m => m.r3 = n
    }
    case Decr(d) => d match {
      case 0 => m => m.r0 -= 1.0
      case 1 => m => m.r1 -= 1.0
      case 2 => m => m.r2 -= 1.0
      case 3 => m => m.r3 -= 1.0
    }
    // can make a difference, suggests Machine.get(i) isn't reliably inlined by JIT
    case Plus(1, 0, 1) => m => m.r1 += m.r0
    case Plus(d, i, j) if d == j =>
      val ci = Machine.get(i)
      d match {
        case 0 => m => m.r0 += ci(m)
        case 1 => m => m.r1 += ci(m)
        case 2 => m => m.r2 += ci(m)
        case 3 => m => m.r3 += ci(m)
      }
    case Plus(d, i, j) =>
      val ci = Machine.get(i)
      val cj = Machine.get(j)
      d match {
        case 0 => m => m.r0 = ci(m) + cj(m)
        case 1 => m => m.r1 = ci(m) + cj(m)
        case 2 => m => m.r2 = ci(m) + cj(m)
        case 3 => m => m.r3 = ci(m) + cj(m)
      }
    case Copy(d, i) =>
      val ci = Machine.get(i)
      d match {
        case 0 => m => m.r0 = ci(m)
        case 1 => m => m.r1 = ci(m)
        case 2 => m => m.r2 = ci(m)
        case 3 => m => m.r3 = ci(m)
      }
    // also can make a difference, suggests Machine.get compiled form isn't reliably inlined by JIT
    case Loop(0, p) =>
      val compiledBody = partialEval2(p)
      m => while (m.r0 != 0.0) compiledBody(m)
    case Loop(haltIf0, p) =>
      val cHaltIf0 = Machine.get(haltIf0)
      val compiledBody = partialEval2(p)
      m => while (cHaltIf0(m) != 0.0) compiledBody(m)
    case Block(es) => partialEvalBlock2(es)
  }

  /**
   * `Machine` counterpart of `partialEvalBlock`: compiles a sequence to one
   * function running it in order. An empty sequence compiles to a no-op.
   */
  def partialEvalBlock2(ps: List[Expr]): Machine => Unit = ps match {
    case Nil => _ => ()
    case List(e) => partialEval2(e)
    case p :: ps2 =>
      val cp = partialEval2(p)
      val cps = partialEvalBlock2(ps2)
      m => { cp(m); cps(m) }
  }
}
object Ex extends App {
  import Expr._
  import quickprofile.QuickProfile.{suite, profile}

  /**
   * Benchmark input. `math.random.floor` is always 0.0, so this is 1e6; presumably
   * the random term is there so the JIT cannot treat the input as a constant.
   * The value must stay integral: the `Expr` loops below count register 0 down
   * by 1.0 until it is exactly 0.0.
   */
  def N = 1e6 + math.random.floor

  /** Register file shared by the array-based sanity checks. */
  val m = Array(0.0, 0.0, 0.0, 0.0)

  /**
   * Fibonacci as an `Expr` program: expects `n` in register 0 and leaves the
   * n-th Fibonacci number in register 1 (registers 2 and 3 are scratch).
   */
  val fib = Block(
    Num(1, 0.0),          // var f1 = 0
    Num(2, 1.0),          // var f2 = 1
    Loop(0, Block(        // while (n != 0) {
      Plus(3, 1, 2),      //   val tmp = f1 + f2
      Copy(1, 2),         //   f1 = f2
      Copy(2, 3),         //   f2 = tmp
      Decr(0)))           //   n -= 1 }
  )

  /** Native counterpart of the `fib` program; `fib(n, 0, 1)` is the n-th Fibonacci number. */
  @annotation.tailrec
  def fib(n: Double, f0: Double, f1: Double): Double =
    if (n == 0) f0 else fib(n - 1.0, f1, f0 + f1)

  /**
   * Sums the numbers 0 to `n`: expects `n` in register 0 and leaves the sum in
   * register 1. Register 0 is counted down to 0 in the process.
   */
  val sumN = Block(
    Num(1, 0.0),
    Loop(0, Block(
      Plus(1, 0, 1),
      Decr(0)
    ))
  )

  /** Native counterpart of the `sumN` program, accumulating into `acc`. */
  @annotation.tailrec
  def sumN(n: Double, acc: Double): Double =
    if (n == 0.0) acc else sumN(n - 1.0, acc + n)

  /** Prints `label`, then `run(i)` for each `i` in 0 until 10, space-separated, on the next line. */
  private def sanityCheck(label: String)(run: Int => Long): Unit = {
    println(label)
    println((0 until 10).map(run).mkString(" "))
  }

  // Sanity check - every implementation should print the same row of results.
  sanityCheck("interpreted") { i =>
    m(0) = i.toDouble
    interpret(sumN, m)
    m(1).toLong
  }

  locally {
    val compiledSum = partialEval(sumN)
    sanityCheck("partially-evaluated") { i =>
      m(0) = i.toDouble
      compiledSum(m)
      m(1).toLong
    }
  }

  locally {
    val machine = Machine(0, 0, 0, 0)
    val compiledSum = partialEval2(sumN)
    sanityCheck("partially-evaluated (2)") { i =>
      machine.r0 = i.toDouble
      compiledSum(machine)
      machine.r1.toLong
    }
  }

  sanityCheck("native")(i => sumN(i.toDouble, 0.0).toLong)

  // Okay, now run the profiling suite. Each enclosing block does its setup once,
  // outside the timed `profile` body.
  suite(
    {
      val compiled = partialEval2(sumN)
      val regs = Machine(0.0, 0.0, 0.0, 0.0)
      profile("partially-evaluated (2)", 0.03) {
        regs.r0 = N
        compiled(regs)
        regs.r1.toLong
      }
    },
    {
      val regs = Array.fill(4)(0.0)
      profile("interpreted", 0.03) {
        regs(0) = N
        interpret(sumN, regs)
        regs(1).toLong
      }
    },
    profile("Scala", 0.03) { sumN(N, 0.0).toLong },
    {
      val compiled = partialEval(sumN)
      val regs = Array.fill(4)(0.0)
      profile("partially-evaluated", 0.03) {
        regs(0) = N
        compiled(regs)
        regs(1).toLong
      }
    }
  )
}
package quickprofile
/**
 * Simple-to-use, fast, and relatively accurate benchmarking functions.
 *
 * `profile` runs an individual benchmark and reports performance.
 * `suite` runs a collection of benchmarks and reports relative performance.
 *
 * Unlike JMH, we do not require picking an arbitrary number of warmup iterations
 * or recorded iterations (often chosen to be either too small, yielding
 * inaccurate or wildly varying results, or too big, leading benchmarks to take
 * forever and not be run as part of normal development).
 *
 * See `profile` docs for more details on the methodology.
 *
 * Example usage: {{{
 * import QuickProfile.{suite, profile}
 * suite(
 *   profile("loop1") {
 *     var n = 1e6 + math.random
 *     while (n > 0.0) n -= 1
 *     n.toLong
 *   },
 *   profile("loop2") {
 *     var n = 1e6 + math.random
 *     (0 until 1000000).foreach { _ => n -= 1.0 }
 *     n.toLong
 *   },
 *   { // setup for the benchmark, won't be measured
 *     val nums = Vector.range(0, 1000000)
 *     // okay, start measuring
 *     profile("loop3") {
 *       var n = 1e6 + math.random
 *       nums.foreach { _ => n -= 1.0 }
 *       n.toLong
 *     }
 *   }
 * )
 * }}}
 * Which produces output like: {{{
 * - loop1: 1.475 milliseconds (4.3% deviation, N=68, K = 0)
 * - loop2: 1.069 milliseconds (4.3% deviation, N=224, K = 0)
 * - loop3: 14.776 milliseconds (3.0% deviation, N=26, K = 0)
 * 1.0             loop2
 * 1.37            loop1
 * 13.81           loop3
 * }}}
 */
object QuickProfile {
  /**
   * Run an action repeatedly, capturing profiling info, until the % deviation of
   * timing info is less than `threshold` (closer to 0). Idea is to discover the
   * steady state of performance that occurs when most hot spots have been JIT'd
   * without needing to pick an arbitrary number of warmup iterations and trials
   * (which are often either too small, leading to inaccurate results, or too big,
   * leading to profiling taking way too long).
   *
   * This function increases N--the number of times `action` is invoked per
   * sample--until each sample takes at least 100ms (so limited granularity
   * of System.nanoTime is no longer an issue). Once reaching this point,
   * it then gradually increases N exponentially (N = N * (1 + epsilon)) until percent
   * deviation drops below `threshold`. This all tends to happen pretty quickly.
   *
   * The `action` must return a `Long`, preferably unique for each execution, and the
   * sum of these numbers is threaded through the profiling computation to prevent the
   * JVM from doing any heroic optimizations that would eliminate executions of `action`.
   *
   * One caveat: JVM optimizations are not totally deterministic, so running the same
   * benchmark with a fresh JVM may reach a different steady state (though if performance
   * is highly sensitive to this, it could be a good idea to find a different way of
   * expressing your program such that performance is not as fragile). It's usually
   * obvious from a handful of runs of a benchmark whether any nondeterminism of JVM
   * optimizations is relevant for performance, but for maximum accuracy it can be a good
   * idea to average results from multiple JVM runs.
   *
   * @param label     name printed with the results
   * @param threshold target relative standard deviation (0.05 = 5%) across samples
   * @param action    the code to benchmark; its results are summed to defeat dead-code elimination
   * @return `(label, mean nanoseconds per invocation of action)`
   */
  def profile(label: String, threshold: Double = 0.05)(action: => Long): (String, Double) = {
    var N = 16L          // invocations of `action` per sample
    var sample = 1e9     // current estimate of nanoseconds per invocation
    var K = 0L           // checksum of every `action` result
    var ok = true
    var percentDeviation = Double.PositiveInfinity
    while (ok) {
      val batchNanos = sample * N // estimated duration of one sample
      N =
        // try to increase N to get at least 100ms sample - 1e8 nanos is 100ms
        if (batchNanos < 1e8) {
          // Linear interpolation guesses the N that hits 100ms exactly. If the timer
          // did not advance at all, interpolating would divide by zero and overflow
          // N, so grow by an order of magnitude instead.
          val N2 =
            if (batchNanos <= 0.0) N * 10
            else N * (1e8 / batchNanos).toLong
          // close enough: stop interpolating and just grow N exponentially
          if ((N.toDouble / N2.toDouble - 1.0).abs < .15) (N.toDouble * 1.2).toLong
          else N2
        }
        // otherwise increase N gradually to decrease variance
        else (N.toDouble * 1.2).toLong
      print(s"\r * $label: ${formatNanos(sample)}, N=$N, deviation=$percentDeviation%, target deviation: ${threshold*100}% ")
      val Y = 10 // number of timed samples per round
      val samples = (0 until Y).map { _ =>
        // Counter and partial checksum are locals of this closure, so the timed loop
        // doesn't write through captured (heap-allocated) var references. The
        // counter is a Long: an Int would wrap and never reach N > Int.MaxValue.
        var i = 0L
        var k = 0L
        val startTime = System.nanoTime
        // note - we sum the `Long` values returned from each `action`, to ensure
        // `action` cannot be optimized away
        while (i < N) { k += action; i += 1 }
        val stopTime = System.nanoTime
        K += k
        print(" ")
        System.gc() // try to minimize variance due to GC timing
        // Floating-point division keeps sub-nanosecond precision that Long division truncates.
        (stopTime - startTime).toDouble / N
      }
      val mean = samples.sum / Y.toDouble
      val variance = samples.map(x => math.pow(x - mean, 2)).sum / Y
      val stddev = math.sqrt(variance)
      val v = stddev / mean
      percentDeviation = (v * 1000).toInt.toDouble / 10
      if (v <= threshold) ok = false
      sample = mean
    }
    println(s"\r - $label: ${formatNanos(sample)} ($percentDeviation% deviation, N=$N, K = ${K.toString.take(3)}) ")
    (label, sample)
  }

  /** Truncates (toward zero) to 3 decimal places; uses Long so large values don't overflow. */
  def roundToThousands(n: Double): Double = (n * 1000).toLong / 1000.0

  /** Truncates (toward zero) to 2 decimal places; uses Long so large values don't overflow. */
  def roundToHundreds(n: Double): Double = (n * 100).toLong / 100.0

  /** Renders a duration in nanoseconds using the largest unit it exceeds, to 3 decimal places. */
  def formatNanos(nanos: Double): String = {
    if (nanos > 1e9) roundToThousands(nanos/1e9).toString + " seconds"
    else if (nanos > 1e6) roundToThousands(nanos/1e6).toString + " milliseconds"
    else if (nanos > 1e3) roundToThousands(nanos/1e3).toString + " microseconds"
    else roundToThousands(nanos).toString + " nanoseconds"
  }

  /**
   * Prints each `(label, nanos)` result as a multiple of the fastest one, fastest
   * first. Prints nothing when given no results.
   */
  def suite(s: (String, Double)*): Unit =
    if (s.nonEmpty) {
      val tests = s.toList.sortBy(_._2)
      val minNanos = tests.head._2
      // min * x = y  =>  x = y / min
      tests.foreach { case (label, nanos) =>
        val x = roundToHundreds(nanos / minNanos)
        println(x.toString.padTo(16, ' ') + label)
      }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment