Created
May 21, 2015 02:01
-
-
Save danielhfrank/a937a00b9f4da8d9968a to your computer and use it in GitHub Desktop.
OverTimeMonoid
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.stripe.danielhfrank | |
import com.twitter.algebird.Monoid | |
import scala.annotation.tailrec | |
import scala.collection.mutable | |
/**
 * Monoid over sequences of `(Long, V)` pairs, where each pair holds the running
 * total `V` as of the `Long` key (presumably a timestamp, going by the class name).
 *
 * Combining two sequences interleaves them by key and adds, to every entry, the
 * most recently seen total from the *other* sequence.
 *
 * NOTE(review): the merge walks both inputs front to back and compares only the
 * current heads, so it appears to expect each input to be sorted ascending by key.
 */
class OverTimeMonoid[V: Monoid] extends Monoid[Seq[(Long, V)]] {

  /** The empty sequence is the identity: merging it in leaves the other side unchanged. */
  override def zero: Seq[(Long, V)] = Seq.empty

  /**
   * Merges two sequences of running totals into one.
   *
   * @param l left-hand sequence
   * @param r right-hand sequence
   * @return the interleaved sequence; each entry's value is its own total plus the
   *         last total seen so far on the opposite side
   */
  override def plus(l: Seq[(Long, V)], r: Seq[(Long, V)]): Seq[(Long, V)] =
    // The merge destructures its inputs with `::`, which only matches List. Converting
    // here (a no-op for inputs that already are Lists) keeps Vector/Buffer inputs from
    // failing with a MatchError.
    tailRecMerge(List.empty[(Long, V)], l.toList, r.toList, Monoid.zero[V], Monoid.zero[V])

  /**
   * General idea - remember the last seen value (which is a total) from each side.
   * Take the entry with the smaller key, add the previous value from the *other*
   * side, and continue down the chain with that entry's value as the new "last"
   * for its own side.
   *
   * When both heads carry the same key the right-hand entry is emitted first, so
   * the output holds two entries for that key; the second one carries the total
   * that includes both sides.
   *
   * @param acc   merged entries so far, in reverse order
   * @param l     remaining left-hand entries
   * @param r     remaining right-hand entries
   * @param lastL value of the most recently consumed left-hand entry (zero if none)
   * @param lastR value of the most recently consumed right-hand entry (zero if none)
   * @return the fully merged entries in forward order
   */
  @tailrec
  private def tailRecMerge(
      acc: List[(Long, V)],
      l: List[(Long, V)],
      r: List[(Long, V)],
      lastL: V,
      lastR: V): List[(Long, V)] =
    (l, r) match {
      case (Nil, Nil) =>
        // acc was built by prepending, so flip it back into merge order.
        acc.reverse

      case (Nil, (rt, rv) :: rs) =>
        // Left side exhausted: every remaining right entry gets the final left total.
        tailRecMerge((rt, Monoid.plus(rv, lastL)) :: acc, Nil, rs, lastL, rv)

      case ((lt, lv) :: ls, Nil) =>
        // Right side exhausted: every remaining left entry gets the final right total.
        tailRecMerge((lt, Monoid.plus(lv, lastR)) :: acc, ls, Nil, lv, lastR)

      case ((lt, lv) :: ls, (rt, rv) :: rs) =>
        if (lt < rt) {
          tailRecMerge((lt, Monoid.plus(lv, lastR)) :: acc, ls, r, lv, lastR)
        } else {
          tailRecMerge((rt, Monoid.plus(rv, lastL)) :: acc, l, rs, lastL, rv)
        }
    }
}
/**
 * [[OverTimeMonoid]] whose `sum` buffers its inputs and merges them in batches,
 * combining each batch by divide-and-conquer.
 *
 * `sum` works through the mutable members `buf` and `acc`, so a single instance
 * must not run overlapping `sum` calls.
 *
 * @param batchSize number of buffered sequences that triggers a merge into `acc`
 */
class BatchingOverTimeMonoid[V: Monoid](batchSize: Int = 2048) extends OverTimeMonoid[V] {

  /** Sequences collected since the last merge. */
  val buf: mutable.Buffer[Seq[(Long, V)]] = mutable.Buffer[Seq[(Long, V)]]()

  /** Merged result of every batch processed so far in the current `sum` call. */
  var acc: Seq[(Long, V)] = Seq.empty[(Long, V)]

  /**
   * Buffer up batches and divide-and-conquer sum them when they fill up.
   * Designed for the case where we are summing together singleton lists - common
   * because that is what you get from a single event/row, and because summing
   * singleton lists means you don't have to do a sort.
   *
   * @param vs the sequences to combine
   * @return the combination of all sequences in `vs`; the empty sequence when `vs` is empty
   */
  override def sum(vs: TraversableOnce[Seq[(Long, V)]]): Seq[(Long, V)] = {
    // Start from a clean slate: without this, results left behind by an earlier
    // call on the same instance would be folded into this call's answer.
    buf.clear()
    acc = Seq.empty[(Long, V)]

    vs.foreach { s =>
      buf.append(s)
      if (buf.size >= batchSize)
        mergeAndClear()
    }
    // Flush whatever is left in a final, partially filled batch.
    mergeAndClear()
    acc
  }

  /** Folds the buffered batch into `acc` and empties the buffer. */
  private def mergeAndClear(): Unit = {
    acc = plus(acc, sumBuf(buf))
    buf.clear()
  }

  /**
   * Sums a batch by recursively splitting it in half and merging the two halves.
   *
   * @param buf the batch to combine
   * @return the combination of every sequence in `buf`
   */
  private def sumBuf(buf: mutable.Buffer[Seq[(Long, V)]]): Seq[(Long, V)] = {
    if (buf.size < 3)
      // Small enough to combine directly; super.sum avoids re-entering the
      // batching `sum` defined above.
      super.sum(buf)
    else
      plus(sumBuf(buf.slice(0, buf.size / 2)), sumBuf(buf.slice(buf.size / 2, buf.size)))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment