import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer, StopWordsRemover}
import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.Row
import sqlContext.implicits._
val numTopics: Int = 100
val maxIterations: Int = 100
val vocabSize: Int = 10000
/** Data Preprocessing */
// Load the raw articles, assign docIDs, and convert to DataFrame
val rawTextRDD = ... // loaded from S3
val docDF = rawTextRDD.zipWithIndex.toDF("text", "docId")
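// Split each document into words (the default pattern splits on whitespace).
// The intermediate column names "words" and "filtered" are assumed here.
val tokens = new RegexTokenizer()
  .setInputCol("text")
  .setOutputCol("words")
  .transform(docDF)
// Filter out stopwords
val stopwords: Array[String] = ... // loaded from S3
val filteredTokens = new StopWordsRemover()
  .setStopWords(stopwords)
  .setInputCol("words")
  .setOutputCol("filtered")
  .transform(tokens)
// Limit to top `vocabSize` most common words and convert to word count vector features
val cvModel = new CountVectorizer()
  .setInputCol("filtered")
  .setOutputCol("features")
  .setVocabSize(vocabSize)
  .fit(filteredTokens)
val countVectors = cvModel.transform(filteredTokens)
  .select("docId", "features")
  .map { case Row(docId: Long, countVector: Vector) => (docId, countVector) }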
/** Configure and run LDA */
val mbf = {
  // add (1.0 / actualCorpusSize) to MiniBatchFraction to be more robust on tiny datasets.
  val corpusSize = countVectors.count()
  2.0 / maxIterations + 1.0 / corpusSize
}
val lda = new LDA()
  .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(math.min(1.0, mbf)))
  .setK(numTopics)
  .setMaxIterations(2) // hard-coded to 2 rather than `maxIterations`
  .setDocConcentration(-1) // use default symmetric document-topic prior
  .setTopicConcentration(-1) // use default symmetric topic-word prior
val startTime = System.nanoTime()
val ldaModel = lda.run(countVectors)
val elapsed = (System.nanoTime() - startTime) / 1e9
/** Print results. */
// Print training time
println(s"Finished training LDA model. Summary:")
println(s"Training time (sec)\t$elapsed")
// Print the topics, showing the top-weighted terms for each topic.
val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
val vocabArray = cvModel.vocabulary
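// Map each topic's term indices back to vocabulary words, paired with their weights.
val topics = topicIndices.map { case (terms, termWeights) =>
  terms.map(vocabArray(_)).zip(termWeights)
}
println(s"$numTopics topics:")
topics.zipWithIndex.foreach { case (topic, i) =>
  println(s"TOPIC $i")
  topic.foreach { case (term, weight) => println(s"$term\t$weight") }
}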
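
The mini-batch fraction used above, `2.0 / maxIterations + 1.0 / corpusSize` capped at 1.0, can be easier to reason about with concrete numbers. The sketch below is not part of the gist: `MiniBatchDemo`, `miniBatchFraction`, and `expectedBatchSize` are hypothetical names, and the corpus sizes are made up for illustration. It shows that the `1.0 / corpusSize` term keeps the expected mini-batch at one document or more, which matters on tiny datasets.

```scala
// Hypothetical helper, not from the gist: the mini-batch fraction as a pure function.
object MiniBatchDemo {
  // Same formula as `mbf` in the script, including the cap at 1.0.
  def miniBatchFraction(maxIterations: Int, corpusSize: Long): Double =
    math.min(1.0, 2.0 / maxIterations + 1.0 / corpusSize)

  // Expected number of documents sampled in each iteration.
  def expectedBatchSize(maxIterations: Int, corpusSize: Long): Double =
    miniBatchFraction(maxIterations, corpusSize) * corpusSize

  def main(args: Array[String]): Unit = {
    // Illustrative corpus sizes only.
    for (size <- Seq(10L, 1000L, 1000000L)) {
      val fraction = miniBatchFraction(100, size)
      val batch = expectedBatchSize(100, size)
      println(f"corpusSize=$size%d fraction=$fraction%.6f expectedBatch=$batch%.1f")
    }
  }
}
```

With `maxIterations = 100` and only 10 documents, the uncorrected fraction `2.0 / 100` would sample about 0.2 documents per iteration on average, so many mini-batches would be empty. The extra term raises the fraction to 0.12, or about 1.2 documents per iteration.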

kuza55 commented Jan 15, 2016

I noticed that the LDA model has maxIterations explicitly set to 2 on line 57, rather than the 100 defined at the top of the file; does this affect the benchmarks in the associated blog post?

abrocod commented Feb 5, 2016

The blog post mentioned that we can also query for the top topics that each document talks about. I am wondering how to do this?

Can you specify how much memory is needed to run this pipeline?
