Created
June 16, 2014 00:03
-
-
Save YOwatari/d52c442b1aca5b85c0f5 to your computer and use it in GitHub Desktop.
機械学習ハッカソン(#MLHackathon)にて、作成したギブスサンプリングによるLDA実装 参考: http://ktsukuda.net/ruby/lda_with_gibbs_sampling_using_ruby/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Rule-of-thumb parameter values (author's note):
#   alpha: k/50  (the commonly cited heuristic is 50/k -- verify before relying on it)
#   beta:  0.01
import csv | |
from random import randint | |
from itertools import chain | |
class LDA_MLHackathon:
    """LDA topic model trained on a bag-of-words CSV file.

    Typical use: ``read_bow`` -> ``initialize_parameter`` -> ``iterate`` ->
    ``show`` / ``save``.

    The count tables carry their smoothing terms: every ``n_dz`` cell starts
    at ``alpha``, every ``n_tz`` cell at ``beta`` and every ``n_z`` cell at
    ``V * beta`` (``V`` being the vocabulary size). Ratios of cells are
    therefore already smoothed probabilities. ``beta`` should be positive,
    otherwise an empty topic leads to a division by zero.

    The code runs unchanged on Python 2 and Python 3.
    """

    def __init__(self, alpha, beta, k, iteration_num):
        """Store the hyper-parameters and create empty model state.

        alpha -- smoothing added to every document/topic count
        beta -- smoothing added to every topic/term count
        k -- number of topics
        iteration_num -- number of full passes ``iterate`` makes over the data
        """
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self.iteration_num = iteration_num
        self.documents = []   # one list of word tokens per document
        self.v = 0            # vocabulary size
        self.n_dz = []        # [document][topic] -> count + alpha
        self.n_tz = []        # [topic][term index] -> count + beta
        self.n_z = []         # [topic] -> number of tokens + V * beta
        self.term_topic = []  # [document][position] -> assigned topic
        self.term_index = {}  # word -> column index in n_tz

    def _open_text(self, path, mode):
        """Open *path* for ``str`` I/O; *mode* is ``'r'`` or ``'w'``."""
        if str is bytes:
            # Python 2: ``str`` is the byte string and csv expects a
            # binary-mode file.
            return open(path, mode + 'b')
        # Python 3: text mode; newline='' is what the csv module asks for and
        # also keeps "\n" untranslated on write, like the binary mode above.
        return open(path, mode, encoding='utf-8', newline='')

    def read_bow(self, filename):
        """Load documents from a CSV file with ``doc``, ``word``, ``count`` columns.

        Consecutive rows sharing a ``doc`` value form one document; each word
        is repeated ``count`` times. The documents are appended to
        ``self.documents`` and all documents held so far are written, one per
        line with space-separated words, to ``input.txt`` in the current
        directory.
        """
        with self._open_text(filename, 'r') as fd:
            current_doc = None  # doc id of the rows being collected
            document = []
            for bow in csv.DictReader(fd):
                doc_number = int(bow['doc'])
                if current_doc is not None and doc_number != current_doc:
                    # The doc id changed: the previous document is complete.
                    self.documents.append(document)
                    document = []
                current_doc = doc_number
                document.extend([bow['word']] * int(bow['count']))
            # Flush the last document, unless the file had no data rows.
            if current_doc is not None:
                self.documents.append(document)

        with self._open_text('input.txt', 'w') as fd:
            for document in self.documents:
                fd.write(" ".join(document) + "\n")

    def initialize_parameter(self):
        """Build the vocabulary, assign random topics and fill the count tables.

        Any state left by an earlier call is discarded, so the method can be
        re-run to restart from a fresh random assignment.
        """
        # Vocabulary, sorted so that term indices are reproducible.
        vocabulary = sorted(set(chain.from_iterable(self.documents)))
        self.v = len(vocabulary)
        self.term_index = dict(
            (term, index) for index, term in enumerate(vocabulary))

        # Random topic for every token of every document.
        self.term_topic = [
            [randint(0, self.k - 1) for _ in document]
            for document in self.documents
        ]

        # One independent row per topic. ``k * [row]`` would make every topic
        # share a single list object.
        self.n_dz = []
        self.n_tz = [[self.beta] * self.v for _ in range(self.k)]
        self.n_z = [self.v * self.beta] * self.k

        for document, topics in zip(self.documents, self.term_topic):
            topic_freqs = [self.alpha] * self.k
            for term, topic in zip(document, topics):
                topic_freqs[topic] += 1
                self.n_tz[topic][self.term_index[term]] += 1
                self.n_z[topic] += 1
            self.n_dz.append(topic_freqs)

    def iterate(self):
        """Run ``iteration_num`` passes that reassign the topic of every token.

        For each token its current assignment is removed from the counts, the
        topic maximising ``n_dz * n_tz / n_z`` is selected, and the counts are
        updated for the selected topic.

        NOTE(review): the most probable topic is picked deterministically; a
        textbook Gibbs sampler would draw the topic at random in proportion
        to these scores. Left as in the original.
        """
        for _ in range(self.iteration_num):
            for d_k, document in enumerate(self.documents):
                doc_topic_counts = self.n_dz[d_k]
                assignments = self.term_topic[d_k]
                for t_k, term in enumerate(document):
                    this_term_index = self.term_index[term]
                    old_topic = assignments[t_k]

                    # Take the token out of the counts.
                    doc_topic_counts[old_topic] -= 1
                    self.n_tz[old_topic][this_term_index] -= 1
                    self.n_z[old_topic] -= 1

                    # Score every topic; ties keep the lowest topic number.
                    max_prob = 0
                    new_topic = 0
                    for topic_k in range(self.k):
                        tmp_prob = doc_topic_counts[topic_k] * (
                            float(self.n_tz[topic_k][this_term_index])
                            / self.n_z[topic_k])
                        if max_prob < tmp_prob:
                            new_topic = topic_k
                            max_prob = tmp_prob

                    # Put the token back under the selected topic. All three
                    # tables must use new_topic to stay consistent.
                    doc_topic_counts[new_topic] += 1
                    self.n_tz[new_topic][this_term_index] += 1
                    self.n_z[new_topic] += 1
                    assignments[t_k] = new_topic

    def _term_probability(self, topic_i, term_j):
        """Return the smoothed share of term *term_j* within topic *topic_i*."""
        return float(self.n_tz[topic_i][term_j]) / self.n_z[topic_i]

    def _topic_shares(self, doc_i):
        """Return the smoothed topic proportions of document *doc_i*."""
        denominator = len(self.documents[doc_i]) + self.k * self.alpha
        return [float(freq) / denominator for freq in self.n_dz[doc_i]]

    def show(self):
        """Print the per-topic term shares and per-document topic shares."""
        # print(<one string>) produces the same output on Python 2 and 3.
        for topic_i in range(self.k):
            print("topic:%02d" % (topic_i + 1))
            for term, term_j in self.term_index.items():
                print("%s:%f " % (
                    term, self._term_probability(topic_i, term_j)))
            print("\n")
        print("\n")
        for doc_i in range(len(self.documents)):
            print("document %d\n" % (doc_i + 1))
            for topic_j, share in enumerate(self._topic_shares(doc_i)):
                print("topic%d:%f " % (topic_j + 1, share))
            print("\n")

    def save(self):
        """Write the model to text files in the current directory.

        ``topicNN.txt`` holds one ``term:share`` line per vocabulary term for
        topic NN; ``document_topic.txt`` holds, per document, a header line
        followed by one line of ``topicNN:share`` entries.
        """
        for topic_i in range(self.k):
            filename = "topic%02d.txt" % (topic_i + 1)
            with self._open_text(filename, 'w') as fd:
                for term, term_j in self.term_index.items():
                    fd.write("%s:%f\n" % (
                        term, self._term_probability(topic_i, term_j)))

        with self._open_text("document_topic.txt", 'w') as fd:
            for doc_i in range(len(self.documents)):
                fd.write("document %02d\n" % (doc_i + 1))
                for topic_j, share in enumerate(self._topic_shares(doc_i)):
                    fd.write("topic%02d:%.4f " % (topic_j + 1, share))
                fd.write("\n")
if __name__ == "__main__":
    def _open_dump(path):
        """Open *path* for writing ``str`` on both Python 2 and Python 3."""
        if str is bytes:
            # Python 2: ``str`` is the byte string, so binary mode is fine.
            return open(path, 'wb')
        # Python 3: text mode without newline translation, like 'wb' above.
        return open(path, 'w', encoding='utf-8', newline='')

    # Train a 5-topic model (alpha = beta = 0.1) with 50 passes over the data.
    lda = LDA_MLHackathon(0.1, 0.1, 5, 50)
    lda.read_bow("./in1_small.csv")
    lda.initialize_parameter()
    lda.iterate()
    # lda.show()
    lda.save()

    # Dump the raw model state, one file per table, for inspection.
    with _open_dump("n_dz.txt") as fd:
        for doc_row in lda.n_dz:
            fd.write("".join("%.1f " % dz for dz in doc_row))
            fd.write("\n")

    with _open_dump("n_tz.txt") as fd:
        for topic_row in lda.n_tz:
            fd.write("".join("%.1f " % tz for tz in topic_row))
            fd.write("\n")

    with _open_dump("n_z.txt") as fd:
        for topic_total in lda.n_z:
            fd.write("%.1f\n" % topic_total)

    with _open_dump("term_topic.txt") as fd:
        for assignments in lda.term_topic:
            fd.write("".join("%s " % topic for topic in assignments))
            fd.write("\n")

    with _open_dump("term_index.txt") as fd:
        for word, index in lda.term_index.items():
            fd.write("%s:%f\n" % (word, index))

    # print(<one string>) produces the same output on Python 2 and 3.
    print("end.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment