Created
December 19, 2020 20:30
-
-
Save AlexRogalskiy/d38b974cf3b6d6ed02b6129712bf677e to your computer and use it in GitHub Desktop.
Scala memoizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.concurrent.duration.FiniteDuration | |
/**
 * A map-like object (with only the most basic functions implemented) that is thread safe
 * without using locks.
 * The size of the map is limited to a maximum bound; when the bound is exceeded, the "oldest"
 * entries (the keys that were inserted first) are thrown out.
 *
 * Writers never modify a map that has been published: they copy the current map, update the
 * private copy and then try to swap it in with a compare-and-set on the shared reference.
 *
 * @param maxSize maximum number of entries kept in the cache
 * @tparam K type of the keys
 * @tparam E type of the stored elements
 */
class LockFreeBoundedCache[K, E](maxSize: Int) {
  import java.util.concurrent.atomic.AtomicReference
  import scala.collection.mutable.LinkedHashMap
  import scala.annotation.tailrec

  // The AtomicReference that keeps track of the current LinkedHashMap snapshot.
  private val cache: AtomicReference[LinkedHashMap[K, E]] =
    new AtomicReference(new LinkedHashMap[K, E]())

  /**
   * Get the stored value for a particular key.
   *
   * @param key the key to look up
   * @return the stored element, or None if the key is not in the map
   */
  def get(key: K): Option[E] = cache.get().get(key)

  /**
   * Put a new element in the map and make sure we do not run over the bound. This operation
   * might take an indefinite amount of time if there is a large amount of contention.
   * To avoid waiting indefinitely on the put, you can specify a maximum duration to wait for;
   * the put is abandoned if a swap fails after that duration has elapsed.
   *
   * @param key     the key to store
   * @param elt     the element to associate with the key
   * @param maxWait optional upper bound on the time spent retrying; None retries until success
   * @return true if the put was successful, false if we were not able to put the new value
   *         within the given time
   */
  def put(key: K, elt: E, maxWait: Option[FiniteDuration] = None): Boolean = {
    // Start counting towards maxWait now.
    val deadline = maxWait.map(_.fromNow)

    // A tail recursive "while" loop that only finishes on a successful CAS or a missed deadline.
    @tailrec
    def cas(): Boolean = {
      // The snapshot we are updating from; another thread may replace it while we work.
      val oldMap = cache.get()

      // 1- Copy the snapshot and add the new element to the private copy.
      val newMap = oldMap.clone()
      newMap.update(key, elt)

      // 2- Enforce the bound by evicting the oldest keys. The copy is not yet visible to any
      //    other thread, so we can remove in place instead of building a second copy.
      val excess = newMap.size - maxSize
      if (excess > 0) {
        // Materialize the keys first so we do not mutate the map while iterating over it.
        val evicted = newMap.keysIterator.take(excess).toList
        newMap --= evicted
      }

      // 3- Try to publish the copy. On failure retry, unless we are past the deadline.
      if (cache.compareAndSet(oldMap, newMap)) true
      else if (deadline.exists(_.isOverdue())) false
      else cas()
    }

    cas()
  }

  /**
   * Try to get the key; if it is not in the map yet, evaluate the value and put it in the map
   * (useful for a cache use). The computed value is returned even when the put could not
   * complete within maxWait.
   *
   * @param key     the key to look up
   * @param maxWait optional upper bound on the time spent storing a freshly computed value
   * @param elt     by-name value, only evaluated when the key is absent
   * @return the cached element if present, the freshly computed one otherwise
   */
  def getOrElse(key: K, maxWait: Option[FiniteDuration] = None)(elt: => E): E =
    get(key).getOrElse {
      // Evaluate the by-name argument exactly once.
      val newElt = elt
      put(key, newElt, maxWait)
      newElt
    }
}
object Memoizer {

  /**
   * Wrap a function within a bounded cache.
   * The maximum size of the cache must be given, and optionally a maximum time to wait for
   * access to the cache. Specifying a maxWait is a good idea: if the cache is too busy we do
   * not want to wait indefinitely to store our computed value, it is better to recompute it.
   *
   * @param bound   maximum number of results kept in the cache
   * @param maxWait optional upper bound on the time spent storing a freshly computed result
   * @param f       the function to memoize
   * @return a function that looks results up in the cache and only calls f on a miss
   */
  def boundedMemoize[I, O](bound: Int, maxWait: Option[FiniteDuration] = None)(f: I => O): I => O = {
    // One lock free cache per wrapped function.
    val memo = new LockFreeBoundedCache[I, O](bound)
    // f(input) is passed by name, so it is only evaluated when the cache has no entry.
    input => memo.getOrElse(input, maxWait)(f(input))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment