Created
March 4, 2013 13:58
-
-
Save krrrr38/5082403 to your computer and use it in GitHub Desktop.
LDA (latent Dirichlet allocation) with collapsed Gibbs sampling; see http://d.hatena.ne.jp/tsubosaka/20091223/1261572639
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.krrrr38.lda | |
/** One word occurrence: document `docId` contains vocabulary entry `wordId` (both 0-based indices). */
final case class Token(docId: Int, wordId: Int)

/** A vocabulary index `id` paired with its probability `prob` under some distribution. */
final case class WordProb(id: Int, prob: Double)
/** Command-line driver that trains [[LDA]] on the NIPS bag-of-words corpus. */
object LDA {
  import java.util.Scanner
  import java.io.File
  import java.io.PrintWriter

  /**
   * Trains a 50-topic model and periodically writes the per-topic word distributions.
   *
   * Inputs (paths relative to the working directory):
   *  - `data/docword.nips.txt`: three integers `D W N`, followed by `N` triples
   *    `docId wordId count` with 1-based ids; each triple becomes `count` tokens.
   *  - `data/vocab.nips.txt`: `W` lines; line `i` is the word with 0-based id `i`.
   *
   * Runs 201 sweeps of [[LDA.update]]. After the sweeps with index 0, 100 and 200 it writes
   * `output/wordtopic_parNNN.txt`, creating the `output` directory if it is missing.
   */
  def main(args: Array[String]): Unit = {
    val (numDocs, numWords, tokens) = withScanner("data/docword.nips.txt") { sc =>
      val d = sc.nextInt
      val w = sc.nextInt
      val n = sc.nextInt
      val expanded = (0 until n).flatMap { _ =>
        // Ids in the file are 1-based; the model indexes its arrays from 0.
        val did = sc.nextInt - 1
        val wid = sc.nextInt - 1
        val count = sc.nextInt
        List.fill(count)(Token(did, wid))
      }
      (d, w, expanded.toList)
    }
    val words = withScanner("data/vocab.nips.txt") { sc =>
      Vector.fill(numWords)(sc.nextLine)
    }

    val numTopics = 50
    val seed = 777
    val numSweeps = 201
    val saveEvery = 100
    val lda = new LDA(numDocs, numTopics, numWords, tokens, seed)

    // PrintWriter does not create missing parent directories; without this the run
    // would fail only after the first (potentially long) sweep.
    new File("output").mkdirs()

    for (i <- 0 until numSweeps) {
      println(s"update count $i")
      lda.update
      if (i % saveEvery == 0) {
        withWriter(f"output/wordtopic_par$i%03d.txt") { out =>
          outputWordTopicProb(lda.getPhi, words, out)
        }
      }
    }
  }

  /**
   * Writes, for every topic `k`, a `=== topic : k` header followed by one
   * `word probability` line per vocabulary entry, from most to least probable.
   *
   * @param phi   topic-word matrix; `phi(k)(w)` is the weight of word id `w` in topic `k`
   * @param words vocabulary; `words(w)` is the surface form of word id `w`
   * @param out   destination writer; it is not closed by this method
   */
  def outputWordTopicProb(phi: Array[Array[Double]], words: Seq[String], out: PrintWriter): Unit = {
    // Index once: `words` may be a List, whose `apply` is linear in the index.
    val vocab = words.toIndexedSeq
    for ((topicDist, k) <- phi.zipWithIndex) {
      out.println(s"=== topic : $k")
      val wordProbs = topicDist.indices.map(w => WordProb(w, topicDist(w)))
      wordProbs.sortBy(_.prob).reverse.foreach { wp =>
        out.println(s"${vocab(wp.id)} ${wp.prob}")
      }
    }
  }

  /** Opens `path` with a [[Scanner]], applies `f`, and always closes the scanner. */
  private def withScanner[A](path: String)(f: Scanner => A): A = {
    val sc = new Scanner(new File(path))
    try f(sc) finally sc.close()
  }

  /** Opens `path` for writing, applies `f`, and always closes (and so flushes) the writer. */
  private def withWriter[A](path: String)(f: PrintWriter => A): A = {
    val out = new PrintWriter(path)
    try f(out) finally out.close()
  }
}
/**
 * Latent Dirichlet allocation fitted with a collapsed Gibbs sampler.
 *
 * The count tables below are kept consistent with the current topic assignment `z`.
 * Call [[update]] repeatedly to run sweeps, then read estimates with [[getTheta]] and [[getPhi]].
 *
 * @param docNumD   number of documents; every `Token.docId` must lie in `[0, docNumD)`
 * @param topicNumK number of topics K
 * @param wordNumW  vocabulary size W; every `Token.wordId` must lie in `[0, wordNumW)`
 * @param tokenList the corpus, one entry per word occurrence
 * @param alpha     pseudo-count added to every document-topic count
 * @param beta      pseudo-count added to every word-topic count
 * @param seed      seed for the random generator used for initialisation and sampling
 */
class LDA(docNumD: Int, topicNumK: Int, wordNumW: Int, tokenList: List[Token], alpha: Double, beta: Double, seed: Int) {

  /** Same as the primary constructor with `alpha = 50.0 / topicNumK` and `beta = 0.1`. */
  def this(docNumD: Int, topicNumK: Int, wordNumW: Int, tokenList: List[Token], seed: Int) =
    this(docNumD, topicNumK, wordNumW, tokenList, 50.0 / topicNumK, 0.1, seed)

  /** `wordCount(w)(k)`: number of tokens of word `w` currently assigned to topic `k`. */
  val wordCount: Array[Array[Int]] = Array.ofDim[Int](wordNumW, topicNumK)
  /** `topicCount(k)`: number of tokens in the whole corpus currently assigned to topic `k`. */
  val topicCount: Array[Int] = Array.ofDim[Int](topicNumK)
  /** `docCount(d)(k)`: number of tokens of document `d` currently assigned to topic `k`. */
  val docCount: Array[Array[Int]] = Array.ofDim[Int](docNumD, topicNumK)
  /** The corpus as an array, for constant-time access by token index. */
  val tokens: Array[Token] = tokenList.toArray
  /** Scratch buffer holding cumulative unnormalised topic weights while sampling. */
  val P: Array[Double] = Array.ofDim[Double](topicNumK)
  /** `z(i)`: topic currently assigned to `tokens(i)`. */
  val z: Array[Int] = Array.ofDim[Int](tokens.length)
  /** Random source seeded from `seed`, so runs with the same seed are reproducible. */
  val rand: scala.util.Random = new scala.util.Random(seed)

  // Must stay after every field above: it reads `tokens` and `rand` and writes the count tables and `z`.
  init()

  /** Assigns every token a uniformly random topic and records it in the count tables. */
  private def init(): Unit = {
    var i = 0
    while (i < tokens.length) {
      val topic = rand.nextInt(topicNumK)
      incrementAllCount(tokens(i), topic)
      z(i) = topic
      i += 1
    }
  }

  /** Adds `delta` to the word, document and corpus counts of `topicId` for `token`. */
  private def adjustCounts(token: Token, topicId: Int, delta: Int): Unit = {
    wordCount(token.wordId)(topicId) += delta
    docCount(token.docId)(topicId) += delta
    topicCount(topicId) += delta
  }

  private def incrementAllCount(token: Token, topicId: Int): Unit = adjustCounts(token, topicId, 1)

  private def decrementAllCount(token: Token, topicId: Int): Unit = adjustCounts(token, topicId, -1)

  /** Runs one Gibbs sweep, resampling the topic of every token once in corpus order. */
  def update: Unit = {
    var i = 0
    while (i < tokens.length) {
      resample(i)
      i += 1
    }
  }

  /**
   * Reassigns `tokens(tokenId)`. The token is first removed from the counts so that
   * the draw is conditioned on all *other* assignments, and is then added back under the new topic.
   */
  private def resample(tokenId: Int): Unit = {
    val token = tokens(tokenId)
    decrementAllCount(token, z(tokenId))
    val next = selectNextTopic(token)
    incrementAllCount(token, next)
    z(tokenId) = next
  }

  /**
   * Draws a topic for `token`, whose own assignment must already be removed from the counts.
   *
   * Topic `k` gets weight
   * `(wordCount(w)(k) + beta) * (docCount(d)(k) + alpha) / (topicCount(k) + wordNumW * beta)`.
   * A topic is then picked by inverse-CDF sampling over the cumulative weights stored in `P`.
   */
  private def selectNextTopic(token: Token): Int = {
    val wordRow = wordCount(token.wordId)
    val docRow = docCount(token.docId)
    val betaSum = wordNumW * beta
    var total = 0.0
    var k = 0
    while (k < topicNumK) {
      total += (wordRow(k) + beta) * (docRow(k) + alpha) / (topicCount(k) + betaSum)
      P(k) = total
      k += 1
    }
    val u = rand.nextDouble() * total
    // First topic whose cumulative weight exceeds u. A plain loop avoids the exception-based
    // nonlocal `return` from a closure in this hot path. The last topic is the fallback.
    var chosen = 0
    while (chosen < topicNumK - 1 && u >= P(chosen)) chosen += 1
    chosen
  }

  /**
   * Smoothed per-document topic proportions:
   * `theta(d)(k) = (docCount(d)(k) + alpha) / sum over k' of (docCount(d)(k') + alpha)`.
   * Each row sums to 1 (given positive `alpha` or a non-empty document).
   */
  def getTheta: Array[Array[Double]] =
    Array.tabulate(docNumD) { d =>
      normalizeInPlace(Array.tabulate(topicNumK)(k => alpha + docCount(d)(k)))
    }

  /**
   * Smoothed per-topic word distributions:
   * `phi(k)(w) = (wordCount(w)(k) + beta) / sum over w' of (wordCount(w')(k) + beta)`.
   * Each row sums to 1 (given positive `beta` or a non-empty topic).
   */
  def getPhi: Array[Array[Double]] =
    Array.tabulate(topicNumK) { k =>
      normalizeInPlace(Array.tabulate(wordNumW)(w => beta + wordCount(w)(k)))
    }

  /** Scales `row` in place by the reciprocal of its sum and returns it. */
  private def normalizeInPlace(row: Array[Double]): Array[Double] = {
    var sum = 0.0
    var i = 0
    while (i < row.length) {
      sum += row(i)
      i += 1
    }
    val inv = 1.0 / sum
    i = 0
    while (i < row.length) {
      row(i) = row(i) * inv
      i += 1
    }
    row
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment