Create a gist now

Instantly share code, notes, and snippets.

@tdhopper /lda.py
Last active Nov 7, 2017

Embed
What would you like to do?
from collections import defaultdict

from numpy.random import choice
from scipy.stats import dirichlet, poisson
# Generative process for LDA: sample a term distribution per topic, a topic
# distribution and a length per document, then a (topic, term) pair for every
# word position in every document.

# --- Model sizes and hyperparameters ---------------------------------------
num_documents = 5
num_topics = 2
# NOTE(review): many LDA references label the document-topic prior "alpha" and
# the topic-term prior "beta"; the labels below are kept as written — confirm
# which notation is intended.
topic_dirichlet_parameter = 1  # beta
term_dirichlet_parameter = 1  # alpha
vocabulary = ["see", "spot", "run"]
num_terms = len(vocabulary)
length_param = 10  # xi

# --- Containers for the sampled quantities ---------------------------------
# Phi: topic id -> probability vector over the vocabulary.
term_distribution_by_topic = {}
# Theta: document id -> probability vector over the topics.
topic_distribution_by_document = {}
# Document id -> sampled number of words.
document_length = {}
# Document id -> list of topic ids, one per word position.
topic_index = defaultdict(list)
# Document id -> list of vocabulary terms, one per word position.
word_index = defaultdict(list)

# Dirichlet distributions with every concentration parameter set to the same
# value (one entry per term / per topic).
term_distribution = dirichlet(num_terms * [term_dirichlet_parameter])
topic_distribution = dirichlet(num_topics * [topic_dirichlet_parameter])

# Candidate topic ids to draw from; identical for every word, so built once.
topics = range(num_topics)

# Topic plate: for each topic...
for topic in range(num_topics):
    # ...sample a multinomial distribution over the terms.
    # rvs() returns a batch containing a single draw; [0] unwraps it.
    term_distribution_by_topic[topic] = term_distribution.rvs()[0]

# Document plate: for each document...
for document in range(num_documents):
    # ...sample a multinomial distribution over the topics.
    topic_distribution_by_document[document] = topic_distribution.rvs()[0]
    topic_distribution_param = topic_distribution_by_document[document]
    # ...sample the document length from a poisson distribution.
    document_length[document] = poisson(length_param).rvs()
    # Word plate: for each word in the document...
    for word in range(document_length[document]):
        # ...sample the topic generating the word.
        topic = choice(topics, p=topic_distribution_param)
        topic_index[document].append(topic)
        # ...sample the term generated by the topic.
        term_distribution_param = term_distribution_by_topic[topic]
        word_index[document].append(
            choice(vocabulary, p=term_distribution_param)
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment