@alex9311
Last active April 15, 2020 23:13
How to implement LDA in Spark and get the topic distributions of new documents

import org.apache.spark.rdd._
import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel, LocalLDAModel}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import scala.collection.mutable

// load stop words; these are filtered out of the vocabulary below
val stopWordsInput = sc.textFile("stopwords.csv")
val stopWords = stopWordsInput.collect()

// create the training document set: one document per line, tokenized on whitespace
val corpus: RDD[Array[String]] = sc.textFile("training_documents/*").map(_.split("\\s"))

// count term frequencies across the corpus, most frequent first
val termCounts: Array[(String, Long)] =
    corpus.flatMap(_.map(_ -> 1L)).reduceByKey(_ + _).collect().sortBy(-_._2)

// build the vocabulary (stop words removed) and index each term
val vocabArray: Array[String] = termCounts.map(_._1).filterNot(stopWords.contains)
val vocab: Map[String, Int] = vocabArray.zipWithIndex.toMap

// Convert training documents into term count vectors
val documents: RDD[(Long, Vector)] =
    corpus.zipWithIndex.map { case (tokens, id) =>
        val counts = new mutable.HashMap[Int, Double]()
        tokens.foreach { term =>
            // terms not in the vocabulary (stop words, unseen words) are skipped
            if (vocab.contains(term)) {
                val idx = vocab(term)
                counts(idx) = counts.getOrElse(idx, 0.0) + 1.0
            }
        }
        (id, Vectors.sparse(vocab.size, counts.toSeq))
    }
// set LDA parameters and train; the default EM optimizer yields a DistributedLDAModel
val numTopics = 10
val ldaModel: DistributedLDAModel =
    new LDA().setK(numTopics).setMaxIterations(20).run(documents).asInstanceOf[DistributedLDAModel]

// convert to a LocalLDAModel, which can infer topic distributions for new documents
val localLDAModel: LocalLDAModel = ldaModel.toLocal

// create a test document, convert it to term counts, and get its topic distribution
val test_input = Seq("this is my test document")
val test_document: RDD[(Long, Vector)] =
    sc.parallelize(test_input.map(doc => doc.split("\\s"))).zipWithIndex.map { case (tokens, id) =>
        val counts = new mutable.HashMap[Int, Double]()
        tokens.foreach { term =>
            if (vocab.contains(term)) {
                val idx = vocab(term)
                counts(idx) = counts.getOrElse(idx, 0.0) + 1.0
            }
        }
        (id, Vectors.sparse(vocab.size, counts.toSeq))
    }

val topicDistributions = localLDAModel.topicDistributions(test_document)
println("first topic distribution: " + topicDistributions.first._2.toArray.mkString(", "))
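To sanity-check what the model learned, the top-weighted terms per topic can be printed with describeTopics. This block is a minimal sketch, not part of the original gist; it reuses the vocabArray built above.

// print the five highest-weighted terms in each topic
ldaModel.describeTopics(maxTermsPerTopic = 5).zipWithIndex.foreach { case ((termIndices, termWeights), topicId) =>
    val topTerms = termIndices.zip(termWeights).map { case (idx, w) => f"${vocabArray(idx)} ($w%.3f)" }
    println(s"topic $topicId: " + topTerms.mkString(", "))
}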
@medoidai commented Apr 11, 2020

Hi Alex,

Hope you are well.

I am thinking of trying this approach for topic modelling in short texts.

I want to extract topic probabilities for some texts and then compare them by comparing the distributions.

I want to ask you the following:

How will the trained Spark LDA model behave when it is tested/applied on new, unseen text? For example, some or all of the words in the input texts might be new.

Thank you.

Efstathios

@alex9311 (Author) commented Apr 15, 2020

Hi @medoidai, good question! I've never tried using an LDA model on text with words it hasn't seen before. I'm no expert in LDA, just using off-the-shelf implementations. Making new types of LDA that allow for infinite vocabularies (new words) does seem to be an area of research, though:

http://proceedings.mlr.press/v28/zhai13.pdf
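Worth noting from the snippet above: tokens missing from vocab are simply skipped during vectorization, so a document made up entirely of unseen words maps to a count vector with no nonzero entries. Below is a minimal sketch illustrating this, reusing vocab from the gist (the example tokens are assumed to be absent from it).

// unseen words are dropped by the vocab.contains check
val unseenDoc: RDD[(Long, Vector)] =
    sc.parallelize(Seq("entirely novel words".split("\\s"))).zipWithIndex.map { case (tokens, id) =>
        val counts = new mutable.HashMap[Int, Double]()
        tokens.foreach { term =>
            if (vocab.contains(term)) {
                counts(vocab(term)) = counts.getOrElse(vocab(term), 0.0) + 1.0
            }
        }
        (id, Vectors.sparse(vocab.size, counts.toSeq))
    }
println(unseenDoc.first._2) // prints (vocabSize,[],[]) when no token is in the vocabulary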
