Last active
April 8, 2022 17:02
Sample of document classification with Naive Bayes and SVM
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# NaiveBayes、SVMによる文書分類\n", | |
"よくあるNaiveBayesやSVBを用いた文書分類を行う。\n", | |
"頻度そのままではなく、TF・IDFを利用する。\n", | |
"名詞を対象とする。\n", | |
"\n", | |
"## livedoorニュースコーパスについて\n", | |
"NHN Japan株式会社が運営する「livedoorニュース」のうち、クリエイティブ・コモンズライセンス「表示ー改変禁止」が適用されるニュース記事を収集し、可能な限りHTMLタグを取り除いて作成されたもの。\n", | |
"\n", | |
"引用\n", | |
"- ロンウイット, livedoorニュースコーパス, http://www.rondhuit.com/download.html#ldcc\n", | |
"\n", | |
"|分類名|URL|文書数|ファイル名| 備考 |\n", | |
"|:---:|:---:|:---:|:---:|:---:|\n", | |
"|トピックニュース|http://news.livedoor.com/category/vender/news/ | 770 | topic-news.xml | 芸能ニュース |\n", | |
"|Sports Watch|http://news.livedoor.com/category/vender/208/ | 900 | sports-watch.xml | スポーツ全般 |\n", | |
"|ITライフハック|http://news.livedoor.com/category/vender/223/ | 870 | it-life-hack.xml | IT系一般 |\n", | |
"|家電チャンネル|http://news.livedoor.com/category/vender/kadench/ | 864 | kaden-channel.xml | 家電? 家庭で使う物?|\n", | |
"|MOVIE ENTER|http://news.livedoor.com/category/vender/movie_enter/ | movie-enter.xml | 870 | 映画や俳優 |\n", | |
"|独女通信|http://news.livedoor.com/category/vender/90/ |870| dokujo-tsushin.xml | 女性向けコラム(どろどろ) |\n", | |
"|エスマックス|http://news.livedoor.com/category/vender/smax/ | 870 | smax.xml | 携帯,通信 |\n", | |
"|livedoor HOMME|http://news.livedoor.com/category/vender/homme/ | 511 | livedoor-homme.xml | ライフスタイル・住居 |\n", | |
"|Peachy|http://news.livedoor.com/category/vender/ldgirls/ | 842 | peachy.xml | 女性向けコラム(ふんわり) |\n", | |
"\n", | |
"## 手法概要\n", | |
"- 1記事を1文書とする。\n", | |
" - 文書は名詞のBag-of-Wordsとする。\n", | |
"- TF・IDFを計算する。\n", | |
"- NaiveBayesにより、分布を学習する。\n", | |
"- 未知の文書を学習したモデルから推定する。\n", | |
"- 問題としては、与えられた文書を9つのうち1つのクラスを推定する問題。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# データの読み込みと整形\n", | |
"livedoorニュースコーパスからテキストを読み込む。\n", | |
"1記事を1文書とし、文書は名詞列から構成されているBag-of-Wordsである。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import xml_read\n", | |
"import os\n", | |
"import random\n", | |
"import snowman_module3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 雪だるま解析器を用いる。\n", | |
"snow = snowman_module3.snowman('/tools/snowman/config/config.ini')\n", | |
"\n", | |
"def pick_noun(text, pos='名詞'):\n", | |
" \"\"\"\n", | |
" 対象の品詞を取り出す。\n", | |
" 複合語処理を行わない。\n", | |
" \"\"\"\n", | |
" words = snow.parse(text, noconbine=True)\n", | |
" return [word.org_lemma for word in words if word.pos == pos]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# カテゴリ名の設定\n", | |
"categories = ['dokujo-tsushin', 'livedoor-homme', 'smax', 'it-life-hack', 'movie-enter', 'sports-watch', 'kaden-channel', 'peachy', 'topic-news']\n", | |
"\n", | |
"# 文書のxmlからの読み込み\n", | |
"docs = list()\n", | |
"for category in categories:\n", | |
" # ファイル名を作成\n", | |
" filename = category + '.xml'\n", | |
" p = os.path.join('.', 'livedoor-news-data', filename)\n", | |
"\n", | |
" # xmlからファイルを読み込む\n", | |
" c_docs = xml_read.get_docs(p)\n", | |
"\n", | |
" # スペース区切りの名詞列を作成する\n", | |
" for c_doc in c_docs:\n", | |
" nouns = list()\n", | |
" for sentence in c_doc.sentences:\n", | |
" nouns.extend(pick_noun(sentence))\n", | |
" c_doc.nouns = \" \".join(nouns)\n", | |
" docs.extend(c_docs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# データのシャッフル\n", | |
"random.shuffle(docs)\n", | |
"\n", | |
"# トレーニングとテストに分ける\n", | |
"X_train = docs[:6000]\n", | |
"X_test = docs[6001:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 「カテゴリ:数値」辞書の作成\n", | |
"categories_dic = {\n", | |
" 'dokujo-tsushin':0,\n", | |
" 'livedoor-homme':1,\n", | |
" 'smax':2, \n", | |
" 'it-life-hack':3,\n", | |
" 'movie-enter':4,\n", | |
" 'sports-watch':5, \n", | |
" 'kaden-channel':6, \n", | |
" 'peachy':7, \n", | |
" 'topic-news':8,\n", | |
"}\n", | |
"categories_dic2 = {\n", | |
" 0:'dokujo-tsushin',\n", | |
" 1:'livedoor-homme',\n", | |
" 2:'smax', \n", | |
" 3:'it-life-hack',\n", | |
" 4:'movie-enter',\n", | |
" 5:'sports-watch', \n", | |
" 6:'kaden-channel', \n", | |
" 7:'peachy', \n", | |
" 8:'topic-news',\n", | |
"}\n", | |
"train_cat = np.array([categories_dic.get(doc.cat) for doc in X_train], dtype=np.int64)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 単語の集合を取り出す\n", | |
"scikit-learn用に単語の集合が必用なので、単語の集合を取り出す。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 学習データの単語の集合\n", | |
"train_words = [doc.nouns for doc in X_train]\n", | |
"# テストデータの単語の集合\n", | |
"test_words = [doc.nouns for doc in X_test]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# カテゴリ名をnumpyのarrayで持たせる。\n", | |
"test_cat = np.array([categories_dic.get(doc.cat) for doc in X_test], dtype=np.int64)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# ベースラインの学習と実行\n", | |
"scikit-learnのモジュールを用いて学習する。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.feature_extraction.text import CountVectorizer\n", | |
"from sklearn.feature_extraction.text import TfidfTransformer\n", | |
"from sklearn.naive_bayes import MultinomialNB\n", | |
"from sklearn.pipeline import Pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# パイプラインの作成\n", | |
"text_clf = Pipeline([\n", | |
" ('vect', CountVectorizer()),\n", | |
" ('tfidf', TfidfTransformer()),\n", | |
" ('clf', MultinomialNB()),\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 学習\n", | |
"text_clf = text_clf.fit(train_words, train_cat)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 推定\n", | |
"predicted = text_clf.predict(test_words)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.83528550512445099" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(predicted == test_cat)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# SVM\n", | |
"from sklearn.linear_model import SGDClassifier" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"code_folding": [ | |
0 | |
], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"text_clf_sgd = Pipeline([\n", | |
" ('vect', CountVectorizer()),\n", | |
" ('tfidf', TfidfTransformer()),\n", | |
" ('clf', SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, n_iter=5, random_state=42)),\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"text_clf_sgd = text_clf_sgd.fit(train_words, train_cat)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"sgd_predicted = text_clf_sgd.predict(test_words)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.87920937042459735" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(sgd_predicted == test_cat)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# カテゴリごとの評価\n", | |
"from sklearn import metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
"dokujo-tsushin 0.71 0.85 0.77 160\n", | |
"livedoor-homme 0.98 0.44 0.60 96\n", | |
" smax 0.80 1.00 0.89 150\n", | |
" it-life-hack 0.94 0.82 0.88 164\n", | |
" movie-enter 0.84 0.98 0.90 155\n", | |
" sports-watch 0.87 0.99 0.93 163\n", | |
" kaden-channel 0.82 0.90 0.86 138\n", | |
" peachy 0.78 0.71 0.74 177\n", | |
" topic-news 0.97 0.71 0.82 163\n", | |
"\n", | |
" avg / total 0.85 0.84 0.83 1366\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"print(metrics.classification_report(test_cat, predicted, target_names=categories))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/d2141e00aca0c7a4e68d1364a8c16d35" | |
}, | |
"gist": { | |
"data": { | |
"description": "Text_Classification_SVM_NaiveBayes.ipynb", | |
"public": true | |
}, | |
"id": "d2141e00aca0c7a4e68d1364a8c16d35" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
}, | |
"nav_menu": {}, | |
"toc": { | |
"navigate_menu": true, | |
"number_sections": true, | |
"sideBar": true, | |
"threshold": 6, | |
"toc_cell": false, | |
"toc_section_display": "block", | |
"toc_window_display": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
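The notebook above depends on two modules that are not included here (`xml_read` and `snowman_module3`), so it cannot be run as-is. The following is a minimal sketch of the same two pipelines (token counts, then TF-IDF, then either multinomial Naive Bayes or a hinge-loss `SGDClassifier`) on a tiny invented dataset. The documents, labels, and the helper name `build_pipeline` are assumptions made for illustration; they are not part of the livedoor corpus or the original notebook. The SGD settings use `max_iter=5, tol=None`, which current scikit-learn releases accept for a fixed five epochs.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Invented toy data: each document is a space-separated string of words,
# mirroring the role of doc.nouns in the notebook.
train_words = [
    "camera battery smartphone screen",
    "smartphone app battery update",
    "match goal team player",
    "team player season coach",
    "movie actor director screen",
    "actor movie premiere trailer",
]
train_cat = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64)
test_words = ["battery smartphone", "goal coach team", "director trailer"]
test_cat = np.array([0, 1, 2], dtype=np.int64)


def build_pipeline(clf):
    """Hypothetical helper: counts -> TF-IDF -> the given classifier."""
    return Pipeline([
        ('vect', CountVectorizer()),
        ('tfidf', TfidfTransformer()),
        ('clf', clf),
    ])


models = {
    'naive_bayes': build_pipeline(MultinomialNB()),
    'sgd_svm': build_pipeline(
        SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3,
                      max_iter=5, tol=None, random_state=42)),
}

predictions = {}
accuracy = {}
for name, model in models.items():
    model.fit(train_words, train_cat)
    predictions[name] = model.predict(test_words)
    # Accuracy is the fraction of test documents whose label was predicted correctly.
    accuracy[name] = float(np.mean(predictions[name] == test_cat))
    print(name, accuracy[name])
```

One detail worth keeping in mind when feeding Japanese noun strings to this setup: with its default `token_pattern`, `CountVectorizer` only keeps tokens of two or more characters, so single-character nouns are silently dropped unless a different pattern or tokenizer is supplied.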