Last active
February 27, 2019 22:39
-
-
Save jesusjavierdediego/d206ebd165a116edf41f8c226642f042 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
  * Trains and evaluates a binary logistic-regression model with Spark ML.
  *
  * NOTE(review): `spark`, `configuration` and `logger` are not defined in this class body;
  * they are presumably provided by `Logging` or by the enclosing file — confirm.
  */
class ModelTrainer extends Logging {

  /**
    * Splits the input into training and test sets, fits a logistic-regression model on the
    * training part, tunes its decision threshold to the one maximising the F-measure on the
    * training summary, and evaluates the tuned model on the test part.
    *
    * Hyper-parameters and split settings are read from configuration keys under `learning.*`.
    *
    * @param trainingDF labelled data; must contain a `features` column (the label column is
    *                   left at the estimator's default)
    * @return the trained model together with a metrics map holding the keys
    *         `areaUnderROC`, `maxFMeasure` and `accuracy`
    */
  def train(trainingDF: DataFrame): (LogisticRegressionModel, Map[String, Double]) = {
    import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.sql.functions.max
    import spark.implicits._

    val splitFactor = configuration.envOrElseConfig("learning.split-factor").toDouble
    val splitSeed = configuration.envOrElseConfig("learning.split-seed").toLong
    val maxIterations = configuration.envOrElseConfig("learning.max-iterations").toInt
    val regularizationParam = configuration.envOrElseConfig("learning.regularization-parameter").toDouble
    val elasticNetParam = configuration.envOrElseConfig("learning.elastic-net-parameter").toDouble

    // 1-Set weights and seed: a fixed seed makes the split reproducible
    val dataSplits = trainingDF.randomSplit(Array(splitFactor, 1 - splitFactor), seed = splitSeed)
    val trainingData = dataSplits(0).cache()
    val testData = dataSplits(1)

    // 2-Check and see datasets...
    trainingData.show(false)
    trainingData.printSchema()
    testData.show(false)
    testData.printSchema()

    // 3-Create the model and set the hyperparams
    val lr = new LogisticRegression()
      .setMaxIter(maxIterations)
      .setRegParam(regularizationParam)
      .setElasticNetParam(elasticNetParam)
      .setFeaturesCol("features")

    // 4-Train the model
    val trainedModel = lr.fit(trainingData)

    // 5-Get some factors: the training summary of a binary problem exposes threshold metrics
    val binarySummary = trainedModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]
    val fMeasure = binarySummary.fMeasureByThreshold
    val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
    // Use the threshold whose F-measure equals the maximum as the model's decision threshold
    val bestThreshold = fMeasure
      .where($"F-Measure" === maxFMeasure)
      .select("threshold")
      .head()
      .getDouble(0)
    trainedModel.setThreshold(bestThreshold)

    logger.info(s"Coefficients: ${trainedModel.coefficients} Intercept: ${trainedModel.intercept}")
    logger.info(s"areaUnderROC: ${binarySummary.areaUnderROC}")
    logger.info(s"max F-measure: ${maxFMeasure}")
    logger.info(s"best threshold: ${bestThreshold}")

    // 6-Test the trained model with the test dataset
    val predictions = trainedModel.transform(testData)

    // 7-Evaluate the performance of the model
    // NOTE(review): the evaluator is used with its default metric, which per the Spark docs is
    // area under ROC rather than classification accuracy — confirm whether that is intended.
    // The name and the "accuracy" key are kept so existing callers keep working.
    val evaluator = new BinaryClassificationEvaluator()
    val accuracy = evaluator.evaluate(predictions)
    logger.info(s"Test Error = ${1.0 - accuracy}")

    val modelMetadata: Map[String, Double] = Map(
      "areaUnderROC" -> binarySummary.areaUnderROC,
      "maxFMeasure" -> maxFMeasure,
      "accuracy" -> accuracy
    )
    (trainedModel, modelMetadata)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment