Created
August 8, 2017 14:26
-
-
Save anonymous/238c0de4cfb95a6f411a525f9f4918d1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Logger used by `logMetrics` and `logTopics` to report evaluation results. */
private val logger = org.log4s.getLogger
/**
 * Evaluates `model` on `dataset` and logs the resulting quality metrics.
 *
 * The transformed dataset is read through its "p" column (a `DenseVector`, whose last
 * element is taken as the probability of the positive class) and its integer
 * "isDuplicateLabel" column. The log line contains area under PR, area under ROC,
 * accuracy, F1 and the output of `explainParams()`.
 *
 * The transformed dataset is cached for the duration of the evaluation and is always
 * unpersisted afterwards, even when a metric computation or the logging itself fails.
 *
 * @param model   fitted pipeline used to transform `dataset`
 * @param dataset data the model is evaluated on
 */
private def logMetrics(model: PipelineModel, dataset: Dataset[_]): Unit = {
  import dataset.sparkSession.implicits.newProductEncoder
  val labeledDataset = model
    .transform(dataset)
    .select("p", "isDuplicateLabel")
    .as[(DenseVector, Int)]
  labeledDataset.cache() // will be used multiple times
  try {
    val scoresAndLabels = labeledDataset map { case (p, label) =>
      (p.values.last, label.toDouble) // the last of values is the probability of the positive class
    }
    val binaryMetrics = new BinaryClassificationMetrics(scoresAndLabels.rdd)
    val areaUnderPR = binaryMetrics.areaUnderPR()
    val areaUnderROC = binaryMetrics.areaUnderROC()

    // Hard predictions: the positive-class probability rounded to 0.0 or 1.0.
    val predictionAndLabels = labeledDataset map { case (p, label) =>
      (Array(p.values.last.round.toDouble), if (label > 0) Array(1.0) else Array(0.0))
    }
    // NOTE(review): every prediction/label array holds exactly one element, so
    // MultilabelMetrics.f1Measure presumably coincides with accuracy here - confirm against
    // the Spark docs and consider a binary/multiclass F1 if a per-class F1 is intended.
    val multiLabelMetrics = new MultilabelMetrics(predictionAndLabels.rdd)

    logger.info(s"Trained a model with area under pr $areaUnderPR and area under roc curve $areaUnderROC, " +
      s"accuracy ${multiLabelMetrics.accuracy} and f1 ${multiLabelMetrics.f1Measure}.\n" +
      s"Params are: ${explainParams()}")
  } finally {
    // Release the cached data even if evaluation or logging threw.
    labeledDataset.unpersist()
  }
}
/**
 * Logs the topics of the LDA model nested in `model`, with term indices resolved to terms.
 *
 * The vocabulary is taken from the nested `CountVectorizerModel` and the topic descriptions
 * from the nested `LDAModel`. The vocabulary is broadcast so the term-lookup UDF can use it
 * on the executors; the broadcast is always unpersisted once the topics have been collected,
 * even when the Spark job fails.
 *
 * @param spark session whose SparkContext is used to broadcast the vocabulary
 * @param model fitted pipeline containing the CountVectorizer and LDA models
 */
private def logTopics(spark: SparkSession, model: PipelineModel): Unit = {
  // NOTE(review): 3 and 5 are presumably the positions of the stages inside the pipeline -
  // confirm against getNestedMcModel and the pipeline definition.
  val vocabulary = getNestedMcModel[CountVectorizerModel](model, 3).vocabulary
  val topicsDF = getNestedMcModel[LDAModel](model, 5).describeTopics()

  val bc = spark.sparkContext.broadcast(vocabulary)
  val topics =
    try {
      val getTerm = udf((indices: Seq[Int]) => indices.map(i => bc.value(i)))
      topicsDF.withColumn("terms", getTerm(col("termIndices"))).collect()
    } finally {
      // The rows are already on the driver (or the job failed); the broadcast is no longer needed.
      bc.unpersist()
    }

  // One line per topic: "<topic>: <term> <weight> <term> <weight> ...".
  val explained = topics.map { row =>
    val topic = row.getAs[Int]("topic")
    val terms = row.getAs[Seq[String]]("terms")
    val termWeights = row.getAs[Seq[Double]]("termWeights")
    val weightedTerms = terms
      .zip(termWeights)
      .map { case (term, weight) => f"$term $weight%.3f" }
      .mkString(" ")
    s"$topic: $weightedTerms"
  }
  val msg = explained.mkString("\n")
  logger.info(s"LDA model topics are:\n$msg")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment