Skip to content

Instantly share code, notes, and snippets.

@miguno
Created September 29, 2014 16:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save miguno/3bf1a24be446b2edaeb1 to your computer and use it in GitHub Desktop.
Save miguno/3bf1a24be446b2edaeb1 to your computer and use it in GitHub Desktop.
Draft of a generic Count-Min Sketch (CMS) implementation
package com.miguno.algebird.extensions
import com.twitter.algebird.{Approximate, Monoid, MonoidAggregator}
import scala.collection.immutable.SortedSet
/**
 * Monoid over generic Count-Min sketches keyed by `K`.
 *
 * Every sketch produced by one monoid instance shares the same `params`. In particular, it shares the same hash
 * functions, whose coefficients are drawn from a `scala.util.Random` seeded with `seed`.
 *
 * @param eps             error parameter; must lie in (0, 1). Determines the table width (`GenCMS.width`).
 * @param delta           failure-probability parameter; must lie in (0, 1). Determines the table depth (`GenCMS.depth`).
 * @param seed            seed for the generator that picks the hash coefficients
 * @param heavyHittersPct fraction of the total count an item must reach to be tracked as a heavy hitter; in (0, 1)
 */
class GenCountMinSketchMonoid[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int,
  heavyHittersPct: Double = 0.01) extends Monoid[GenCMS[K]] {

  assert(0 < eps && eps < 1, "eps must lie in (0, 1)")
  assert(0 < delta && delta < 1, "delta must lie in (0, 1)")
  assert(0 < heavyHittersPct && heavyHittersPct < 1, "heavyHittersPct must lie in (0, 1)")

  // The textbook family is h_i(x) = a_i * x + b_i (mod p). The additive term b_i only shifts every result of
  // one row by the same amount, so it is fixed at 0 and only the multipliers a_i are drawn at random.
  val hashes: Seq[GenCMSHash[K]] = {
    val rng = new scala.util.Random(seed)
    val rowCount = GenCMS.depth(delta)
    val columnCount = GenCMS.width(eps)
    // One hash function per row; coefficients are drawn in row order.
    Vector.fill(rowCount)(GenCMSHash[K](rng.nextInt(), 0, columnCount))
  }

  val params: GenCMSParams[K] = GenCMSParams[K](hashes, eps, delta, heavyHittersPct)

  val zero: GenCMS[K] = GenCMSZero[K](params)

  /**
   * Combines two sketches. Both are assumed to have been built with this monoid's hash functions.
   */
  def plus(left: GenCMS[K], right: GenCMS[K]): GenCMS[K] = left ++ right

  /**
   * Wraps a single item in a sketch.
   */
  def create(item: K): GenCMS[K] = GenCMSItem[K](item, params)

  /**
   * Builds a sketch summarizing every element of `data`, adding them left to right starting from `zero`.
   */
  def create(data: Seq[K]): GenCMS[K] =
    data.foldLeft(zero)((sketch, key) => plus(sketch, create(key)))
}
/**
 * Factory methods and parameter conversions for generic Count-Min sketches.
 */
object GenCMS {

  /** Creates a sketch monoid from error bounds `eps` and `delta`. */
  def monoid[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int, heavyHittersPct: Double = 0.01): GenCountMinSketchMonoid[K] =
    new GenCountMinSketchMonoid[K](eps, delta, seed, heavyHittersPct)

  /** Creates a sketch monoid from explicit table dimensions `depth` x `width`. */
  def monoid[K: Ordering : GenCMSHasher](depth: Int, width: Int, seed: Int, heavyHittersPct: Double): GenCountMinSketchMonoid[K] =
    new GenCountMinSketchMonoid[K](GenCMS.eps(width), GenCMS.delta(depth), seed, heavyHittersPct)

  /** Creates a sketch aggregator from error bounds `eps` and `delta`. */
  def aggregator[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int, heavyHittersPct: Double = 0.01): GenCountMinSketchAggregator[K] =
    new GenCountMinSketchAggregator[K](new GenCountMinSketchMonoid[K](eps, delta, seed, heavyHittersPct))

  /** Creates a sketch aggregator from explicit table dimensions `depth` x `width`. */
  def aggregator[K: Ordering : GenCMSHasher](depth: Int, width: Int, seed: Int, heavyHittersPct: Double): GenCountMinSketchAggregator[K] =
    new GenCountMinSketchAggregator[K](
      new GenCountMinSketchMonoid[K](GenCMS.eps(width), GenCMS.delta(depth), seed, heavyHittersPct))

  /**
   * Functions to translate between (eps, delta) and (depth, width). The translation is:
   *   depth = ceil(ln 1/delta)
   *   width = ceil(e / eps)
   *
   * `depth(delta(d)) == d` and `width(eps(w)) == w` must hold, so that the (depth, width) overloads above build
   * tables of exactly the requested size. Without truncation, floating-point error breaks this: e.g.
   * `e / eps(39)` evaluates to 39.00000000000001, whose ceiling is 40.
   */
  def eps(width: Int): Double = scala.math.exp(1.0) / width

  def delta(depth: Int): Double = 1.0 / scala.math.exp(depth)

  def depth(delta: Double): Int = scala.math.ceil(truncatePrecisionError(scala.math.log(1.0 / delta))).toInt

  def width(eps: Double): Int = scala.math.ceil(truncatePrecisionError(scala.math.exp(1) / eps)).toInt

  /**
   * Rounds `x` to `decimalPlaces` decimal places (HALF_UP) to strip floating-point noise before taking a ceiling.
   * Non-finite values are returned unchanged, because `BigDecimal` cannot represent them.
   */
  private def truncatePrecisionError(x: Double, decimalPlaces: Int = 6): Double =
    if (x.isNaN || x.isInfinite) x
    else BigDecimal(x).setScale(decimalPlaces, BigDecimal.RoundingMode.HALF_UP).toDouble
}
/**
 * The actual Count-Min sketch data structure.
 *
 * Concrete variants: `GenCMSZero` (empty), `GenCMSItem` (a single element) and `GenCMSInstance` (the general
 * table-backed sketch).
 */
sealed abstract class GenCMS[K] extends java.io.Serializable {

  // ---- Parameters -------------------------------------------------------------------------------------------

  /** Error parameter: with probability >= 1 - delta, estimates overshoot by at most eps * totalCount. */
  def eps: Double

  /** Probability that the eps error bound does not hold. */
  def delta: Double

  /** Number of hash functions, i.e. rows of the counts table; derived from `delta`. */
  def depth: Int = GenCMS.depth(delta)

  /** Number of counters per hash function, i.e. columns of the counts table; derived from `eps`. */
  def width: Int = GenCMS.width(eps)

  // ---- Stream totals and moments ----------------------------------------------------------------------------

  /** Total number of elements seen in the data stream so far. */
  def totalCount: Long

  /** First frequency moment: the total number of elements in the stream (same as `totalCount`). */
  def f1: Long = totalCount

  /** Second frequency moment `\sum a_i^2`, where a_i is the count of the ith element; the self inner product. */
  def f2: Approximate[Long] = innerProduct(this)

  // ---- Combination and queries ------------------------------------------------------------------------------

  /** Merges two sketches that were built with the same hash functions. */
  def ++(other: GenCMS[K]): GenCMS[K]

  /**
   * Estimated number of times `item` has been seen in the stream so far.
   *
   * The estimate never undercounts: trueFrequency <= estimatedFrequency.
   * With probability p >= 1 - delta it also satisfies
   * estimatedFrequency <= trueFrequency + eps * totalCount.
   */
  def frequency(item: K): Approximate[Long]

