Skip to content

Instantly share code, notes, and snippets.

Created August 8, 2017 14:26
Show Gist options
  • Save anonymous/238c0de4cfb95a6f411a525f9f4918d1 to your computer and use it in GitHub Desktop.
Save anonymous/238c0de4cfb95a6f411a525f9f4918d1 to your computer and use it in GitHub Desktop.
// log4s logger used by logMetrics and logTopics to report their results.
private val logger = org.log4s.getLogger
/**
 * Scores `dataset` with `model` and logs binary-classification quality metrics.
 *
 * The transformed data is reduced to the columns "p" and "isDuplicateLabel",
 * cached while the metrics are computed, and always unpersisted afterwards,
 * even when a metric computation or the logging call throws.
 *
 * Logged values: area under the PR curve, area under the ROC curve, accuracy,
 * F1 and the output of `explainParams()`.
 *
 * @param model   pipeline model whose output contains the columns "p" and "isDuplicateLabel"
 * @param dataset data to score with `model`
 */
private def logMetrics(model: PipelineModel, dataset: Dataset[_]): Unit = {
  import dataset.sparkSession.implicits.newProductEncoder

  val labeledDataset =
    model.transform(dataset).select("p", "isDuplicateLabel").as[(DenseVector, Int)]
  labeledDataset.cache() // will be used multiple times
  try {
    val scoresAndLabels = labeledDataset.map { case (p, label) =>
      (p.values.last, label.toDouble) // the last of values is the probability of the positive class
    }
    val binaryMetrics = new BinaryClassificationMetrics(scoresAndLabels.rdd)
    val areaUnderPR = binaryMetrics.areaUnderPR()
    val areaUnderROC = binaryMetrics.areaUnderROC()

    // Round the positive-class score to a hard 0.0/1.0 prediction and wrap both
    // prediction and label in single-element arrays, as MultilabelMetrics expects.
    // NOTE(review): with single-element arrays the document-averaged f1Measure
    // appears to coincide with accuracy - confirm whether a positive-class F1
    // was intended here.
    val predictionAndLabels = labeledDataset.map { case (p, label) =>
      (Array(p.values.last.round.toDouble), if (label > 0) Array(1.0) else Array(0.0))
    }
    val multiLabelMetrics = new MultilabelMetrics(predictionAndLabels.rdd)

    logger.info(s"Trained a model with area under pr $areaUnderPR and area under roc curve $areaUnderROC, " +
      s"accuracy ${multiLabelMetrics.accuracy} and f1 ${multiLabelMetrics.f1Measure}.\n" +
      s"Params are: ${explainParams()}")
  } finally {
    // Release the cached data even if a metric computation failed.
    labeledDataset.unpersist()
  }
}
/**
 * Logs every topic of the pipeline's LDA model as one line of the form
 * "topic: term weight term weight ...", with weights printed to three decimals.
 *
 * The vocabulary is taken from the `CountVectorizerModel` returned by
 * `getNestedMcModel(model, 3)` and the topics from the `LDAModel` returned by
 * `getNestedMcModel(model, 5)`.
 * NOTE(review): 3 and 5 are presumably positions inside the pipeline - confirm
 * against `getNestedMcModel` and the pipeline definition.
 *
 * @param spark session whose SparkContext is used to broadcast the vocabulary
 * @param model pipeline model containing the vectorizer and LDA models
 */
private def logTopics(spark: SparkSession, model: PipelineModel): Unit = {
  val vocabulary = getNestedMcModel[CountVectorizerModel](model, 3).vocabulary
  val topicsDF = getNestedMcModel[LDAModel](model, 5).describeTopics()

  // Broadcast the vocabulary so the UDF can translate term indices into terms.
  val bc = spark.sparkContext.broadcast(vocabulary)
  val topics =
    try {
      val getTerm = udf((indices: Seq[Int]) => indices.map(i => bc.value(i)))
      topicsDF.withColumn("terms", getTerm(col("termIndices"))).collect()
    } finally {
      // The broadcast is only needed while the job above runs; release it even
      // when that job fails.
      bc.unpersist()
    }

  val explained = topics.map { row =>
    val topic = row.getAs[Int]("topic")
    val terms = row.getAs[Seq[String]]("terms")
    val termWeights = row.getAs[Seq[Double]]("termWeights")
    val described = terms
      .zip(termWeights)
      .map { case (term, weight) => f"$term $weight%.3f" }
      .mkString(" ")
    s"$topic: $described"
  }

  val msg = explained.mkString("\n")
  logger.info(s"LDA model topics are:\n$msg")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment