Last active: January 1, 2016 08:39
Save soulmachine/8119952 to your computer and use it in GitHub Desktop.
run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
/**
 * Trains a Naive Bayes model from labeled points using additive smoothing.
 *
 * Each partition first accumulates, per integer label, the number of points
 * and the element-wise sum of their feature arrays. Those partial results are
 * merged with `reduceByKey`. The merged table has one row per distinct label,
 * so it is collected to the driver once and all log weights are computed
 * there. This avoids re-running the whole pipeline for every derived RDD.
 *
 * The smoothing parameter `lambda` is not defined in this method; it is read
 * from the enclosing scope.
 *
 * @param C    number of classes; contributes `C * lambda` to the log-prior
 *             denominator
 * @param D    number of features; sizes the per-label accumulators and
 *             contributes `D * lambda` to each per-label feature denominator.
 *             Feature sums are combined with `zip`, so they are truncated to
 *             the shortest array seen (at most `D` elements).
 * @param data training points; labels are converted with `toInt`
 * @return a [[NaiveBayesModel]] built from the per-label log priors and the
 *         per-label log feature weights, both ordered by ascending label
 */
def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
  // Per partition: label -> (point count, element-wise sum of features).
  val partitionCounts = data.mapPartitions { iterator =>
    val localCountPerLabel = mutable.Map.empty[Int, Int].withDefaultValue(0)
    // The shared default array is never mutated: every update below builds a
    // fresh array via zip/map, so using one instance as the default is safe.
    val localSummedObservations = mutable.Map.empty[Int, Array[Double]]
      .withDefaultValue(Array.fill(D)(0.0))
    iterator.foreach { case LabeledPoint(label, features) =>
      val y = label.toInt
      localCountPerLabel(y) += 1
      localSummedObservations(y) = localSummedObservations(y).zip(features)
        .map(pair => pair._1 + pair._2)
    }
    localCountPerLabel.keys.toIterator.map { label =>
      // Widen to Long before the cross-partition merge so the global count
      // cannot overflow Int.
      label -> (localCountPerLabel(label).toLong, localSummedObservations(label))
    }
  }

  // Merge partial results per label (reduceByKey combines map-side, unlike
  // groupByKey), then bring the small per-label table to the driver sorted by
  // label so both model arrays share the same row order.
  val counts: Array[(Int, (Long, Array[Double]))] = partitionCounts
    .reduceByKey { case ((count1, sum1), (count2, sum2)) =>
      (count1 + count2, sum1.zip(sum2).map(pair => pair._1 + pair._2))
    }
    .collect()
    .sortBy(_._1)

  // Total number of training points: sum the per-label counts, not the
  // label keys.
  val N = counts.map { case (_, (count, _)) => count }.sum
  val logDenominator = math.log(N + C * lambda)

  // Smoothed log prior for each label.
  val labelWeights = counts.map { case (_, (count, _)) =>
    math.log(count + lambda) - logDenominator
  }

  // Smoothed log weight of each feature within each label.
  val weightsMat = counts.map { case (_, (_, summed)) =>
    val logDenom = math.log(summed.sum + D * lambda)
    summed.map(w => math.log(w + lambda) - logDenom)
  }

  new NaiveBayesModel(labelWeights, weightsMat)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
/**
 * Trains a Naive Bayes model from labeled points using additive smoothing.
 *
 * A single `reduceByKey` pass computes, per integer label, the number of
 * points and the element-wise sum of their feature arrays. Partial results
 * are combined map-side, so no label's full list of feature arrays is ever
 * materialized (as `groupByKey` would require). The total point count is
 * derived from the per-label counts instead of a separate `count()` job.
 *
 * The smoothing parameter `lambda` is not defined in this method; it is read
 * from the enclosing scope.
 *
 * @param C    number of classes; contributes `C * lambda` to the log-prior
 *             denominator
 * @param D    number of features; contributes `D * lambda` to each per-label
 *             feature denominator
 * @param data training points; labels are converted with `toInt`
 * @return a [[NaiveBayesModel]] built from the per-label log priors and the
 *         per-label log feature weights, both ordered by ascending label
 */
def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
  // label -> (point count, element-wise feature sum), collected to the driver
  // (one row per distinct label) and sorted by label so both model arrays
  // share the same row order.
  val perLabel = data
    .map(p => p.label.toInt -> (1L, p.features))
    .reduceByKey { case ((count1, sum1), (count2, sum2)) =>
      (count1 + count2, sum1.zip(sum2).map(pair => pair._1 + pair._2))
    }
    .collect()
    .sortBy(_._1)

  // Total number of training points.
  val N = perLabel.map { case (_, (count, _)) => count }.sum
  val logDenominator = math.log(N + C * lambda)

  // Smoothed log prior for each label.
  val labelWeights = perLabel.map { case (_, (count, _)) =>
    math.log(count + lambda) - logDenominator
  }

  // Smoothed log weight of each feature within each label.
  val weightsMat = perLabel.map { case (_, (_, summed)) =>
    val logDenom = math.log(summed.sum + D * lambda)
    summed.map(w => math.log(w + lambda) - logDenom)
  }

  new NaiveBayesModel(labelWeights, weightsMat)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.