@jasonbaldridge, created May 29, 2012
Gibbs sampler for topic models for artificial data in Steyvers and Griffiths 2007.
## An implementation of Gibbs sampling for topic models for the
## example in section 4 of Steyvers and Griffiths (2007):
## http://cocosci.berkeley.edu/tom/papers/SteyversGriffiths.pdf
##
## Author: Jason Baldridge (jasonbaldridge@gmail.com)
# Functions to parse the input data.
# words.to.indices maps each word (a single character) to its index 1..5;
# word.vector converts a document string to a vector of word indices,
# e.g. word.vector("rsb") yields c(1, 2, 3).
words.to.indices = data.frame(row.names = c("r", "s", "b", "m", "l"), 1:5)
mysplit = function(x) { strsplit(x, "")[[1]] }
word.vector = function(x) { words.to.indices[mysplit(x), ] }
# The document by words matrix (one row per document, one word index per token)
d = matrix(c(word.vector("bbbbmmmmmmllllll"),
             word.vector("bbbbbmmmmmmmllll"),
             word.vector("bbbbbbbmmmmmllll"),
             word.vector("bbbbbbbmmmmmmlll"),
             word.vector("bbbbbbbmmlllllll"),
             word.vector("bbbbbbbbbmmmllll"),
             word.vector("rbbbbmmmmmmlllll"),
             word.vector("rssbbbbbbmmmmlll"),
             word.vector("rsssbbbbbbmmmmll"),
             word.vector("rrlllbbbbbbmllll"),
             word.vector("rrsssbbbbbbbmmml"),
             word.vector("rrrssssssbbbbbbm"),
             word.vector("rrrrrrsssbbbbbbl"),
             word.vector("rrssssssssbbbbbb"),
             word.vector("rrrrsssssssbbbbb"),
             word.vector("rrrrrsssssssbbbb")), ncol = 16, byrow = TRUE)
# The initial topic assignment for each token (one row per document,
# one topic id per token -- not a document-by-topic matrix)
t.start = matrix(as.numeric(c(
  mysplit("2222122212112122"),
  mysplit("2212211111121221"),
  mysplit("2221222212121222"),
  mysplit("1112122211222222"),
  mysplit("1121212122122222"),
  mysplit("2112111112122211"),
  mysplit("2211111221221112"),
  mysplit("1212211112112112"),
  mysplit("1221222221121121"),
  mysplit("1211212222211221"),
  mysplit("2121122211221111"),
  mysplit("2222222121211212"),
  mysplit("2221112121222112"),
  mysplit("2211222111112122"),
  mysplit("2111111221212121"),
  mysplit("1211211222211112"))), ncol = 16, byrow = TRUE)
t = t.start  # current topic assignments (note: shadows base R's t(); transpose is not used below)
# Parameters
alpha = .1
beta = .1
num.iterations = 64
# Constants
num.docs = nrow(d)
vocab.size = length(unique(as.vector(d)))
num.topics = length(unique(as.vector(t)))
# Populate the matrix of document by topic counts
# (xtabs assumes every topic occurs at least once in each document,
# which holds for the initial assignments above)
cdt = matrix(nrow = num.docs, ncol = num.topics)
for (i in 1:num.docs) {
  cdt[i, ] = xtabs(~ t.start[i, ])
}
# Populate the matrix of word by topic counts
cwt = matrix(0, nrow = vocab.size, ncol = num.topics)
for (i in 1:num.docs) {
  for (j in 1:ncol(d)) {
    word.id = d[i, j]
    topic.id = t[i, j]
    cwt[word.id, topic.id] = cwt[word.id, topic.id] + 1
  }
}
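# Sanity check (added, not in the original gist): every token is counted
# exactly once in each count matrix, so both matrices must sum to the total
# number of tokens, and each row of cdt must sum to the document length.
stopifnot(sum(cdt) == length(d))
stopifnot(sum(cwt) == length(d))
stopifnot(all(rowSums(cdt) == ncol(d)))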
# Gibbs sampling iterations
for (iteration in 1:num.iterations) {
  print(iteration)
  for (i in 1:num.docs) {
    for (j in 1:ncol(d)) {
      word.id = d[i, j]
      topic.old = t[i, j]
      # Remove the current token's counts before computing equation (3)
      cdt[i, topic.old] = cdt[i, topic.old] - 1
      cwt[word.id, topic.old] = cwt[word.id, topic.old] - 1
      # Equation (3) for each topic: P(topic | word, doc) is proportional to
      # P(word | topic) * P(topic | doc), smoothed by beta and alpha
      vals = prop.table(cwt + beta, 2)[word.id, ] * prop.table(cdt[i, ] + alpha)
      # Sample the new topic from the normalized results for (3)
      topic.new = sample(num.topics, 1, prob = vals / sum(vals))
      # Set the new topic and restore the counts
      t[i, j] = topic.new
      cdt[i, topic.new] = cdt[i, topic.new] + 1
      cwt[word.id, topic.new] = cwt[word.id, topic.new] + 1
    }
  }
}
# Document-topic distributions (each row of theta sums to 1)
theta = prop.table(cdt + alpha, 1)
# Word-topic distributions (each column of phi sums to 1)
phi = prop.table(cwt + beta, 2)
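# Inspection (added, not in the original gist): rows of theta and columns
# of phi are probability distributions, so each should sum to 1.
stopifnot(all(abs(rowSums(theta) - 1) < 1e-12))
stopifnot(all(abs(colSums(phi) - 1) < 1e-12))
round(theta, 2)  # per-document topic proportions
round(phi, 2)    # per-topic distributions over the words r, s, b, m, l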