Skip to content

Instantly share code, notes, and snippets.

@chiral
Last active May 4, 2019 13:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chiral/2560383d5643da80b6bf to your computer and use it in GitHub Desktop.
Save chiral/2560383d5643da80b6bf to your computer and use it in GitHub Desktop.
An implementation of the collapsed Gibbs sampling algorithm for LDA (latent Dirichlet allocation) in R
# LDA collapsed Gibbs sampler implementation in R by isobe
# Convert a list of bags of words into an indexed corpus.
#
# bows: list with one element per document; each element maps a word (its
#   name) to that word's count (its value), e.g. list(apple = 2, pear = 1).
#   NULL, empty or unnamed elements are treated as empty documents.
# Returns a list with
#   docs:  one numeric vector per document giving the vocabulary index of every
#          token (a word with count n contributes n copies of its index);
#   words: the vocabulary, in order of first appearance across documents;
#   index: named list mapping each word to its position in `words`.
bows2corpus <- function(bows) {
  print("bows2corpus")
  # Vocabulary in first-appearance order: documents in order, and words in
  # name order within each document.
  words <- unique(unlist(lapply(bows, names), use.names = FALSE))
  index <- as.list(as.numeric(seq_along(words)))
  names(index) <- words
  docs <- lapply(bows, function(bow) {
    # Emit an empty document instead of assigning NULL (which would drop the
    # entry), so docs stays aligned one-to-one with bows.
    if (length(bow) == 0 || is.null(names(bow))) {
      return(numeric(0))
    }
    ids <- unlist(index[names(bow)], use.names = FALSE)
    rep(ids, times = unlist(bow, use.names = FALSE))
  })
  list(docs = unname(docs), words = words, index = index)
}
# Fit LDA to a corpus with a collapsed Gibbs sampler.
#
# corpus: list with `docs` (list of numeric vectors of word indices, one per
#   document; indices must lie in 1..length(words)) and `words` (vocabulary).
# K: number of topics.
# alpha, beta: smoothing hyperparameters added to the document-topic and
#   topic-word counts respectively in the sampling conditional.
# num_iter: number of full sweeps over every token of every document.
# verbose: print progress messages (default TRUE).
# Returns a data.frame with one row per token per sweep: iter, d (document),
#   w (token position within the document), i (word index), k (sampled topic).
lda_cgs <- function(corpus, K, alpha, beta, num_iter = 50, verbose = TRUE) {
  if (verbose) print("lda_cgs")
  docs <- corpus$docs
  words <- corpus$words
  V <- length(words)
  D <- length(docs)
  if (verbose) print(paste("V,D=", V, D))

  ### initialize ###
  L <- 0                      # total number of tokens in the corpus
  topics <- vector("list", D) # topics[[d]][w]: current topic of token w of doc d
  n_td <- matrix(0, K, D)     # topic x document token counts
  n_wt <- matrix(0, V, K)     # word x topic token counts
  n_t <- rep(0, K)            # token count per topic
  for (d in seq_len(D)) {
    ws <- docs[[d]]
    N <- length(ws)
    L <- L + N
    ks <- ceiling(runif(N) * K) # uniform random initial topics in 1..K
    topics[[d]] <- ks
    for (w in seq_len(N)) {
      n_wt[ws[w], ks[w]] <- n_wt[ws[w], ks[w]] + 1
    }
    # Number of tokens assigned to each topic (not the sum of their positions).
    nk <- tabulate(ks, nbins = K)
    n_td[, d] <- nk
    n_t <- n_t + nk
  }

  # Draw a topic for a token of word i in document d from the collapsed
  # conditional. Must be called while that token is excluded from the counts.
  draw_topic <- function(d, i) {
    prob <- (alpha + n_td[, d]) * (beta + n_wt[i, ]) / (beta * V + n_t)
    r <- rmultinom(1, 1, prob / sum(prob))
    which(r == 1)
  }

  ### main loop ###
  stats <- matrix(0, L * num_iter, 5)
  row <- 0
  for (iter in seq_len(num_iter)) {
    if (verbose) print(paste("iter=", iter))
    for (d in seq_len(D)) {
      if (verbose) print(paste("iter=", iter, "doc=", d, "/", D))
      ws <- docs[[d]]
      for (w in seq_along(ws)) {
        i <- ws[w]
        # Remove the token's current assignment from the counts ...
        k <- topics[[d]][w]
        n_wt[i, k] <- n_wt[i, k] - 1
        n_td[k, d] <- n_td[k, d] - 1
        n_t[k] <- n_t[k] - 1
        # ... resample it given all other tokens ...
        k <- draw_topic(d, i)
        # ... and add it back under the new topic.
        topics[[d]][w] <- k
        n_wt[i, k] <- n_wt[i, k] + 1
        n_td[k, d] <- n_td[k, d] + 1
        n_t[k] <- n_t[k] + 1
        row <- row + 1
        stats[row, ] <- c(iter, d, w, i, k)
      }
    }
  }

  ### process result ###
  data.frame(iter = stats[, 1], d = stats[, 2], w = stats[, 3],
             i = stats[, 4], k = stats[, 5])
}
# Read a bag-of-words CSV into a list of per-document word counts.
#
# fn: path or URL of a CSV with columns `doc` (positive integer document id),
#   `word` and `count`.
# Returns a list whose element d is a named list mapping word -> count for
#   document d. Ids below the largest id that never occur are left as NULL.
#   If a (doc, word) pair appears more than once, the last row wins.
csv2bows <- function(fn) {
  print("csv2bows")
  df <- read.csv(fn, stringsAsFactors = FALSE)
  # Pull the columns out once rather than building a one-row data frame on
  # every iteration.
  doc_ids <- df[["doc"]]
  # Force character keys: a numeric word column would otherwise make
  # `[[word]] <-` index by position instead of by name.
  words <- as.character(df[["word"]])
  counts <- df[["count"]]
  bows <- list()
  for (r in seq_len(nrow(df))) {
    d <- doc_ids[r]
    # Start a fresh bag for a new id, including NULL gaps left behind when a
    # larger id was seen first.
    if (length(bows) < d || is.null(bows[[d]])) {
      bows[[d]] <- list()
    }
    bows[[d]][[words[r]]] <- counts[r]
  }
  bows
}
# Globals filled in by test() so its results can be inspected after it runs.
stats <- NULL
corpus <- NULL
# Run the whole pipeline: read a (doc, word, count) CSV from f_in, fit LDA
# with K topics, and write the per-token sampling trace to f_out.
# alpha, beta and num_iter are forwarded to lda_cgs(); the defaults match the
# previously hard-coded values. Also sets the globals `corpus` and `stats`,
# and returns the trace invisibly.
test <- function(f_in, f_out, K, alpha = 50 / K, beta = 0.1,
                 num_iter = 50) {
  bows <- csv2bows(f_in)
  corpus <<- bows2corpus(bows)
  stats <<- lda_cgs(corpus, K, alpha, beta, num_iter = num_iter)
  write.csv(stats, f_out)
  invisible(stats)
}
# Fit 5 topics to the sample dataset; results land in `corpus` and `stats`.
# NOTE(review): remote dataset URL -- confirm it is still reachable.
test("http://labs.adfive.net/mlhackathon/20140614/in1_small.csv",
     "out1.csv", 5)
# Histogram of the topics sampled for document 1's tokens over all sweeps.
hist(stats$k[stats$d == 1])
# Vocabulary index assigned to the word 'タモリ'.
corpus$index['タモリ']
# Topics sampled for word index 66 over all sweeps.
stats$k[stats$i == 66]
# Word 1's string, once per occurrence of word 1 in the final sweep.
corpus$words[stats$i[stats$iter == max(stats$iter) & stats$i == 1]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment