Collapsed Gibbs sampler for Latent Dirichlet Allocation.
""" | |
Copyright (C) 2015 Baxter Eaves | |
License: Do what the fuck you want to public license (WTFPL) V2 | |
Collapsed Gibbs sampler for Latent Dirichlet Allocation. | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import random | |
from scipy.misc import logsumexp | |
from math import log | |
from math import sqrt | |
from math import ceil | |
def discrete_draw(p, logp=False):
    """ Single draw from a discrete distribution defined by the probabilities p.

    Parameters
    ----------
    p : numpy.ndarray(n,)
        array of probabilities or log probabilities
    logp : bool
        True if p is a vector of log probabilities

    Returns
    -------
    idx : int
        the index in [0,...,len(p)-1] drawn from p
    """
    if logp:
        p = np.exp(p-logsumexp(p))
    else:
        # normalize a copy rather than dividing in place, so the caller's
        # array is not modified
        p = p/np.sum(p)
    return np.digitize([np.random.rand()], np.cumsum(p))[0]
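
# A minimal usage sketch for discrete_draw (the weights are made-up examples):
#
#   discrete_draw(np.array([.1, .6, .3]))           # returns 1 about 60% of
#                                                   # the time
#   discrete_draw(np.log([.1, .6, .3]), logp=True)  # the same draw, from log
#                                                   # probabilities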
class GibbsLDA(object):
    """ Latent Dirichlet Allocation (LDA) via collapsed Gibbs sampling.

    Attributes
    ----------
    z : list<numpy.ndarray>
        Entry z[d][w] is the topic to which the w^th word in document d is
        assigned
    n_dk : numpy.ndarray(n_docs, n_topics)
        number of words assigned to topic k in doc d
    """
    def __init__(self, docs, n_topics, n_words, alpha=1.0, beta=1.0):
        """
        Parameters
        ----------
        docs : list<dict>
            Each entry corresponds to a document and has the following key:
            - w : list, the index (in 0,...,n_words-1) of each word
        n_topics : int
            number of topics
        n_words : int
            number of words in corpus (vocabulary)
        alpha : float (0, Inf), optional
            symmetric Dirichlet parameter for the document/topic distribution
        beta : float (0, Inf), optional
            symmetric Dirichlet parameter for the topic/word distribution
        """
        self._docs = docs
        self._n_docs = len(docs)
        self._n_topics = n_topics
        self._n_words = n_words
        self._alpha = alpha
        self._beta = beta

        # number of words assigned to topic k in doc d
        self._n_dk = np.zeros((self._n_docs, self._n_topics))
        # number of times word w is assigned to topic k
        self._n_kw = np.zeros((self._n_topics, self._n_words))
        # number of times any word is assigned to topic k
        self._n_k = np.zeros(self._n_topics)

        # Entry z[d][w] is the topic to which the w^th word in document d is
        # assigned. Initialize each word to a uniformly random topic and
        # update the count caches to match.
        self._z = []
        for d, doc in enumerate(self._docs):
            self._z.append([])
            for w, wrd in enumerate(doc['w']):
                topic = random.randrange(self._n_topics)
                self._z[d].append(topic)
                self._n_dk[d, topic] += 1.0
                self._n_kw[topic, wrd] += 1.0
                self._n_k[topic] += 1.0
    @property
    def z(self):
        return self._z

    @property
    def n_dk(self):
        return self._n_dk
    def _log_conditional(self, d, k, word):
        logp = log(self._n_dk[d, k] + self._alpha)
        logp += log(self._n_kw[k, word] + self._beta)
        logp -= log(self._n_k[k] + self._beta*self._n_words)
        return logp
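
    # _log_conditional above is the standard collapsed Gibbs conditional for
    # LDA (Griffiths & Steyvers, 2004): with the word's current assignment
    # removed from the counts, the probability that the word in document d
    # gets topic k is proportional to
    #
    #     (n_dk + alpha) * (n_kw + beta) / (n_k + beta*n_words)
    #
    # i.e., how much document d uses topic k times how much topic k uses the
    # word. The per-document denominator (n_d + alpha*n_topics) is constant
    # in k, so it cancels under normalization and is omitted.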
    def _step(self):
        for d, doc in enumerate(self._docs):
            for w, (topic, word) in enumerate(zip(self._z[d], doc['w'])):
                # remove this word's current assignment from the counts
                self._n_dk[d, topic] -= 1.0
                self._n_kw[topic, word] -= 1.0
                self._n_k[topic] -= 1.0

                # resample the word's topic from its full conditional
                logp_k = np.zeros(self._n_topics)
                for k in range(self._n_topics):
                    logp_k[k] = self._log_conditional(d, k, word)
                topic = discrete_draw(logp_k, logp=True)
                self._z[d][w] = topic

                # add the new assignment back into the counts
                self._n_dk[d, topic] += 1.0
                self._n_kw[topic, word] += 1.0
                self._n_k[topic] += 1.0

    def run(self, n_steps=1):
        """ Run the sampler for n_steps full Gibbs scans. """
        for _ in range(n_steps):
            self._step()
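
# A minimal usage sketch for GibbsLDA (the corpus below is made up for
# illustration): each document is a dict whose 'w' entry is a list of word
# indices in 0,...,n_words-1.
#
#   toy_docs = [{'w': [0, 2, 2, 1]}, {'w': [3, 3, 0]}]
#   lda = GibbsLDA(toy_docs, n_topics=2, n_words=4)
#   lda.run(n_steps=100)
#   lda.z     # per-word topic assignments
#   lda.n_dk  # document-by-topic count matrix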
def gen_docs(n_docs, n_topics, n_words, alpha=1.0, beta=1.0,
             n_words_per_doc=100):
    """ Generate 'documents' from LDA's generative distribution.

    Parameters
    ----------
    n_docs : int
        number of documents to generate
    n_topics : int
        number of topics
    n_words : int
        number of words in the corpus (vocabulary)
    alpha : float (0, Inf), optional
        symmetric Dirichlet parameter for the document/topic distribution
    beta : float (0, Inf), optional
        symmetric Dirichlet parameter for the topic/word distribution
    n_words_per_doc : int, optional
        number of words per document

    Returns
    -------
    phi : numpy.ndarray
        Each row is a topic's distribution over words
    docs : list<dict>
        Each entry corresponds to a document and has the following keys:
        - theta : numpy.ndarray, the distribution of topics in the document
        - w : list, the index (in 0,...,n_words-1) of each word
        - z : list, the topic of each word
    """
    phi = np.random.dirichlet([beta]*n_words, n_topics)
    docs = []
    for d in range(n_docs):
        theta_d = np.random.dirichlet([alpha]*n_topics)
        docs.append({'z': [], 'w': [], 'theta': theta_d})
        for _ in range(n_words_per_doc):
            # draw a topic from the document's topic distribution, then a
            # word from that topic's word distribution
            z_i = discrete_draw(theta_d)
            w_i = discrete_draw(phi[z_i])
            docs[d]['z'].append(z_i)
            docs[d]['w'].append(w_i)
    return phi, docs
def answer_err(docs, n_dk, beta, do_plot=False):
    """ Error of the topic-document distribution and the predictive mass
    function induced by the topic-document counts (n_dk).
    """
    n_docs = len(docs)
    n_topics = len(docs[0]['theta'])
    d_sbplt = int(ceil(sqrt(n_docs)))
    x = np.arange(n_topics)
    sumerr = 0
    for d, (doc, n_dk_d) in enumerate(zip(docs, n_dk)):
        # Sort so we can compare the true and inferred distributions (label
        # switching)
        p_true = np.sort(doc['theta'])
        p_inferred = np.sort((n_dk_d+beta)/np.sum(n_dk_d+beta))
        sumerr += np.sum(np.abs(p_true-p_inferred))
        if do_plot:
            ax = plt.subplot(d_sbplt, d_sbplt, d+1)
            ax.bar(x, p_true, fc='dodgerblue', alpha=.5, label='true')
            ax.bar(x, p_inferred, fc='deeppink', alpha=.5, label='inferred')
            ax.set_ylabel('p')
            ax.set_xlabel('topic index')
            ax.set_title('document %d' % (d,))
    merr = sumerr/n_docs
    return merr
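
# Note on the np.sort calls in answer_err: Gibbs sampling identifies topics
# only up to a permutation of their labels ("label switching"), so the true
# and inferred theta vectors cannot be compared coordinate-by-coordinate.
# Sorting both vectors compares them as unlabeled distributions instead. For
# example, these made-up vectors describe the same unlabeled mixture:
#
#   p_true     = [.7, .1, .2]  ->  sorted: [.1, .2, .7]
#   p_inferred = [.2, .7, .1]  ->  sorted: [.1, .2, .7]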
if __name__ == '__main__':
    n_steps = 200
    n_words = 20
    n_topics = 5
    n_docs = 16
    n_words_per_doc = 500
    alpha = .1
    beta = .1

    _, docs = gen_docs(n_docs, n_topics, n_words, alpha=alpha, beta=beta,
                       n_words_per_doc=n_words_per_doc)
    glda = GibbsLDA(docs, n_topics, n_words, alpha=alpha, beta=beta)

    err = [answer_err(docs, glda.n_dk, beta)]
    for i in range(n_steps):
        glda.run()
        err.append(answer_err(docs, glda.n_dk, beta))

    plt.figure(facecolor='white')
    plt.plot(err)
    plt.xlabel('Iteration')
    plt.ylabel('Mean Absolute Error')
    plt.title('Error by iteration of topic/document distributions')
    plt.show()