@TomLous
Last active April 25, 2017 09:13
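import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql.functions.max
// assumes `spark` is an active SparkSession; `path`, LabeledVector and Config are defined elsewhere in the project
import spark.implicits._

// load the labeled data
val labeledSet = spark.read.parquet(path).as[LabeledVector]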
// split train/test (80/20 with Config.trainSplit = 0.8)
val Array(trainingData, testData) = labeledSet.randomSplit(Array(Config.trainSplit, 1 - Config.trainSplit))
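// Basic model: logistic regression with elastic-net regularization
// Spark ML penalty: regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2),
// i.e. here 0.008 * ||w||_1 + 0.001 * ||w||_2^2
val lr = new LogisticRegression().setMaxIter(200).setRegParam(0.01).setElasticNetParam(0.8)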
// Train
val lrModel = lr.fit(trainingData)
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
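// Training summary & ROC
val trainingSummary = lrModel.summary
// for a binary label the training summary can be cast to BinaryLogisticRegressionSummary
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"training areaUnderROC: ${binarySummary.areaUnderROC}")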
// F-measure by threshold: pick the threshold with the highest F-measure on the training data
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).select("threshold").head().getDouble(0)
lrModel.setThreshold(bestThreshold)
println(s"max F-measure: ${maxFMeasure}")
println(s"best threshold: ${bestThreshold}")
// test model
val predictions = lrModel.transform(testData)
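// evaluate on the test set
// BinaryClassificationEvaluator computes areaUnderROC from the rawPrediction column,
// so this metric does not depend on the threshold set above
val evaluator = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")
val testAUC = evaluator.evaluate(predictions)
println(s"Test areaUnderROC = ${testAUC}")
// accuracy uses the prediction column, i.e. the chosen threshold
val accuracy = new MulticlassClassificationEvaluator().setMetricName("accuracy").evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))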
// save model
lrModel.write.overwrite().save(Config.lrModelFile)
println("Saved model " + Config.lrModelFile)
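The areaUnderROC values printed by the script (from the training summary and from BinaryClassificationEvaluator) can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one. The following is a minimal pure-Scala sketch of that pairwise (Mann-Whitney) definition, with no Spark involved. The `AucSketch` object, its `auc` helper and the toy (score, label) pairs are hypothetical and serve only to illustrate the idea. Ties count as half a win.

```scala
object AucSketch {
  // Pairwise AUC: fraction of (positive, negative) pairs in which the positive
  // example has the higher score; tied scores count as half a win.
  def auc(scoresAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoresAndLabels.collect { case (s, l) if l == 1.0 => s }
    val neg = scoresAndLabels.collect { case (s, l) if l == 0.0 => s }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one positive and one negative")
    val wins = for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }
    wins.sum / (pos.size * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // hypothetical (score, label) pairs
    val data = Seq((0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0), (0.2, 0.0))
    println(f"AUC = ${auc(data)}%.4f")
  }
}
```

Five of the six positive/negative pairs in the toy data are ordered correctly, so the sketch yields 5/6. Because AUC only depends on the ranking of scores, moving the decision threshold does not change it.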
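The threshold step in the script takes the row of fMeasureByThreshold with the highest F-measure. Below is a rough pure-Scala sketch of the same search, assuming F1 (beta = 1) and treating a score >= t as a positive prediction. The `ThresholdSketch` object, its helpers and the toy data are hypothetical. The sketch is not Spark's implementation. Each distinct score is tried as a candidate threshold, and the one with the highest F1 is kept.

```scala
object ThresholdSketch {
  // F1 at threshold t, predicting positive when score >= t.
  // 2TP / (2TP + FP + FN) equals 2 * precision * recall / (precision + recall).
  def f1At(data: Seq[(Double, Double)], t: Double): Double = {
    val tp = data.count { case (s, l) => s >= t && l == 1.0 }
    val fp = data.count { case (s, l) => s >= t && l == 0.0 }
    val fn = data.count { case (s, l) => s < t && l == 1.0 }
    if (tp == 0) 0.0 else 2.0 * tp / (2.0 * tp + fp + fn)
  }

  // Try every distinct score as a threshold; return (threshold, F1) with the highest F1
  def bestThreshold(data: Seq[(Double, Double)]): (Double, Double) =
    data.map(_._1).distinct.map(t => (t, f1At(data, t))).maxBy(_._2)

  def main(args: Array[String]): Unit = {
    // hypothetical (score, label) pairs
    val data = Seq((0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0), (0.2, 0.0))
    val (t, f) = bestThreshold(data)
    println(f"best threshold = $t%.2f, max F1 = $f%.4f")
  }
}
```

On the toy data the best cut is 0.6: all three positives are caught at the cost of one false positive, which gives F1 = 6/7.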