-
-
Save ericsson49/fa4fa9a767bef2ec3cd8cef6666d66c2 to your computer and use it in GitHub Desktop.
Compares the performance of pure (immutable) vs. impure (mutable) Scala snippets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable | |
// Requires Scala 3 (uses `@main` entry points and creator applications such as `PureAppendTest()`)
/** A benchmark case: builds an input, then runs the measured workload on it.
  *
  * @tparam T type of the input produced by `setup` and consumed by `runTest`
  */
trait PerfTest[T] {

  /** Builds the input for a run with size parameter `n` (implementations may ignore `n`). */
  def setup(n: Int): T

  /** Runs the workload on `o` with size `n` and returns an `Int` summary of the outcome.
    *
    * Paired pure/impure implementations are expected to return equal values for equal inputs.
    */
  def runTest(o: T, n: Int): Int
}
/** Pure append benchmark: grows an immutable `Vector` one element at a time.
  *
  * Each `:+` returns a new vector, so the loop rebinds a local `var` instead of mutating
  * the collection.
  */
class PureAppendTest extends PerfTest[Vector[Int]] {

  /** Starts from an empty vector; `n` is not used here. */
  def setup(n: Int): Vector[Int] = Vector.empty[Int]

  /** Appends `0 until n` to `lst` and returns the sum of the resulting vector.
    *
    * The sum is an `Int` and wraps on overflow (starting from an empty vector, once `n > 65536`).
    */
  def runTest(lst: Vector[Int], n: Int): Int = {
    var res = lst
    for (i <- 0 until n) {
      res = res :+ i
    }
    res.sum
  }
}
/** Impure append benchmark: appends to a mutable `ArrayBuffer` in place.
  *
  * Note that `runTest` mutates the buffer it is given.
  */
class ImpureAppendTest extends PerfTest[mutable.ArrayBuffer[Int]] {

  /** Starts from an empty buffer; `n` is not used here. */
  def setup(n: Int): mutable.ArrayBuffer[Int] = mutable.ArrayBuffer[Int]()

  /** Appends `0 until n` to `lst` in place and returns the sum of its contents.
    *
    * The sum is an `Int` and wraps on overflow.
    */
  def runTest(lst: mutable.ArrayBuffer[Int], n: Int): Int = {
    for (i <- 0 until n) {
      lst += i
    }
    lst.sum
  }
}
/** Pure update benchmark: adds a constant to each slot of an immutable `Vector` via `updated`. */
class PureUpdateListTest extends PerfTest[Vector[Int]] {

  /** A vector of `n` zeros. */
  def setup(n: Int): Vector[Int] = Vector.fill(n)(0)

  /** Adds 11 to each of the first `n` elements and returns the sum of the final vector.
    *
    * Every `updated` call yields a new vector, rebound to a local `var`.
    * Requires `lst.size >= n`; otherwise the indexed access throws.
    */
  def runTest(lst: Vector[Int], n: Int): Int = {
    var res = lst
    for (i <- 0 until n) {
      res = res.updated(i, res(i) + 11)
    }
    res.sum
  }
}
/** Impure update benchmark: adds a constant to each slot of a mutable `ArrayBuffer` in place. */
class ImpureUpdateListTest extends PerfTest[mutable.ArrayBuffer[Int]] {

  /** A buffer of `n` zeros. */
  def setup(n: Int): mutable.ArrayBuffer[Int] = mutable.ArrayBuffer.fill(n)(0)

  /** Adds 11 to each of the first `n` elements of `lst` in place and returns the buffer's sum.
    *
    * Requires `lst.size >= n`; otherwise the indexed access throws.
    */
  def runTest(lst: mutable.ArrayBuffer[Int], n: Int): Int = {
    for (i <- 0 until n) {
      lst(i) += 11
    }
    lst.sum
  }
}
/** Immutable validator record; only the balance is modelled. */
case class Validator(balance: Int)

