-
-
Save Salinger/aae3c73c2245870bf1c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
#-*- coding:utf-8 -*- | |
### 使用するライブラリ | |
import MySQLdb | |
import pandas.io.sql as psql | |
import pandas as pd | |
import numpy as np | |
import MeCab | |
from sklearn import svm | |
from sklearn.grid_search import GridSearchCV | |
from sklearn.feature_extraction.text import CountVectorizer | |
### 対象ドキュメント取得 | |
print "[INFO] ドキュメント取得" | |
tweets = psql.read_sql( | |
"SELECT text, label FROM pn_tweet", | |
MySQLdb.connect( | |
host = "YOUR DB HOST", | |
user = "YOUR DB USER", | |
passwd = "YOUR DB PASSWORD", | |
db = "textdata", | |
charset = 'utf8' | |
) | |
) | |
### 形態素に分割 & 名詞・動詞(終止形)・形容詞のみ抽出 | |
print "[INFO] 分かち書き" | |
def wakati(text):
    """Tokenize *text* with MeCab and keep nouns, verbs and adjectives.

    Returns a space-joined unicode string of lemmas (dictionary base forms);
    when the dictionary has no base form ("*"), the surface form is used.
    """
    tagger = MeCab.Tagger()
    # Work around a known mecab-python binding bug: without this throwaway
    # parse() call, node.surface can come back empty/garbled (the internal
    # buffer is garbage-collected) and .decode() raises UnicodeDecodeError.
    tagger.parse('')
    text = text.encode("utf-8")
    node = tagger.parseToNode(text)
    word_list = []
    while node:
        pos = node.feature.split(",")[0]
        if pos in ["名詞", "動詞", "形容詞"]:
            lemma = node.feature.split(",")[6].decode("utf-8")
            if lemma == u"*":
                lemma = node.surface.decode("utf-8")
            word_list.append(lemma)
        node = node.next
    # BOS/EOS nodes carry POS "BOS/EOS" and never pass the filter above, so
    # the old word_list[1:-1] slice was dropping the first and last *content*
    # words, not the sentinels. Join everything that was collected.
    return u" ".join(word_list)
tweets['wakati'] = tweets['text'].apply(wakati) | |
### BoW (Term Frequency) 素性ベクトルへ変換 | |
# feature_vectors: scipy.sparse の csr_matrix 形式 | |
# vocabulary: 列要素(単語) 名 | |
print "[INFO] 素性ベクトル作成" | |
count_vectorizer = CountVectorizer() | |
feature_vectors = count_vectorizer.fit_transform(tweets['wakati']) | |
vocabulary = count_vectorizer.get_feature_names() | |
### SVM による学習 | |
print "[INFO] SVM (グリッドサーチ)" | |
svm_tuned_parameters = [ | |
{ | |
'kernel': ['rbf'], | |
'gamma': [2**n for n in range(-15, 3)], | |
'C': [2**n for n in range(-5, 15)] | |
} | |
] | |
gscv = GridSearchCV( | |
svm.SVC(), | |
svm_tuned_parameters, | |
cv=5, | |
n_jobs=1, | |
verbose=3 | |
) | |
gscv.fit(feature_vectors, list(tweets['label'])) | |
svm_model = gscv.best_estimator_ | |
print svm_model | |
# SVC(C=64, cache_size=200, class_weight=None, coef0=0.0, | |
# decision_function_shape=None, degree=3, gamma=0.0001220703125, | |
# kernel='rbf', max_iter=-1, probability=False, random_state=None, | |
# shrinking=True, tol=0.001, verbose=False) | |
### SVM による分類 | |
print "[INFO] SVM (分類)" | |
sample_text = pd.Series([ | |
u"無免許運転をネット中継 逮捕 - Y!ニュース news.yahoo.co.jp/pickup/...", | |
u"田舎特有のいじめが原因かな……複数殺人および未遂って尋常じゃない恨みだろ | Reading:兵庫県洲本市で男女5人刺される 3人死亡 NHKニュース", | |
u"BABYMETAL、CDショップ大賞おめでとうございます。これからも沢山の方がBABYMETALに触れる事でしょうね。音楽ってこんなにも楽しいって教えられましたもん。 ", | |
u"タカ丸さんかわいいな~" | |
]) | |
split_sample_text = sample_text.apply(wakati) | |
count_vectorizer = CountVectorizer( | |
vocabulary=vocabulary # 学習時の vocabulary を指定する | |
) | |
feature_vectors = count_vectorizer.fit_transform(split_sample_text) | |
print svm_model.predict(feature_vectors) | |
# [0, 0, 1, 1] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
素晴らしいコードを提供していただきありがとうございます!私は OS X El Capitan で Python 3 を使っていますが、
def wakati(text)
の中に tagger.parse('')
を入れなければ utf-8 のエラーが起こり、node.surface
が取得できません。(同じ問題のようです)