Created
September 29, 2014 16:25
-
-
Save miguno/3bf1a24be446b2edaeb1 to your computer and use it in GitHub Desktop.
Draft of generic CMS
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.miguno.algebird.extensions | |
import com.twitter.algebird.{Approximate, Monoid, MonoidAggregator} | |
import scala.collection.immutable.SortedSet | |
/**
 * Monoid over generic Count-Min sketches whose items are of type `K`.
 *
 * The constructor draws one hash function per sketch row from a random generator seeded with `seed`, so two
 * monoids built with the same `(eps, delta, seed)` use the same hash functions.
 *
 * @param eps             error-bound parameter; must lie in (0, 1)
 * @param delta           confidence parameter; must lie in (0, 1)
 * @param seed            seed of the random generator used to pick the hash coefficients
 * @param heavyHittersPct fraction of the total count an item must reach to be tracked as a heavy hitter;
 *                        must lie in (0, 1)
 */
class GenCountMinSketchMonoid[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int,
                                                          heavyHittersPct: Double = 0.01) extends Monoid[GenCMS[K]] {

  assert(0 < eps && eps < 1, "eps must lie in (0, 1)")
  assert(0 < delta && delta < 1, "delta must lie in (0, 1)")
  assert(0 < heavyHittersPct && heavyHittersPct < 1, "heavyHittersPct must lie in (0, 1)")

  // Typically, we would use d pair-wise independent hash functions of the form
  //
  //   h_i(x) = a_i * x + b_i (mod p)
  //
  // But for this particular application b_i only shifts the results of a particular hash, so it is fixed to 0
  // and the hash functions take the form
  //
  //   h_i(x) = a_i * x (mod p)
  //
  val hashes: Seq[GenCMSHash[K]] = {
    val rng = new scala.util.Random(seed)
    val rows = GenCMS.depth(delta)
    val columns = GenCMS.width(eps)
    // One hash per row; each evaluation pulls the next coefficient from the seeded generator, in row order.
    Vector.fill(rows)(GenCMSHash[K](rng.nextInt(), 0, columns))
  }

  val params: GenCMSParams[K] = GenCMSParams[K](hashes, eps, delta, heavyHittersPct)

  val zero: GenCMS[K] = GenCMSZero[K](params)

  /**
   * Merges two sketches. Both sides are assumed to use the same hash functions.
   */
  def plus(left: GenCMS[K], right: GenCMS[K]): GenCMS[K] = left ++ right

  /**
   * Builds a sketch that holds exactly one occurrence of `item`.
   */
  def create(item: K): GenCMS[K] = GenCMSItem[K](item, params)

  /**
   * Builds a sketch from a sequence of items by adding them one at a time, starting from `zero`.
   */
  def create(data: Seq[K]): GenCMS[K] =
    data.foldLeft(zero)((sketch, element) => plus(sketch, create(element)))

}
/**
 * Factory methods for monoids and aggregators, plus the conversions between (eps, delta) and (depth, width).
 */
object GenCMS {

  /** Creates a monoid from the error-bound parameters `eps` and `delta`. */
  def monoid[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int, heavyHittersPct: Double = 0.01): GenCountMinSketchMonoid[K] =
    new GenCountMinSketchMonoid[K](eps, delta, seed, heavyHittersPct)

  /** Creates a monoid from table dimensions; they are translated to (eps, delta) first. */
  def monoid[K: Ordering : GenCMSHasher](depth: Int, width: Int, seed: Int, heavyHittersPct: Double): GenCountMinSketchMonoid[K] =
    monoid[K](GenCMS.eps(width), GenCMS.delta(depth), seed, heavyHittersPct)

  /** Creates an aggregator backed by a monoid built from `eps` and `delta`. */
  def aggregator[K: Ordering : GenCMSHasher](eps: Double, delta: Double, seed: Int, heavyHittersPct: Double = 0.01): GenCountMinSketchAggregator[K] =
    new GenCountMinSketchAggregator[K](monoid[K](eps, delta, seed, heavyHittersPct))

  /** Creates an aggregator from table dimensions; they are translated to (eps, delta) first. */
  def aggregator[K: Ordering : GenCMSHasher](depth: Int, width: Int, seed: Int, heavyHittersPct: Double): GenCountMinSketchAggregator[K] =
    aggregator[K](GenCMS.eps(width), GenCMS.delta(depth), seed, heavyHittersPct)

  /*
   * Functions to translate between (eps, delta) and (depth, width). The translation is:
   *   depth = ceil(ln 1/delta)
   *   width = ceil(e / eps)
   */

  /** eps = e / width */
  def eps(width: Int): Double = math.exp(1.0) / width

  /** delta = 1 / e^depth */
  def delta(depth: Int): Double = 1.0 / math.exp(depth)

  /** depth = ceil(ln 1/delta) */
  def depth(delta: Double): Int = math.ceil(math.log(1.0 / delta)).toInt

  /** width = ceil(e / eps) */
  def width(eps: Double): Int = math.ceil(math.exp(1) / eps).toInt

}
/**
 * The actual Count-Min sketch data structure: the interface shared by the zero, single-item and general
 * representations.
 */
sealed abstract class GenCMS[K] extends java.io.Serializable {

  // ---- Parameters -------------------------------------------------------------------------------------------

  /** Parameter used to bound the error of estimates. */
  def eps: Double

  /** Parameter used to bound the confidence in error estimates. */
  def delta: Double

  /** Fraction of `totalCount` an item must reach to count as a heavy hitter. */
  def heavyHittersPct: Double

  /** Number of hash functions, derived from `delta`. */
  def depth: Int = GenCMS.depth(delta)

  /** Number of counters per hash function, derived from `eps`. */
  def width: Int = GenCMS.width(eps)

  // ---- Combination ------------------------------------------------------------------------------------------

  /** Merges this sketch with `other`. */
  def ++(other: GenCMS[K]): GenCMS[K]

  // ---- Queries ----------------------------------------------------------------------------------------------

  /**
   * Total number of elements seen in the data stream so far.
   */
  def totalCount: Long

  /**
   * Returns an estimate of the total number of times this item has been seen
   * in the stream so far. This estimate is an upper bound.
   *
   * It is always true that trueFrequency <= estimatedFrequency.
   * With probability p >= 1 - delta, it also holds that
   * estimatedFrequency <= trueFrequency + eps * totalCount.
   */
  def frequency(item: K): Approximate[Long]

  /**
   * Returns an estimate of the inner product against another data stream.
   *
   * In other words, let a_i denote the number of times element i has been seen in
   * the data stream summarized by this CMS, and let b_i denote the same for the other CMS.
   * Then this returns an estimate of <a, b> = \sum a_i b_i
   *
   * Note: this can also be viewed as the join size between two relations.
   *
   * It is always true that actualInnerProduct <= estimatedInnerProduct.
   * With probability p >= 1 - delta, it also holds that
   * estimatedInnerProduct <= actualInnerProduct + eps * thisTotalCount * otherTotalCount
   */
  def innerProduct(other: GenCMS[K]): Approximate[Long]

  /**
   * Finds all heavy hitters, i.e., elements in the stream that appear at least
   * (heavyHittersPct * totalCount) times.
   *
   * Every item that appears at least (heavyHittersPct * totalCount) times is output,
   * and with probability p >= 1 - delta, no item whose count is less than
   * (heavyHittersPct - eps) * totalCount is output.
   *
   * Note that the set of heavy hitters contains at most 1 / heavyHittersPct
   * elements, so keeping track of all elements that appear more than (say) 1% of the
   * time requires tracking at most 100 items.
   */
  def heavyHitters: Set[K]

  /**
   * The first frequency moment is the total number of elements in the stream.
   */
  def f1: Long = totalCount

  /**
   * The second frequency moment is `\sum a_i^2`, where a_i is the count of the ith element.
   * It is computed as the inner product of this sketch with itself.
   */
  def f2: Approximate[Long] = innerProduct(this)

}
/**
 * Zero element: a sketch that has seen nothing. Used for initialization.
 *
 * @param params the sketch parameters this zero element carries
 */