/** Immutable state model used by the pure benchmarks.
  *
  * @param balances   balances indexed by position
  * @param validators validator records indexed by position
  */
case class BeaconState(
  balances: Vector[Int],
  validators: Vector[Validator]
)
/** Mutable validator record for the impure benchmarks; `balance` can be updated in place.
  *
  * Because this is a case class with a `var` field, its `equals`/`hashCode` change when the
  * balance changes, so do not use instances as keys in hashed collections.
  */
case class MValidator(var balance: Int)

/** Mutable state model used by the impure benchmarks, backed by `ArrayBuffer`s so
  * elements can be updated in place.
  *
  * @param balances   balances indexed by position
  * @param validators mutable validator records indexed by position
  */
case class MBeaconState(
  balances: mutable.ArrayBuffer[Int],
  validators: mutable.ArrayBuffer[MValidator]
)
/** Pure state-update benchmark: adds to each balance by copying the whole immutable state.
  *
  * Every step builds a new `balances` vector and a new `BeaconState` via `copy`.
  */
class PureUpdateBalancesTest extends PerfTest[BeaconState] {

  /** A state with `n` zero balances and `n` validators whose balance is zero. */
  def setup(n: Int): BeaconState = BeaconState(
    Vector.fill(n)(0),
    Vector.tabulate(n)(_ => Validator(0)))

  /** Adds 23 to each of the first `n` balances and returns the sum of the final balances.
    *
    * Requires `initial.balances.size >= n`; otherwise the indexed access throws.
    */
  def runTest(initial: BeaconState, n: Int): Int = {
    var state = initial
    for (i <- 0 until n) {
      state = state.copy(balances = state.balances.updated(i, state.balances(i) + 23))
    }
    state.balances.sum
  }
}
/** Impure state-update benchmark: adds to each balance of a mutable state in place.
  *
  * The class name's "Imure" spelling is kept because the class is referenced by that name.
  */
class ImureUpdateBalancesTest extends PerfTest[MBeaconState] {

  /** A state with `n` zero balances and `n` distinct validators whose balance is zero. */
  def setup(n: Int): MBeaconState = MBeaconState(
    mutable.ArrayBuffer.fill(n)(0),
    mutable.ArrayBuffer.tabulate(n)(_ => MValidator(0)))

  /** Adds 23 to each of the first `n` balances of `state` in place and returns their sum.
    *
    * Mutates `state`. Requires `state.balances.size >= n`; otherwise the indexed access throws.
    */
  def runTest(state: MBeaconState, n: Int): Int = {
    for (i <- 0 until n) {
      state.balances(i) += 23
    }
    state.balances.sum
  }
}
/** Pure nested-update benchmark: changes a field of each validator record in an immutable state.
  *
  * Every step copies the validator, the `validators` vector and the enclosing `BeaconState`.
  */
class PureUpdateValidatorBalancesTest extends PerfTest[BeaconState] {

  /** A state with `n` zero balances and `n` validators whose balance is zero. */
  def setup(n: Int): BeaconState = BeaconState(
    Vector.fill(n)(0),
    Vector.tabulate(n)(_ => Validator(0)))

  /** Adds 32 to the balance of each of the first `n` validators and returns the sum of
    * all validator balances.
    *
    * Requires `initial.validators.size >= n`; otherwise the indexed access throws.
    */
  def runTest(initial: BeaconState, n: Int): Int = {
    var state = initial
    for (i <- 0 until n) {
      state = state.copy(
        validators = state.validators.updated(
          i, state.validators(i).copy(balance = state.validators(i).balance + 32)))
    }
    state.validators.map(_.balance).sum
  }
}
/** Impure nested-update benchmark: changes a field of each mutable validator record in place.
  *
  * The class name's "Imure" spelling is kept because the class is referenced by that name.
  */
class ImureUpdateValidatorBalancesTest extends PerfTest[MBeaconState] {

  /** A state with `n` zero balances and `n` distinct validators whose balance is zero. */
  def setup(n: Int): MBeaconState = MBeaconState(
    mutable.ArrayBuffer.fill(n)(0),
    mutable.ArrayBuffer.tabulate(n)(_ => MValidator(0)))

  /** Adds 32 to the balance of each of the first `n` validators in place and returns the sum
    * of all validator balances.
    *
    * Mutates `state`. Requires `state.validators.size >= n`; otherwise the indexed access throws.
    */
  def runTest(state: MBeaconState, n: Int): Int = {
    for (i <- 0 until n) {
      state.validators(i).balance += 32
    }
    state.validators.map(_.balance).sum
  }
}
/** Benchmark entry point.
  *
  * For each pure/impure pair, warms both variants up, then at several input sizes prints the
  * mean time per run of each variant and the pure/impure ratio (rounded to one decimal).
  * If the two variants ever disagree on a per-run result, the program aborts with a
  * `RuntimeException` naming the test and the size.
  */
@main def test: Unit = {
  val n = 4096 // base input size; also the size used for warm-up
  val c = 5000 // timed runs per (test, size) combination
  val params = List(n, 2 * n, 4 * n, 8 * n)
  val tests: List[(String, (PerfTest[?], PerfTest[?]))] = List(
    ("list append", (PureAppendTest(), ImpureAppendTest())),
    ("list update", (PureUpdateListTest(), ImpureUpdateListTest())),
    ("update balance", (PureUpdateBalancesTest(), ImureUpdateBalancesTest())),
    ("update validator balance", (PureUpdateValidatorBalancesTest(), ImureUpdateValidatorBalancesTest()))
  )
  for ((testName, (pure, impure)) <- tests) {
    println(testName)
    // Untimed warm-up runs, intended to keep JVM start-up/JIT effects out of the measurements.
    warmUp(10 * c, n, pure)
    warmUp(10 * c, n, impure)
    for (p <- params) {
      println(s"n = $p")
      val (pureResults, pureTime) = performTest(c, p, pure)
      val (impureResults, impureTime) = performTest(c, p, impure)
      if (impureResults != pureResults)
        throw RuntimeException(s"'$testName' (n = $p): pure and impure results differ")
      println(s"pure = $pureTime ns")
      println(s"impure = $impureTime ns")
      println(s"ratio = ${Math.round(pureTime * 10.0 / impureTime) / 10.0}")
    }
  }
}
/** Runs `test` `c` times, each on a freshly set-up input of size `n`, and discards the results.
  *
  * Nothing is measured here; this only exercises the code paths before timing.
  */
def warmUp[T](c: Int, n: Int, test: PerfTest[T]): Unit =
  (0 until c).foreach { _ =>
    val input = test.setup(n)
    test.runTest(input, n)
  }
/** Times `c` independent runs of `test`, each on a freshly set-up input of size `n`.
  *
  * Only `runTest` falls inside the measured window. `setup` and recording the result into
  * the results buffer happen outside it, so buffer growth does not inflate the timings.
  *
  * @param c    number of timed runs; must be positive
  * @param n    size parameter passed to both `setup` and `runTest`
  * @param test the benchmark to run
  * @return the per-run results in run order, and the mean run time in nanoseconds
  * @throws IllegalArgumentException if `c <= 0` (a mean over zero runs is undefined)
  */
def performTest[T](c: Int, n: Int, test: PerfTest[T]): (Seq[Int], Long) = {
  require(c > 0, s"c must be positive, got $c")
  var totalNanos: Long = 0L
  val results = mutable.ArrayBuffer[Int]()
  for (_ <- 0 until c) {
    val input = test.setup(n)
    val start = System.nanoTime()
    val result = test.runTest(input, n)
    totalNanos += System.nanoTime() - start
    // Appended after the clock is read so occasional buffer resizes are not timed.
    results += result
  }
  (results.toSeq, Math.round(totalNanos.toDouble / c))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment