Created
December 4, 2011 12:25
-
-
Save j5ik2o/1430072 to your computer and use it in GitHub Desktop.
Lock & STM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env xsbtscript | |
!# | |
/*** | |
scalaVersion := "2.9.1" | |
resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" | |
libraryDependencies ++= Seq( | |
"se.scalablesolutions.akka" % "akka-stm" % "1.2" | |
) | |
*/ | |
import java.util.concurrent.CountDownLatch | |
import scala.collection.mutable.HashSet | |
import scala.testing.Benchmark | |
import akka.stm._ | |
import akka.util.duration._ | |
import compat.Platform | |
/**
 * A counter that several threads increment concurrently.
 *
 * Implementations differ only in how (or whether) they make the increment
 * atomic. Note: despite the name, every implementation in this script returns
 * the value AFTER the increment (incrementAndGet semantics).
 */
trait Sequence {
  /** Increments the counter and returns the resulting value. */
  def getAndIncrement(): Int

  /** Returns the counter's current value. */
  def getValue(): Int
}
/**
 * Baseline Sequence with no synchronization at all.
 *
 * `value` is a plain var updated with an unsynchronized read-modify-write, so
 * concurrent callers may lose increments. That is the point of this variant:
 * it shows what happens without a lock or STM.
 */
class SequenceByUnsafe extends Sequence {
  var value = 0

  /** Current counter value (plain field read). */
  def getValue() = value

  /**
   * Adds one to `value`, then re-reads and returns the field.
   * NOTE(review): returns the NEW value, unlike java.util.concurrent's
   * getAndIncrement, which returns the previous one.
   */
  def getAndIncrement() = {
    value = value + 1
    value
  }
}
/**
 * Lock-based Sequence: both the increment and the read of the inherited
 * `value` field happen while holding this instance's monitor.
 */
class SequenceBySynchronized extends SequenceByUnsafe {
  /** Increments under the monitor and returns the new value. */
  override def getAndIncrement(): Int = synchronized {
    super.getAndIncrement()
  }

  /**
   * Reads under the same monitor. Without this, a reader on another thread has
   * no happens-before edge with the locked increments, because `value` is a
   * plain (non-volatile) field, so it could observe a stale count.
   */
  override def getValue(): Int = synchronized {
    super.getValue()
  }
}
/**
 * Sequence backed by Akka STM. The counter lives in a transactional `Ref`,
 * and each increment runs inside an `atomic` transaction.
 */
class SequenceBySTM extends Sequence {
  // NOTE(review): this could be a `val`; replacing the Ref would bypass the STM.
  var value = Ref(0)

  /** Reads the Ref's current value. */
  def getValue() = value.get

  // Generous retry budget: the benchmark drives heavy contention on a single Ref.
  // Must stay declared before getAndIncrement so the implicit is in scope there.
  implicit val txFactory = TransactionFactory(maxRetries = 10000)

  /**
   * Atomically adds one and returns the new value (`alter` yields the updated value).
   * Abort/commit counters can be gathered by adding `deferred { ... }` (commit)
   * and `compensating { ... }` (abort) blocks inside the transaction.
   */
  def getAndIncrement() = atomic {
    value alter { current => current + 1 }
  }
}
/**
 * Driver for the lock vs. STM comparison.
 *
 * Each mode starts THREAD_COUNT threads that all increment one shared
 * [[Sequence]]. Every thread prints its own timing statistics via
 * [[BenchUtil.bench]].
 */
object Main {
  /** Total increments each worker performs (split evenly across benchmark rounds). */
  val THREAD_LOOP = 1000
  /** Number of worker threads sharing one Sequence. */
  val THREAD_COUNT = 2

  /**
   * Runs the benchmark selected by the single argument:
   * "L" = synchronized lock, "S" = STM, "U" = unsynchronized.
   * Any other argument list prints a usage line instead of throwing MatchError.
   */
  def main(args: Array[String]): Unit = {
    args match {
      case Array("L") =>
        println(">>synchronized")
        threadTest(new SequenceBySynchronized)
      case Array("S") =>
        println(">>STM")
        threadTest(new SequenceBySTM)
      case Array("U") =>
        println(">>unsafe")
        threadTest(new SequenceByUnsafe)
      case _ =>
        Console.err.println("usage: <script> L|S|U  (L = synchronized, S = STM, U = unsafe)")
    }
  }

  /**
   * Starts THREAD_COUNT workers on `sequence`, releases them at the same
   * moment via a latch, and blocks until all of them have finished.
   */
  def threadTest(sequence: Sequence): Unit = {
    System.gc()
    val startLatch = new CountDownLatch(1)
    // Prepare the threads; each one blocks on startLatch until released.
    // map over a Range is strict, so every thread is started here.
    val threads = (1 to THREAD_COUNT).map { _ =>
      val thread = new Thread(new ThreadAccess(sequence, startLatch, THREAD_LOOP))
      thread.start()
      thread
    }
    // Release all workers together.
    startLatch.countDown()
    try {
      // Wait for every worker to finish.
      threads.foreach(_.join())
    } catch {
      case e: InterruptedException =>
        e.printStackTrace()
        // Restore the interrupt status so it is not silently lost.
        Thread.currentThread.interrupt()
    }
  }

  /**
   * Worker body: waits for the shared start signal, then benchmarks `Rounds`
   * rounds of (loopCount / Rounds) calls to `sequence.getAndIncrement()`.
   * The returned counter values are discarded.
   */
  class ThreadAccess(sequence: Sequence, startLatch: CountDownLatch, loopCount: Int) extends Runnable {
    import BenchUtil._

    /** Number of timed rounds handed to `bench`. */
    private val Rounds = 100

    override def run(): Unit = {
      try {
        startLatch.await()
        bench(Rounds) {
          (1 to (loopCount / Rounds)).foreach { _ =>
            sequence.getAndIncrement()
          }
        }
      } catch {
        case e: InterruptedException =>
          e.printStackTrace()
          Thread.currentThread.interrupt()
      }
    }
  }

  /** Micro-benchmark helpers; timings are in microseconds. */
  object BenchUtil {
    /** Arithmetic mean; `xs` must be non-empty. */
    private def avg(xs: List[BigDecimal]): BigDecimal =
      xs.sum / xs.size

    /** Population standard deviation; `xs` must be non-empty. */
    private def std(xs: List[BigDecimal]): BigDecimal = {
      val a = avg(xs)
      // Sum of squared deviations. The original used (c - a) * (c + a), which
      // equals this only in exact arithmetic and obscures the intent.
      val variance = xs.foldLeft(BigDecimal(0))((s, c) => s + (c - a) * (c - a)) / xs.size
      BigDecimal(math.sqrt(variance.toDouble))
    }

    /**
     * Median over ALL samples. Duplicates must be kept: the original applied
     * `toSet` first, which skews the median whenever timings repeat.
     */
    private def median(xs: List[BigDecimal]): BigDecimal = {
      require(xs.nonEmpty, "median of an empty sample list")
      val sorted = xs.sorted.toIndexedSeq
      val mid = sorted.size / 2
      if (sorted.size % 2 != 0) sorted(mid)
      else (sorted(mid - 1) + sorted(mid)) / 2
    }

    /** Most frequent sample (ties broken arbitrarily); `xs` must be non-empty. */
    private def mode(xs: List[BigDecimal]): BigDecimal =
      xs.groupBy(identity).maxBy(_._2.size)._1

    /**
     * Runs `f` `n` times, timing each run with System.nanoTime (converted to
     * microseconds). It then drops the fastest and slowest n/5 samples and
     * prints statistics for the rest, tagged with the current thread id.
     */
    def bench(n: Int)(f: => Unit): Unit = {
      val truncate = n / 5
      val times = List.fill(n) {
        val start = System.nanoTime
        f
        val stop = System.nanoTime
        // The forced GC runs after `stop`, so it is not part of the measured time.
        System.gc()
        BigDecimal(stop - start) / 1000
      }
      // Trim outliers at both ends (slice replaces the deprecated view(from, until)).
      val result = times.sorted.slice(truncate, n - truncate)
      if (result.nonEmpty) {
        println("threadId = %d, n = %d, avg = %11.2f, std = %11.2f, median = %11.2f, mode = %11.2f, min = %11.2f, max = %11.2f".
          format(Thread.currentThread.getId, result.size, avg(result), std(result), median(result), mode(result), result.min, result.max))
      }
    }
  }
}
// Script entry point: forward the script's command-line arguments
// (expected: exactly one of "L", "S" or "U") to the benchmark driver.
Main.main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-----STMの方を読み込みだけにしたテスト結果 (Test results with the STM variant changed to read-only)-----