  /**
   * Estimated inner product `<a, b> = \sum a_i b_i` with another stream. Here a_i and b_i are the counts of
   * element i in this sketch's stream and in `other`'s stream. This is also the join size of the two relations.
   *
   * The estimate never undercounts: actualInnerProduct <= estimatedInnerProduct.
   * With probability p >= 1 - delta it also satisfies
   * estimatedInnerProduct <= actualInnerProduct + eps * thisTotalCount * otherTotalCount.
   */
  def innerProduct(other: GenCMS[K]): Approximate[Long]

  // ---- Heavy hitters ----------------------------------------------------------------------------------------

  /** Fraction of `totalCount` an item must reach to count as a heavy hitter. */
  def heavyHittersPct: Double

  /**
   * Heavy hitters: elements that appear at least (heavyHittersPct * totalCount) times.
   *
   * Every such element is returned. With probability p >= 1 - delta, no element whose true count is below
   * (heavyHittersPct - eps) * totalCount is returned.
   *
   * At most 1 / heavyHittersPct elements can truly be heavy hitters. Tracking everything above (say) 1% of the
   * stream therefore needs at most about 100 entries.
   */
  def heavyHitters: Set[K]
}
/**
 * The empty sketch: the identity element of the Count-Min sketch monoid, used for initialization.
 * It has seen no elements, so every query answers exactly zero.
 */
case class GenCMSZero[K: Ordering](params: GenCMSParams[K]) extends GenCMS[K] {

  def eps: Double = params.eps
  def delta: Double = params.delta
  def heavyHittersPct: Double = params.heavyHittersPct

  def totalCount: Long = 0L
  def heavyHitters: Set[K] = Set.empty[K]

  // Merging anything with the empty sketch yields that thing unchanged.
  def ++(other: GenCMS[K]): GenCMS[K] = other

  // No elements have been seen, so both answers are exactly zero.
  def frequency(item: K): Approximate[Long] = Approximate.exact(0L)
  def innerProduct(other: GenCMS[K]): Approximate[Long] = Approximate.exact(0L)
}
/**
 * Used for holding a single element, to avoid repeatedly adding elements from sparse counts tables.
 * A full `GenCMSInstance` is only materialized once this sketch is merged with another non-empty one.
 */
case class GenCMSItem[K: Ordering](item: K, params: GenCMSParams[K]) extends GenCMS[K] {
  def eps: Double = params.eps
  def delta: Double = params.delta
  def heavyHittersPct: Double = params.heavyHittersPct
  def totalCount: Long = 1L

  /**
   * Merges `other` into this single-element sketch. Both are assumed to share the same `params`.
   */
  def ++(other: GenCMS[K]): GenCMS[K] = other match {
    // Constructor patterns let the compiler infer the type argument `K` from `GenCMS[K]`. Typed patterns such as
    // `case x: GenCMSItem[K]` would instead be unchecked because of type erasure.
    case GenCMSZero(_) => this
    case GenCMSItem(otherItem, _) => GenCMSInstance[K](params) + item + otherItem
    case instance @ GenCMSInstance(_, _, _, _) => instance + item
  }

  // Exactly one occurrence of `item` and nothing else has been seen.
  def frequency(x: K): Approximate[Long] = Approximate.exact(if (item == x) 1L else 0L)

  // With a single element (count 1), <this, other> reduces to other's count of that element.
  def innerProduct(other: GenCMS[K]): Approximate[Long] = other.frequency(item)

  def heavyHitters: Set[K] = Set(item)
}
/**
 * The general Count-Min sketch structure, used for holding any number of elements.
 *
 * @param countsTable table of counters with one row per hash function in `params.hashes`
 * @param totalCount  sum of all counts added to this sketch
 * @param hhs         heavy-hitter candidates with their estimated counts
 * @param params      parameters (hash functions, eps, delta, heavyHittersPct) shared with sketches it is merged with
 */
case class GenCMSInstance[K: Ordering](countsTable: GenCMSInstance.GenCMSCountsTable[K], totalCount: Long,
  hhs: GenCMSInstance.HeavyHitters[K], params: GenCMSParams[K]) extends GenCMS[K] {

  def eps: Double = params.eps
  def delta: Double = params.delta
  def heavyHittersPct: Double = params.heavyHittersPct

  /**
   * Merges `other` into this sketch. Both sketches are assumed to use the same `params` (in particular the same
   * hash functions); this is not checked.
   */
  def ++(other: GenCMS[K]): GenCMS[K] = other match {
    // Constructor patterns infer `K` from the scrutinee type, avoiding unchecked type-erasure warnings.
    case GenCMSZero(_) => this
    case GenCMSItem(otherItem, _) => this + otherItem
    case GenCMSInstance(otherTable, otherTotalCount, otherHhs, _) =>
      val mergedTotalCount = totalCount + otherTotalCount
      val mergedTable = countsTable ++ otherTable
      val mergedHhs = reestimateHeavyHitters(hhs.items ++ otherHhs.items, mergedTable, mergedTotalCount)
      GenCMSInstance[K](mergedTable, mergedTotalCount, mergedHhs, params)
  }

  /**
   * Rebuilds the heavy hitters of a merged sketch.
   *
   * The counts stored in each operand's `hhs` were estimated against that operand's table alone. Merging the two
   * sorted sets directly has two flaws:
   *  - an item tracked by both sides would appear twice with different counts;
   *  - an item whose occurrences are split across both sides would be judged on a partial count and could be
   *    dropped.
   * Instead, every candidate is re-estimated against the merged table and then filtered by the merged threshold.
   */
  private def reestimateHeavyHitters(candidates: Set[K], mergedTable: GenCMSInstance.GenCMSCountsTable[K],
    mergedTotalCount: Long): GenCMSInstance.HeavyHitters[K] = {
    val noHeavyHitters = GenCMSInstance.HeavyHitters[K](
      SortedSet.empty[GenCMSInstance.HeavyHitter[K]](GenCMSInstance.HeavyHitter.ordering[K]))
    // `frequency(...).estimate` reads only the counts table and the hash functions, so a throw-away sketch over the
    // merged table is enough to query it.
    val mergedSketch = GenCMSInstance[K](mergedTable, mergedTotalCount, noHeavyHitters, params)
    candidates
      .foldLeft(noHeavyHitters) { (acc, candidate) =>
        acc + GenCMSInstance.HeavyHitter[K](candidate, mergedSketch.frequency(candidate).estimate)
      }
      .dropCountsBelow(heavyHittersPct * mergedTotalCount)
  }

  /**
   * Wraps a raw table estimate: zero is reported as exact (the sketch never undercounts). Any other estimate gets a
   * lower bound of est - eps * totalCount (floored at 0) that holds with probability 1 - delta.
   */
  private def makeApprox(est: Long): Approximate[Long] = {
    if (est == 0L) {
      Approximate.exact(0L)
    } else {
      val lower = math.max(0L, est - (eps * totalCount).toLong)
      Approximate(lower, est, est, 1 - delta)
    }
  }

  /**
   * Point query: the minimum, over all rows, of the counter that `item` hashes to in that row.
   */
  def frequency(item: K): Approximate[Long] = {
    val estimates = countsTable.counts.zipWithIndex.map {
      case (row, i) =>
        row(params.hashes(i)(item))
    }
    makeApprox(estimates.min)
  }

  /**
   * Let X be a CMS, and let count_X[j, k] denote the value in X's 2-dimensional count table at row j and column k.
   * Then the Count-Min sketch estimate of the inner product between A and B is the minimum inner product between their
   * rows:
   *   estimatedInnerProduct = min_j (\sum_k count_A[j, k] * count_B[j, k])
   *
   * Non-instance sketches (empty / single item) compute the product from their own side.
   */
  def innerProduct(other: GenCMS[K]): Approximate[Long] = other match {
    case that @ GenCMSInstance(_, _, _, _) =>
      assert((that.depth, that.width) == (depth, width), "Tables must have the same dimensions.")
      def innerProductAtDepth(d: Int): Long =
        (0 until width).map(w => countsTable.getCount((d, w)) * that.countsTable.getCount((d, w))).sum
      val est = (0 until depth).map(innerProductAtDepth).min
      // A true inner product of non-negative counts is never negative, so clamp the lower bound at zero.
      val lower = math.max(0L, est - (eps * totalCount * that.totalCount).toLong)
      Approximate(lower, est, est, 1 - delta)
    case _ => other.innerProduct(this)
  }

  def heavyHitters: Set[K] = hhs.items

  /**
   * Updates the sketch with one new occurrence of `item` from the data stream.
   */
  def +(item: K): GenCMSInstance[K] = this.+(item, 1L)

  /**
   * Updates the sketch with `count` new occurrences of `item`.
   *
   * @throws IllegalArgumentException if `count` is negative
   */
  def +(item: K, count: Long): GenCMSInstance[K] = {
    require(count >= 0, "count must be >= 0 (negative counts not implemented)")
    // Heavy hitters are computed from the current (pre-update) table, so this must run before the table changes.
    val newHhs = updateHeavyHitters(item, count)
    val newCountsTable = (0 until depth).foldLeft(countsTable) { (table, row) =>
      table.+((row, params.hashes(row)(item)), count)
    }
    GenCMSInstance[K](newCountsTable, totalCount + count, newHhs, params)
  }

  /**
   * Updates the heavy hitters when `count` new occurrences of `item` enter the stream.
   */
  private def updateHeavyHitters(item: K, count: Long): GenCMSInstance.HeavyHitters[K] = {
    val oldItemCount = frequency(item).estimate
    val newItemCount = oldItemCount + count
    val newTotalCount = totalCount + count
    val threshold = heavyHittersPct * newTotalCount

    val newHhs =
      if (newItemCount >= threshold) {
        // Drop any previous entry for `item` by item, not by (item, oldItemCount). The stored count can lag the current
        // estimate, because collisions from other items also raise it. Removing by exact pair would leave a stale
        // duplicate behind.
        val withoutItem = GenCMSInstance.HeavyHitters[K](hhs.hhs.filterNot(_.item == item))
        withoutItem + GenCMSInstance.HeavyHitter[K](item, newItemCount)
      } else {
        hhs
      }

    // Remove any items below the new heavy-hitter threshold.
    newHhs.dropCountsBelow(threshold)
  }
}
object GenCMSInstance {

  /**
   * Creates an empty sketch: every counter is zero, the total count is zero and no heavy hitters are tracked.
   * Table dimensions are derived from `params.delta` (rows) and `params.eps` (columns).
   */
  def apply[K: Ordering](params: GenCMSParams[K]): GenCMSInstance[K] = {
    val emptyTable = CMSCountsTable[K](GenCMS.depth(params.delta), GenCMS.width(params.eps))
    val noHeavyHitters = HeavyHitters[K](SortedSet.empty[HeavyHitter[K]](HeavyHitter.ordering[K]))
    GenCMSInstance[K](emptyTable, 0L, noHeavyHitters, params)
  }

  /**
   * The 2-dimensional table of counters used in the Count-Min sketch.
   * Each row corresponds to a particular hash function.
   * TODO: implement a dense matrix type, and use it here
   */
  case class GenCMSCountsTable[K](counts: Vector[Vector[Long]]) {
    assert(depth > 0, "Table must have at least 1 row.")
    assert(width > 0, "Table must have at least 1 column.")

    /** Number of rows. */
    def depth: Int = counts.size

    /** Number of columns, taken from the first row. */
    def width: Int = counts(0).size

    /** Reads the counter at `(row, col)`; asserts that both indices are below the table's dimensions. */
    def getCount(pos: (Int, Int)): Long = pos match {
      case (row, col) =>
        assert(row < depth && col < width, "Position must be within the bounds of this table.")
        counts(row)(col)
    }

    /**
     * Returns a copy of this table with `count` added to the single cell at `pos`.
     */
    def +(pos: (Int, Int), count: Long): GenCMSCountsTable[K] = {
      // Read (and bounds-check) the current value before touching the rows.
      val current = getCount(pos)
      val (row, col) = pos
      val updatedRow = counts(row).updated(col, current + count)
      GenCMSCountsTable[K](counts.updated(row, updatedRow))
    }

    /**
     * Adds another counts table to this one, through element-wise addition.
     */
    def ++(other: GenCMSCountsTable[K]): GenCMSCountsTable[K] = {
      assert((depth, width) == (other.depth, other.width), "Tables must have the same dimensions.")
      val summed = Monoid.plus[IndexedSeq[IndexedSeq[Long]]](counts, other.counts)
      GenCMSCountsTable[K](summed.map(_.toVector).toVector)
    }
  }

  object CMSCountsTable {
    /** Creates a `depth` x `width` counts table with every counter set to zero. */
    def apply[K: Ordering](depth: Int, width: Int): GenCMSCountsTable[K] =
      GenCMSCountsTable[K](Vector.fill(depth)(Vector.fill(width)(0L)))
  }

  /**
   * Container for heavy-hitter candidates and their associated counts.
   * The set is ordered by `HeavyHitter.ordering` (ascending count), which `dropCountsBelow` relies on.
   */
  case class HeavyHitters[K: Ordering](hhs: SortedSet[HeavyHitter[K]]) {
    /** Removes exactly this (item, count) entry, if present. */
    def -(hh: HeavyHitter[K]): HeavyHitters[K] = copy(hhs = hhs - hh)

    /** Adds an (item, count) entry. */
    def +(hh: HeavyHitter[K]): HeavyHitters[K] = copy(hhs = hhs + hh)

    /** Unions the entries. An item present in both with different counts keeps both entries. */
    def ++(other: HeavyHitters[K]): HeavyHitters[K] = copy(hhs = hhs ++ other.hhs)

    /** The distinct items being tracked. */
    def items: Set[K] = hhs.map(_.item)

    /** Drops every entry with a count strictly below `minCount`; relies on the ascending-count ordering. */
    def dropCountsBelow(minCount: Double): HeavyHitters[K] =
      copy(hhs = hhs.dropWhile(_.count < minCount))
  }

  /** A single heavy-hitter candidate with its estimated count. */
  case class HeavyHitter[K: Ordering](item: K, count: Long)

  object HeavyHitter {
    /** Orders heavy hitters by ascending count, breaking ties by the item's own ordering. */
    def ordering[K: Ordering]: Ordering[HeavyHitter[K]] =
      Ordering.by((hh: HeavyHitter[K]) => (hh.count, hh.item))
  }
}
/**
 * Constant parameters shared by every sketch created from the same monoid.
 *
 * @param hashes          one hash function per row of the counts table (row i uses `hashes(i)`)
 * @param eps             error parameter that determines the table width
 * @param delta           failure-probability parameter that determines the table depth
 * @param heavyHittersPct fraction of the total count an item must reach to be tracked as a heavy hitter
 */
case class GenCMSParams[K](
  hashes: Seq[GenCMSHash[K]],
  eps: Double,
  delta: Double,
  heavyHittersPct: Double)
/**
 * An Aggregator for generic Count-Min sketches. Each input value is turned into a single-item sketch with
 * `cmsMonoid.create`, and sketches are combined with `cmsMonoid`. The final sketch is returned unchanged.
 * Can be created using `GenCMS.aggregator`.
 */
case class GenCountMinSketchAggregator[K](cmsMonoid: GenCountMinSketchMonoid[K])
  extends MonoidAggregator[K, GenCMS[K], GenCMS[K]] {

  val monoid: GenCountMinSketchMonoid[K] = cmsMonoid

  def prepare(value: K): GenCMS[K] = cmsMonoid.create(value)

  def present(cms: GenCMS[K]): GenCMS[K] = cms
}
/**
 * Type class that hashes keys of type `K` into the columns of a Count-Min sketch row.
 */
trait GenCMSHasher[K] {

  /** The Mersenne prime 2^31 - 1, used as the modulus `p` of the hash family. */
  val PRIME_MODULUS: Long = Int.MaxValue.toLong

  /**
   * Returns `a * x + b (mod p) (mod width)`.
   */
  def hash(a: Int, b: Int, width: Int)(x: K): Int
}
/**
 * One member of the hash family used by the Count-Min sketch. The sketch uses `d` pair-wise independent hash
 * functions drawn from a universal hashing family of the form:
 *
 * `h(x) = [a * x + b (mod p)] (mod m)`
 *
 * @param a     multiplier coefficient
 * @param b     additive coefficient
 * @param width number of columns (`m`) the result is reduced into
 */
case class GenCMSHash[K: GenCMSHasher](a: Int, b: Int, width: Int) {
  /**
   * Hashes `x` with this function's coefficients by delegating to the implicit `GenCMSHasher[K]`.
   * The built-in hashers return a column index in [0, width); custom hashers are assumed to do the same.
   */
  def apply(x: K): Int = {
    val hasher = implicitly[GenCMSHasher[K]]
    hasher.hash(a, b, width)(x)
  }
}
/**
 * Built-in `GenCMSHasher` instances for common integral key types.
 */
object GenCountMinSketchImplicits {

  implicit object GenCMSHasherLong extends GenCMSHasher[Long] {
    def hash(a: Int, b: Int, width: Int)(x: Long): Int = {
      val unmodded: Long = (x * a) + b
      // Apparently a super fast way of computing x mod 2^p-1
      // See page 149 of http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
      // after Proposition 7.
      // NOTE(review): a Mersenne reduction modulo 2^31 - 1 would normally shift by 31, not 32. The shift is left
      // unchanged so that existing Long hash values are preserved; confirm intent.
      // The `& PRIME_MODULUS` mask clears the sign bit, so `modded` is non-negative and fits in an Int.
      val modded: Long = (unmodded + (unmodded >> 32)) & PRIME_MODULUS
      // Modulo-ing integers is apparently twice as fast as modulo-ing Longs.
      modded.toInt % width
    }
  }

  implicit object GenCMSHasherByte extends GenCMSHasher[Byte] {
    def hash(a: Int, b: Int, width: Int)(x: Byte): Int = GenCMSHasherInt.hash(a, b, width)(x)
  }

  implicit object GenCMSHasherShort extends GenCMSHasher[Short] {
    def hash(a: Int, b: Int, width: Int)(x: Short): Int = GenCMSHasherInt.hash(a, b, width)(x)
  }

  implicit object GenCMSHasherInt extends GenCMSHasher[Int] {
    // Widen to Long before hashing. On an Int, the JVM masks the shift distance to 5 bits, so `unmodded >> 32` would
    // be a no-op. Every result would then be `2 * unmodded` masked, i.e. always even. That leaves half of the columns
    // unused whenever the width is even.
    def hash(a: Int, b: Int, width: Int)(x: Int): Int = GenCMSHasherLong.hash(a, b, width)(x.toLong)
  }

  implicit object GenCMSHasherBigInt extends GenCMSHasher[BigInt] {
    def hash(a: Int, b: Int, width: Int)(x: BigInt): Int = {
      // BigInt arithmetic does not overflow, so the product is exact before reduction.
      val unmodded: BigInt = (x * a) + b
      val modded: BigInt = (unmodded + (unmodded >> 32)) & PRIME_MODULUS
      modded.toInt % width
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment