Skip to content

Instantly share code, notes, and snippets.

@jesusjavierdediego
Last active February 27, 2019 22:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jesusjavierdediego/d206ebd165a116edf41f8c226642f042 to your computer and use it in GitHub Desktop.
Save jesusjavierdediego/d206ebd165a116edf41f8c226642f042 to your computer and use it in GitHub Desktop.
/**
 * Trains a logistic-regression classifier on a DataFrame and reports a few quality metrics.
 *
 * The train/test split and the model hyperparameters are read from `configuration`
 * under the `learning.*` keys.
 */
class ModelTrainer extends Logging{

  /**
   * Splits `trainingDF` into train/test sets and fits a [[LogisticRegression]] on the train set.
   * It then sets the model's decision threshold to the one with the highest F-measure on the
   * train set, and scores the held-out test set.
   *
   * @param trainingDF input data. It must contain a `features` column. The label column is not
   *                   set explicitly, so LogisticRegression's default label column is used
   *                   (NOTE(review): confirm the input provides it).
   * @return the fitted model, with its threshold already set, and a metadata map with keys:
   *         - `areaUnderROC`: train-set AUC
   *         - `maxFMeasure`: best train-set F-measure
   *         - `accuracy`: despite its name, the test-set area under ROC (see comment below)
   * @throws IllegalArgumentException if `learning.split-factor` is not strictly between 0 and 1
   * @throws IllegalStateException if the training summary is not a binary summary
   */
  def train(trainingDF: DataFrame): (LogisticRegressionModel, Map[String, Double]) = {
    import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.sql.functions.max

    val splitFactor = configuration.envOrElseConfig("learning.split-factor").toDouble
    val splitSeed = configuration.envOrElseConfig("learning.split-seed").toLong
    val maxIterations = configuration.envOrElseConfig("learning.max-iterations").toInt
    val regularizationParam = configuration.envOrElseConfig("learning.regularization-parameter").toDouble
    val elasticNetParam = configuration.envOrElseConfig("learning.elastic-net-parameter").toDouble

    // A factor outside (0, 1) yields either a negative split weight or an empty test set.
    require(splitFactor > 0.0 && splitFactor < 1.0,
      s"learning.split-factor must be strictly between 0 and 1, got $splitFactor")

    // 1-Split into train/test with a reproducible seed
    val dataSplits = trainingDF.randomSplit(Array(splitFactor, 1 - splitFactor), seed = splitSeed)
    // Cached because fitting and the training summary traverse the training split repeatedly.
    val trainingData = dataSplits(0).cache()
    val testData = dataSplits(1)

    try {
      // 2-Print sample rows and schemas of both splits for a quick visual check
      trainingData.show(false)
      trainingData.printSchema()
      testData.show(false)
      testData.printSchema()

      // 3-Create the model and set the hyperparameters
      val lr = new LogisticRegression()
        .setMaxIter(maxIterations)
        .setRegParam(regularizationParam)
        .setElasticNetParam(elasticNetParam)
        .setFeaturesCol("features")

      // 4-Train the model
      val trainedModel = lr.fit(trainingData)

      // 5-Training-set diagnostics; choose the threshold that maximises F-measure
      val binarySummary = trainedModel.summary match {
        case summary: BinaryLogisticRegressionSummary => summary
        case other => throw new IllegalStateException(
          s"Expected a binary logistic regression training summary, got ${other.getClass.getName} " +
            "(is the label non-binary?)")
      }
      val fMeasure = binarySummary.fMeasureByThreshold
      val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
      // Exact double equality is safe here: maxFMeasure was read from this very column.
      val bestThreshold = fMeasure
        .where(fMeasure("F-Measure") === maxFMeasure)
        .select("threshold")
        .head()
        .getDouble(0)
      // Updates the model in place, so the test predictions below use the tuned threshold.
      trainedModel.setThreshold(bestThreshold)
      logger.info(s"Coefficients: ${trainedModel.coefficients} Intercept: ${trainedModel.intercept}")
      logger.info(s"areaUnderROC: ${binarySummary.areaUnderROC}")
      logger.info(s"max F-measure: ${maxFMeasure}")
      logger.info(s"best threshold: ${bestThreshold}")

      // 6-Score the held-out test dataset
      val predictions = trainedModel.transform(testData)

      // 7-Evaluate on the test set. BinaryClassificationEvaluator computes area under ROC
      // (made explicit here), not accuracy, so "1 - result" is not an error rate.
      val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")
      val testAreaUnderROC = evaluator.evaluate(predictions)
      logger.info(s"Test areaUnderROC = $testAreaUnderROC")

      val modelMetadata = Map(
        "areaUnderROC" -> binarySummary.areaUnderROC,
        "maxFMeasure" -> maxFMeasure,
        // The key name is kept for existing consumers; the value is the test-set AUC.
        "accuracy" -> testAreaUnderROC
      )
      (trainedModel, modelMetadata)
    } finally {
      // Release the cached training split even if training or evaluation fails.
      trainingData.unpersist()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment