Skip to content

Instantly share code, notes, and snippets.

@Krimit
Created July 6, 2015 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Krimit/f61b5e3be2b1e380d1ca to your computer and use it in GitHub Desktop.
Save Krimit/f61b5e3be2b1e380d1ca to your computer and use it in GitHub Desktop.
Return per-class scores from a multiclass org.apache.spark.mllib.classification.LogisticRegressionModel
/**
 * Predicts the label of a single data point.
 *
 * For a binary model the positive-class score (logistic function of the margin) is compared
 * against `threshold` to yield 1.0 or 0.0, or returned as-is when no threshold is set.
 * For a multiclass model the class with the largest margin (see `computeMargins`) is returned.
 *
 * @param dataMatrix feature vector; its size must equal `numFeatures`
 * @param weightMatrix model weights, forwarded to `computeMargins`
 * @param intercept model intercept, forwarded to `computeMargins`
 * @return predicted class label, or the raw binary score when `threshold` is `None`
 */
override protected def predictPoint(
    dataMatrix: Vector,
    weightMatrix: Vector,
    intercept: Double): Double = {
  require(dataMatrix.size == numFeatures)
  val (margins, _, bestClass) = computeMargins(dataMatrix, weightMatrix, intercept)
  if (numClasses == 2) {
    // Entry 0 is the positive-class margin w.x + b; its logistic value is the positive-class
    // score. Computing it directly avoids building (and mutating) a full score vector.
    val score = 1.0 / (1.0 + math.exp(-margins(0)))
    threshold match {
      case Some(t) => if (score > t) 1.0 else 0.0
      case None => score
    }
  } else {
    bestClass.toDouble
  }
}
/**
 * Computes the margins of a data point together with the largest margin and predicted label.
 *
 * Binary model: the margins are `(m, -m)`, where `m = dot(weightMatrix, dataMatrix) + intercept`
 * is the positive-class margin. The largest margin is `math.max(m, -m)`. The label is 1 when
 * `m > 0`, else 0, matching the sigmoid-score > 0.5 rule used by `predictPoint`.
 *
 * Multiclass model: one margin per class 1 until `numClasses` (entry `i` belongs to class
 * `i + 1`), read from `weightsArray`. Class 0 implicitly has margin 0.0, so the largest margin is
 * never below 0.0, and the label is 0 unless some margin is strictly positive. `weightMatrix` and
 * `intercept` are not read in this branch; a bias term, when present, is the trailing entry of
 * each `dataWithBiasSize`-wide row of `weightsArray`.
 *
 * @param dataMatrix feature vector
 * @param weightMatrix model weights (binary branch only)
 * @param intercept model intercept (binary branch only)
 * @return (margins, largest margin, predicted class label)
 */
private def computeMargins(
    dataMatrix: Vector,
    weightMatrix: Vector,
    intercept: Double): (Vector, Double, Int) = {
  if (numClasses == 2) {
    val margin = dot(weightMatrix, dataMatrix) + intercept
    // Return a class label (like the multiclass branch), not an index into the margins vector:
    // a positive margin means class 1.
    val bestClass = if (margin > 0.0) 1 else 0
    (Vectors.dense(margin, -margin), math.max(margin, -margin), bestClass)
  } else {
    var bestClass = 0
    var maxMargin = 0.0
    val withBias = dataMatrix.size + 1 == dataWithBiasSize
    val margins = Array.tabulate(numClasses - 1) { i =>
      // Row i of weightsArray holds the coefficients of class i + 1.
      val rowOffset = i * dataWithBiasSize
      var margin = 0.0
      dataMatrix.foreachActive { (index, value) =>
        if (value != 0.0) margin += value * weightsArray(rowOffset + index)
      }
      // Intercept is required to be added into margin.
      if (withBias) {
        margin += weightsArray(rowOffset + dataMatrix.size)
      }
      if (margin > maxMargin) {
        maxMargin = margin
        bestClass = i + 1
      }
      margin
    }
    (Vectors.dense(margins), maxMargin, bestClass)
  }
}
/**
 * Converts margins into scores, in place, by applying the logistic function
 * `1 / (1 + exp(-m))` to every entry.
 *
 * @param margins margins as produced by `computeMargins`; must be a `DenseVector`
 * @return the same vector instance, now holding the scores
 * @throws RuntimeException if `margins` is a `SparseVector`
 */
private def computeScores(margins: Vector): Vector = {
  margins match {
    case dv: DenseVector =>
      val values = dv.values
      // Convert every entry, including the last one.
      var i = 0
      while (i < values.length) {
        values(i) = 1.0 / (1.0 + math.exp(-values(i)))
        i += 1
      }
      dv
    case _: SparseVector =>
      throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
        " computeScores encountered SparseVector")
  }
}
/**
 * Computes the scores of a single data point.
 *
 * The margins from `computeMargins` are passed through `computeScores`, which applies the
 * logistic function to them. For a binary model entry 0 is the positive-class score; for a
 * multiclass model entry `i` scores class `i + 1` (class 0, whose margin is implicitly 0.0,
 * gets no entry).
 *
 * NOTE(review): these are per-class logistic scores, not softmax-normalized multinomial
 * probabilities, so they need not sum to 1. If normalized probabilities are wanted, compute a
 * softmax over (0.0 +: margins), subtracting the largest margin first for numerical stability.
 *
 * @param dataMatrix feature vector; its size must equal `numFeatures`
 * @param weightMatrix model weights, forwarded to `computeMargins`
 * @param intercept model intercept, forwarded to `computeMargins`
 * @return vector of scores
 */
def scorePoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Vector = {
  require(dataMatrix.size == numFeatures)
  // Margins are used unshifted: subtracting the largest margin is only valid before a softmax,
  // and before a per-entry logistic function it would change every score.
  val (margins, _, _) = computeMargins(dataMatrix, weightMatrix, intercept)
  computeScores(margins)
}
/**
 * Compute the scores of every data point in an RDD using the trained model.
 *
 * @param testData RDD of feature vectors to score
 * @return RDD in which each entry is the score vector of the corresponding input point,
 *         as returned by `scorePoint`
 */
def score(testData: RDD[Vector]): RDD[Vector] = {
  // Broadcast the weights once rather than shipping them inside every task closure.
  // NOTE(review): `scorePoint` is an instance method, so the closure below still captures the
  // model itself; the model therefore has to be serializable.
  val bcWeights = testData.context.broadcast(weights)
  val localIntercept = intercept
  testData.mapPartitions { iter =>
    val w = bcWeights.value
    iter.map(v => scorePoint(v, w, localIntercept))
  }
}
/**
 * Compute the scores of a single data point using the trained model.
 *
 * @param testData feature vector representing a single data point
 * @return score vector, as returned by `scorePoint` with this model's weights and intercept
 */
def score(testData: Vector): Vector =
  scorePoint(testData, weights, intercept)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment