Skip to content

Instantly share code, notes, and snippets.

@danielhfrank
Created May 21, 2015 02:01
Show Gist options
  • Save danielhfrank/a937a00b9f4da8d9968a to your computer and use it in GitHub Desktop.
Save danielhfrank/a937a00b9f4da8d9968a to your computer and use it in GitHub Desktop.
OverTimeMonoid
package com.stripe.danielhfrank
import com.twitter.algebird.Monoid
import scala.annotation.tailrec
import scala.collection.mutable
/**
 * Monoid over time series of running totals.
 *
 * Each operand is a sequence of `(timestamp, value)` pairs in which `value` is the running
 * total of that series as of `timestamp`. `plus` interleaves two series by timestamp. At every
 * point it combines that series' total with the most recent total of the other series, so the
 * result is again a running-total series for the combined input.
 *
 * NOTE(review): operands are assumed to already be sorted by ascending timestamp. The merge
 * never sorts, so unsorted input produces unsorted output — confirm with callers.
 *
 * @tparam V value type. Only its Monoid is used, and totals are always combined as
 *           left-total + right-total, so `V` need not be commutative.
 */
class OverTimeMonoid[V: Monoid] extends Monoid[Seq[(Long, V)]] {
  override def zero: Seq[(Long, V)] = Seq.empty

  /**
   * Merges two running-total series.
   *
   * Every point of `l` and `r` appears in the output, ordered by timestamp. When both sides
   * have a point at the same timestamp, the point from `r` is emitted first.
   *
   * @param l left series, sorted by timestamp
   * @param r right series, sorted by timestamp
   * @return the merged series
   */
  override def plus(l: Seq[(Long, V)], r: Seq[(Long, V)]): Seq[(Long, V)] =
    // Normalise to List up front. The merge deconstructs with `::`/`Nil`, which would throw a
    // MatchError for other Seq implementations such as Vector or ArrayBuffer.
    tailRecMerge(Nil, l.toList, r.toList, Monoid.zero[V], Monoid.zero[V])

  /**
   * Stack-safe two-way merge.
   *
   * Repeatedly takes the earlier head of `l`/`r` and emits it combined with the latest total
   * seen from the other side. That side's "last seen" total is updated before moving on.
   *
   * @param acc   output built so far, in reverse order
   * @param l     remaining left points
   * @param r     remaining right points
   * @param lastL latest left total emitted so far (zero before any left point)
   * @param lastR latest right total emitted so far (zero before any right point)
   * @return the merged series in timestamp order
   */
  @tailrec
  private def tailRecMerge(
      acc: List[(Long, V)],
      l: List[(Long, V)],
      r: List[(Long, V)],
      lastL: V,
      lastR: V): List[(Long, V)] =
    (l, r) match {
      case (Nil, Nil) =>
        acc.reverse
      case ((lt, lv) :: ls, (rt, _) :: _) if lt < rt =>
        // Left point is strictly earlier: left total first, then the right total.
        tailRecMerge((lt, Monoid.plus(lv, lastR)) :: acc, ls, r, lv, lastR)
      case ((lt, lv) :: ls, Nil) =>
        tailRecMerge((lt, Monoid.plus(lv, lastR)) :: acc, ls, Nil, lv, lastR)
      case (_, (rt, rv) :: rs) =>
        // Right point is earlier or tied. Keep left-then-right order so that a
        // non-commutative V stays associative.
        tailRecMerge((rt, Monoid.plus(lastL, rv)) :: acc, l, rs, lastL, rv)
    }
}
/**
 * An OverTimeMonoid whose `sum` buffers inputs into batches and reduces each batch with a
 * balanced divide-and-conquer merge instead of a left fold.
 *
 * Designed for summing many singleton (or very short) series. That is the shape produced by
 * a single event/row, and building the result by merging singletons means no explicit sort is
 * needed. Because `plus` walks both of its operands, folding singletons into one growing
 * accumulator repeatedly re-walks that accumulator. Pairwise merging keeps most merges small.
 *
 * @param batchSize number of input series buffered before a batch is reduced and merged into
 *                  the running result
 */
class BatchingOverTimeMonoid[V: Monoid](batchSize: Int = 2048) extends OverTimeMonoid[V] {

  /**
   * Sums `vs` batch by batch.
   *
   * All working state lives inside this call. Each call therefore returns only the sum of its
   * own input, and the instance can be reused or shared safely.
   *
   * @param vs series to sum, each assumed sorted by timestamp
   * @return the combined series
   */
  override def sum(vs: TraversableOnce[Seq[(Long, V)]]): Seq[(Long, V)] = {
    val batch = mutable.ArrayBuffer.empty[Seq[(Long, V)]]
    var total: Seq[(Long, V)] = zero
    vs.foreach { s =>
      batch += s
      if (batch.size >= batchSize) {
        total = plus(total, sumRange(batch, 0, batch.size))
        batch.clear()
      }
    }
    // Flush whatever remains in the final, partially filled batch.
    plus(total, sumRange(batch, 0, batch.size))
  }

  /**
   * Balanced reduction of `batch[from, until)`, preserving left-to-right operand order.
   * Recursion depth is logarithmic in the range length.
   */
  private def sumRange(
      batch: mutable.ArrayBuffer[Seq[(Long, V)]],
      from: Int,
      until: Int): Seq[(Long, V)] =
    until - from match {
      case 0 => zero
      case 1 => batch(from)
      case n =>
        val mid = from + n / 2
        plus(sumRange(batch, from, mid), sumRange(batch, mid, until))
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment