@YOwatari
Created June 16, 2014 00:03
An LDA implementation using Gibbs sampling, written at the Machine Learning Hackathon (#MLHackathon). Reference: http://ktsukuda.net/ruby/lda_with_gibbs_sampling_using_ruby/
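Roughly, the per-token update in the script below corresponds to the standard collapsed Gibbs conditional for LDA; in the code the counters are pre-loaded with alpha and beta, and each token is reassigned to the argmax of the unnormalized probability instead of being drawn from it:

P(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w}) \;\propto\; (n_{d,k}^{\neg i} + \alpha) \cdot \frac{n_{k,w_i}^{\neg i} + \beta}{n_k^{\neg i} + V\beta}

where n_{d,k} counts the tokens of document d assigned to topic k (n_dz), n_{k,w} counts how often term w is assigned to topic k (n_tz), n_k is the total number of tokens assigned to topic k (n_z), V is the vocabulary size, and \neg i means the current token is excluded (the decrement step).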
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Rough guide for the parameters
#   alpha: k/50
#   beta:  0.01

import csv
from random import randint
from itertools import chain

class LDA_MLHackathon:
    def __init__(self, alpha, beta, k, iteration_num):
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self.iteration_num = iteration_num
        self.documents = []
        self.v = 0
        self.n_dz = []        # document x topic counts
        self.n_tz = []        # topic x term counts
        self.n_z = []         # number of terms assigned to each topic
        self.term_topic = []  # topic assignment per term position, per document
        self.term_index = {}  # term -> vocabulary index
    def read_bow(self, filename):
        # expects a CSV with a header row and the columns: doc, word, count
        with open(filename, 'rb') as fd:
            reader = csv.DictReader(fd)
            document = []
            doc_number = 1
            for bow in reader:
                if doc_number != int(bow['doc']):
                    doc_number = int(bow['doc'])
                    self.documents.append(document)
                    document = []
                for i in xrange(int(bow['count'])):
                    document.append(bow['word'])
            # the last document
            doc_number = int(bow['doc'])
            self.documents.append(document)
        # dump the expanded documents to a file
        with open('input.txt', 'wb') as fd:
            for document in self.documents:
                fd.write(" ".join(document) + "\n")
    def initialize_parameter(self):
        # term_index: map each term to a vocabulary index
        flatten_documents = list(set(chain.from_iterable(self.documents)))
        self.v = len(flatten_documents)
        for k, term in enumerate(flatten_documents):
            self.term_index[term] = k
        # term_topic
        # randomly assign a topic to every term of every document
        for document in self.documents:
            topics = []
            for term in document:
                topics.append(randint(0, self.k - 1))
            self.term_topic.append(topics)
        # n_dz: document x topic counts, initialized with alpha
        for topics in self.term_topic:
            topic_freqs = self.k * [self.alpha]
            for topic in topics:
                topic_freqs[topic] += 1
            self.n_dz.append(topic_freqs)
        # n_tz: topic x term counts, initialized with beta (one independent row per topic)
        self.n_tz = [len(self.term_index) * [self.beta] for _ in xrange(self.k)]
        for d_k, document in enumerate(self.documents):
            for t_k, term in enumerate(document):
                index = self.term_index[term]
                this_term_topic = self.term_topic[d_k][t_k]
                self.n_tz[this_term_topic][index] += 1
        # n_z: terms per topic, initialized with V * beta
        self.n_z = self.k * [len(self.term_index) * self.beta]
        for d_k, document in enumerate(self.documents):
            for t_k, term in enumerate(document):
                this_term_topic = self.term_topic[d_k][t_k]
                self.n_z[this_term_topic] += 1
    def iterate(self):
        for i in xrange(self.iteration_num):
            for d_k, document in enumerate(self.documents):
                for t_k, term in enumerate(document):
                    this_term_index = self.term_index[term]
                    this_term_topic = self.term_topic[d_k][t_k]
                    # decrement the counts for the current assignment
                    self.n_dz[d_k][this_term_topic] -= 1
                    self.n_tz[this_term_topic][this_term_index] -= 1
                    self.n_z[this_term_topic] -= 1
                    # pick the topic with the largest unnormalized probability
                    # (hard assignment instead of drawing a sample)
                    max_prob = 0
                    new_topic = 0
                    for topic_k in xrange(self.k):
                        tmp_prob = self.n_dz[d_k][topic_k] * (float(self.n_tz[topic_k][this_term_index]) / self.n_z[topic_k])
                        if max_prob < tmp_prob:
                            new_topic = topic_k
                            max_prob = tmp_prob
                    # increment the counts for the new assignment
                    self.n_dz[d_k][new_topic] += 1
                    self.n_tz[new_topic][this_term_index] += 1
                    self.n_z[new_topic] += 1
                    self.term_topic[d_k][t_k] = new_topic
    def show(self):
        # per-topic term distributions (phi)
        for topic_i in xrange(self.k):
            print "topic:%02d\n" % (topic_i + 1),
            for term, term_j in self.term_index.items():
                print "%s:%f " % (term, float(self.n_tz[topic_i][term_j]) / self.n_z[topic_i])
            print "\n"
        print "\n"
        # per-document topic distributions (theta)
        for doc_i, document in enumerate(self.documents):
            print "document %d\n" % (doc_i + 1)
            denominator = len(document) + self.k * self.alpha
            for topic_j, topic_freq in enumerate(self.n_dz[doc_i]):
                print "topic%d:%f " % (topic_j + 1, float(topic_freq) / denominator)
            print "\n"
    def save(self):
        # one file per topic with its term distribution (phi)
        for topic_i in xrange(self.k):
            filename = "topic%02d.txt" % (topic_i + 1)
            with open(filename, "wb") as fd:
                for term, term_j in self.term_index.items():
                    fd.write("%s:%f\n" % (term, float(self.n_tz[topic_i][term_j]) / self.n_z[topic_i]))
        # one file with the topic distribution of every document (theta)
        filename = "document_topic.txt"
        with open(filename, "wb") as fd:
            for doc_i, document in enumerate(self.documents):
                fd.write("document %02d\n" % (doc_i + 1))
                denominator = len(document) + self.k * self.alpha
                for topic_j, topic_freq in enumerate(self.n_dz[doc_i]):
                    fd.write("topic%02d:%.4f " % (topic_j + 1, float(topic_freq) / denominator))
                fd.write("\n")
if __name__ == "__main__":
    lda = LDA_MLHackathon(0.1, 0.1, 5, 50)
    lda.read_bow("./in1_small.csv")
    lda.initialize_parameter()
    lda.iterate()
    # lda.show()
    lda.save()

    with open("n_dz.txt", 'wb') as fd:
        for n_dz in lda.n_dz:
            for dz in n_dz:
                fd.write("%.1f " % dz)
            fd.write("\n")

    with open("n_tz.txt", 'wb') as fd:
        for n_tz in lda.n_tz:
            for tz in n_tz:
                fd.write("%.1f " % tz)
            fd.write("\n")

    with open("n_z.txt", 'wb') as fd:
        for n_z in lda.n_z:
            fd.write("%.1f\n" % n_z)

    with open("term_topic.txt", 'wb') as fd:
        for term in lda.term_topic:
            for t in term:
                fd.write("%s " % t)
            fd.write("\n")

    with open("term_index.txt", 'wb') as fd:
        for k, v in lda.term_index.items():
            fd.write("%s:%f\n" % (k, v))

    print "end."