case class GenCMSZero[K: Ordering](params: GenCMSParams[K]) extends GenCMS[K] {

  def eps: Double = params.eps

  def delta: Double = params.delta

  def heavyHittersPct: Double = params.heavyHittersPct

  /** Nothing has been counted yet. */
  def totalCount: Long = 0L

  /** Merging with the zero element yields the other sketch unchanged. */
  def ++(other: GenCMS[K]): GenCMS[K] = other

  /** Every item has an exact frequency of zero. */
  def frequency(item: K): Approximate[Long] = Approximate.exact(0L)

  /** The inner product with any sketch is exactly zero. */
  def innerProduct(other: GenCMS[K]): Approximate[Long] = Approximate.exact(0L)

  /** There are no heavy hitters. */
  def heavyHitters: Set[K] = Set.empty[K]

}
/**
 * Sketch holding a single element, to avoid repeatedly adding elements from sparse counts tables.
 *
 * @param item   the one element this sketch has seen
 * @param params the sketch parameters
 */
case class GenCMSItem[K: Ordering](item: K, params: GenCMSParams[K]) extends GenCMS[K] {

  def eps: Double = params.eps

  def delta: Double = params.delta

  def heavyHittersPct: Double = params.heavyHittersPct

  /** Exactly one element has been seen. */
  def totalCount: Long = 1L

  /**
   * Merges this single item into `other`. Two single items are promoted to a full counts table.
   */
  def ++(other: GenCMS[K]): GenCMS[K] = other match {
    // TODO: Properly handle type erasure for K?
    case _: GenCMSZero[_] => this
    case single: GenCMSItem[K] =>
      val seeded = GenCMSInstance[K](params) + item
      seeded + single.item
    case sketch: GenCMSInstance[K] => sketch + item
  }

  /** Exact frequency: 1 for the held item, 0 for anything else. */
  def frequency(x: K): Approximate[Long] = Approximate.exact(if (item == x) 1L else 0L)

  /** The inner product with `other` is the frequency of the held item in `other`. */
  def innerProduct(other: GenCMS[K]): Approximate[Long] = other.frequency(item)

  /** The held item is the only heavy hitter. */
  def heavyHitters: Set[K] = Set(item)

}
/**
 * The general Count-Min sketch structure, used for holding any number of elements.
 *
 * @param countsTable the 2-dimensional table of counters, one row per hash function
 * @param totalCount  total number of elements added to this sketch
 * @param hhs         the currently tracked heavy hitters and their estimated counts
 * @param params      the sketch parameters (hash functions, eps, delta, heavyHittersPct)
 */
case class GenCMSInstance[K: Ordering](countsTable: GenCMSInstance.GenCMSCountsTable[K], totalCount: Long,
                                       hhs: GenCMSInstance.HeavyHitters[K], params: GenCMSParams[K]) extends GenCMS[K] {

  def eps: Double = params.eps

  def delta: Double = params.delta

  def heavyHittersPct: Double = params.heavyHittersPct

  /**
   * Merges this sketch with `other`. For two full sketches the counts tables are added element-wise, the
   * heavy-hitter sets are unioned, and entries below the new heavy-hitter threshold are dropped.
   */
  def ++(other: GenCMS[K]): GenCMS[K] = other match {
    // TODO: Properly handle type erasure for K?
    case _: GenCMSZero[_] => this
    case single: GenCMSItem[K] => this + single.item
    case that: GenCMSInstance[K] =>
      val newTotalCount = totalCount + that.totalCount
      val newHhs = (hhs ++ that.hhs).dropCountsBelow(params.heavyHittersPct * newTotalCount)
      GenCMSInstance[K](countsTable ++ that.countsTable, newTotalCount, newHhs, params)
  }

  /**
   * Wraps a frequency estimate into an `Approximate`: the estimate is the upper bound, and the lower bound is
   * `est - eps * totalCount`, clamped at zero. An estimate of zero is reported as exact.
   */
  private def makeApprox(est: Long): Approximate[Long] = {
    if (est == 0L) {
      Approximate.exact(0L)
    } else {
      val lower = math.max(0L, est - (eps * totalCount).toLong)
      Approximate(lower, est, est, 1 - delta)
    }
  }

  /**
   * Estimates how often `item` was added: the minimum, over all rows, of the counter the row's hash function
   * maps `item` to.
   */
  def frequency(item: K): Approximate[Long] = {
    val estimates = countsTable.counts.zipWithIndex.map {
      case (row, i) => row(params.hashes(i)(item))
    }
    makeApprox(estimates.min)
  }

  /**
   * Let X be a CMS, and let count_X[j, k] denote the value in X's 2-dimensional count table at row j and column k.
   * Then the Count-Min sketch estimate of the inner product between A and B is the minimum inner product between their
   * rows:
   * estimatedInnerProduct = min_j (\sum_k count_A[j, k] * count_B[j, k])
   *
   * For a zero or single-item `other`, the computation is delegated to `other`.
   */
  def innerProduct(other: GenCMS[K]): Approximate[Long] = other match {
    // TODO: Properly handle type erasure for K?
    case that: GenCMSInstance[_] =>
      assert((that.depth, that.width) == (depth, width), "Tables must have the same dimensions.")

      // Positions are passed as explicit tuples; getCount takes a single (row, col) pair.
      def innerProductAtDepth(d: Int): Long =
        (0 until width).map { w =>
          countsTable.getCount((d, w)) * that.countsTable.getCount((d, w))
        }.sum

      val est = (0 until depth).map(innerProductAtDepth).min
      // Counts are never negative (see the `require` in `+`), so neither is the true inner product:
      // clamp the lower bound at zero, as makeApprox does for frequencies.
      val lower = math.max(0L, est - (eps * totalCount * that.totalCount).toLong)
      Approximate(lower, est, est, 1 - delta)
    case _ => other.innerProduct(this)
  }

  def heavyHitters: Set[K] = hhs.items

  /**
   * Updates the sketch with a new element from the data stream.
   */
  def +(item: K): GenCMSInstance[K] = this.+(item, 1L)

  /**
   * Updates the sketch with `count` occurrences of `item`.
   *
   * @throws IllegalArgumentException if `count` is negative
   */
  def +(item: K, count: Long): GenCMSInstance[K] = {
    require(count >= 0, "count must be >= 0 (negative counts not implemented)")
    // Heavy hitters are updated against the counts table as it was before this item is added.
    val newHhs = updateHeavyHitters(item, count)
    val newCountsTable = (0 until depth).foldLeft(countsTable) { (table, row) =>
      table.+((row, params.hashes(row)(item)), count)
    }
    GenCMSInstance[K](newCountsTable, totalCount + count, newHhs, params)
  }

  /**
   * Updates the data structure of heavy hitters when a new item (with associated count) enters the stream.
   */
  private def updateHeavyHitters(item: K, count: Long): GenCMSInstance.HeavyHitters[K] = {
    val oldItemCount = frequency(item).estimate
    val newItemCount = oldItemCount + count
    val newTotalCount = totalCount + count
    val threshold = heavyHittersPct * newTotalCount

    // If the new item is a heavy hitter, add it, and remove any previous instances.
    val newHhs =
      if (newItemCount >= threshold) {
        hhs - GenCMSInstance.HeavyHitter[K](item, oldItemCount) + GenCMSInstance.HeavyHitter[K](item, newItemCount)
      } else {
        hhs
      }

    // Remove any items below the new heavy hitter threshold.
    newHhs.dropCountsBelow(threshold)
  }

}
/**
 * Companion of the general sketch: an all-zero factory plus the counts-table and heavy-hitter containers.
 */
object GenCMSInstance {

  /**
   * Initializes a GenCMSInstance with all zeroes and no heavy hitters.
   */
  def apply[K: Ordering](params: GenCMSParams[K]): GenCMSInstance[K] = {
    val emptyTable = CMSCountsTable[K](GenCMS.depth(params.delta), GenCMS.width(params.eps))
    val noHeavyHitters = SortedSet.empty[HeavyHitter[K]](HeavyHitter.ordering[K])
    GenCMSInstance[K](emptyTable, 0L, HeavyHitters[K](noHeavyHitters), params)
  }

  /**
   * The 2-dimensional table of counters used in the Count-Min sketch.
   * Each row corresponds to a particular hash function.
   * TODO: implement a dense matrix type, and use it here
   */
  case class GenCMSCountsTable[K](counts: Vector[Vector[Long]]) {

    // The row check must come first: `width` reads the first row.
    assert(depth > 0, "Table must have at least 1 row.")
    assert(width > 0, "Table must have at least 1 column.")

    /** Number of rows. */
    def depth: Int = counts.size

    /** Number of columns, taken from the first row. */
    def width: Int = counts(0).size

    /** Returns the counter at `pos`, given as (row, column). */
    def getCount(pos: (Int, Int)): Long = pos match {
      case (row, col) =>
        assert(row < depth && col < width, "Position must be within the bounds of this table.")
        counts(row)(col)
    }

    /**
     * Updates the count of a single cell in the table.
     */
    def +(pos: (Int, Int), count: Long): GenCMSCountsTable[K] = {
      val updatedCell = getCount(pos) + count
      val updatedRow = counts(pos._1).updated(pos._2, updatedCell)
      GenCMSCountsTable[K](counts.updated(pos._1, updatedRow))
    }

    /**
     * Adds another counts table to this one, through element-wise addition.
     */
    def ++(other: GenCMSCountsTable[K]): GenCMSCountsTable[K] = {
      assert((depth, width) == (other.depth, other.width), "Tables must have the same dimensions.")

      // The monoid sum is typed as IndexedSeq; convert back to Vector, reusing it when it already is one.
      def asVector[V](xs: IndexedSeq[V]): Vector[V] = xs match {
        case v: Vector[_] => v
        case _ => Vector(xs: _*)
      }

      val summed = Monoid.plus[IndexedSeq[IndexedSeq[Long]]](counts, other.counts)
      val rows = summed.map(row => asVector(row))
      GenCMSCountsTable[K](asVector(rows))
    }

  }

  object CMSCountsTable {

    // Creates a new counts table with counts initialized to all zeroes.
    def apply[K: Ordering](depth: Int, width: Int): GenCMSCountsTable[K] = {
      val zeroes = Vector.fill[Long](depth, width)(0L)
      GenCMSCountsTable[K](zeroes)
    }

  }

  /**
   * Containers for holding heavy hitter items and their associated counts.
   */
  case class HeavyHitters[K: Ordering](hhs: SortedSet[HeavyHitter[K]]) {

    def -(hh: HeavyHitter[K]): HeavyHitters[K] = HeavyHitters[K](hhs - hh)

    def +(hh: HeavyHitter[K]): HeavyHitters[K] = HeavyHitters[K](hhs + hh)

    def ++(other: HeavyHitters[K]): HeavyHitters[K] = HeavyHitters[K](hhs ++ other.hhs)

    /** The tracked items, without their counts. */
    def items: Set[K] = hhs.map(hh => hh.item)

    /** Drops the leading entries (in set order) whose count is below `minCount`. */
    def dropCountsBelow(minCount: Double): HeavyHitters[K] =
      HeavyHitters[K](hhs.dropWhile(hh => hh.count < minCount))

  }

  case class HeavyHitter[K: Ordering](item: K, count: Long)

  object HeavyHitter {

    /** Orders heavy hitters by count first, then by item. */
    def ordering[K: Ordering]: Ordering[HeavyHitter[K]] =
      Ordering.by[HeavyHitter[K], (Long, K)](hh => (hh.count, hh.item))

  }

}
/**
 * Convenience class for holding constant parameters of a Count-Min sketch.
 *
 * @param hashes          the hash functions, one per row of the counts table
 * @param eps             parameter used to bound the error of estimates
 * @param delta           parameter used to bound the confidence in error estimates
 * @param heavyHittersPct fraction of the total count an item must reach to count as a heavy hitter
 */
case class GenCMSParams[K](
  hashes: Seq[GenCMSHash[K]],
  eps: Double,
  delta: Double,
  heavyHittersPct: Double
)
/**
 * An Aggregator for the generic Count-Min sketch. Can be created using `GenCMS.aggregator`.
 *
 * @param cmsMonoid the monoid used to create and combine sketches
 */
case class GenCountMinSketchAggregator[K](cmsMonoid: GenCountMinSketchMonoid[K])
  extends MonoidAggregator[K, GenCMS[K], GenCMS[K]] {

  val monoid: GenCountMinSketchMonoid[K] = cmsMonoid

  /** Turns a single input value into a sketch. */
  def prepare(value: K): GenCMS[K] = monoid.create(value)

  /** The aggregated sketch is presented as-is. */
  def present(cms: GenCMS[K]): GenCMS[K] = cms

}
/**
 * Type class that knows how to hash a value of type `K` into a column index of a counts table.
 */
trait GenCMSHasher[K] {

  // 2^31 - 1, the modulus `p` available to implementations.
  val PRIME_MODULUS: Long = (1L << 31) - 1L

  /**
   * Returns `a * x + b (mod p) (mod width)`.
   *
   * @param a     multiplier of the hash function
   * @param b     offset of the hash function
   * @param width number of columns in the counts table
   * @param x     the value to hash
   */
  def hash(a: Int, b: Int, width: Int)(x: K): Int

}
/**
 * The Count-Min sketch uses `d` pair-wise independent hash functions drawn from a universal hashing family of the form:
 *
 * `h(x) = [a * x + b (mod p)] (mod m)`
 *
 * This class fixes `a`, `b` and the table width; the actual computation is done by the `GenCMSHasher[K]`
 * instance in implicit scope.
 */
case class GenCMSHash[K: GenCMSHasher](a: Int, b: Int, width: Int) {

  /**
   * Returns `a * x + b (mod p) (mod width)`.
   */
  def apply(x: K): Int = {
    val hasher = implicitly[GenCMSHasher[K]]
    hasher.hash(a, b, width)(x)
  }

}
/**
 * `GenCMSHasher` instances for the integral key types `Long`, `Int`, `Short`, `Byte` and `BigInt`.
 *
 * `Int`, `Short` and `Byte` keys are widened and hashed with the `Long` hasher, so the same numeric value
 * hashes to the same column regardless of which of these types carries it.
 */
object GenCountMinSketchImplicits {

  implicit object GenCMSHasherLong extends GenCMSHasher[Long] {

    def hash(a: Int, b: Int, width: Int)(x: Long): Int = {
      val unmodded: Long = (x * a) + b
      // Apparently a super fast way of computing x mod 2^p-1
      // See page 149 of http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
      // after Proposition 7.
      val modded: Long = (unmodded + (unmodded >> 32)) & PRIME_MODULUS
      // Modulo-ing integers is apparently twice as fast as modulo-ing Longs.
      modded.toInt % width
    }

  }

  implicit object GenCMSHasherByte extends GenCMSHasher[Byte] {

    def hash(a: Int, b: Int, width: Int)(x: Byte): Int = GenCMSHasherInt.hash(a, b, width)(x.toInt)

  }

  implicit object GenCMSHasherShort extends GenCMSHasher[Short] {

    def hash(a: Int, b: Int, width: Int)(x: Short): Int = GenCMSHasherInt.hash(a, b, width)(x.toInt)

  }

  implicit object GenCMSHasherInt extends GenCMSHasher[Int] {

    // The arithmetic must be done in Long. On an Int, `>> 32` is a shift by 0 (only the low five bits of an
    // Int shift distance are used), so `unmodded + (unmodded >> 32)` would be `2 * unmodded`: always even,
    // and computed from a product that silently overflows. Widening the key and reusing the Long hasher
    // avoids both problems.
    // NOTE(review): this changes the hash values for Int/Short/Byte keys, so sketches built with the previous
    // implementation are not compatible with sketches built after this change.
    def hash(a: Int, b: Int, width: Int)(x: Int): Int = GenCMSHasherLong.hash(a, b, width)(x.toLong)

  }

  implicit object GenCMSHasherBigInt extends GenCMSHasher[BigInt] {

    def hash(a: Int, b: Int, width: Int)(x: BigInt): Int = {
      val unmodded: BigInt = (x * a) + b
      // Masking with the (non-negative) modulus leaves a value in [0, 2^31 - 1], which always fits an Int.
      val modded: BigInt = (unmodded + (unmodded >> 32)) & PRIME_MODULUS
      modded.toInt % width
    }

  }

}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment