Create a gist now

Instantly share code, notes, and snippets.

import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
def get_vectors(vocab_size=5000):
    """Vectorize the 20-newsgroups training set as bag-of-words counts.

    vocab_size: maximum number of terms kept by CountVectorizer.

    Returns (vecs, inverse_vocabulary, newsgroups_train):
      vecs               -- CSR matrix of word counts, one row per document.
      inverse_vocabulary -- array mapping column index -> term string.
      newsgroups_train   -- the dataset object returned by fetch_20newsgroups.
    """
    newsgroups_train = fetch_20newsgroups(subset='train')
    vectorizer = CountVectorizer(max_df=.9, max_features=vocab_size)
    vecs = vectorizer.fit_transform(newsgroups_train.data)
    # `vocabulary_` (trailing underscore) is the fitted term -> column mapping;
    # `vocabulary` is only the constructor argument and is None by default.
    vocabulary = vectorizer.vocabulary_
    # Order terms by their column index so inverse_vocabulary[col] is the term.
    inverse_vocabulary = np.array(sorted(vocabulary, key=vocabulary.get))
    return vecs.tocsr(), inverse_vocabulary, newsgroups_train
def train_model(mat, topicConcentrationInDocuments, wordConcentrationInTopics, numTopics,
                numSamples=250, burnin=5):
    """Fit an LDA Model to `mat` and keep its highest-probability Gibbs sample.

    mat: document x vocabulary word-count matrix (e.g. from get_vectors).
    topicConcentrationInDocuments: Dirichlet prior on topics per document (alpha).
    wordConcentrationInTopics: Dirichlet prior on words per topic (eta).
    numTopics: number of topics to fit.
    numSamples, burnin: forwarded to Model.bestSample.

    Returns a dict meant for ``globals().update(...)`` with keys
    model, logprobs, logProb, topicCountsInDocs and wordCountsInTopics.
    """
    model = Model(mat, numTopics=numTopics,
                  topicConcentrationInDocuments=topicConcentrationInDocuments,
                  wordConcentrationInTopics=wordConcentrationInTopics)
    # List handed to bestSample to collect the per-sample log probabilities.
    logprobs = []
    logProb, topicCountsInDocs, wordCountsInTopics = model.bestSample(
        numSamples, burnin, logprobs)
    return dict(model=model,
                logprobs=logprobs,
                logProb=logProb,
                topicCountsInDocs=topicCountsInDocs,
                wordCountsInTopics=wordCountsInTopics)
def summarizeByWords(model, wordCountsInTopics, terms, n=20):
    """Print the `n` most probable terms of each topic, most probable first.

    model: object providing wordDistribution(counts) -> (topics x vocabulary) array.
    wordCountsInTopics: word-count sample passed to model.wordDistribution.
    terms: indexable mapping column index -> term string.
    Each output line has the form ``"<topic>: term, term, ..."``.
    """
    wordDistForTopic = model.wordDistribution(wordCountsInTopics)
    for topic, dist in enumerate(wordDistForTopic):
        # argsort is ascending: take the last n and reverse for descending order.
        topTerms = dist.argsort()[-n:][::-1]
        print("{}: {}".format(topic, ', '.join(terms[term] for term in topTerms)))
def summarizeByDocs(model, topicCountsInDocs, docids, n=5):
    """Print, for each topic, the `n` documents that devote the most mass to it.

    model: object providing topicDistribution(counts) -> (documents x topics) array.
    topicCountsInDocs: topic-count sample passed to model.topicDistribution.
    docids: indexable mapping document index -> label.
    Each output line has the form ``"<topic>: id(p), id(p), ..."`` with p shown
    to two significant digits.
    """
    topicDistForDoc = model.topicDistribution(topicCountsInDocs)
    # Transpose so we iterate topics, each row holding that topic's weight per doc.
    for topic, docs in enumerate(topicDistForDoc.T):
        topDocs = docs.argsort()[-n:][::-1]
        print("{}: {}".format(topic, ', '.join('{}({:.2})'.format(docids[doc], docs[doc])
                                               for doc in topDocs)))
def summarize_model(model, terms, dataset, wordCountsInTopics, topicCountsInDocs):
    """Print each topic's top words and plot the mean topic mixture per newsgroup.

    model: the fitted Model; wordCountsInTopics / topicCountsInDocs: one sample
    from it (e.g. as returned by Model.bestSample).
    terms: column index -> term array (as returned by get_vectors).
    dataset: the 20-newsgroups object; its ``target`` array labels each document.

    Draws a 5x4 grid of bar charts, one per newsgroup, whose bars are the
    average topic proportions over that newsgroup's documents.
    """
    from matplotlib import pyplot as plt
    summarizeByWords(model, wordCountsInTopics, terms)
    # One row per document, one column per topic.
    topicDist = model.topicDistribution(topicCountsInDocs)
    # Bar positions must match the number of topics, which need not be 20.
    x = np.arange(topicDist.shape[1])
    for newsgroup in range(20):
        plt.subplot(5, 4, newsgroup + 1)
        plt.bar(x, topicDist[dataset.target == newsgroup].mean(axis=0))
        # Only the bottom row of the grid keeps its x tick labels.
        if newsgroup < 16:
            plt.xticks(visible=False)
# Run this with something like
# vecs, terms, dataset = get_vectors()
# globals().update(train_model(vecs, 1, 1, 20))
# summarize_model(model, terms, dataset, wordCountsInTopics, topicCountsInDocs)
import numpy as np
from scipy.special import gammaln
from itertools import islice
# Latent Dirichlet Allocation using collapsed Gibbs sampling
# Based on the insight of Griffiths and Steyvers (2004) that you can
# analytically integrate out the distributions of words in a topic and
# of topics in a document.
def multinomialDraw(dist):
    """Sample one category index from the discrete distribution ``dist``.

    A single multinomial trial yields a one-hot count vector; the index of
    its only nonzero entry is the drawn category.
    """
    oneHot = np.random.multinomial(1, dist)
    return oneHot.argmax()
def iterWords(mat, row):
    """Yield each word index of one CSR row, repeated once per occurrence.

    mat: CSR matrix of word counts (documents x vocabulary); only its
         indptr / indices / data arrays are read.
    row: index of the document whose words are yielded.
    Words come out in column order, e.g. counts [2, 0, 1] yield 0, 0, 2.
    """
    for ind in range(mat.indptr[row], mat.indptr[row + 1]):
        word = mat.indices[ind]
        # data[ind] is how many times `word` occurs; int() also accepts float counts.
        for _ in range(int(mat.data[ind])):
            yield word
class Model(object):
    """Latent Dirichlet Allocation fitted by collapsed Gibbs sampling.

    State, all updated in place by doSample():
      topicCountsInDocs  -- (numTopics, numDocs): words in each doc assigned to each topic.
      wordCountsInTopics -- (numVocab, numTopics): occurrences of each word assigned to each topic.
      numWordsInTopic    -- (numTopics,): total words assigned to each topic.
      topicAssignments   -- per document, the current topic of every word
                            occurrence, in the order produced by iterWords.
    """

    def logProb(self):
        """Return the log joint probability of the current topic assignments.

        The per-document topic and per-topic word distributions are integrated
        out. The additive constants D*(gammaln(K*alpha)-K*gammaln(alpha)) and
        K*(gammaln(V*eta)-V*gammaln(eta)) are omitted, so values are only
        comparable between states of the same model.
        """
        topicTerm = self.topicConcentrationInDocuments + self.topicCountsInDocs
        vocabTerm = self.wordConcentrationInTopics + self.wordCountsInTopics
        # Axis 0 is topics for topicTerm (one term per document) and vocabulary
        # for vocabTerm (one term per topic).
        return ((gammaln(topicTerm).sum(axis=0) - gammaln(topicTerm.sum(axis=0))).sum() +
                (gammaln(vocabTerm).sum(axis=0) - gammaln(vocabTerm.sum(axis=0))).sum())

    def __init__(self, documentWordCounts, numTopics,
                 topicConcentrationInDocuments, wordConcentrationInTopics):
        """Set up an unsampled model.

        documentWordCounts: CSR matrix (numDocs x numVocab) of word counts;
            doSample reads it through iterWords.
        numTopics: number of topics.
        topicConcentrationInDocuments: symmetric Dirichlet prior alpha.
        wordConcentrationInTopics: symmetric Dirichlet prior eta.
        """
        self.topicConcentrationInDocuments = float(topicConcentrationInDocuments)  # alpha
        self.wordConcentrationInTopics = float(wordConcentrationInTopics)  # eta
        self.numTopics = numTopics
        numDocs, numVocab = documentWordCounts.shape
        # Counts start at zero; the first doSample() sweep fills them in.
        self.topicCountsInDocs = np.zeros((numTopics, numDocs))
        self.wordCountsInTopics = np.zeros((numVocab, numTopics))
        self.numWordsInTopic = np.zeros(numTopics)
        # One assignment slot per word occurrence; uint16 limits numTopics to 65536.
        wordsPerDoc = np.asarray(documentWordCounts.sum(axis=1)).ravel()
        self.topicAssignments = [np.zeros(int(total), dtype=np.uint16)
                                 for total in wordsPerDoc]
        self.documentWordCounts = documentWordCounts
        self.sampleNum = 0

    def doSample(self):
        """Run one Gibbs sweep, resampling the topic of every word in place.

        On the first sweep (sampleNum == 0) words have no previous topic, so
        nothing is decremented; that sweep initialises the counts.
        Prints sampleNum before sweeping.
        """
        topicCountsInDocs = self.topicCountsInDocs
        wordCountsInTopics = self.wordCountsInTopics
        numWordsInTopic = self.numWordsInTopic
        topicAssignments = self.topicAssignments
        eta = self.wordConcentrationInTopics
        alpha = self.topicConcentrationInDocuments
        numVocab = wordCountsInTopics.shape[0]
        numDocs = self.documentWordCounts.shape[0]
        hasPreviousAssignments = self.sampleNum > 0
        print(self.sampleNum)
        for doc in range(numDocs):
            assignments = topicAssignments[doc]
            for i, word in enumerate(iterWords(self.documentWordCounts, doc)):
                if hasPreviousAssignments:
                    # Remove this word's current assignment from all counts.
                    prevTopic = assignments[i]
                    wordCountsInTopics[word, prevTopic] -= 1
                    topicCountsInDocs[prevTopic, doc] -= 1
                    numWordsInTopic[prevTopic] -= 1
                # Collapsed conditional: p(z=k) is proportional to
                # (n_wk + eta) / (n_k + V*eta) * (n_dk + alpha).
                p_wordInTopic = ((wordCountsInTopics[word, :] + eta) /
                                 (numWordsInTopic + numVocab * eta))
                p_topicInDoc = topicCountsInDocs[:, doc] + alpha
                dist = p_wordInTopic * p_topicInDoc
                topic = multinomialDraw(dist / dist.sum())
                wordCountsInTopics[word, topic] += 1
                topicCountsInDocs[topic, doc] += 1
                numWordsInTopic[topic] += 1
                assignments[i] = topic
        self.sampleNum += 1

    def iterSamples(self, burnin, lag):
        """Yield samples from the posterior, forever.

        burnin: sweeps to run before the first sample is yielded (in addition
                to the initialising sweep); must be >= lag.
        lag: extra sweeps between consecutive yielded samples.
        Yields: (logProb, topicCountsInDocs, wordCountsInTopics). The arrays
        are copies, so later sweeps do not modify them.
        Raises ValueError (on first iteration) if burnin < lag.
        """
        if burnin < lag:
            raise ValueError("burnin ({}) must be >= lag ({})".format(burnin, lag))
        for _ in range(burnin - lag):
            self.doSample()
        while True:
            for _ in range(lag + 1):
                self.doSample()
            yield self.logProb(), self.topicCountsInDocs.copy(), self.wordCountsInTopics.copy()

    def drawSamples(self, burnin, nSamples, lag):
        """Collect nSamples posterior samples (see iterSamples for burnin/lag).

        Returns (samples, logProbs): samples is a list of
        (topicCountsInDocs, wordCountsInTopics) tuples and logProbs[i] is the
        log probability of samples[i].
        """
        logProbs = np.full(nSamples, np.nan)
        samples = []
        sampleStream = islice(self.iterSamples(burnin, lag), nSamples)
        for i, (logProb, topicCountsInDocs, wordCountsInTopics) in enumerate(sampleStream):
            samples.append((topicCountsInDocs, wordCountsInTopics))
            logProbs[i] = logProb
            print("Sample {}, logProb {}".format(i, logProbs[i]))
        return samples, logProbs

    def bestSample(self, n, burnin=0, logprobs=None):
        """Return the sample with the highest log probability among n samples.

        Samples are consecutive sweeps (lag 0) after `burnin` sweeps.
        logprobs: optional list; every sample's log probability is appended to
                  it. When omitted, each log probability is printed instead.
        Returns (logProb, topicCountsInDocs, wordCountsInTopics); the arrays
        are None if n == 0.
        """
        best = -np.inf, None, None
        for logProb, topicCountsInDocs, wordCountsInTopics in (
                islice(self.iterSamples(burnin, 0), n)):
            # iterSamples already yields copies, so they can be kept directly.
            if logProb > best[0]:
                best = logProb, topicCountsInDocs, wordCountsInTopics
            if logprobs is None:
                print(logProb)
            else:
                logprobs.append(logProb)
        return best

    def wordDistribution(self, wordCountsInTopics_sample):
        """Return the smoothed word distribution of each topic for a sample.

        Input is (numVocab, numTopics); output is (numTopics, numVocab) with
        rows summing to 1 after adding eta to every count.
        """
        dist = wordCountsInTopics_sample.T + self.wordConcentrationInTopics
        dist /= dist.sum(axis=1)[:, np.newaxis]
        return dist

    def topicDistribution(self, topicCountsInDocs_sample):
        """Return the smoothed topic distribution of each document for a sample.

        Input is (numTopics, numDocs); output is (numDocs, numTopics) with
        rows summing to 1 after adding alpha to every count.
        """
        dist = topicCountsInDocs_sample.T + self.topicConcentrationInDocuments
        dist /= dist.sum(axis=1)[:, np.newaxis]
        return dist
def dirichletDraw(alpha, N):
    """Draw one length-N probability vector from a symmetric Dirichlet(alpha)."""
    concentration = np.full(N, alpha, dtype=float)
    return np.random.dirichlet(concentration)
def exampleDataset(n, topicConcentrationInDocuments, docLength):
    """Generate a synthetic corpus from 10 "bar" topics over a 5x5 vocabulary.

    Topics 0-4 spread their mass uniformly over row i of a 5x5 word grid,
    topics 5-9 over column i-5. Each of the n documents draws a topic mixture
    from a symmetric Dirichlet(topicConcentrationInDocuments), then docLength
    words, each from a topic picked by that mixture.

    Returns (topics, docs): topics is a (10, 25) array with rows summing to 1;
    docs is an (n, 25) CSR matrix of word counts (the format Model expects).
    """
    from scipy.sparse import csr_matrix
    numTopics, side = 10, 5
    topics = np.zeros((numTopics, side * side))
    for i in range(numTopics):
        grid = np.zeros((side, side))
        if i < side:
            grid[i, :] = 1
        else:
            grid[:, i - side] = 1
        topics[i, :] = grid.ravel() / grid.sum()
    # int64 counts: a uint8 row could overflow once docLength exceeds 255.
    counts = np.zeros((n, side * side), dtype=np.int64)
    for d in range(n):
        mixture = dirichletDraw(topicConcentrationInDocuments, numTopics)
        for _ in range(docLength):
            topic = multinomialDraw(mixture)
            word = multinomialDraw(topics[topic])
            counts[d, word] += 1
    return topics, csr_matrix(counts)
def showTopics(dist):
    """Show each topic's word distribution as a 5x5 grayscale image.

    dist: array with one row per topic and 25 columns, one per word of the
    5x5 exampleDataset vocabulary (e.g. Model.wordDistribution output).
    Topics are laid out five per row, one subplot each.
    """
    from matplotlib import pyplot as plt
    numTopics = len(dist)
    numCols = 5
    numRows = -(-numTopics // numCols)  # ceiling division
    for i in range(numTopics):
        # Each topic needs its own axes; otherwise every imshow overwrites the last.
        plt.subplot(numRows, numCols, i + 1)
        plt.imshow(dist[i].reshape(5, 5), interpolation='nearest', cmap='gray')
def demo(n, topicConcentrationInDocuments, nSamples, lag, wordConcentrationInTopics=1.0):
    """Fit a 15-topic Model to a synthetic exampleDataset corpus.

    n: number of synthetic documents (100 words each).
    topicConcentrationInDocuments: alpha, used both to generate and to fit.
    nSamples, lag: forwarded to Model.drawSamples (burn-in fixed at 10).
    wordConcentrationInTopics: eta for the fitted model.
        NOTE(review): the original value was lost; 1.0 is an assumption.
    Returns locals(): topics, docs, lda, samples, logProbs and the arguments.
    """
    topics, docs = exampleDataset(n, topicConcentrationInDocuments, docLength=100)
    lda = Model(docs, numTopics=15,
                topicConcentrationInDocuments=topicConcentrationInDocuments,
                wordConcentrationInTopics=wordConcentrationInTopics)
    samples, logProbs = lda.drawSamples(
        burnin=10, nSamples=nSamples, lag=lag)
    return locals()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment