Skip to content

Instantly share code, notes, and snippets.

@emesday
Created July 18, 2019 03:46
Show Gist options
  • Save emesday/77d63be99b2dfe23f4528ab5a513e0d8 to your computer and use it in GitHub Desktop.
Save emesday/77d63be99b2dfe23f4528ab5a513e0d8 to your computer and use it in GitHub Desktop.
LocalBinaryClassificationMetrics.scala
package org.apache.spark.mllib.evaluation
import org.apache.spark.mllib.evaluation.binary._
import scala.collection.mutable
/**
 * In-memory (non-RDD) evaluator for binary classifiers that works on a plain `Seq`
 * of `(score, label)` pairs.
 *
 * Scores are grouped into distinct thresholds, sorted in descending order, and the
 * label counts are accumulated so that each threshold carries the counts of every
 * example whose score is greater than or equal to it.
 *
 * An empty `scoreAndLabels` is supported: threshold-based curves are empty, `roc()`
 * contains only its two fixed end points, and `totalCount` is a zero counter.
 *
 * @param scoreAndLabels pairs of (predicted score, label) as consumed by
 *                       `BinaryLabelCounter.+=`
 * @param numBins        requested number of bins for down-sampling the curves; only `0`
 *                       (no down-sampling) is implemented. Any positive value makes the
 *                       metric methods throw `NotImplementedError` on first use.
 * @throws IllegalArgumentException if `numBins` is negative
 */
class LocalBinaryClassificationMetrics(
    val scoreAndLabels: Seq[(Double, Double)],
    val numBins: Int) {

  require(numBins >= 0, "numBins must be nonnegative")

  /** Creates an evaluator without down-sampling (`numBins = 0`). */
  def this(scoreAndLabels: Seq[(Double, Double)]) = this(scoreAndLabels, 0)

  /** Returns the distinct scores in descending order. */
  def thresholds(): Seq[Double] = cumulativeCounts.map(_._1)

  /**
   * Returns the ROC curve as (false positive rate, recall) points, with (0.0, 0.0)
   * prepended and (1.0, 1.0) appended.
   */
  def roc(): Seq[(Double, Double)] = {
    val rocCurve = createCurve(FalsePositiveRate, Recall)
    Seq((0.0, 0.0)) ++ rocCurve ++ Seq((1.0, 1.0))
  }

  /** Returns the area under the curve produced by [[roc]]. */
  def areaUnderROC(): Double = AreaUnderCurve.of(roc())

  /**
   * Returns the precision-recall curve as (recall, precision) points, with
   * (0.0, p) prepended where p is the precision of the first curve point.
   * When there are no points at all the result is empty.
   */
  def pr(): Seq[(Double, Double)] = {
    val prCurve = createCurve(Recall, Precision)
    // headOption instead of head: an empty input must not throw here.
    prCurve.headOption match {
      case Some((_, firstPrecision)) => Seq((0.0, firstPrecision)) ++ prCurve
      case None => prCurve
    }
  }

  /** Returns the area under the curve produced by [[pr]]. */
  def areaUnderPR(): Double = AreaUnderCurve.of(pr())

  /**
   * Returns (threshold, F-measure) pairs.
   *
   * @param beta the beta factor passed to `FMeasure`
   */
  def fMeasureByThreshold(beta: Double): Seq[(Double, Double)] = createCurve(FMeasure(beta))

  /** Returns (threshold, F-measure) pairs with beta = 1.0. */
  def fMeasureByThreshold(): Seq[(Double, Double)] = fMeasureByThreshold(1.0)

  /** Returns (threshold, precision) pairs. */
  def precisionByThreshold(): Seq[(Double, Double)] = createCurve(Precision)

  /** Returns (threshold, recall) pairs. */
  def recallByThreshold(): Seq[(Double, Double)] = createCurve(Recall)

  /** Returns a copy of the label counts over the whole input. */
  def totalCount: BinaryLabelCounter = totalCount0.clone()

  /**
   * (threshold, cumulative label counts) in descending threshold order; the counter
   * at each threshold includes every preceding (higher-scored) threshold.
   */
  private lazy val cumulativeCounts: Seq[(Double, BinaryLabelCounter)] = {
    // Local mutable aggregation: one counter per distinct score.
    val counterMap = mutable.HashMap[Double, BinaryLabelCounter]()
    for ((score, label) <- scoreAndLabels) {
      val counter = counterMap.getOrElseUpdate(score, new BinaryLabelCounter(0L, 0L))
      counter += label
    }
    val counts = counterMap.toSeq.sortBy(-_._1)
    val binnedCounts = if (numBins == 0) {
      counts
    } else {
      throw new NotImplementedError("numBins > 0 is not implemented")
    }
    // Clone before adding so every threshold owns an independent running total;
    // drop(1) removes the zero seed emitted by scanLeft.
    val runningTotals = binnedCounts
      .map(_._2)
      .scanLeft(new BinaryLabelCounter())((agg, c) => agg.clone() += c)
      .drop(1)
    binnedCounts.map(_._1).zip(runningTotals)
  }

  /** Counts over the whole input: the last cumulative counter, or zero when empty. */
  private lazy val totalCount0: BinaryLabelCounter =
    cumulativeCounts.lastOption.map(_._2).getOrElse(new BinaryLabelCounter())

  /** (threshold, confusion matrix) for every threshold, in descending threshold order. */
  private lazy val confusions: Seq[(Double, BinaryConfusionMatrix)] = {
    val total = totalCount0
    cumulativeCounts.map { case (score, cumCount) =>
      // Type ascription is a compile-checked upcast; no runtime cast needed.
      (score, BinaryConfusionMatrixImpl(cumCount, total): BinaryConfusionMatrix)
    }
  }

  /** Builds a (threshold, metric) curve by applying `y` to each confusion matrix. */
  private def createCurve(y: BinaryClassificationMetricComputer): Seq[(Double, Double)] = {
    confusions.map { case (s, c) =>
      (s, y(c))
    }
  }

  /** Builds an (x-metric, y-metric) curve by applying `x` and `y` to each confusion matrix. */
  private def createCurve(
      x: BinaryClassificationMetricComputer,
      y: BinaryClassificationMetricComputer): Seq[(Double, Double)] = {
    confusions.map { case (_, c) =>
      (x(c), y(c))
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment