Skip to content

Instantly share code, notes, and snippets.

@krrrr38
Created March 4, 2013 13:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krrrr38/5082403 to your computer and use it in GitHub Desktop.
Save krrrr38/5082403 to your computer and use it in GitHub Desktop.
package com.krrrr38.lda
/**
 * One occurrence of a word inside a document.
 *
 * The sampler uses both fields directly as array indices, so they must be
 * zero-based and lie inside the document / vocabulary sizes given to `LDA`.
 *
 * @param docId  index of the document the occurrence belongs to
 * @param wordId index of the word in the vocabulary
 */
final case class Token(docId: Int, wordId: Int)
/**
 * A vocabulary entry paired with its probability under one topic.
 *
 * @param id   index of the word in the vocabulary
 * @param prob probability of that word for the topic being reported
 */
final case class WordProb(id: Int, prob: Double)
/** Command-line driver: trains an LDA model on the NIPS bag-of-words data and dumps the topics. */
object LDA {
  import java.util.Scanner
  import java.io.File
  import java.io.PrintWriter

  /**
   * Trains a 50-topic model on `data/docword.nips.txt` for 201 sweeps and
   * writes the word/topic distribution to `output/wordtopic_parNNN.txt`
   * every 100th sweep (sweeps 0, 100 and 200).
   *
   * @param args ignored
   */
  def main(args: Array[String]): Unit = {
    val (docNum, wordNum, tokens) = readCorpus(new File("data/docword.nips.txt"))
    val words = readVocabulary(new File("data/vocab.nips.txt"), wordNum)

    val K = 50 // the number of topics
    val seed = 777
    val lda = new LDA(docNum, K, wordNum, tokens, seed)

    // PrintWriter(String) fails when the parent directory is missing; create it
    // up front so the run does not die only after the first (expensive) sweep.
    new File("output").mkdirs()

    for (i <- 0 until 201) {
      println(s"update count $i")
      lda.update
      if (i % 100 == 0) {
        val out = new PrintWriter(f"output/wordtopic_par$i%03d.txt")
        // Close the file even when writing fails part-way through.
        try outputWordTopicProb(lda.getPhi, words, out)
        finally out.close()
      }
    }
  }

  /**
   * Reads a corpus file: three header integers (document count, vocabulary
   * size, number of entries) followed by that many `docId wordId count` triples.
   * Ids are decremented by one so they can be used as array indices (the input
   * is presumably 1-based), and every entry is expanded into `count` tokens.
   *
   * @return (document count, vocabulary size, expanded tokens in file order)
   */
  private def readCorpus(file: File): (Int, Int, List[Token]) = {
    val sc = new Scanner(file)
    try {
      val docNum = sc.nextInt()
      val wordNum = sc.nextInt()
      val entryNum = sc.nextInt()
      val tokens = (0 until entryNum).flatMap { _ =>
        val docId = sc.nextInt() - 1
        val wordId = sc.nextInt() - 1
        val count = sc.nextInt()
        List.fill(count)(Token(docId, wordId))
      }
      (docNum, wordNum, tokens.toList)
    } finally sc.close()
  }

  /** Reads the first `wordNum` lines of the vocabulary file, one word per line. */
  private def readVocabulary(file: File, wordNum: Int): IndexedSeq[String] = {
    val sc = new Scanner(file)
    try (0 until wordNum).map(_ => sc.nextLine())
    finally sc.close()
  }

  /**
   * Writes, for every topic, a `=== topic : k` header followed by one
   * `word probability` line per word, ordered by descending probability.
   *
   * @param phi   topic-by-word probability matrix (`phi(k)(w)`); may be empty
   * @param words vocabulary, indexed by word id
   * @param out   destination writer; it is neither flushed nor closed here
   */
  def outputWordTopicProb(phi: Array[Array[Double]], words: Seq[String], out: PrintWriter): Unit = {
    // Guarantee cheap random access even if the caller hands in a List.
    val vocab = words.toIndexedSeq
    for (k <- phi.indices) {
      // Use each row's own length rather than phi(0).length, which throws on an empty matrix.
      val row = phi(k)
      out.println(s"=== topic : $k")
      val wordProbs = row.indices.map(w => WordProb(w, row(w)))
      wordProbs.sortBy(_.prob).reverse.foreach { wp =>
        out.println(s"${vocab(wp.id)} ${wp.prob}")
      }
    }
  }
}
/**
 * Latent Dirichlet Allocation trained by Gibbs sampling over per-token topic
 * assignments. Only the count tables are stored; the document/topic and
 * topic/word distributions are derived from them on demand.
 *
 * Construction assigns every token a random topic. Each call to `update` then
 * performs one full sweep that resamples the topic of every token.
 *
 * @param docNumD   number of documents; every `Token.docId` must be in `[0, docNumD)`
 * @param topicNumK number of topics
 * @param wordNumW  vocabulary size; every `Token.wordId` must be in `[0, wordNumW)`
 * @param tokenList one entry per word occurrence in the corpus
 * @param alpha     smoothing added to the document/topic counts
 * @param beta      smoothing added to the word/topic counts
 * @param seed      seed of the random generator; equal seeds give equal runs
 */
class LDA(docNumD: Int, topicNumK: Int, wordNumW: Int, tokenList: List[Token], alpha: Double, beta: Double, seed: Int) {

  /** Uses the default hyper-parameters `alpha = 50 / topicNumK` and `beta = 0.1`. */
  def this(docNumD: Int, topicNumK: Int, wordNumW: Int, tokenList: List[Token], seed: Int) =
    this(docNumD, topicNumK, wordNumW, tokenList, 50.0 / topicNumK, 0.1, seed)

  /** `wordCount(w)(k)`: number of occurrences of word `w` currently assigned to topic `k`. */
  val wordCount: Array[Array[Int]] = Array.ofDim[Int](wordNumW, topicNumK)

  /** `topicCount(k)`: number of tokens currently assigned to topic `k`. */
  val topicCount: Array[Int] = Array.ofDim[Int](topicNumK)

  /** `docCount(d)(k)`: number of tokens of document `d` currently assigned to topic `k`. */
  val docCount: Array[Array[Int]] = Array.ofDim[Int](docNumD, topicNumK)

  /** The corpus as an array, for constant-time access by token index. */
  val tokens: Array[Token] = tokenList.toArray

  /** Scratch buffer holding the cumulative, unnormalised topic weights of the token being resampled. */
  val P: Array[Double] = Array.ofDim[Double](topicNumK)

  /** `z(i)`: topic currently assigned to `tokens(i)`. */
  val z: Array[Int] = Array.ofDim[Int](tokens.length)

  // A dedicated generator seeded from the constructor argument. The original
  // used the shared `scala.util.Random` object, which silently ignored `seed`.
  val rand: scala.util.Random = new scala.util.Random(seed)

  initialize()

  /** Gives every token a uniformly random starting topic and records it in the count tables. */
  private def initialize(): Unit = {
    var i = 0
    while (i < tokens.length) {
      val topic = rand.nextInt(topicNumK)
      adjustCounts(tokens(i), topic, 1)
      z(i) = topic
      i += 1
    }
  }

  /** Adds `delta` (+1 or -1) to the three count tables for `token` under `topicId`. */
  private def adjustCounts(token: Token, topicId: Int, delta: Int): Unit = {
    wordCount(token.wordId)(topicId) += delta
    docCount(token.docId)(topicId) += delta
    topicCount(topicId) += delta
  }

  /**
   * Runs one sweep: resamples the topic of every token, in corpus order.
   *
   * Deliberately declared without a parameter list so that existing
   * `lda.update` call sites keep compiling.
   */
  def update: Unit = {
    var i = 0
    while (i < tokens.length) {
      resample(i)
      i += 1
    }
  }

  /** Removes the token from its current topic, draws a new topic and records the new assignment. */
  private def resample(tokenId: Int): Unit = {
    val token = tokens(tokenId)
    // The token's own assignment must not influence its conditional distribution.
    adjustCounts(token, z(tokenId), -1)
    val next = selectNextTopic(token)
    adjustCounts(token, next, 1)
    z(tokenId) = next
  }

  /**
   * Draws a topic for `token` with probability proportional to
   * `(wordCount + beta) * (docCount + alpha) / (topicCount + wordNumW * beta)`.
   */
  private def selectNextTopic(token: Token): Int = {
    val wordTopic = wordCount(token.wordId)
    val docTopic = docCount(token.docId)
    val smoothing = wordNumW * beta

    // Fill P with the running sum of the unnormalised weights.
    var acc = 0.0
    var k = 0
    while (k < topicNumK) {
      acc += (wordTopic(k) + beta) * (docTopic(k) + alpha) / (topicCount(k) + smoothing)
      P(k) = acc
      k += 1
    }

    // Inverse-CDF draw: the first topic whose cumulative weight exceeds u.
    // A plain while loop replaces the former `return` inside a `for` closure,
    // which was implemented by throwing an exception on every single draw.
    val u = rand.nextDouble() * P(topicNumK - 1)
    var chosen = 0
    // Stops at the last topic at the latest, which is the fallback when
    // rounding leaves u >= P(k) for every k.
    while (chosen < topicNumK - 1 && !(u < P(chosen))) chosen += 1
    chosen
  }

  /** Divides every element of `row` by the sum of the row, in place. */
  private def normalizeInPlace(row: Array[Double]): Unit = {
    var sum = 0.0
    var j = 0
    while (j < row.length) {
      sum += row(j)
      j += 1
    }
    val sinv = 1.0 / sum
    j = 0
    while (j < row.length) {
      row(j) = row(j) * sinv
      j += 1
    }
  }

  /**
   * Document/topic distribution.
   *
   * @return a fresh `docNumD x topicNumK` matrix where row `d` is
   *         `alpha + docCount(d)(k)` normalised to sum to one
   */
  def getTheta: Array[Array[Double]] =
    Array.tabulate(docNumD) { d =>
      val row = Array.tabulate(topicNumK)(k => alpha + docCount(d)(k))
      normalizeInPlace(row)
      row
    }

  /**
   * Topic/word distribution.
   *
   * @return a fresh `topicNumK x wordNumW` matrix where row `k` is
   *         `beta + wordCount(w)(k)` normalised to sum to one
   */
  def getPhi: Array[Array[Double]] =
    Array.tabulate(topicNumK) { k =>
      val row = Array.tabulate(wordNumW)(w => beta + wordCount(w)(k))
      normalizeInPlace(row)
      row
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment