Create a gist now

Instantly share code, notes, and snippets.

Embed
What would you like to do?
How to implement LDA in Spark and get the topic distributions of new documents

Trains a Spark MLlib LDA topic model on a directory of training documents, then uses the trained model to compute the topic distribution of new, unseen documents.

import org.apache.spark.rdd._
import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel, LocalLDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import scala.collection.mutable

// ---------------------------------------------------------------------------
// Train an LDA topic model on a directory of documents, then infer the topic
// distribution of new documents with the trained model.
// Training and test text go through the SAME tokenize/vectorize helpers, so
// test term counts line up with the training vocabulary.
// ---------------------------------------------------------------------------

// Stop words, lower-cased. Each line of stopwords.csv may hold a single word or
// several comma-separated words; both layouts are accepted.
// NOTE(review): confirm the actual layout of stopwords.csv.
val stopWordsInput = sc.textFile("stopwords.csv")
val stopWords: Array[String] =
  stopWordsInput
    .flatMap(_.split(","))
    .map(_.trim.toLowerCase(java.util.Locale.ROOT))
    .filter(_.nonEmpty)
    .collect()
val stopWordSet: Set[String] = stopWords.toSet

/** Splits raw text into lower-cased, whitespace-separated tokens.
  *
  * Drops empty strings, which `split` produces for leading or repeated
  * whitespace, and drops anything in `stopWordSet`.
  *
  * @param text raw document text
  * @return the tokens that remain after filtering
  */
def tokenize(text: String): Array[String] =
  text
    .toLowerCase(java.util.Locale.ROOT)
    .split("\\s+")
    .filter(token => token.nonEmpty && !stopWordSet.contains(token))

/** Builds a sparse term-count vector of length `vocab.size`.
  *
  * Tokens that are not in `vocab` are ignored.
  *
  * @param tokens tokens of a single document
  * @param vocab  term -> column index
  * @return the document's term counts as a sparse vector
  */
def toTermCountVector(tokens: Array[String], vocab: Map[String, Int]): Vector = {
  val counts: Map[Int, Double] = tokens.foldLeft(Map.empty[Int, Double]) { (acc, term) =>
    vocab.get(term).fold(acc)(idx => acc.updated(idx, acc.getOrElse(idx, 0.0) + 1.0))
  }
  Vectors.sparse(vocab.size, counts.toSeq)
}

// Training documents: ONE DOCUMENT PER FILE. sc.textFile would instead yield
// one record per line, which silently turns every line into its own document.
val input: RDD[String] = sc.wholeTextFiles("training_documents/*").map { case (_, content) => content }

// Cached because the corpus is traversed twice: once for term counts, once for vectors.
val corpus: RDD[Array[String]] = input.map(tokenize).cache()

// Vocabulary: every distinct training term, ordered by descending frequency.
val termCounts: Array[(String, Long)] =
  corpus.flatMap(_.map(_ -> 1L)).reduceByKey(_ + _).collect().sortBy(-_._2)
val vocabArray: Array[String] = termCounts.map(_._1)
val vocab: Map[String, Int] = vocabArray.zipWithIndex.toMap

// Convert training documents into (documentId, termCountVector) pairs.
val documents: RDD[(Long, Vector)] =
  corpus.zipWithIndex.map { case (tokens, id) => (id, toTermCountVector(tokens, vocab)) }

// Set LDA parameters and train the model.
val numTopics = 10
// Only the EM optimizer produces a DistributedLDAModel. Select it explicitly
// rather than relying on the library default.
val ldaModel: DistributedLDAModel =
  new LDA()
    .setK(numTopics)
    .setMaxIterations(20)
    .setOptimizer("em")
    .run(documents) match {
      case distributed: DistributedLDAModel => distributed
      case other =>
        throw new IllegalStateException(
          s"Expected a DistributedLDAModel from the EM optimizer but got ${other.getClass.getName}")
    }
// Convert to a local model, which can infer topics for documents not seen in training.
val localLDAModel: LocalLDAModel = ldaModel.toLocal

// Test input: tokenized and vectorized exactly like the training documents.
// NOTE(review): if none of a test document's tokens are in the training vocabulary,
// its vector is empty. Spark's LocalLDAModel is expected to return an all-zero
// distribution for such documents; verify this for your Spark version.
val test_input = Seq("this is my test document")
val test_document: RDD[(Long, Vector)] =
  sc.parallelize(
    test_input.map(tokenize).zipWithIndex.map { case (tokens, id) =>
      (id.toLong, toTermCountVector(tokens, vocab))
    })

val topicDistributions = localLDAModel.topicDistributions(test_document)
println(s"first topic distribution:${topicDistributions.first._2.toArray.mkString(", ")}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment