Skip to content

Instantly share code, notes, and snippets.

@yuroyoro
Created September 8, 2010 06:12
Show Gist options
  • Save yuroyoro/569718 to your computer and use it in GitHub Desktop.
Save yuroyoro/569718 to your computer and use it in GitHub Desktop.
/**
 * Mixin that gives a function object its own result cache.
 *
 * @tparam T lookup key type (the wrapped function's argument, or a tuple of its arguments)
 * @tparam R cached result type
 */
trait Memorized[T,R] {
  import scala.collection.mutable

  /** Every result computed so far, keyed by argument. Plain mutable map; no locking is done here. */
  val cache: mutable.HashMap[T, R] = mutable.HashMap.empty[T, R]

  /**
   * Looks `t` up in the cache; on a miss, evaluates `f`, stores the result under `t`,
   * prints a trace line to stdout and returns the result.
   *
   * @param t lookup key
   * @param f by-name computation, evaluated only when `t` is not cached yet
   * @return the cached or freshly computed result
   */
  def cacheOrApply( t:T )( f: => R ):R = cache.get(t) match {
    case Some(hit) => hit
    case None =>
      val computed = f
      cache(t) = computed
      println("do %s:%s".format(t, computed))
      computed
  }
}
/**
 * Factories that wrap a plain function of one to three arguments in a function
 * object whose results are cached through the Memorized mixin. Multi-argument
 * functions are keyed by the tuple of their arguments.
 *
 * Only calls that go through the returned wrapper consult the cache; any calls
 * `f` makes internally are invisible to it.
 */
object Memoized {
  /** Memoizes a one-argument function, keyed by its argument. */
  def apply[T1, R](f: T1 => R): (T1 => R) with Memorized[T1, R] =
    new (T1 => R) with Memorized[T1, R] {
      def apply(a1: T1): R = cacheOrApply(a1)(f(a1))
    }

  /** Memoizes a two-argument function, keyed by the pair of its arguments. */
  def apply[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) with Memorized[(T1, T2), R] =
    new ((T1, T2) => R) with Memorized[(T1, T2), R] {
      def apply(a1: T1, a2: T2): R = {
        val key = (a1, a2)
        cacheOrApply(key)(f(a1, a2))
      }
    }

  /** Memoizes a three-argument function, keyed by the triple of its arguments. */
  def apply[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) with Memorized[(T1, T2, T3), R] =
    new ((T1, T2, T3) => R) with Memorized[(T1, T2, T3), R] {
      def apply(a1: T1, a2: T2, a3: T3): R = {
        val key = (a1, a2, a3)
        cacheOrApply(key)(f(a1, a2, a3))
      }
    }
}
object Main {
  /**
   * Benchmarks a naive recursive Fibonacci against a memoized one and prints
   * each run's timing plus the average for both variants.
   *
   * @param args ignored
   */
  def main(args: Array[String]): Unit = {
    // Number of timed runs per variant, and the Fibonacci index each run computes.
    val runs = 10
    val target = 35

    // Baseline: exponential-time recursion with fib(0) = fib(1) = 1.
    def fib(n: Int): BigInt = if (n < 2) 1 else fib(n - 1) + fib(n - 2)

    // The recursion goes back through the memoizing wrapper, so every fib(k) for
    // k <= n ends up cached. Wrapping `fib _` instead only caches the top-level
    // arguments while the inner recursion stays exponential. The explicit type is
    // required because the local value refers to itself.
    lazy val memorizedFib: Int => BigInt =
      Memoized((n: Int) => if (n < 2) BigInt(1) else memorizedFib(n - 1) + memorizedFib(n - 2))

    // Local stand-in for scala.testing.Benchmark#runBenchmark (removed in Scala 2.11):
    // evaluates `body` `times` times and returns each run's elapsed time in ms. A GC is
    // requested after each run so one run's garbage is less likely to be charged to the next.
    def runBenchmark(times: Int)(body: => Any): List[Long] =
      List.fill(times) {
        val start = System.nanoTime
        body
        val elapsedMs = (System.nanoTime - start) / 1000000L
        System.gc()
        elapsedMs
      }

    // Prints every run's timing, then the mean over the runs actually taken.
    def printResult(name: String, res: Seq[Long]): Unit = {
      println("Benchmark %s".format(name))
      res.zipWithIndex.foreach { case (a, n) =>
        println("%2d: %d millsec".format(n, a))
      }
      // Divide by the real run count (previously hard-coded to 10); guard the empty case.
      val average = if (res.isEmpty) 0L else res.sum / res.size
      println("%s Ave:%s millsec".format(name, average))
    }

    printResult("normal fib", runBenchmark(runs)(fib(target)))
    printResult("memorized fib", runBenchmark(runs)(memorizedFib(target)))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment