Skip to content

Instantly share code, notes, and snippets.

@mrkm4ntr
Created December 19, 2017 06:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrkm4ntr/a7d2093cc23d2f077c2226e2d19d0bf6 to your computer and use it in GitHub Desktop.
Minimal implementation of LogisticRegression in Spark ML
package org.apache.spark.ml.classification
import breeze.linalg.{DenseVector => BDV}
import breeze.numerics.sigmoid
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
/** One labeled training example: its class label (assumed 0.0 or 1.0 — TODO confirm callers guarantee this) and feature vector. */
case class Datum(label: Double, features: Vector)
/**
 * Minimal binary logistic regression estimator (no intercept, no
 * regularization). Fits coefficients by minimizing the mean log-loss
 * with Breeze's L-BFGS.
 */
class BinaryLogisticRegression(override val uid: String)
  extends ProbabilisticClassifier[Vector, BinaryLogisticRegression, BinaryLogisticRegressionModel] {

  override def copy(extra: ParamMap): BinaryLogisticRegression = defaultCopy(extra)

  /**
   * Trains the model on a dataset with "label" (Double) and "features"
   * (Vector) columns.
   *
   * @param dataset the training data
   * @return a model holding the learned coefficient vector
   */
  override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = {
    val data = dataset.select("label", "features").rdd.map {
      case Row(label: Double, features: Vector) => Datum(label, features)
    }
    // L-BFGS evaluates the loss and gradient many times; cache the RDD so
    // each evaluation does not recompute the whole upstream lineage.
    data.cache()
    try {
      val numFeatures = data.first().features.size
      val optimizer = new LBFGS[BDV[Double]]()
      val initialCoefficients = Vectors.zeros(numFeatures)
      val costFun = new BinaryLogisticLossFun(data)
      // CachedDiffFunction memoizes the last point so asking for value and
      // gradient at the same coefficients does not trigger two Spark jobs.
      val x = optimizer.minimize(
        new CachedDiffFunction[BDV[Double]](costFun),
        new BDV[Double](initialCoefficients.toArray))
      new BinaryLogisticRegressionModel(uid, Vectors.dense(x.toArray.clone))
    } finally {
      data.unpersist()
    }
  }
}
/**
 * Model produced by [[BinaryLogisticRegression]].
 *
 * @param coefficients the learned weight vector (no intercept term)
 */
class BinaryLogisticRegressionModel(
    override val uid: String,
    val coefficients: Vector)
  extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {

  /** Applies the logistic function element-wise to turn raw scores into probabilities. */
  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector =
    Vectors.dense(rawPrediction.toDense.values.map(sigmoid(_)))

  /** Binary classifier: exactly two classes. */
  override def numClasses: Int = 2

  /** Raw scores are (-margin, margin) where margin = coefficients . features. */
  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(coefficients, features)
    Vectors.dense(-margin, margin)
  }

  override def copy(extra: ParamMap): BinaryLogisticRegressionModel = {
    // BUG FIX: the original ignored `extra`; copyValues must receive it so
    // the extra param overrides are applied to the copied model.
    val model = copyValues(new BinaryLogisticRegressionModel(uid, coefficients), extra)
    model.setParent(parent)
  }
}
/**
 * Breeze [[DiffFunction]] computing the mean binary log-loss and its
 * gradient over `data` in one distributed pass.
 */
class BinaryLogisticLossFun(data: RDD[Datum]) extends DiffFunction[BDV[Double]] {

  /**
   * Evaluates the loss and gradient at the given coefficient vector.
   *
   * @return (mean log-loss, mean gradient) at `coefficients`
   */
  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
    // Ship the coefficients to the executors once per evaluation.
    val bcCoefficients = data.context.broadcast(Vectors.fromBreeze(coefficients))
    val aggregated = data.treeAggregate(new BinaryLogisticAggregator(bcCoefficients))(
      seqOp = (agg, datum) => agg.add(datum),
      combOp = (a, b) => a.merge(b))
    // Aggregation is finished; the broadcast variable can be released.
    bcCoefficients.destroy(blocking = false)
    (aggregated.loss, BDV(aggregated.gradient.toArray))
  }
}
/**
 * Accumulates the binary log-loss and its gradient over a partition of
 * training data; partial aggregators are combined with [[merge]].
 */
class BinaryLogisticAggregator(bcCoefficients: Broadcast[Vector]) extends Serializable {
  private var numData = 0
  private var lossSum = 0.0
  private val numCoefficients = bcCoefficients.value.size
  // Resolved lazily on the executor so the coefficient values are not
  // serialized along with each aggregator instance.
  @transient
  private lazy val coefficients: Vector = bcCoefficients.value
  private lazy val localGradient = Vectors.zeros(numCoefficients)

  /** Folds one example into the running loss and gradient sums. */
  def add(datum: Datum): BinaryLogisticAggregator = datum match {
    case Datum(label, features) =>
      val margin = BLAS.dot(coefficients, features)
      // d(logloss)/dw = (sigmoid(margin) - label) * features
      BLAS.axpy(sigmoid(margin) - label, features, localGradient)
      // Numerically stable binary cross-entropy via log1pExp.
      lossSum += label * MLUtils.log1pExp(-margin) + (1.0 - label) * (MLUtils.log1pExp(-margin) + margin)
      numData += 1
      this
  }

  /** Combines another partition's partial sums into this aggregator. */
  def merge(other: BinaryLogisticAggregator): BinaryLogisticAggregator = {
    if (other.numData > 0) {
      numData += other.numData
      lossSum += other.lossSum
      // BUG FIX: the partial gradient must be merged as well; the original
      // dropped every gradient contribution except one partition's, making
      // the optimizer see a wrong (under-counted) gradient.
      BLAS.axpy(1.0, other.localGradient, localGradient)
    }
    this
  }

  /** Mean log-loss over all examples seen so far. */
  def loss: Double = lossSum / numData

  /** Mean gradient over all examples seen so far. */
  def gradient: Vector = {
    val result = Vectors.dense(localGradient.toArray.clone)
    BLAS.scal(1.0 / numData, result)
    result
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment