Created
December 19, 2017 06:42
-
-
Save mrkm4ntr/a7d2093cc23d2f077c2226e2d19d0bf6 to your computer and use it in GitHub Desktop.
Minimal implementation of LogisticRegression in Spark ML
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.ml.classification | |
import breeze.linalg.{DenseVector => BDV} | |
import breeze.numerics.sigmoid | |
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS} | |
import org.apache.spark.broadcast.Broadcast | |
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} | |
import org.apache.spark.ml.param.ParamMap | |
import org.apache.spark.mllib.util.MLUtils | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.{Dataset, Row} | |
/**
 * One labeled training instance: the target label (expected to be 0.0 or 1.0
 * for binary classification) together with its feature vector.
 */
case class Datum(label: Double, features: Vector)
/**
 * Minimal binary logistic regression estimator.
 *
 * Fits unregularized coefficients (no intercept) by minimizing the mean
 * logistic loss with L-BFGS over an RDD of (label, features) rows.
 */
class BinaryLogisticRegression(override val uid: String)
  extends ProbabilisticClassifier[Vector, BinaryLogisticRegression, BinaryLogisticRegressionModel] {

  override def copy(extra: ParamMap): BinaryLogisticRegression = defaultCopy(extra)

  /**
   * Trains the model: extracts (label, features) from the dataset, then runs
   * L-BFGS from an all-zeros starting point using the distributed loss/gradient
   * computed by [[BinaryLogisticLossFun]].
   */
  override protected def train(dataset: Dataset[_]): BinaryLogisticRegressionModel = {
    val instances = dataset.select("label", "features").rdd.map {
      case Row(label: Double, features: Vector) => Datum(label, features)
    }
    // Dimensionality is taken from the first instance; all rows are assumed
    // to share it (TODO confirm upstream guarantees consistent vector sizes).
    val dim = instances.first().features.size

    val lossFunction = new CachedDiffFunction[BDV[Double]](new BinaryLogisticLossFun(instances))
    val start = BDV.zeros[Double](dim)
    val solution = new LBFGS[BDV[Double]]().minimize(lossFunction, start)

    // Copy the solver's buffer so the model owns its coefficient storage.
    new BinaryLogisticRegressionModel(uid, Vectors.dense(solution.toArray.clone))
  }
}
/**
 * Model produced by [[BinaryLogisticRegression]].
 *
 * Holds the fitted coefficient vector (no intercept term). Raw prediction for
 * an instance x is (-w.x, w.x); probabilities are the element-wise sigmoid of
 * the raw scores, which sum to 1 by the identity sigmoid(-m) + sigmoid(m) = 1.
 */
class BinaryLogisticRegressionModel(
    override val uid: String,
    val coefficients: Vector)
  extends ProbabilisticClassificationModel[Vector, BinaryLogisticRegressionModel] {

  /** Maps each raw score through the logistic sigmoid to get class probabilities. */
  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector =
    Vectors.dense(rawPrediction.toDense.values.map(sigmoid(_)))

  /** Binary problem: classes are {0, 1}. */
  override def numClasses: Int = 2

  /** Raw scores (margin for class 0, margin for class 1) = (-w.x, w.x). */
  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(coefficients, features)
    Vectors.dense(-margin, margin)
  }

  override def copy(extra: ParamMap): BinaryLogisticRegressionModel = {
    // FIX: pass `extra` to copyValues so caller-supplied param overrides are
    // applied to the copy (the previous version silently dropped them).
    val model = copyValues(new BinaryLogisticRegressionModel(uid, coefficients), extra)
    model.setParent(parent)
  }
}
/**
 * Breeze [[DiffFunction]] adapter that evaluates the mean logistic loss and
 * its gradient over the full RDD in one distributed pass.
 */
class BinaryLogisticLossFun(data: RDD[Datum]) extends DiffFunction[BDV[Double]] {

  /**
   * Broadcasts the current coefficients, aggregates per-instance loss and
   * gradient contributions across partitions, and returns (loss, gradient).
   */
  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
    val bcCoefficients = data.context.broadcast(Vectors.fromBreeze(coefficients))
    val aggregated = data.treeAggregate(new BinaryLogisticAggregator(bcCoefficients))(
      (agg: BinaryLogisticAggregator, instance: Datum) => agg.add(instance),
      (left: BinaryLogisticAggregator, right: BinaryLogisticAggregator) => left.merge(right))
    // The broadcast is only needed for this evaluation; release it eagerly.
    bcCoefficients.destroy(blocking = false)
    (aggregated.loss, BDV(aggregated.gradient.toArray))
  }
}
/**
 * Accumulates logistic-loss and gradient contributions over instances.
 *
 * Used with `treeAggregate`: executors call [[add]] per instance, then
 * partial aggregators are combined with [[merge]]. The coefficient vector is
 * shipped once via broadcast and dereferenced lazily per executor.
 */
class BinaryLogisticAggregator(bcCoefficients: Broadcast[Vector]) extends Serializable {
  // Number of instances folded into this aggregator.
  private var numData = 0
  // Running sum of per-instance logistic losses (unnormalized).
  private var lossSum = 0.0
  private val numCoefficients = bcCoefficients.value.size
  // @transient + lazy: resolve the broadcast on the executor and avoid
  // serializing the materialized vector with the aggregator itself.
  @transient
  private lazy val coefficients: Vector = bcCoefficients.value
  // Accumulated (unnormalized) gradient; allocated lazily per task.
  private lazy val localGradient = Vectors.zeros(numCoefficients)

  /** Folds one instance into the running loss and gradient sums. */
  def add(datum: Datum): BinaryLogisticAggregator = datum match {
    case Datum(label, features) =>
      val margin = BLAS.dot(coefficients, features)
      // Per-instance gradient: (sigmoid(margin) - label) * features.
      BLAS.axpy(sigmoid(margin) - label, features, localGradient)
      // Numerically stable negative log-likelihood:
      //   label == 1: log(1 + e^-margin)
      //   label == 0: log(1 + e^-margin) + margin == log(1 + e^margin)
      lossSum += label * MLUtils.log1pExp(-margin) + (1.0 - label) * (MLUtils.log1pExp(-margin) + margin)
      numData += 1
      this
  }

  /** Combines another partial aggregator into this one. */
  def merge(other: BinaryLogisticAggregator): BinaryLogisticAggregator = {
    if (other.numData > 0) {
      numData += other.numData
      lossSum += other.lossSum
      // BUG FIX: the other aggregator's gradient was previously discarded,
      // yielding incorrect gradients whenever the RDD has >1 partition.
      BLAS.axpy(1.0, other.localGradient, localGradient)
    }
    this
  }

  /** Mean loss over all added instances. Requires at least one instance. */
  def loss: Double = {
    require(numData > 0, "Cannot compute loss: no instances were added.")
    lossSum / numData
  }

  /** Mean gradient over all added instances. Requires at least one instance. */
  def gradient: Vector = {
    require(numData > 0, "Cannot compute gradient: no instances were added.")
    // Clone so callers cannot mutate the internal accumulation buffer.
    val result = Vectors.dense(localGradient.toArray.clone)
    BLAS.scal(1.0 / numData, result)
    result
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment