Collapsed Gibbs sampler for Latent Dirichlet Allocation.
"""
Copyright (C) 2015 Baxter Eaves
License: Do what the fuck you want to public license (WTFPL) V2
Collapsed Gibbs sampler for Latent Dirichlet Allocation.
"""
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.special import logsumexp  # moved from scipy.misc in newer SciPy
from math import log
from math import sqrt
from math import ceil

def discrete_draw(p, logp=False):
    """ Single draw from a discrete distribution defined by the probabilities p.

    Parameters
    ----------
    p : numpy.ndarray(n,)
        array of probabilities or log probabilities
    logp : bool
        True if p is a vector of log probabilities

    Returns
    -------
    idx : int
        the index in [0, ..., len(p)-1] drawn from p
    """
    if logp:
        p = np.exp(p - logsumexp(p))
    else:
        p = p / np.sum(p)  # normalize without mutating the caller's array
    return np.digitize([np.random.rand()], np.cumsum(p))[0]

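
# A quick sketch of how discrete_draw behaves. The values below are
# illustrative only and are not part of the original gist:
#     counts = np.array([1.0, 2.0, 7.0])
#     idx = discrete_draw(counts)                     # returns 2 about 70% of the time
#     idx = discrete_draw(np.log(counts), logp=True)  # same draw, from log weights
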

class GibbsLDA(object):
    """ Latent Dirichlet Allocation (LDA) via collapsed Gibbs sampling.

    Attributes
    ----------
    z : list<numpy.ndarray>
        Entry z[d][w] is the topic to which the w^th word in document d is
        assigned
    n_dk : numpy.ndarray(n_docs, n_topics)
        number of words assigned to topic k in doc d
    """
    def __init__(self, docs, n_topics, n_words, alpha=1.0, beta=1.0):
        """
        Parameters
        ----------
        docs : list<dict>
            Each entry corresponds to a document and has the following key:
            - w : list, the index (in 0,...,n_words-1) of each word
        n_topics : int
            number of topics
        n_words : int
            number of words in the corpus (vocabulary)
        alpha : float (0, Inf), optional
            symmetric Dirichlet parameter for the document/topic distribution
        beta : float (0, Inf), optional
            symmetric Dirichlet parameter for the topic/word distribution
        """
        self._docs = docs
        self._n_docs = len(docs)
        self._n_topics = n_topics
        self._n_words = n_words
        self._alpha = alpha
        self._beta = beta

        # number of words assigned to topic k in doc d
        self._n_dk = np.zeros((self._n_docs, self._n_topics))
        # number of times word w is assigned to topic k
        self._n_kw = np.zeros((self._n_topics, self._n_words))
        # number of times any word is assigned to topic k
        self._n_k = np.zeros(self._n_topics)

        # Entry z[d][w] is the topic to which the w^th word in document d is
        # assigned. Initialize each token with a uniformly random topic.
        self._z = []
        for d, doc in enumerate(self._docs):
            self._z.append([])
            for w, wrd in enumerate(doc['w']):
                topic = random.randrange(self._n_topics)
                self._z[d].append(topic)
                self._n_dk[d, topic] += 1.0
                self._n_kw[topic, wrd] += 1.0
                self._n_k[topic] += 1.0
    @property
    def z(self):
        return self._z

    @property
    def n_dk(self):
        return self._n_dk
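
    # _log_conditional computes the (unnormalized) collapsed Gibbs conditional
    # for reassigning a word token in document d to topic k:
    #     p(z = k | z_-, w)  is proportional to
    #     (n_dk + alpha) * (n_kw + beta) / (n_k + beta * W)
    # where W is the vocabulary size. The document-length normalizer is
    # constant in k and is therefore omitted.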
    def _log_conditional(self, d, k, word):
        logp = log(self._n_dk[d, k] + self._alpha)
        logp += log(self._n_kw[k, word] + self._beta)
        logp -= log(self._n_k[k] + self._beta*self._n_words)
        return logp
    def _step(self):
        """ One Gibbs sweep over every word token in every document. """
        for d, doc in enumerate(self._docs):
            for w, (topic, word) in enumerate(zip(self._z[d], doc['w'])):
                # remove this token from the counts
                self._n_dk[d, topic] -= 1.0
                self._n_kw[topic, word] -= 1.0
                self._n_k[topic] -= 1.0

                # sample a new topic from the collapsed conditional
                logp_k = np.zeros(self._n_topics)
                for k in range(self._n_topics):
                    logp_k[k] = self._log_conditional(d, k, word)
                topic = discrete_draw(logp_k, logp=True)

                # reassign the token and restore the counts
                self._z[d][w] = topic
                self._n_dk[d, topic] += 1.0
                self._n_kw[topic, word] += 1.0
                self._n_k[topic] += 1.0
    def run(self, n_steps=1):
        """ Run the sampler for n_steps Gibbs sweeps. """
        for _ in range(n_steps):
            self._step()

def gen_docs(n_docs, n_topics, n_words, alpha=1.0, beta=1.0,
             n_words_per_doc=100):
    """ Generate 'documents' from LDA's generative distribution.

    Parameters
    ----------
    n_docs : int
        number of documents to generate
    n_topics : int
        number of topics
    n_words : int
        number of words in the corpus (vocabulary)
    alpha : float (0, Inf), optional
        symmetric Dirichlet parameter for the document/topic distribution
    beta : float (0, Inf), optional
        symmetric Dirichlet parameter for the topic/word distribution
    n_words_per_doc : int, optional
        number of words per document

    Returns
    -------
    phi : numpy.ndarray
        Each row is a topic's distribution over words
    docs : list<dict>
        Each entry corresponds to a document and has the following keys:
        - theta : numpy.ndarray, the distribution of topics in the document
        - w : list, the index (in 0,...,n_words-1) of each word
        - z : list, the topic of each word
    """
    phi = np.random.dirichlet([beta]*n_words, n_topics)
    docs = []
    for d in range(n_docs):
        theta_d = np.random.dirichlet([alpha]*n_topics)
        docs.append({'z': [], 'w': [], 'theta': theta_d})
        for _ in range(n_words_per_doc):
            z_i = discrete_draw(theta_d)
            w_i = discrete_draw(phi[z_i])
            docs[d]['z'].append(z_i)
            docs[d]['w'].append(w_i)
    return phi, docs

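
# Sketch of gen_docs output; the values are illustrative, not from the
# original gist:
#     phi, docs = gen_docs(n_docs=2, n_topics=3, n_words=10, n_words_per_doc=5)
#     phi.shape          # (3, 10): one word distribution per topic
#     sorted(docs[0])    # ['theta', 'w', 'z']
#     len(docs[0]['w'])  # 5 word indices per document
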
def answer_err(docs, n_dk, beta, do_plot=False):
    """ Mean absolute error between the true topic/document distributions and
    the predictive mass functions induced by the topic-document counts (n_dk).
    """
    n_docs = len(docs)
    n_topics = len(docs[0]['theta'])
    d_sbplt = ceil(sqrt(float(n_docs)))
    x = np.arange(n_topics)
    sumerr = 0
    for d, (doc, n_dk_d) in enumerate(zip(docs, n_dk)):
        # Sort so we can compare the true and inferred distributions (label
        # switching)
        p_true = np.sort(doc['theta'])
        p_inferred = np.sort((n_dk_d + beta)/np.sum(n_dk_d + beta))
        sumerr += np.sum(np.abs(p_true - p_inferred))
        if do_plot:
            ax = plt.subplot(d_sbplt, d_sbplt, d + 1)
            ax.bar(x, p_true, fc='dodgerblue', alpha=.5, label='true')
            ax.bar(x, p_inferred, fc='deeppink', alpha=.5, label='inferred')
            ax.set_ylabel('p')
            ax.set_xlabel('topic index')
            ax.set_title('document %d' % (d,))
    merr = sumerr/n_docs
    return merr

if __name__ == '__main__':
    n_steps = 200
    n_words = 20
    n_topics = 5
    n_docs = 16
    n_words_per_doc = 500
    alpha = .1
    beta = .1

    _, docs = gen_docs(n_docs, n_topics, n_words, alpha=alpha, beta=beta,
                       n_words_per_doc=n_words_per_doc)
    glda = GibbsLDA(docs, n_topics, n_words, alpha=alpha, beta=beta)

    err = [answer_err(docs, glda.n_dk, beta)]
    for i in range(n_steps):
        glda.run()
        err.append(answer_err(docs, glda.n_dk, beta))

    plt.figure(facecolor='white')
    plt.plot(err)
    plt.xlabel('Iteration')
    plt.ylabel('Mean Absolute Error')
    plt.title('Error by iteration of topic/document distributions')
    plt.show()