@AlexRogalskiy
Created December 19, 2020 20:30
Scala memoizer
import scala.concurrent.duration.FiniteDuration

/**
 * A map-like object (with only the most basic functions implemented) that is thread safe but does not use locks.
 * The size of the map is limited to a maximum bound and the "oldest" entries (in insertion order) are thrown out.
 */
class LockFreeBoundedCache[K, E](maxSize: Int) {

  import java.util.concurrent.atomic.AtomicReference
  import scala.collection.mutable.LinkedHashMap
  import scala.annotation.tailrec

  // the AtomicReference that keeps track of the current LinkedHashMap.
  // A published map is never mutated; every put swaps in an updated copy.
  private val cache: AtomicReference[LinkedHashMap[K, E]] = new AtomicReference(new LinkedHashMap())

  /**
   * Gets the stored value for a particular key. Returns None if the key is not in the map.
   */
  def get(key: K): Option[E] = cache.get.get(key) //basic debug: .map { k => println(s"from cache $key"); k }

  /**
   * Puts a new element in the map and makes sure we do not exceed the bound. Under heavy contention this
   * operation may keep retrying for an unbounded amount of time.
   * To avoid retrying indefinitely, you can specify a maximum duration; the put fails if it has not succeeded by then.
   * Returns true if the put was successful, false if we were not able to put the new value within the given time.
   */
  def put(key: K, elt: E, maxWait: Option[FiniteDuration] = None): Boolean = {
    // start the clock for maxWait
    val deadline = maxWait.map(_.fromNow)

    // a tail-recursive "while" loop that finishes when a CAS on the reference succeeds
    // or when the deadline has passed (the deadline is only checked after a failed CAS)
    @tailrec
    def cas(): Boolean = {
      // get the reference to the current map (it might be replaced later by another thread)
      val oldMap = cache.get
      // create a new updated map
      val newMap = {
        // 1- clone and add the new element
        val clone = oldMap.clone += key -> elt
        // 2- check that we are within the bound; drop the oldest entries if not
        if (clone.size > maxSize) clone.drop(clone.size - maxSize)
        else clone
      }
      // try to CAS the AtomicReference. If we fail, try again
      if (!cache.compareAndSet(oldMap, newMap)) deadline match {
        // except if we are past the deadline
        case Some(d) if d.isOverdue => false
        case _ => cas()
        // equivalent to (but covering all cases):
        // case None => cas()
        // case Some(d) if d.hasTimeLeft => cas()
      } else {
        true
      }
    }

    cas()
  }

  /**
   * Tries to get the value for the key; if it is not in the map yet, computes the value and puts it in
   * (useful for cache use). The computed value is returned even if the put times out.
   */
  def getOrElse(key: K, maxWait: Option[FiniteDuration] = None)(elt: => E): E = {
    get(key).getOrElse {
      val newElt = elt
      put(key, newElt, maxWait)
      newElt
    }
  }
}

object Memoizer {

  /**
   * Wraps the function within a bounded cache.
   * We have to specify the maximum size of the cache, and optionally a maximum time to wait for access to the cache.
   * It's a good idea to specify a maxWait: if the cache is too busy, we do not want to wait indefinitely to store
   * our computed value; it's better to recompute it next time.
   */
  def boundedMemoize[I, O](bound: Int, maxWait: Option[FiniteDuration] = None)(f: I => O): I => O = {
    // use the lock-free map for the cache, one per wrapped function
    val cache = new LockFreeBoundedCache[I, O](bound)
    // return a new function that uses the cache instead of doing the straight computation
    ((in: I) =>
      cache.getOrElse(in, maxWait) {
        f(in)
      })
  }
}
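
A minimal usage sketch, assuming the `LockFreeBoundedCache` and `Memoizer` definitions above are in scope. The names `MemoizerExample`, `slowSquare`, `fastSquare` and `calls` are hypothetical and only serve the illustration. It shows a hit being served from the cache and the oldest entry being evicted once the bound of 2 is exceeded.

```scala
import scala.concurrent.duration._

object MemoizerExample extends App {
  // counts how often the underlying function really runs (single-threaded demo only)
  var calls = 0
  val slowSquare: Int => Int = { n => calls += 1; n * n }

  // keep at most 2 results; give up storing a result after 50 ms of contention
  val fastSquare = Memoizer.boundedMemoize[Int, Int](bound = 2, maxWait = Some(50.millis))(slowSquare)

  fastSquare(3) // miss: computed and cached
  fastSquare(3) // hit: served from the cache
  fastSquare(4) // miss: cache now holds 3 and 4
  fastSquare(5) // miss: exceeds the bound, so the oldest key (3) is dropped
  fastSquare(3) // miss again: 3 was evicted and is recomputed

  println(s"underlying function ran $calls times")
}
```

In this single-threaded run every CAS succeeds on the first attempt, so `maxWait` never comes into play; it only matters when several threads put into the same cache at once.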