# -*- coding: utf-8 -*-
import os
from os.path import join, dirname
import MeCab
import gensim
from sklearn.ensemble import RandomForestClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report
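# NOTE: sklearn.grid_search and sklearn.cross_validation are the pre-0.18 scikit-learn
# module paths; in current releases both GridSearchCV and train_test_split live in
# sklearn.model_selection.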
labels = [
    'birth',     # birth rate rises for the first time in five years
    'ekiden',    # New Year Ekiden
    'tunnel',    # mortar peeling off inside a tunnel in Kimitsu
    'ikukyu',    # a Diet member taking childcare leave
    'fe',        # revision of the iron content of hijiki in the Standard Tables of Food Composition in Japan
    'takahama',  # Takahama nuclear power plant
    'thief',     # arrest of Kenichi Takahashi of King of Comedy
    'starwars',  # Star Wars: The Force Awakens
    'design',    # the National Stadium design
    'riken',     # RIKEN winning the naming rights for a new element
]
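# One LSI dimension per expected topic (10 here).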
num_topics = len(labels)
m = MeCab.Tagger('')
article_path = join(dirname(__file__), 'articles')
for root, dirs, files in os.walk(article_path):
    print '# MORPHOLOGICAL ANALYSIS'
    docs = {}
    for docname in files:
        docs[docname] = []
        f = open(join(article_path, docname))
        lines = f.readlines()
        for text in lines:
            res = m.parseToNode(text)
            while res:
                arr = res.feature.split(",")
                word = arr[6]  # with the default IPADIC dictionary, feature field 6 is the base form (lemma)
                docs[docname].append(word)
                res = res.next
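    # Build the gensim dictionary over all documents; filter_extremes keeps only tokens
    # that occur in at least 2 documents and in no more than 10% of all documents.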
    dct = gensim.corpora.Dictionary(docs.values())
    dct.filter_extremes(no_below=2, no_above=0.1)
    filtered = dct.token2id.keys()
    print 'number of features', len(filtered)
    # for key in filtered:
    #     print key
print "# BAG OF WORDS" | |
bow_docs = {} | |
for docname in files: | |
bow_docs[docname] = dct.doc2bow(docs[docname]) | |
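    # LsiModel fits a truncated SVD on the bag-of-words corpus and maps each document
    # onto num_topics latent dimensions.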
    print '# LSI Model'
    lsi_model = gensim.models.LsiModel(bow_docs.values(), num_topics=num_topics)
    lsi_docs = {}
    for i, docname in enumerate(files):
        vec = bow_docs[docname]
        lsi_docs[i] = lsi_model[vec]
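    # gensim.matutils.corpus2dense returns a dense (num_terms x num_docs) numpy array;
    # vec2dense below flattens one document's LSI vector into a plain list of
    # num_topics floats usable as a scikit-learn feature row.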
print "# TRAIN DATA" | |
def vec2dense(vec, num_terms): | |
return list(gensim.matutils.corpus2dense([vec], num_terms=num_terms).T[0]) | |
data_all = [vec2dense(lsi_docs[i],num_topics) for i, docname in enumerate(files)] | |
label_all = [docname.split("-")[0] for docname in files] | |
data_train_s, data_test_s, label_train_s, label_test_s = train_test_split(data_all, label_all, test_size=0.6) | |
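    # GridSearchCV evaluates every n_estimators x max_features combination (8 x 4 = 32
    # candidates) with 2-fold cross-validation on the training split; grid_scores_ is
    # the pre-0.18 results attribute (replaced by cv_results_ in later scikit-learn).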
print "# GRID SEARCH" | |
tuned_parameters = [{'n_estimators': [10, 20, 25, 30, 35, 40, 60, 100], 'max_features': ['auto', 'sqrt', 'log2', None]}] | |
clf = GridSearchCV(RandomForestClassifier(), tuned_parameters, cv=2, scoring='accuracy', n_jobs=-1) | |
clf.fit(data_train_s, label_train_s) | |
for params, mean_score, all_scores in clf.grid_scores_: | |
print("{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params)) | |
print 'best score', clf.best_score_ | |
print 'best estimator', clf.best_estimator_ | |
print "# PREDICTION" | |
prediction_clf = clf.predict(data_test_s) | |
print(classification_report(label_test_s, prediction_clf)) | |
for item in zip(label_test_s, prediction_clf): | |
print item | |
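Note: the article corpus itself is not part of this gist. Judging from label_all = [docname.split("-")[0] for docname in files], the script appears to expect a flat articles/ directory next to it, with one plain-text file per article named with its label as a prefix, e.g. articles/birth-1.txt or articles/starwars-2.txt; that layout is inferred from the code, not stated anywhere in the gist.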