Last active
January 2, 2016 06:02
-
-
Save satzz/1fc75314652e1a99fe4c to your computer and use it in GitHub Desktop.
MeCab, gensim, scikit-learnでニュース記事の分類 ref: http://qiita.com/satzz/items/69beb439ed440d459585
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import os
from os.path import join, dirname

import MeCab
import gensim
from sklearn.ensemble import RandomForestClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report

# One label per news-story cluster; article filenames start with one of these.
labels = [
    'birth',     # first rise in the birth rate in 5 years
    'ekiden',    # New Year ekiden (relay race) coverage
    'tunnel',    # mortar peeling in the Kimitsu tunnel
    'ikukyu',    # Diet member taking childcare leave
    'fe',        # correction of hijiki iron content in the Standard Tables of Food Composition
    'takahama',  # Takahama nuclear power plant news
    'thief',     # arrest of Kenichi Takahashi of King of Comedy
    'starwars',  # Star Wars (The Force Awakens) coverage
    'design',    # National Stadium design controversy
    'riken',     # RIKEN winning naming rights for a new element
]
num_topics = len(labels)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
m = MeCab.Tagger('') | |
article_path = join(dirname(__file__), 'articles') | |
for root, dirs, files in os.walk(article_path): | |
print '# MORPHOLOGICAL ANALYSIS' | |
docs = {} | |
for docname in files: | |
docs[docname] = [] | |
f = open(join(article_path,docname)) | |
lines = f.readlines() | |
for text in lines: | |
res = m.parseToNode(text) | |
while res: | |
arr = res.feature.split(",") | |
word = arr[6] | |
docs[docname].append(word) | |
res = res.next | |
dct = gensim.corpora.Dictionary(docs.values()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
dct.filter_extremes(no_below=2, no_above=0.1) | |
filtered = dct.token2id.keys() | |
print 'number of features', len(filtered) | |
# for key in filtered: | |
# print key |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
number of features 1564 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print "# BAG OF WORDS" | |
bow_docs = {} | |
for docname in files: | |
bow_docs[docname] = dct.doc2bow(docs[docname]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print '# LSI Model' | |
lsi_model = gensim.models.LsiModel(bow_docs.values(), num_topics=num_topics) | |
lsi_docs = {} | |
for i, docname in enumerate(files): | |
vec = bow_docs[docname] | |
lsi_docs[i] = lsi_model[vec] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print "# TRAIN DATA" | |
def vec2dense(vec, num_terms): | |
return list(gensim.matutils.corpus2dense([vec], num_terms=num_terms).T[0]) | |
data_all = [vec2dense(lsi_docs[i],num_topics) for i, docname in enumerate(files)] | |
label_all = [docname.split("-")[0] for docname in files] | |
data_train_s, data_test_s, label_train_s, label_test_s = train_test_split(data_all, label_all, test_size=0.6) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print "# GRID SEARCH" | |
tuned_parameters = [{'n_estimators': [10, 20, 25, 30, 35, 40, 60, 100], 'max_features': ['auto', 'sqrt', 'log2', None]}] | |
clf = GridSearchCV(RandomForestClassifier(), tuned_parameters, cv=2, scoring='accuracy', n_jobs=-1) | |
clf.fit(data_train_s, label_train_s) | |
for params, mean_score, all_scores in clf.grid_scores_: | |
print("{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params)) | |
print 'best score', clf.best_score_ | |
print 'best estimator', clf.best_estimator_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
print "# PREDICTION" | |
prediction_clf = clf.predict(data_test_s) | |
print(classification_report(label_test_s, prediction_clf)) | |
for item in zip(label_test_s, prediction_clf): | |
print item |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MORPHOLOGICAL ANALYSIS | |
number of features 1564 | |
# BAG OF WORDS | |
# LSI Model | |
# TRAIN DATA | |
# GRID SEARCH | |
0.531 (+/- 0.035) for {'max_features': 'auto', 'n_estimators': 10} | |
0.594 (+/- 0.009) for {'max_features': 'auto', 'n_estimators': 20} | |
0.531 (+/- 0.068) for {'max_features': 'auto', 'n_estimators': 25} | |
0.594 (+/- 0.074) for {'max_features': 'auto', 'n_estimators': 30} | |
0.625 (+/- 0.093) for {'max_features': 'auto', 'n_estimators': 35} | |
0.625 (+/- 0.061) for {'max_features': 'auto', 'n_estimators': 40} | |
0.625 (+/- 0.061) for {'max_features': 'auto', 'n_estimators': 60} | |
0.719 (+/- 0.021) for {'max_features': 'auto', 'n_estimators': 100} | |
0.500 (+/- 0.049) for {'max_features': 'sqrt', 'n_estimators': 10} | |
0.594 (+/- 0.041) for {'max_features': 'sqrt', 'n_estimators': 20} | |
0.594 (+/- 0.009) for {'max_features': 'sqrt', 'n_estimators': 25} | |
0.531 (+/- 0.100) for {'max_features': 'sqrt', 'n_estimators': 30} | |
0.531 (+/- 0.035) for {'max_features': 'sqrt', 'n_estimators': 35} | |
0.594 (+/- 0.074) for {'max_features': 'sqrt', 'n_estimators': 40} | |
0.656 (+/- 0.080) for {'max_features': 'sqrt', 'n_estimators': 60} | |
0.656 (+/- 0.048) for {'max_features': 'sqrt', 'n_estimators': 100} | |
0.562 (+/- 0.022) for {'max_features': 'log2', 'n_estimators': 10} | |
0.625 (+/- 0.061) for {'max_features': 'log2', 'n_estimators': 20} | |
0.656 (+/- 0.080) for {'max_features': 'log2', 'n_estimators': 25} | |
0.688 (+/- 0.067) for {'max_features': 'log2', 'n_estimators': 30} | |
0.500 (+/- 0.081) for {'max_features': 'log2', 'n_estimators': 35} | |
0.625 (+/- 0.061) for {'max_features': 'log2', 'n_estimators': 40} | |
0.531 (+/- 0.068) for {'max_features': 'log2', 'n_estimators': 60} | |
0.688 (+/- 0.067) for {'max_features': 'log2', 'n_estimators': 100} | |
0.531 (+/- 0.100) for {'max_features': None, 'n_estimators': 10} | |
0.594 (+/- 0.106) for {'max_features': None, 'n_estimators': 20} | |
0.719 (+/- 0.054) for {'max_features': None, 'n_estimators': 25} | |
0.656 (+/- 0.080) for {'max_features': None, 'n_estimators': 30} | |
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 35} | |
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 40} | |
0.500 (+/- 0.081) for {'max_features': None, 'n_estimators': 60} | |
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 100} | |
best score 0.71875 | |
best estimator RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', | |
max_depth=None, max_features='auto', max_leaf_nodes=None, | |
min_samples_leaf=1, min_samples_split=2, | |
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1, | |
oob_score=False, random_state=None, verbose=0, | |
warm_start=False) | |
# PREDICTION | |
precision recall f1-score support | |
birth 1.00 1.00 1.00 5 | |
design 1.00 1.00 1.00 4 | |
ekiden 0.60 0.50 0.55 6 | |
fe 1.00 1.00 1.00 5 | |
ikukyu 0.00 0.00 0.00 7 | |
riken 0.44 1.00 0.62 4 | |
starwars 0.20 0.50 0.29 2 | |
takahama 1.00 0.60 0.75 5 | |
thief 0.62 1.00 0.77 5 | |
tunnel 1.00 0.80 0.89 5 | |
avg / total 0.69 0.71 0.67 48 | |
('fe', 'fe') | |
('thief', 'thief') | |
('tunnel', 'tunnel') | |
('design', 'design') | |
('takahama', 'takahama') | |
('birth', 'birth') | |
('ekiden', 'ekiden') | |
('ekiden', 'starwars') | |
('ikukyu', 'riken') | |
('birth', 'birth') | |
('takahama', 'takahama') | |
('ekiden', 'ekiden') | |
('riken', 'riken') | |
('birth', 'birth') | |
('thief', 'thief') | |
('riken', 'riken') | |
('ekiden', 'thief') | |
('fe', 'fe') | |
('takahama', 'starwars') | |
('starwars', 'thief') | |
('ekiden', 'ekiden') | |
('fe', 'fe') | |
('tunnel', 'tunnel') | |
('riken', 'riken') | |
('tunnel', 'starwars') | |
('ikukyu', 'ekiden') | |
('tunnel', 'tunnel') | |
('takahama', 'takahama') | |
('ikukyu', 'riken') | |
('ikukyu', 'riken') | |
('tunnel', 'tunnel') | |
('ikukyu', 'ekiden') | |
('design', 'design') | |
('thief', 'thief') | |
('takahama', 'starwars') | |
('riken', 'riken') | |
('birth', 'birth') | |
('starwars', 'starwars') | |
('ikukyu', 'riken') | |
('design', 'design') | |
('ikukyu', 'riken') | |
('thief', 'thief') | |
('fe', 'fe') | |
('fe', 'fe') | |
('thief', 'thief') | |
('ekiden', 'thief') | |
('birth', 'birth') | |
('design', 'design') | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment