Created
September 28, 2016 10:43
-
-
Save kanjirz50/85cbcc6ac7ec0bc1270cde1c59055c15 to your computer and use it in GitHub Desktop.
Chainerで多層パーセプトロンを構築し、文書分類するサンプル
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# フィードフォワードニューラルネットワークで文書分類\n", | |
"- [参考サイト](http://qiita.com/ichiroex/items/9aa0bcada0b5bf6f9e1c)\n", | |
"- [参考スクリプト](https://github.com/ichiroex/chainer-ffnn/blob/master/train.py)\n", | |
"- [参考サイト2](http://aidiary.hatenablog.com/entry/20151005/1444051251)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 標準ライブラリ\n", | |
"import sys\n", | |
"import os\n", | |
"from collections import defaultdict\n", | |
"\n", | |
"# 数値計算、ニューラルネット\n", | |
"import numpy as np\n", | |
"import chainer\n", | |
"import chainer.links as L\n", | |
"import chainer.functions as F\n", | |
"from chainer import optimizers, cuda, serializers\n", | |
"from gensim import corpora, matutils\n", | |
"from tqdm import tqdm_notebook\n", | |
"from sklearn.cross_validation import train_test_split\n", | |
"\n", | |
"# 雪だるま\n", | |
"import snowman_module3\n", | |
"import xml_read" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Print Versions\n", | |
"Python: sys.version_info(major=3, minor=5, micro=1, releaselevel='final', serial=0)\n", | |
"Chainer: 1.7.1\n", | |
"numpy: 1.10.4\n" | |
] | |
} | |
], | |
"source": [ | |
"# Report interpreter and library versions for reproducibility.\n", | |
"print(\"Print Versions\")\n", | |
"print(\"Python:\", sys.version_info)\n", | |
"print(\"Chainer:\", chainer.__version__)\n", | |
"print(\"numpy:\", np.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"code_folding": [], | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Build the 'Snowman' morphological analyzer, used below to extract nouns.\n", | |
"snow = snowman_module3.snowman('/tools/snowman/config/config.ini')\n", | |
"\n", | |
"def pick_noun(sentence, pos='名詞'):\n", | |
" \"\"\"\n", | |
" Parse a sentence with the Snowman analyzer and return the words of the\n", | |
" target part of speech. Compound-word merging is disabled (noconbine).\n", | |
"\n", | |
" :param str sentence: sentence to analyze\n", | |
" :param str pos: part-of-speech tag to keep (default: noun)\n", | |
" :return: list of base-form lemmas\n", | |
" \"\"\"\n", | |
" words = snow.parse(sentence, noconbine=True)\n", | |
" return [word.org_lemma for word in words if word.pos == pos]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"code_folding": [ | |
0 | |
], | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Category names (one XML file per category under ./livedoor-news-data)\n", | |
"categories = ['dokujo-tsushin', 'livedoor-homme', 'smax', 'it-life-hack', 'movie-enter', 'sports-watch', 'kaden-channel', 'peachy', 'topic-news']\n", | |
"\n", | |
"# Load the documents from XML\n", | |
"docs = list()\n", | |
"for category in tqdm_notebook(categories):\n", | |
" # Build the file path (e.g. ./livedoor-news-data/dokujo-tsushin.xml)\n", | |
" filename = category + '.xml'\n", | |
" p = os.path.join('.', 'livedoor-news-data', filename)\n", | |
"\n", | |
" # Read the set of documents belonging to this category from the XML\n", | |
" c_docs = xml_read.get_docs(p)\n", | |
"\n", | |
" # Build the list of nouns for each document\n", | |
" for c_doc in c_docs:\n", | |
" nouns = list()\n", | |
" for sentence in c_doc.sentences:\n", | |
" nouns.extend(pick_noun(sentence))\n", | |
" # Attach the noun list to the document instance\n", | |
" c_doc.nouns = nouns\n", | |
" # Add this category's documents to the overall list\n", | |
" docs.extend(c_docs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 「カテゴリ:数値」辞書の作成\n", | |
"categories_dic = {\n", | |
" 'dokujo-tsushin':0,\n", | |
" 'livedoor-homme':1,\n", | |
" 'smax':2, \n", | |
" 'it-life-hack':3,\n", | |
" 'movie-enter':4,\n", | |
" 'sports-watch':5, \n", | |
" 'kaden-channel':6, \n", | |
" 'peachy':7, \n", | |
" 'topic-news':8,\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Build the gensim dictionary (word <-> integer id) from all documents' nouns\n", | |
"words = [d.nouns for d in docs]\n", | |
"w_dic = corpora.Dictionary(words)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Convert each document into a (dense) bag-of-words feature vector.\n", | |
"data = []\n", | |
"target = []\n", | |
"for doc in tqdm_notebook(docs):\n", | |
" tmp = w_dic.doc2bow(doc.nouns)\n", | |
" # Convert the sparse BoW to a dense vector of length len(w_dic)\n", | |
" dense = list(matutils.corpus2dense([tmp], num_terms=len(w_dic)).T[0])\n", | |
" \n", | |
" data.append(dense)\n", | |
" target.append(categories_dic[doc.cat])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Convert the dataset to numpy arrays\n", | |
"dataset = {}\n", | |
"# Bag-of-words feature vectors (float32)\n", | |
"dataset['source'] = np.array(data).astype(np.float32)\n", | |
"# Integer category labels (int32)\n", | |
"dataset['target'] = np.array(target).astype(np.int32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"語彙数w_dic: 34003\n" | |
] | |
} | |
], | |
"source": [ | |
"# Print the vocabulary size of the gensim dictionary\n", | |
"print(\"語彙数w_dic:\",len(w_dic.keys()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Mini-batch size\n", | |
"batchsize = 128\n", | |
"# Number of training epochs (parameter-update sweeps over the data)\n", | |
"n_epoch = 10\n", | |
"# Split into training data and a 10% held-out test set\n", | |
"x_train, x_test, y_train, y_test = train_test_split(dataset['source'], dataset['target'], test_size=0.1, random_state=4)\n", | |
"# Number of test examples\n", | |
"N_test = y_test.size\n", | |
"# Number of training examples\n", | |
"N = len(x_train)\n", | |
"# Input-layer units = vocabulary size\n", | |
"in_units = x_train.shape[1]\n", | |
"# Hidden-layer units (original comment said 500, but 100 is actually used)\n", | |
"n_units = 100\n", | |
"# Output-layer units = number of category labels\n", | |
"n_label = 9" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Define the Chainer model: three linear layers (input -> hidden -> hidden -> output)\n", | |
"model = chainer.Chain(l1=L.Linear(in_units, n_units),\n", | |
" l2=L.Linear(n_units, n_units),\n", | |
" l3=L.Linear(n_units, n_label))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Use the GPU: verify CUDA, select device 0, and move the model onto it\n", | |
"cuda.check_cuda_available()\n", | |
"cuda.get_device(0).use()\n", | |
"model.to_gpu()\n", | |
"# Use cupy arrays via the xp alias\n", | |
"xp = cuda.cupy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def forward(x, t, train=True):\n", | |
" \"\"\"\n", | |
" Run a forward pass and return (softmax cross-entropy loss, accuracy).\n", | |
"\n", | |
" NOTE(review): `train` is currently unused; no dropout or other\n", | |
" train/test behavior switch is implemented.\n", | |
" \"\"\"\n", | |
" h1 = F.sigmoid(model.l1(x)) # input layer -> hidden layer\n", | |
" h2 = F.sigmoid(model.l2(h1)) # hidden layer -> hidden layer\n", | |
" y = model.l3(h2) # hidden layer -> output layer\n", | |
" return F.softmax_cross_entropy(y, t), F.accuracy(y, t)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Use the Adam optimizer\n", | |
"optimizer = optimizers.Adam()\n", | |
"optimizer.setup(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch 1\n", | |
"train mean loss=1.8211338626852942, accuracy=0.5859728503730682\n", | |
"test mean loss=1.3589255775816567, accuracy=0.8168249646229531\n", | |
"epoch 2\n", | |
"train mean loss=0.939631807696226, accuracy=0.8880844646449542\n", | |
"test mean loss=0.6766098602655947, accuracy=0.8955223870892053\n", | |
"epoch 3\n", | |
"train mean loss=0.4159292607106055, accuracy=0.9508295626122487\n", | |
"test mean loss=0.413845614063853, accuracy=0.915875168797766\n", | |
"epoch 4\n", | |
"train mean loss=0.1985766537139139, accuracy=0.9794871795411203\n", | |
"test mean loss=0.3165921981735825, accuracy=0.9240162814164906\n", | |
"epoch 5\n", | |
"train mean loss=0.10457960665900243, accuracy=0.9935143283589394\n", | |
"test mean loss=0.2758631820675476, accuracy=0.9226594293133699\n", | |
"epoch 6\n", | |
"train mean loss=0.0618711051313406, accuracy=0.9975867269984917\n", | |
"test mean loss=0.25518176279042115, accuracy=0.9267299856227321\n", | |
"epoch 7\n", | |
"train mean loss=0.04092924251736649, accuracy=0.998340874811463\n", | |
"test mean loss=0.24679633902241774, accuracy=0.9240162813356159\n", | |
"epoch 8\n", | |
"train mean loss=0.029200922100499925, accuracy=0.9987933634992459\n", | |
"test mean loss=0.2410990115017548, accuracy=0.9294436897480989\n", | |
"epoch 9\n", | |
"train mean loss=0.02206639587137494, accuracy=0.9995475113122172\n", | |
"test mean loss=0.23926810424881034, accuracy=0.9321573939543405\n", | |
"epoch 10\n", | |
"train mean loss=0.017627534009010543, accuracy=0.9993966817496229\n", | |
"test mean loss=0.23763269744055243, accuracy=0.9321573939543405\n" | |
] | |
} | |
], | |
"source": [ | |
"# Training loop\n", | |
"for epoch in range(1, n_epoch + 1):\n", | |
" print('epoch', epoch)\n", | |
"\n", | |
" # --- Training phase ---\n", | |
" # Fresh random permutation of example indices each epoch\n", | |
" perm = np.random.permutation(N)\n", | |
" sum_train_loss = 0.0\n", | |
" sum_train_accuracy = 0.0\n", | |
"\n", | |
" for i in range(0, N, batchsize):\n", | |
" # Select a mini-batch from x_train/y_train via perm, so each epoch sees the data in a different order\n", | |
" x = chainer.Variable(xp.asarray(x_train[perm[i:i+batchsize]]))\n", | |
" t = chainer.Variable(xp.asarray(y_train[perm[i:i+batchsize]]))\n", | |
"\n", | |
" # Zero the gradients\n", | |
" model.zerograds()\n", | |
" # Forward pass\n", | |
" loss, acc = forward(x, t)\n", | |
" # Accumulate loss, weighted by batch size, for the epoch mean\n", | |
" sum_train_loss += float(cuda.to_cpu(loss.data)) * len(t)\n", | |
" # Accumulate accuracy, weighted by batch size, for the epoch mean\n", | |
" sum_train_accuracy += float(cuda.to_cpu(acc.data)) * len(t)\n", | |
" # Backpropagation\n", | |
" loss.backward()\n", | |
" # Parameter update\n", | |
" optimizer.update()\n", | |
"\n", | |
" print('train mean loss={}, accuracy={}'.format(sum_train_loss / N, sum_train_accuracy / N))\n", | |
"\n", | |
" # --- Evaluation phase on the held-out test set ---\n", | |
" # NOTE(review): this forward pass still builds the computation graph;\n", | |
" # marking inputs volatile would save memory during evaluation.\n", | |
" sum_test_loss = 0.0\n", | |
" sum_test_accuracy = 0.0\n", | |
"\n", | |
" for i in range(0, N_test, batchsize):\n", | |
" x = chainer.Variable(xp.asarray(x_test[i:i+batchsize]))\n", | |
" t = chainer.Variable(xp.asarray(y_test[i:i+batchsize]))\n", | |
"\n", | |
" loss, acc = forward(x, t, train=False)\n", | |
"\n", | |
" sum_test_loss += float(cuda.to_cpu(loss.data)) * len(t)\n", | |
" sum_test_accuracy += float(cuda.to_cpu(acc.data)) * len(t)\n", | |
"\n", | |
" print('test mean loss={}, accuracy={}'.format(sum_test_loss / N_test, sum_test_accuracy / N_test))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Save the trained model and the optimizer state to disk\n", | |
"serializers.save_npz('text_classifier_ffnn.model', model)\n", | |
"serializers.save_npz('text_classifier_ffnn.state', optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[[ 0. 0. 0. ..., 0. 0. 0.]\n", | |
" [ 0. 0. 0. ..., 0. 0. 0.]\n", | |
" [ 0. 0. 0. ..., 0. 0. 0.]\n", | |
" ..., \n", | |
" [ 0. 0. 0. ..., 0. 0. 0.]\n", | |
" [ 0. 0. 0. ..., 0. 0. 0.]\n", | |
" [ 0. 0. 0. ..., 0. 0. 0.]]\n", | |
"737\n" | |
] | |
} | |
], | |
"source": [ | |
"# Inspect the test feature matrix (mostly zeros: sparse bag-of-words)\n", | |
"print(x_test)\n", | |
"print(len(x_test))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[5 5 3 6 2 2 6 4 2 4 4 7 7 3 4 7 3 3 4 3 6 4 3 4 1 2 1 1 0 2 7 3 7 3 5 2 5\n", | |
" 6 1 6 5 4 6 4 1 3 7 6 8 7 7 6 5 7 8 3 8 6 4 7 4 4 6 5 7 8 3 2 5 2 3 3 3 6\n", | |
" 1 4 5 7 6 4 7 0 1 3 2 2 6 0 5 1 0 1 0 8 3 4 8 6 3 0 6 4 5 0 8 7 0 0 8 0 2\n", | |
" 1 7 2 0 5 2 2 6 7 3 3 8 5 4 7 1 1 2 6 0 5 5 5 8 4 7 4 6 0 3 2 8 4 7 7 3 3\n", | |
" 6 7 4 7 3 4 1 3 7 6 0 2 8 2 8 4 4 6 8 1 8 5 7 8 2 7 5 0 2 6 6 6 3 0 7 4 2\n", | |
" 2 0 6 4 0 1 5 4 5 2 2 7 7 0 6 4 6 0 4 7 1 4 7 8 2 6 7 7 0 5 6 8 8 8 8 6 2\n", | |
" 3 3 8 7 3 2 4 2 7 2 7 7 0 8 3 6 4 5 4 0 8 6 0 7 3 7 3 8 7 2 5 6 1 3 2 0 4\n", | |
" 3 8 6 3 3 7 0 3 7 3 6 3 8 3 7 8 6 8 3 4 7 2 6 1 5 5 2 6 4 3 3 8 7 7 4 5 4\n", | |
" 7 8 6 8 7 7 5 6 4 2 7 3 4 5 5 8 4 4 4 2 5 7 3 7 7 7 3 4 8 0 7 3 6 7 0 8 2\n", | |
" 8 4 3 3 5 0 7 8 1 5 3 0 4 4 0 7 0 2 4 6 2 8 6 5 8 8 2 5 1 5 6 0 7 7 6 2 3\n", | |
" 6 6 5 2 0 0 1 6 5 0 7 1 2 6 1 3 8 4 8 8 2 2 2 3 5 0 0 1 5 8 2 3 8 1 6 0 0\n", | |
" 4 5 4 6 4 0 8 8 6 3 4 5 5 6 2 2 0 4 0 2 3 7 3 2 5 6 2 4 4 0 4 3 0 2 8 6 8\n", | |
" 7 2 8 2 3 0 6 7 7 4 5 3 1 0 8 6 5 8 6 3 3 0 4 0 5 4 5 3 4 8 3 0 2 8 1 8 0\n", | |
" 0 6 2 3 3 3 4 5 1 6 0 0 2 3 0 6 4 3 3 0 4 3 8 7 8 3 6 0 8 0 5 8 0 0 1 4 6\n", | |
" 4 6 4 7 2 5 2 7 1 2 4 3 0 0 6 2 7 7 3 2 6 4 4 8 5 1 5 1 3 0 0 7 7 0 3 1 3\n", | |
" 3 8 4 5 1 8 4 2 2 6 0 5 0 4 0 8 8 5 6 4 5 7 1 0 3 5 1 5 6 7 7 4 5 3 6 2 6\n", | |
" 8 4 6 1 2 0 6 5 8 3 2 1 0 2 3 8 1 7 1 0 8 4 6 8 8 5 6 6 6 2 4 3 1 1 6 0 5\n", | |
" 5 5 4 3 7 0 0 6 7 1 7 1 6 2 6 8 8 2 5 7 2 7 0 8 6 0 7 5 3 7 4 7 0 1 5 5 2\n", | |
" 5 6 7 5 2 4 3 4 5 7 6 2 7 2 4 7 2 5 5 8 4 6 7 0 2 5 4 3 2 8 4 5 4 7 6 8 0\n", | |
" 0 3 0 6 0 0 3 3 5 8 5 7 8 2 1 8 0 6 2 0 7 3 0 2 7 0 7 0 5 3 7 5 1 8]\n", | |
"737\n" | |
] | |
} | |
], | |
"source": [ | |
"# Inspect the test labels\n", | |
"print(y_test)\n", | |
"print(len(y_test))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def forward_detail(x, t, train=True):\n", | |
" \"\"\"\n", | |
" Same forward pass as `forward`, but additionally returns the raw\n", | |
" output-layer activations `y` so individual predictions can be inspected.\n", | |
"\n", | |
" NOTE(review): duplicates `forward`; `train` is unused here as well.\n", | |
" \"\"\"\n", | |
" h1 = F.sigmoid(model.l1(x)) # input layer -> hidden layer\n", | |
" h2 = F.sigmoid(model.l2(h1)) # hidden layer -> hidden layer\n", | |
" y = model.l3(h2) # hidden layer -> output layer\n", | |
" return F.softmax_cross_entropy(y, t), F.accuracy(y, t), y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Run prediction on the first test mini-batch only (note the break)\n", | |
"for i in range(0, N_test, batchsize):\n", | |
" x = chainer.Variable(xp.asarray(x_test[i:i+batchsize]))\n", | |
" t = chainer.Variable(xp.asarray(y_test[i:i+batchsize]))\n", | |
"\n", | |
" loss, acc, pred = forward_detail(x, t, train=False)\n", | |
" ps = pred.data.tolist()\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"正解数:120, 精度:0.9375\n" | |
] | |
} | |
], | |
"source": [ | |
"# Manually recompute accuracy on the first batch via argmax of the outputs.\n", | |
"# NOTE(review): zip truncates y_test to len(ps); dividing by batchsize\n", | |
"# assumes the batch is full (len(ps) == batchsize).\n", | |
"correct_num = 0\n", | |
"for p, yt in zip(ps, y_test):\n", | |
" m = p.index(max(p))\n", | |
"\n", | |
" if m == yt:\n", | |
" correct_num += 1\n", | |
"print(\"正解数:{}, 精度:{}\".format(correct_num, correct_num / batchsize))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
}, | |
"nav_menu": {}, | |
"toc": { | |
"navigate_menu": true, | |
"number_sections": true, | |
"sideBar": true, | |
"threshold": 6, | |
"toc_cell": false, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment