Skip to content

Instantly share code, notes, and snippets.

@satzz
Last active January 2, 2016 06:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save satzz/1fc75314652e1a99fe4c to your computer and use it in GitHub Desktop.
MeCab, gensim, scikit-learnでニュース記事の分類 ref: http://qiita.com/satzz/items/69beb439ed440d459585
# -*- coding: utf-8 -*-
import os
from os.path import join, dirname
import MeCab
import gensim
from sklearn.ensemble import RandomForestClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report
# Topic labels, one per news-article category. Each article file encodes its
# label in the filename prefix (e.g. "birth-01.txt" -> "birth").
labels = [
    'birth',     # birth rate rose for the first time in 5 years
    'ekiden',    # New Year ekiden (road relay) coverage
    'tunnel',    # mortar peeling in the Kimitsu tunnel
    'ikukyu',    # Diet member taking childcare leave
    'fe',        # correction of hijiki iron content in the Japanese food composition tables
    'takahama',  # Takahama nuclear plant coverage
    'thief',     # arrest of Kenichi Takahashi (King of Comedy)
    'starwars',  # Star Wars (The Force Awakens) coverage
    'design',    # National Stadium design controversy
    'riken',     # RIKEN winning naming rights for a new element
]

# One LSI topic per category.
num_topics = len(labels)

# MeCab tagger with default options, used for morphological analysis below.
m = MeCab.Tagger('')

# Directory of article text files, located next to this script.
article_path = join(dirname(__file__), 'articles')
for root, dirs, files in os.walk(article_path):
print '# MORPHOLOGICAL ANALYSIS'
docs = {}
for docname in files:
docs[docname] = []
f = open(join(article_path,docname))
lines = f.readlines()
for text in lines:
res = m.parseToNode(text)
while res:
arr = res.feature.split(",")
word = arr[6]
docs[docname].append(word)
res = res.next
dct = gensim.corpora.Dictionary(docs.values())
dct.filter_extremes(no_below=2, no_above=0.1)
filtered = dct.token2id.keys()
print 'number of features', len(filtered)
# for key in filtered:
# print key
# console output: number of features 1564
print "# BAG OF WORDS"
bow_docs = {}
for docname in files:
bow_docs[docname] = dct.doc2bow(docs[docname])
print '# LSI Model'
lsi_model = gensim.models.LsiModel(bow_docs.values(), num_topics=num_topics)
lsi_docs = {}
for i, docname in enumerate(files):
vec = bow_docs[docname]
lsi_docs[i] = lsi_model[vec]
print "# TRAIN DATA"
def vec2dense(vec, num_terms):
return list(gensim.matutils.corpus2dense([vec], num_terms=num_terms).T[0])
data_all = [vec2dense(lsi_docs[i],num_topics) for i, docname in enumerate(files)]
label_all = [docname.split("-")[0] for docname in files]
data_train_s, data_test_s, label_train_s, label_test_s = train_test_split(data_all, label_all, test_size=0.6)
print "# GRID SEARCH"
tuned_parameters = [{'n_estimators': [10, 20, 25, 30, 35, 40, 60, 100], 'max_features': ['auto', 'sqrt', 'log2', None]}]
clf = GridSearchCV(RandomForestClassifier(), tuned_parameters, cv=2, scoring='accuracy', n_jobs=-1)
clf.fit(data_train_s, label_train_s)
for params, mean_score, all_scores in clf.grid_scores_:
print("{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params))
print 'best score', clf.best_score_
print 'best estimator', clf.best_estimator_
print "# PREDICTION"
prediction_clf = clf.predict(data_test_s)
print(classification_report(label_test_s, prediction_clf))
for item in zip(label_test_s, prediction_clf):
print item
# MORPHOLOGICAL ANALYSIS
number of features 1564
# BAG OF WORDS
# LSI Model
# TRAIN DATA
# GRID SEARCH
0.531 (+/- 0.035) for {'max_features': 'auto', 'n_estimators': 10}
0.594 (+/- 0.009) for {'max_features': 'auto', 'n_estimators': 20}
0.531 (+/- 0.068) for {'max_features': 'auto', 'n_estimators': 25}
0.594 (+/- 0.074) for {'max_features': 'auto', 'n_estimators': 30}
0.625 (+/- 0.093) for {'max_features': 'auto', 'n_estimators': 35}
0.625 (+/- 0.061) for {'max_features': 'auto', 'n_estimators': 40}
0.625 (+/- 0.061) for {'max_features': 'auto', 'n_estimators': 60}
0.719 (+/- 0.021) for {'max_features': 'auto', 'n_estimators': 100}
0.500 (+/- 0.049) for {'max_features': 'sqrt', 'n_estimators': 10}
0.594 (+/- 0.041) for {'max_features': 'sqrt', 'n_estimators': 20}
0.594 (+/- 0.009) for {'max_features': 'sqrt', 'n_estimators': 25}
0.531 (+/- 0.100) for {'max_features': 'sqrt', 'n_estimators': 30}
0.531 (+/- 0.035) for {'max_features': 'sqrt', 'n_estimators': 35}
0.594 (+/- 0.074) for {'max_features': 'sqrt', 'n_estimators': 40}
0.656 (+/- 0.080) for {'max_features': 'sqrt', 'n_estimators': 60}
0.656 (+/- 0.048) for {'max_features': 'sqrt', 'n_estimators': 100}
0.562 (+/- 0.022) for {'max_features': 'log2', 'n_estimators': 10}
0.625 (+/- 0.061) for {'max_features': 'log2', 'n_estimators': 20}
0.656 (+/- 0.080) for {'max_features': 'log2', 'n_estimators': 25}
0.688 (+/- 0.067) for {'max_features': 'log2', 'n_estimators': 30}
0.500 (+/- 0.081) for {'max_features': 'log2', 'n_estimators': 35}
0.625 (+/- 0.061) for {'max_features': 'log2', 'n_estimators': 40}
0.531 (+/- 0.068) for {'max_features': 'log2', 'n_estimators': 60}
0.688 (+/- 0.067) for {'max_features': 'log2', 'n_estimators': 100}
0.531 (+/- 0.100) for {'max_features': None, 'n_estimators': 10}
0.594 (+/- 0.106) for {'max_features': None, 'n_estimators': 20}
0.719 (+/- 0.054) for {'max_features': None, 'n_estimators': 25}
0.656 (+/- 0.080) for {'max_features': None, 'n_estimators': 30}
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 35}
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 40}
0.500 (+/- 0.081) for {'max_features': None, 'n_estimators': 60}
0.625 (+/- 0.093) for {'max_features': None, 'n_estimators': 100}
best score 0.71875
best estimator RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
max_depth=None, max_features='auto', max_leaf_nodes=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1,
oob_score=False, random_state=None, verbose=0,
warm_start=False)
# PREDICTION
precision recall f1-score support
birth 1.00 1.00 1.00 5
design 1.00 1.00 1.00 4
ekiden 0.60 0.50 0.55 6
fe 1.00 1.00 1.00 5
ikukyu 0.00 0.00 0.00 7
riken 0.44 1.00 0.62 4
starwars 0.20 0.50 0.29 2
takahama 1.00 0.60 0.75 5
thief 0.62 1.00 0.77 5
tunnel 1.00 0.80 0.89 5
avg / total 0.69 0.71 0.67 48
('fe', 'fe')
('thief', 'thief')
('tunnel', 'tunnel')
('design', 'design')
('takahama', 'takahama')
('birth', 'birth')
('ekiden', 'ekiden')
('ekiden', 'starwars')
('ikukyu', 'riken')
('birth', 'birth')
('takahama', 'takahama')
('ekiden', 'ekiden')
('riken', 'riken')
('birth', 'birth')
('thief', 'thief')
('riken', 'riken')
('ekiden', 'thief')
('fe', 'fe')
('takahama', 'starwars')
('starwars', 'thief')
('ekiden', 'ekiden')
('fe', 'fe')
('tunnel', 'tunnel')
('riken', 'riken')
('tunnel', 'starwars')
('ikukyu', 'ekiden')
('tunnel', 'tunnel')
('takahama', 'takahama')
('ikukyu', 'riken')
('ikukyu', 'riken')
('tunnel', 'tunnel')
('ikukyu', 'ekiden')
('design', 'design')
('thief', 'thief')
('takahama', 'starwars')
('riken', 'riken')
('birth', 'birth')
('starwars', 'starwars')
('ikukyu', 'riken')
('design', 'design')
('ikukyu', 'riken')
('thief', 'thief')
('fe', 'fe')
('fe', 'fe')
('thief', 'thief')
('ekiden', 'thief')
('birth', 'birth')
('design', 'design')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment