Skip to content

Instantly share code, notes, and snippets.

@avnerbarr
Last active June 15, 2020 10:23
Show Gist options
  • Save avnerbarr/4e2efa42970fa1b76d7e55bcfd6f2353 to your computer and use it in GitHub Desktop.
Save avnerbarr/4e2efa42970fa1b76d7e55bcfd6f2353 to your computer and use it in GitHub Desktop.
import java.util.concurrent.locks.ReentrantReadWriteLock
import com.typesafe.scalalogging.LazyLogging
import scala.collection.mutable
import scala.concurrent.duration._
/**
 * Trivial in-memory cache, bounded by entry count, with a per-entry time to live (TTL).
 *
 * Entries are kept in recency order: inserting or successfully reading a key moves it to the
 * most-recent position. When a new key arrives while the cache is full, the least recently
 * used entry is evicted. TTLs are evaluated lazily: an expired entry is only discarded when
 * it is next read through one of the `get` methods.
 *
 * All public operations are guarded by a single read/write lock, which is always released
 * in a `finally`, so an exception inside an operation cannot leave the cache locked.
 *
 * @param limitSize the maximum number of entries; `None` falls back to 10,000,000. A limit
 *                  below 1 means nothing can be retained, so inserts become no-ops.
 * @param defaultRetainTime the TTL applied to entries inserted without an explicit one
 * @tparam K the key type
 * @tparam V the value type
 */
class TrivialCache[K, V](val limitSize: Option[Long], val defaultRetainTime: FiniteDuration = 10.minutes) extends LazyLogging {

  logger.info(s"starting cache with limit size of $limitSize and retention time of $defaultRetainTime")

  /** A `System.nanoTime()` reading; only differences between two readings are meaningful. */
  type TimeStamp = Long

  /** Wraps a cached value with its cache meta information (insertion time and TTL). */
  case class Wrapper[X](value: X, insertedAt: TimeStamp, ttl: Option[Duration])

  // Effective capacity, resolved once so every insert applies the same bound.
  private val maxEntries: Long = limitSize.getOrElse(10000000L)

  // Synchronizing lock around reads and writes to the cache: multiple readers may proceed
  // simultaneously, while a writer is exclusive, which keeps the mutable map consistent.
  private val lock = new ReentrantReadWriteLock()

  // LinkedHashMap iterates in insertion order, so its head is always the eviction candidate.
  // https://alvinalexander.com/scala/how-to-choose-map-implementation-class-sorted-scala-cookbook
  private val keyValueMapping = mutable.LinkedHashMap[K, Wrapper[V]]()

  /** Runs `body` while holding the shared (read) lock, releasing it even if `body` throws. */
  private def reading[T](body: => T): T = {
    val shared = lock.readLock()
    shared.lock()
    try body finally shared.unlock()
  }

  /** Runs `body` while holding the exclusive (write) lock, releasing it even if `body` throws. */
  private def writing[T](body: => T): T = {
    val exclusive = lock.writeLock()
    exclusive.lock()
    try body finally exclusive.unlock()
  }

  /**
   * Computes how much of an entry's TTL is left as of `now`.
   *
   * The elapsed time is subtracted from the TTL (rather than adding the TTL to the insertion
   * stamp) so that a very large TTL cannot overflow the arithmetic.
   *
   * @param entry the stored wrapper
   * @param now   the current `System.nanoTime()` reading
   * @return `None` when the entry has expired, otherwise the TTL still remaining. A
   *         non-finite TTL never expires and is returned unchanged, except for
   *         `Duration.MinusInf`, which counts as already expired.
   */
  private def remainingTtl(entry: Wrapper[V], now: TimeStamp): Option[Duration] =
    entry.ttl.getOrElse(defaultRetainTime) match {
      case finite: FiniteDuration =>
        val left = finite.toNanos - (now - entry.insertedAt)
        if (left < 0) None else Some(left.nanoseconds)
      case unbounded =>
        // toNanos is not defined for non-finite durations, so they are handled separately.
        if (unbounded < Duration.Zero) None else Some(unbounded)
    }

  /**
   * Returns a dictionary representation of the current state of the cache
   * @return an immutable snapshot mapping each key to its value
   */
  def toDict(): Map[K, V] = reading {
    keyValueMapping.iterator.map { case (k, wrapper) => k -> wrapper.value }.toMap
  }

  /**
   * Returns a snapshot of the cache contents together with their time stamps
   * @return (key, time stamp, value) triples, from least to most recently used
   */
  def keyValues(): Iterable[(K, TimeStamp, V)] = reading {
    keyValueMapping.iterator.map { case (k, wrapper) => (k, wrapper.insertedAt, wrapper.value) }.toVector
  }

  /**
   * Inserts a key,value pair with a given TTL duration, if no ttl is supplied the default duration will be used
   * If the same key is inserted multiple times, the TTL will be in regards to the last insertion time
   *
   * Replacing an existing key never evicts another entry. A new key evicts the least
   * recently used entry when the cache is full.
   *
   * @param key the key
   * @param value the value
   * @param ttl the optional override for TTL - notice that objects are still evicted by recency order and the TTL only impacts reheating of this object
   */
  def insert(key: K, value: V, ttl: Option[Duration]): Unit = writing {
    if (maxEntries > 0) {
      // Drop any previous entry first: this both moves the key to the most-recent position
      // and tells us whether the insert grows the map at all.
      val replaced = keyValueMapping.remove(key).isDefined
      if (!replaced && keyValueMapping.size >= maxEntries) {
        // The head of the LinkedHashMap is the least recently inserted/read key.
        keyValueMapping.headOption.foreach { case (eldestKey, _) => keyValueMapping.remove(eldestKey) }
      }
      keyValueMapping.update(key, Wrapper(value, System.nanoTime(), Some(ttl.getOrElse(defaultRetainTime))))
    }
  }

  /**
   * Remove a value for a given key
   * @param key the key to remove
   * @return true if the value existed and was removed
   */
  def remove(key: K): Boolean = writing {
    keyValueMapping.remove(key).isDefined
  }

  /**
   * Remove all values from the cache
   */
  def removeAll(): Unit = writing {
    logger.debug("cleaning cache")
    keyValueMapping.clear()
  }

  /**
   * Get a value for a given key
   * @param key the key to get
   * @return The optional value if exists and has not expired
   */
  def get(key: K): Option[V] = get(key, None)

  /**
   * Get a value for a given key, if the key doesn't exist the block is called which will populate the cache for the given key if the value exists
   *
   * The lookup, the expiry check and the recency update all happen under one acquisition of
   * the write lock, so a concurrent `remove` or `insert` cannot be undone by this call. The
   * block `f` is evaluated at most once and outside the lock.
   *
   * @param key the key
   * @param f the block that is applied when the key doesn't exist or the ttl has expired for the key;
   *          it yields the value to cache and an optional TTL override
   * @return the option for the value of the key
   */
  def get(key: K, f: => Option[(V, Option[Duration])]): Option[V] = {
    val cached: Option[V] = writing {
      keyValueMapping.get(key).flatMap { entry =>
        val now = System.nanoTime()
        remainingTtl(entry, now) match {
          case Some(ttlLeft) =>
            // Re-insert so the key moves to the most-recent position; the stored TTL shrinks
            // to what is left, so the absolute expiry moment stays the same.
            keyValueMapping.remove(key)
            keyValueMapping.update(key, Wrapper(entry.value, now, Some(ttlLeft)))
            Some(entry.value)
          case None =>
            logger.debug(s"cache expire for key: $key")
            keyValueMapping.remove(key)
            None
        }
      }
    }
    cached.orElse {
      // Missing or expired: ask the caller's block for a fresh value and cache it.
      f.map { case (fresh, ttlOverride) =>
        logger.debug(s"cache warming for key: $key")
        insert(key, fresh, ttlOverride)
        fresh
      }
    }
  }

  /**
   * returns the current size of the cache
   * @return the number of entries currently stored, including expired ones not yet discarded
   */
  def size(): Int = reading {
    keyValueMapping.size
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment