@kanjirz50
Sample of document classification with Naive Bayes and SVM
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# NaiveBayes、SVMによる文書分類\n",
"よくあるNaiveBayesやSVBを用いた文書分類を行う。\n",
"頻度そのままではなく、TF・IDFを利用する。\n",
"名詞を対象とする。\n",
"\n",
"## livedoorニュースコーパスについて\n",
"NHN Japan株式会社が運営する「livedoorニュース」のうち、クリエイティブ・コモンズライセンス「表示ー改変禁止」が適用されるニュース記事を収集し、可能な限りHTMLタグを取り除いて作成されたもの。\n",
"\n",
"引用\n",
"- ロンウイット, livedoorニュースコーパス, http://www.rondhuit.com/download.html#ldcc\n",
"\n",
"|分類名|URL|文書数|ファイル名| 備考 |\n",
"|:---:|:---:|:---:|:---:|:---:|\n",
"|トピックニュース|http://news.livedoor.com/category/vender/news/ | 770 | topic-news.xml | 芸能ニュース |\n",
"|Sports Watch|http://news.livedoor.com/category/vender/208/ | 900 | sports-watch.xml | スポーツ全般 |\n",
"|ITライフハック|http://news.livedoor.com/category/vender/223/ | 870 | it-life-hack.xml | IT系一般 |\n",
"|家電チャンネル|http://news.livedoor.com/category/vender/kadench/ | 864 | kaden-channel.xml | 家電? 家庭で使う物?|\n",
"|MOVIE ENTER|http://news.livedoor.com/category/vender/movie_enter/ | movie-enter.xml | 870 | 映画や俳優 |\n",
"|独女通信|http://news.livedoor.com/category/vender/90/ |870| dokujo-tsushin.xml | 女性向けコラム(どろどろ) |\n",
"|エスマックス|http://news.livedoor.com/category/vender/smax/ | 870 | smax.xml | 携帯,通信 |\n",
"|livedoor HOMME|http://news.livedoor.com/category/vender/homme/ | 511 | livedoor-homme.xml | ライフスタイル・住居 |\n",
"|Peachy|http://news.livedoor.com/category/vender/ldgirls/ | 842 | peachy.xml | 女性向けコラム(ふんわり) |\n",
"\n",
"## 手法概要\n",
"- 1記事を1文書とする。\n",
" - 文書は名詞のBag-of-Wordsとする。\n",
"- TF・IDFを計算する。\n",
"- NaiveBayesにより、分布を学習する。\n",
"- 未知の文書を学習したモデルから推定する。\n",
"- 問題としては、与えられた文書を9つのうち1つのクラスを推定する問題。"
]
},
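{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TF-IDF weighting (reference sketch)\n",
"For reference, a sketch of the weighting applied by `TfidfTransformer` under what are assumed here to be its default settings (`smooth_idf=True`, `norm='l2'`); this is not part of the original notebook's code:\n",
"\n",
"$$\\mathrm{idf}(t) = \\ln\\frac{1 + n}{1 + \\mathrm{df}(t)} + 1, \\qquad \\mathrm{tfidf}(t, d) = \\mathrm{tf}(t, d)\\,\\mathrm{idf}(t)$$\n",
"\n",
"where $n$ is the number of documents, $\\mathrm{df}(t)$ is the number of documents containing term $t$, and each document vector is then L2-normalized."
]
},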
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# データの読み込みと整形\n",
"livedoorニュースコーパスからテキストを読み込む。\n",
"1記事を1文書とし、文書は名詞列から構成されているBag-of-Wordsである。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import xml_read\n",
"import os\n",
"import random\n",
"import snowman_module3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# 雪だるま解析器を用いる。\n",
"snow = snowman_module3.snowman('/tools/snowman/config/config.ini')\n",
"\n",
"def pick_noun(text, pos='名詞'):\n",
" \"\"\"\n",
" 対象の品詞を取り出す。\n",
" 複合語処理を行わない。\n",
" \"\"\"\n",
" words = snow.parse(text, noconbine=True)\n",
" return [word.org_lemma for word in words if word.pos == pos]"
]
},
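{
"cell_type": "markdown",
"metadata": {},
"source": [
"`snowman_module3` is a local analyzer that is not distributed with this gist. As a rough stand-in, the sketch below uses the janome morphological analyzer (an assumption: `janome` is installed via `pip install janome`; its segmentation differs from Snowman's, so results will not match exactly)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch: noun extraction with janome instead of the local Snowman analyzer.\n",
"from janome.tokenizer import Tokenizer\n",
"\n",
"_tokenizer = Tokenizer()\n",
"\n",
"def pick_noun_janome(text, pos='名詞'):\n",
"    \"\"\"Extract base forms of tokens whose top-level POS matches `pos`.\"\"\"\n",
"    return [t.base_form for t in _tokenizer.tokenize(text)\n",
"            if t.part_of_speech.split(',')[0] == pos]"
]
},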
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# カテゴリ名の設定\n",
"categories = ['dokujo-tsushin', 'livedoor-homme', 'smax', 'it-life-hack', 'movie-enter', 'sports-watch', 'kaden-channel', 'peachy', 'topic-news']\n",
"\n",
"# 文書のxmlからの読み込み\n",
"docs = list()\n",
"for category in categories:\n",
" # ファイル名を作成\n",
" filename = category + '.xml'\n",
" p = os.path.join('.', 'livedoor-news-data', filename)\n",
"\n",
" # xmlからファイルを読み込む\n",
" c_docs = xml_read.get_docs(p)\n",
"\n",
" # スペース区切りの名詞列を作成する\n",
" for c_doc in c_docs:\n",
" nouns = list()\n",
" for sentence in c_doc.sentences:\n",
" nouns.extend(pick_noun(sentence))\n",
" c_doc.nouns = \" \".join(nouns)\n",
" docs.extend(c_docs)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# データのシャッフル\n",
"random.shuffle(docs)\n",
"\n",
"# トレーニングとテストに分ける\n",
"X_train = docs[:6000]\n",
"X_test = docs[6001:]"
]
},
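{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, `train_test_split` with `stratify` keeps the category proportions similar in the training and test sets. A minimal sketch (assuming scikit-learn >= 0.18 for `sklearn.model_selection`; not used in the rest of this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch: stratified split that preserves category proportions (not used below).\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train_alt, X_test_alt = train_test_split(\n",
"    docs,\n",
"    test_size=0.2,\n",
"    stratify=[doc.cat for doc in docs],\n",
"    random_state=42,\n",
")"
]
},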
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"code_folding": [],
"collapsed": false
},
"outputs": [],
"source": [
"# 「カテゴリ:数値」辞書の作成\n",
"categories_dic = {\n",
" 'dokujo-tsushin':0,\n",
" 'livedoor-homme':1,\n",
" 'smax':2, \n",
" 'it-life-hack':3,\n",
" 'movie-enter':4,\n",
" 'sports-watch':5, \n",
" 'kaden-channel':6, \n",
" 'peachy':7, \n",
" 'topic-news':8,\n",
"}\n",
"categories_dic2 = {\n",
" 0:'dokujo-tsushin',\n",
" 1:'livedoor-homme',\n",
" 2:'smax', \n",
" 3:'it-life-hack',\n",
" 4:'movie-enter',\n",
" 5:'sports-watch', \n",
" 6:'kaden-channel', \n",
" 7:'peachy', \n",
" 8:'topic-news',\n",
"}\n",
"train_cat = np.array([categories_dic.get(doc.cat) for doc in X_train], dtype=np.int64)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 単語の集合を取り出す\n",
"scikit-learn用に単語の集合が必用なので、単語の集合を取り出す。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# 学習データの単語の集合\n",
"train_words = [doc.nouns for doc in X_train]\n",
"# テストデータの単語の集合\n",
"test_words = [doc.nouns for doc in X_test]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"code_folding": [],
"collapsed": false
},
"outputs": [],
"source": [
"# カテゴリ名をnumpyのarrayで持たせる。\n",
"test_cat = np.array([categories_dic.get(doc.cat) for doc in X_test], dtype=np.int64)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ベースラインの学習と実行\n",
"scikit-learnのモジュールを用いて学習する。"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"from sklearn.naive_bayes import MultinomialNB\n",
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# パイプラインの作成\n",
"text_clf = Pipeline([\n",
" ('vect', CountVectorizer()),\n",
" ('tfidf', TfidfTransformer()),\n",
" ('clf', MultinomialNB()),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# 学習\n",
"text_clf = text_clf.fit(train_words, train_cat)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"code_folding": [],
"collapsed": true
},
"outputs": [],
"source": [
"# 推定\n",
"predicted = text_clf.predict(test_words)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.83528550512445099"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(predicted == test_cat)"
]
},
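{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a usage sketch, a single unseen text can be classified by converting it to a noun string with `pick_noun` and mapping the predicted label ID back through `categories_dic2`. The sample text below is a made-up placeholder, and the analyzer must be available:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch: classify one new text with the trained Naive Bayes pipeline.\n",
"# The sample text is a made-up placeholder.\n",
"sample_text = \"新しいスマートフォンが発売された\"\n",
"sample_nouns = \" \".join(pick_noun(sample_text))\n",
"predicted_id = text_clf.predict([sample_nouns])[0]\n",
"print(categories_dic2[predicted_id])"
]
},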
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# SVM\n",
"from sklearn.linear_model import SGDClassifier"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"code_folding": [
0
],
"collapsed": true
},
"outputs": [],
"source": [
"text_clf_sgd = Pipeline([\n",
" ('vect', CountVectorizer()),\n",
" ('tfidf', TfidfTransformer()),\n",
" ('clf', SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, n_iter=5, random_state=42)),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"text_clf_sgd = text_clf_sgd.fit(train_words, train_cat)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sgd_predicted = text_clf_sgd.predict(test_words)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.87920937042459735"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(sgd_predicted == test_cat)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# カテゴリごとの評価\n",
"from sklearn import metrics"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
"dokujo-tsushin 0.71 0.85 0.77 160\n",
"livedoor-homme 0.98 0.44 0.60 96\n",
" smax 0.80 1.00 0.89 150\n",
" it-life-hack 0.94 0.82 0.88 164\n",
" movie-enter 0.84 0.98 0.90 155\n",
" sports-watch 0.87 0.99 0.93 163\n",
" kaden-channel 0.82 0.90 0.86 138\n",
" peachy 0.78 0.71 0.74 177\n",
" topic-news 0.97 0.71 0.82 163\n",
"\n",
" avg / total 0.85 0.84 0.83 1366\n",
"\n"
]
}
],
"source": [
"print(metrics.classification_report(test_cat, predicted, target_names=categories))"
]
},
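{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which categories are confused with each other, a confusion matrix can complement the report above. A minimal sketch using `metrics.confusion_matrix` (not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Sketch: confusion matrix for the Naive Bayes predictions.\n",
"# Rows are true categories, columns are predicted ones, in the order of `categories`.\n",
"cm = metrics.confusion_matrix(test_cat, predicted, labels=list(range(len(categories))))\n",
"print(cm)"
]
},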
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/d2141e00aca0c7a4e68d1364a8c16d35"
},
"gist": {
"data": {
"description": "Text_Classification_SVM_NaiveBayes.ipynb",
"public": true
},
"id": "d2141e00aca0c7a4e68d1364a8c16d35"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}