TomLous/TrainModel.scala
Last active Apr 25, 2017

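import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.sql.functions.max
import spark.implicits._ // encoders for .as[LabeledVector] and the $"column" syntax

// load the labeled data (LabeledVector, Config and path are not defined in this snippet)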
val labeledSet = spark.read.parquet(path).as[LabeledVector]
// split into training and test sets (80/20, i.e. Config.trainSplit = 0.8)
val Array(trainingData, testData) = labeledSet.randomSplit(Array(Config.trainSplit, 1 - Config.trainSplit))
// Basic model: logistic regression with elastic-net regularization
val lr = new LogisticRegression().setMaxIter(200).setRegParam(0.01).setElasticNetParam(0.8)
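// Per the Spark ML docs, the elastic-net penalty added to the logistic loss is
//   alpha * lambda * ||w||_1 + (1 - alpha) * (lambda / 2) * ||w||_2^2
// with lambda = regParam and alpha = elasticNetParam. Plugging in the values above
// (lambda = 0.01, alpha = 0.8) gives 0.008 * ||w||_1 + 0.001 * ||w||_2^2, a mostly-L1
// penalty that can drive some coefficients to exactly zero.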
// Train
val lrModel = lr.fit(trainingData)
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
// Summary & ROC
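// for binary labels the training summary can be cast to BinaryLogisticRegressionSummary
// (Spark 2.3+ also exposes it directly as lrModel.binarySummary)
val trainingSummary = lrModel.summary
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]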
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
// F-measure by threshold: keep the threshold that maximizes F1 on the training data
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).select("threshold").head().getDouble(0)
lrModel.setThreshold(bestThreshold)
println(s"max F-measure: ${maxFMeasure}")
println(s"best threshold: ${bestThreshold}")
// score the held-out test set
val predictions = lrModel.transform(testData)
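// evaluate on the test set: BinaryClassificationEvaluator computes areaUnderROC by default
// (not accuracy) from the rawPrediction column, so it is unaffected by the threshold set above
val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")
val testAUC = evaluator.evaluate(predictions)
println(s"Test areaUnderROC = $testAUC")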
// save model
lrModel.write.overwrite().save(Config.lrModelFile)
println("Saved model " + Config.lrModelFile)
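As a rough illustration of what the summary and evaluation steps above compute, here is a minimal plain-Scala sketch with no Spark dependency. It scores a few rows with the logistic formula sigmoid(w · x + b), picks the threshold with the highest F1 (the role fMeasureByThreshold plays above), computes ROC AUC as the probability that a random positive outscores a random negative, and reports the error rate at the chosen threshold. The object name, helper functions, coefficients and sample rows are all hypothetical; Spark computes its metrics from its own score curves and may differ in details such as whether a score exactly equal to the threshold counts as positive. For brevity the sketch tunes and evaluates on the same rows, whereas the gist tunes on the training summary and evaluates on testData.

```scala
object ThresholdSketch {
  // One scored example: the model's P(label = 1) and the true 0/1 label.
  final case class Scored(prob: Double, label: Int)

  // Logistic link: maps the margin w · x + b to a probability.
  def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

  // Hypothetical stand-ins for lrModel.coefficients and lrModel.intercept.
  val coefficients: Array[Double] = Array(1.5, -2.0)
  val intercept: Double = 0.25

  // P(label = 1 | x) = sigmoid(w · x + b)
  def probability(x: Array[Double]): Double =
    sigmoid(coefficients.zip(x).map { case (w, xi) => w * xi }.sum + intercept)

  // Hypothetical held-out rows: (features, true label), scored with the model above.
  val sample: Seq[Scored] = Seq(
    (Array(1.0, 0.0), 1),
    (Array(0.0, 1.0), 0),
    (Array(0.5, 0.5), 1),
    (Array(0.2, 0.8), 0),
    (Array(0.9, 0.1), 0)
  ).map { case (x, y) => Scored(probability(x), y) }

  // F1 = 2PR / (P + R), counting prob >= threshold as a positive prediction.
  def f1At(data: Seq[Scored], threshold: Double): Double = {
    val tp = data.count(s => s.prob >= threshold && s.label == 1)
    val fp = data.count(s => s.prob >= threshold && s.label == 0)
    val fn = data.count(s => s.prob < threshold && s.label == 1)
    if (tp == 0) 0.0
    else {
      val precision = tp.toDouble / (tp + fp)
      val recall = tp.toDouble / (tp + fn)
      2 * precision * recall / (precision + recall)
    }
  }

  // Try every distinct score as a threshold; return (threshold, F1) with the highest F1.
  def bestThreshold(data: Seq[Scored]): (Double, Double) =
    data.map(_.prob).distinct.map(t => (t, f1At(data, t))).maxBy(_._2)

  // ROC AUC: share of (positive, negative) pairs where the positive scores higher (ties count 1/2).
  def areaUnderROC(data: Seq[Scored]): Double = {
    val pos = data.filter(_.label == 1).map(_.prob)
    val neg = data.filter(_.label == 0).map(_.prob)
    val wins = for (p <- pos; n <- neg) yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size * neg.size)
  }

  // Fraction of rows misclassified at the given threshold.
  def testError(data: Seq[Scored], threshold: Double): Double =
    data.count(s => (if (s.prob >= threshold) 1 else 0) != s.label).toDouble / data.size

  def main(args: Array[String]): Unit = {
    val (threshold, f1) = bestThreshold(sample)
    println(f"areaUnderROC: ${areaUnderROC(sample)}%.4f")
    println(f"best threshold: $threshold%.4f (F1 = $f1%.4f)")
    println(f"error at that threshold: ${testError(sample, threshold)}%.4f")
  }
}
```

Running `ThresholdSketch.main` prints the AUC, the chosen threshold with its F1, and the error rate for the five hypothetical rows. Note that AUC depends only on how the scores are ranked, while F1 and the error rate depend on where the threshold is placed, which is why the gist tunes the threshold separately from the ROC step.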