Skip to content

Instantly share code, notes, and snippets.

@yusuke0519
Created November 16, 2015 09:45
Show Gist options
  • Save yusuke0519/92eb3df2bb5988412b65 to your computer and use it in GitHub Desktop.
Save yusuke0519/92eb3df2bb5988412b65 to your computer and use it in GitHub Desktop.
theano, logistic regression
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Theanoで多クラスロジスティック回帰\n",
"\n",
"* ロジスティック回帰を実装する\n",
"* ロジスティック回帰は次のように定式化される\n",
"$$\n",
"\\newcommand{\\argmax}{\\mathop{\\rm arg~max}\\limits}\n",
"\\argmax_{i} P(Y=i|x, W, b) = softmax(Wx+b) = \\frac{\\exp^{W_ix + b_i}}{\\sum_j \\exp^{W_jx + b_j}}\n",
"$$\n",
"* これを実現するには、mini-batch SGDなどによる最適化を行う"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ロジスティック回帰の実装に必要なもの\n",
"0. パラメタの初期化\n",
"1. 事後確率P(Y|x, W, b)の計算\n",
"2. コスト関数の定義\n",
"3. コスト関数を最適にするための設定"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import cPickle\n",
"import gzip\n",
"import os\n",
"import sys\n",
"import timeit\n",
"\n",
"import numpy as np\n",
"\n",
"import theano\n",
"import theano.tensor as T"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class LogisticRegression(object):\n",
" def __init__(self, input, n_in, n_out):\n",
" \"\"\" ロジスティック回帰の初期化\n",
" Input:\n",
" input: 入力変数(バッチ)に対応するシンボル変数\n",
" n_in: 入力次元(MNISTの場合だと32 * 32)\n",
" n_out: 出力次元(MNISTの場合だと10)\n",
" \"\"\"\n",
" # パラメタをshared variableで初期化\n",
" # 1. W\n",
" self.W = theano.shared(\n",
" value = np.zeros(\n",
" (n_in, n_out), \n",
" dtype=theano.config.floatX\n",
" ), \n",
" name='W',\n",
" borrow=True\n",
" )\n",
" # 2. b\n",
" self.b = theano.shared(\n",
" value = np.zeros(\n",
" (n_out, ), \n",
" dtype=theano.config.floatX\n",
" ), \n",
" name='b',\n",
" borrow=True\n",
" )\n",
" \n",
" # 事後確率をtheanoの関数として定義\n",
" self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)\n",
"\n",
" # 事後確率を最大にするyを選択\n",
" self.y_pred = T.argmax(self.p_y_given_x, axis=1)\n",
"\n",
" # SGDで更新するパラメタのリストを作成\n",
" self.params = [self.W, self.b]\n",
" \n",
" # 入力シンボルを記録\n",
" self.input = input\n",
" \n",
" def negative_log_likelihood(self, y):\n",
" \"\"\" 出力シンボルyが与えられた時に負の対数尤度を返すtheano関数を戻す\n",
" Input:\n",
" y: 出力ラベルに関するtheanoシンボル\n",
" Output:\n",
" 負の対数尤度を返すtheano関数\n",
" \"\"\"\n",
" return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) # いまいちピンときてない\n",
" \n",
" def errors(self, y):\n",
" if y.ndim != self.y_pred.ndim:\n",
" raise TypeError('y should have the same shape as self.y_pred')\n",
" if y.dtype.startswith('int'):\n",
" return T.mean(T.neq(self.y_pred, y)) # 間違ってる割合"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### データセット読み込み\n",
"* MNISTのデータセットを利用"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_data(dataset):\n",
" ''' Loads the dataset\n",
"\n",
" :type dataset: string\n",
" :param dataset: the path to the dataset (here MNIST)\n",
" '''\n",
"\n",
" #############\n",
" # LOAD DATA #\n",
" #############\n",
"\n",
" # Download the MNIST dataset if it is not present\n",
" data_dir, data_file = os.path.split(dataset)\n",
" if data_dir == \"\" and not os.path.isfile(dataset):\n",
" # Check if dataset is in the data directory.\n",
" new_path = os.path.join(\n",
" os.path.split('__file__')[0],\n",
" \"..\",\n",
" \"data\",\n",
" dataset\n",
" )\n",
" if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':\n",
" dataset = new_path\n",
"\n",
" if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':\n",
" import urllib\n",
" origin = (\n",
" 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'\n",
" )\n",
" print 'Downloading data from %s' % origin\n",
" urllib.urlretrieve(origin, dataset)\n",
"\n",
" print '... loading data'\n",
"\n",
" # Load the dataset\n",
" f = gzip.open(dataset, 'rb')\n",
" train_set, valid_set, test_set = cPickle.load(f)\n",
" f.close()\n",
" #train_set, valid_set, test_set format: tuple(input, target)\n",
" #input is an numpy.ndarray of 2 dimensions (a matrix)\n",
" #witch row's correspond to an example. target is a\n",
" #numpy.ndarray of 1 dimensions (vector)) that have the same length as\n",
" #the number of rows in the input. It should give the target\n",
" #target to the example with the same index in the input.\n",
"\n",
" def shared_dataset(data_xy, borrow=True):\n",
" # Sharedにしておかないとかなり遅くなる(GPUの場合)ので注意\n",
" # GPUのメモリにデータを移すコストがかなり大きいため\n",
" data_x, data_y = data_xy\n",
" shared_x = theano.shared(np.asarray(data_x,\n",
" dtype=theano.config.floatX),\n",
" borrow=borrow)\n",
" shared_y = theano.shared(np.asarray(data_y,\n",
" dtype=theano.config.floatX),\n",
" borrow=borrow)\n",
" # When storing data on the GPU it has to be stored as floats\n",
" # therefore we will store the labels as ``floatX`` as well\n",
" # (``shared_y`` does exactly that). But during our computations\n",
" # we need them as ints (we use labels as index, and if they are\n",
" # floats it doesn't make sense) therefore instead of returning\n",
" # ``shared_y`` we will have to cast it to int. This little hack\n",
" # lets ous get around this issue\n",
" return shared_x, T.cast(shared_y, 'int32')\n",
"\n",
" test_set_x, test_set_y = shared_dataset(test_set)\n",
" valid_set_x, valid_set_y = shared_dataset(valid_set)\n",
" train_set_x, train_set_y = shared_dataset(train_set)\n",
"\n",
" rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),\n",
" (test_set_x, test_set_y)]\n",
" return rval"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"... loading data\n"
]
}
],
"source": [
"datasets = load_data('mnist.pkl.gz')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Train, Validation, Testに分割\n",
"train_set_x, train_set_y = datasets[0]\n",
"valid_set_x, valid_set_y = datasets[1]\n",
"test_set_x, test_set_y = datasets[2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SGD\n",
"* 定義されたコスト関数をパラメタについて偏微分してパラメタを更新"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, minibatch 83/83, validation error 12.458333 %\n",
" epoch 1, minibatch 83/83, test error of best model 12.375000 %\n",
"The code run for 1 epochs with 15.865700 epochs/sec\n",
"epoch 2, minibatch 83/83, validation error 11.010417 %\n",
" epoch 2, minibatch 83/83, test error of best model 10.958333 %\n",
"The code run for 2 epochs with 15.949228 epochs/sec\n",
"epoch 3, minibatch 83/83, validation error 10.312500 %\n",
" epoch 3, minibatch 83/83, test error of best model 10.312500 %\n",
"The code run for 3 epochs with 15.999308 epochs/sec\n",
"epoch 4, minibatch 83/83, validation error 9.875000 %\n",
" epoch 4, minibatch 83/83, test error of best model 9.833333 %\n",
"The code run for 4 epochs with 16.014157 epochs/sec\n",
"epoch 5, minibatch 83/83, validation error 9.562500 %\n",
" epoch 5, minibatch 83/83, test error of best model 9.479167 %\n",
"The code run for 5 epochs with 15.449362 epochs/sec\n",
"epoch 6, minibatch 83/83, validation error 9.322917 %\n",
" epoch 6, minibatch 83/83, test error of best model 9.291667 %\n",
"The code run for 6 epochs with 15.185566 epochs/sec\n",
"epoch 7, minibatch 83/83, validation error 9.187500 %\n",
" epoch 7, minibatch 83/83, test error of best model 9.000000 %\n",
"The code run for 7 epochs with 14.893745 epochs/sec\n",
"epoch 8, minibatch 83/83, validation error 8.989583 %\n",
" epoch 8, minibatch 83/83, test error of best model 8.958333 %\n",
"The code run for 8 epochs with 14.653838 epochs/sec\n",
"epoch 9, minibatch 83/83, validation error 8.937500 %\n",
" epoch 9, minibatch 83/83, test error of best model 8.812500 %\n",
"The code run for 9 epochs with 14.691694 epochs/sec\n",
"epoch 10, minibatch 83/83, validation error 8.750000 %\n",
" epoch 10, minibatch 83/83, test error of best model 8.666667 %\n",
"The code run for 10 epochs with 14.825818 epochs/sec\n",
"epoch 11, minibatch 83/83, validation error 8.666667 %\n",
" epoch 11, minibatch 83/83, test error of best model 8.520833 %\n",
"The code run for 11 epochs with 14.880873 epochs/sec\n",
"epoch 12, minibatch 83/83, validation error 8.583333 %\n",
" epoch 12, minibatch 83/83, test error of best model 8.416667 %\n",
"The code run for 12 epochs with 14.952862 epochs/sec\n",
"epoch 13, minibatch 83/83, validation error 8.489583 %\n",
" epoch 13, minibatch 83/83, test error of best model 8.291667 %\n",
"The code run for 13 epochs with 15.042273 epochs/sec\n",
"epoch 14, minibatch 83/83, validation error 8.427083 %\n",
" epoch 14, minibatch 83/83, test error of best model 8.281250 %\n",
"The code run for 14 epochs with 15.109374 epochs/sec\n",
"epoch 15, minibatch 83/83, validation error 8.354167 %\n",
" epoch 15, minibatch 83/83, test error of best model 8.270833 %\n",
"The code run for 15 epochs with 15.176273 epochs/sec\n",
"epoch 16, minibatch 83/83, validation error 8.302083 %\n",
" epoch 16, minibatch 83/83, test error of best model 8.239583 %\n",
"The code run for 16 epochs with 15.232987 epochs/sec\n",
"epoch 17, minibatch 83/83, validation error 8.250000 %\n",
" epoch 17, minibatch 83/83, test error of best model 8.177083 %\n",
"The code run for 17 epochs with 15.283110 epochs/sec\n",
"epoch 18, minibatch 83/83, validation error 8.229167 %\n",
" epoch 18, minibatch 83/83, test error of best model 8.062500 %\n",
"The code run for 18 epochs with 15.320284 epochs/sec\n",
"epoch 19, minibatch 83/83, validation error 8.260417 %\n",
"The code run for 19 epochs with 15.447846 epochs/sec\n",
"epoch 20, minibatch 83/83, validation error 8.260417 %\n",
"The code run for 20 epochs with 15.562980 epochs/sec\n",
"epoch 21, minibatch 83/83, validation error 8.208333 %\n",
" epoch 21, minibatch 83/83, test error of best model 7.947917 %\n",
"The code run for 21 epochs with 15.590202 epochs/sec\n",
"epoch 22, minibatch 83/83, validation error 8.187500 %\n",
" epoch 22, minibatch 83/83, test error of best model 7.927083 %\n",
"The code run for 22 epochs with 15.536120 epochs/sec\n",
"epoch 23, minibatch 83/83, validation error 8.156250 %\n",
" epoch 23, minibatch 83/83, test error of best model 7.958333 %\n",
"The code run for 23 epochs with 15.381415 epochs/sec\n",
"epoch 24, minibatch 83/83, validation error 8.114583 %\n",
" epoch 24, minibatch 83/83, test error of best model 7.947917 %\n",
"The code run for 24 epochs with 15.352218 epochs/sec\n",
"epoch 25, minibatch 83/83, validation error 8.093750 %\n",
" epoch 25, minibatch 83/83, test error of best model 7.947917 %\n",
"The code run for 25 epochs with 15.314788 epochs/sec\n",
"epoch 26, minibatch 83/83, validation error 8.104167 %\n",
"The code run for 26 epochs with 15.409225 epochs/sec\n",
"epoch 27, minibatch 83/83, validation error 8.104167 %\n",
"The code run for 27 epochs with 15.456711 epochs/sec\n",
"epoch 28, minibatch 83/83, validation error 8.052083 %\n",
" epoch 28, minibatch 83/83, test error of best model 7.843750 %\n",
"The code run for 28 epochs with 15.406752 epochs/sec\n",
"epoch 29, minibatch 83/83, validation error 8.052083 %\n",
"The code run for 29 epochs with 15.484653 epochs/sec\n",
"epoch 30, minibatch 83/83, validation error 8.031250 %\n",
" epoch 30, minibatch 83/83, test error of best model 7.843750 %\n",
"The code run for 30 epochs with 15.501657 epochs/sec\n",
"epoch 31, minibatch 83/83, validation error 8.010417 %\n",
" epoch 31, minibatch 83/83, test error of best model 7.833333 %\n",
"The code run for 31 epochs with 15.519096 epochs/sec\n",
"epoch 32, minibatch 83/83, validation error 7.979167 %\n",
" epoch 32, minibatch 83/83, test error of best model 7.812500 %\n",
"The code run for 32 epochs with 15.535383 epochs/sec\n",
"epoch 33, minibatch 83/83, validation error 7.947917 %\n",
" epoch 33, minibatch 83/83, test error of best model 7.739583 %\n",
"The code run for 33 epochs with 15.552642 epochs/sec\n",
"epoch 34, minibatch 83/83, validation error 7.875000 %\n",
" epoch 34, minibatch 83/83, test error of best model 7.729167 %\n",
"The code run for 34 epochs with 15.568058 epochs/sec\n",
"epoch 35, minibatch 83/83, validation error 7.885417 %\n",
"The code run for 35 epochs with 15.633773 epochs/sec\n",
"epoch 36, minibatch 83/83, validation error 7.843750 %\n",
" epoch 36, minibatch 83/83, test error of best model 7.697917 %\n",
"The code run for 36 epochs with 15.642115 epochs/sec\n",
"epoch 37, minibatch 83/83, validation error 7.802083 %\n",
" epoch 37, minibatch 83/83, test error of best model 7.635417 %\n",
"The code run for 37 epochs with 15.652412 epochs/sec\n",
"epoch 38, minibatch 83/83, validation error 7.812500 %\n",
"The code run for 38 epochs with 15.708705 epochs/sec\n",
"epoch 39, minibatch 83/83, validation error 7.812500 %\n",
"The code run for 39 epochs with 15.761843 epochs/sec\n",
"epoch 40, minibatch 83/83, validation error 7.822917 %\n",
"The code run for 40 epochs with 15.814840 epochs/sec\n",
"epoch 41, minibatch 83/83, validation error 7.791667 %\n",
" epoch 41, minibatch 83/83, test error of best model 7.625000 %\n",
"The code run for 41 epochs with 15.813825 epochs/sec\n",
"epoch 42, minibatch 83/83, validation error 7.770833 %\n",
" epoch 42, minibatch 83/83, test error of best model 7.614583 %\n",
"The code run for 42 epochs with 15.818221 epochs/sec\n",
"epoch 43, minibatch 83/83, validation error 7.750000 %\n",
" epoch 43, minibatch 83/83, test error of best model 7.593750 %\n",
"The code run for 43 epochs with 15.823850 epochs/sec\n",
"epoch 44, minibatch 83/83, validation error 7.739583 %\n",
" epoch 44, minibatch 83/83, test error of best model 7.593750 %\n",
"The code run for 44 epochs with 15.826740 epochs/sec\n",
"epoch 45, minibatch 83/83, validation error 7.739583 %\n",
"The code run for 45 epochs with 15.873368 epochs/sec\n",
"epoch 46, minibatch 83/83, validation error 7.739583 %\n",
"The code run for 46 epochs with 15.916454 epochs/sec\n",
"epoch 47, minibatch 83/83, validation error 7.739583 %\n",
"The code run for 47 epochs with 15.957325 epochs/sec\n",
"epoch 48, minibatch 83/83, validation error 7.708333 %\n",
" epoch 48, minibatch 83/83, test error of best model 7.583333 %\n",
"The code run for 48 epochs with 15.957134 epochs/sec\n",
"epoch 49, minibatch 83/83, validation error 7.677083 %\n",
" epoch 49, minibatch 83/83, test error of best model 7.572917 %\n",
"The code run for 49 epochs with 15.959851 epochs/sec\n",
"epoch 50, minibatch 83/83, validation error 7.677083 %\n",
"The code run for 50 epochs with 15.998726 epochs/sec\n",
"epoch 51, minibatch 83/83, validation error 7.677083 %\n",
"The code run for 51 epochs with 16.035674 epochs/sec\n",
"epoch 52, minibatch 83/83, validation error 7.656250 %\n",
" epoch 52, minibatch 83/83, test error of best model 7.541667 %\n",
"The code run for 52 epochs with 16.036584 epochs/sec\n",
"epoch 53, minibatch 83/83, validation error 7.656250 %\n",
"The code run for 53 epochs with 16.072481 epochs/sec\n",
"epoch 54, minibatch 83/83, validation error 7.635417 %\n",
" epoch 54, minibatch 83/83, test error of best model 7.520833 %\n",
"The code run for 54 epochs with 16.070318 epochs/sec\n",
"epoch 55, minibatch 83/83, validation error 7.635417 %\n",
"The code run for 55 epochs with 16.101268 epochs/sec\n",
"epoch 56, minibatch 83/83, validation error 7.635417 %\n",
"The code run for 56 epochs with 16.135487 epochs/sec\n",
"epoch 57, minibatch 83/83, validation error 7.604167 %\n",
" epoch 57, minibatch 83/83, test error of best model 7.489583 %\n",
"The code run for 57 epochs with 16.132597 epochs/sec\n",
"epoch 58, minibatch 83/83, validation error 7.583333 %\n",
" epoch 58, minibatch 83/83, test error of best model 7.458333 %\n",
"The code run for 58 epochs with 16.129691 epochs/sec\n",
"epoch 59, minibatch 83/83, validation error 7.572917 %\n",
" epoch 59, minibatch 83/83, test error of best model 7.468750 %\n",
"The code run for 59 epochs with 16.127119 epochs/sec\n",
"epoch 60, minibatch 83/83, validation error 7.572917 %\n",
"The code run for 60 epochs with 16.158208 epochs/sec\n",
"epoch 61, minibatch 83/83, validation error 7.583333 %\n",
"The code run for 61 epochs with 16.185764 epochs/sec\n",
"epoch 62, minibatch 83/83, validation error 7.572917 %\n",
" epoch 62, minibatch 83/83, test error of best model 7.520833 %\n",
"The code run for 62 epochs with 16.182345 epochs/sec\n",
"epoch 63, minibatch 83/83, validation error 7.562500 %\n",
" epoch 63, minibatch 83/83, test error of best model 7.510417 %\n",
"The code run for 63 epochs with 16.180552 epochs/sec\n",
"epoch 64, minibatch 83/83, validation error 7.572917 %\n",
"The code run for 64 epochs with 16.205560 epochs/sec\n",
"epoch 65, minibatch 83/83, validation error 7.562500 %\n",
"The code run for 65 epochs with 16.231046 epochs/sec\n",
"epoch 66, minibatch 83/83, validation error 7.552083 %\n",
" epoch 66, minibatch 83/83, test error of best model 7.520833 %\n",
"The code run for 66 epochs with 16.227855 epochs/sec\n",
"epoch 67, minibatch 83/83, validation error 7.552083 %\n",
"The code run for 67 epochs with 16.252578 epochs/sec\n",
"epoch 68, minibatch 83/83, validation error 7.531250 %\n",
" epoch 68, minibatch 83/83, test error of best model 7.520833 %\n",
"The code run for 68 epochs with 16.247633 epochs/sec\n",
"epoch 69, minibatch 83/83, validation error 7.531250 %\n",
"The code run for 69 epochs with 16.273158 epochs/sec\n",
"epoch 70, minibatch 83/83, validation error 7.510417 %\n",
" epoch 70, minibatch 83/83, test error of best model 7.500000 %\n",
"The code run for 70 epochs with 16.271350 epochs/sec\n",
"epoch 71, minibatch 83/83, validation error 7.520833 %\n",
"The code run for 71 epochs with 16.296458 epochs/sec\n",
"epoch 72, minibatch 83/83, validation error 7.510417 %\n",
"The code run for 72 epochs with 16.320288 epochs/sec\n",
"epoch 73, minibatch 83/83, validation error 7.500000 %\n",
" epoch 73, minibatch 83/83, test error of best model 7.489583 %\n",
"The code run for 73 epochs with 16.315598 epochs/sec\n",
"epoch 74, minibatch 83/83, validation error 7.479167 %\n",
" epoch 74, minibatch 83/83, test error of best model 7.489583 %\n",
"The code run for 74 epochs with 16.312955 epochs/sec\n"
]
}
],
"source": [
"# parameters\n",
"learning_rate=0.13\n",
"n_epochs=1000\n",
"batch_size=600\n",
"\n",
"# Train, Validation, Testに分割\n",
"train_set_x, train_set_y = datasets[0]\n",
"valid_set_x, valid_set_y = datasets[1]\n",
"test_set_x, test_set_y = datasets[2]\n",
"\n",
"# Batch数を計算\n",
"nb_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size # borrow=Trueはcuda arrayで戻すという意味\n",
"nb_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size\n",
"nb_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size\n",
"\n",
"# mini bacthのインデックスに関するシンボル変数を割り当て\n",
"index = T.lscalar()\n",
"\n",
"# xとyのシンボル変数を作成\n",
"x = T.matrix('x')\n",
"y = T.ivector('y')\n",
"\n",
"# 分類器を作成\n",
"classifier = LogisticRegression(input=x, n_in=28*28, n_out=10)\n",
"\n",
"# コスト関数のシンボルを取得\n",
"cost = classifier.negative_log_likelihood(y)\n",
"\n",
"# モデルをテストするためのTheano functionを定義\n",
"test_model = theano.function(\n",
" inputs=[index],\n",
" outputs=classifier.errors(y),\n",
" givens={\n",
" x: test_set_x[index * batch_size: (index + 1) * batch_size],\n",
" y: test_set_y[index * batch_size: (index + 1) * batch_size]\n",
" }\n",
")\n",
"\n",
"# モデルを検証するためのTheano functionを定義\n",
"validate_model = theano.function(\n",
" inputs=[index],\n",
" outputs=classifier.errors(y),\n",
" givens={\n",
" x: valid_set_x[index * batch_size: (index + 1) * batch_size],\n",
" y: valid_set_y[index * batch_size: (index + 1) * batch_size]\n",
" }\n",
")\n",
"\n",
"# コスト関数のパラメタについての偏微分のシンボル\n",
"g_W = T.grad(cost=cost, wrt=classifier.W)\n",
"g_b = T.grad(cost=cost, wrt=classifier.b)\n",
"\n",
"# 更新式をシンボルで定義\n",
"updates = [(classifier.W, classifier.W - learning_rate * g_W),\n",
" (classifier.b, classifier.b - learning_rate * g_b)\n",
" ]\n",
"\n",
"# パラメタを更新するためのTheano関数を定義\n",
"train_model = theano.function(\n",
" inputs=[index],\n",
" outputs=cost,\n",
" updates=updates,\n",
" givens={\n",
" x: train_set_x[index * batch_size: (index+1) * batch_size],\n",
" y: train_set_y[index * batch_size: (index+1) * batch_size]\n",
" }\n",
")\n",
"\n",
"# ####### Training ###### #\n",
"# early stoppingのパラメタ\n",
"patience = 5000\n",
"patience_increase = 2\n",
"\n",
"improvement_threshold = 0.995\n",
"validation_frequency = min(nb_train_batches, patience/2)\n",
"\n",
"best_validation_loss = np.inf # 大きな値で初期化\n",
"test_score = 0.\n",
"start_time = timeit.default_timer()\n",
"\n",
"done_looping = False\n",
"epoch = 0\n",
"while (epoch < n_epochs) and (not done_looping):\n",
" epoch += 1\n",
" for minibatch_index in xrange(nb_train_batches):\n",
" minibatch_avg_cost = train_model(minibatch_index) # indexシンボルに現在のmini batchのインデックスを入力\n",
" \n",
" iteration = (epoch - 1) * nb_train_batches + minibatch_index\n",
" if (iteration + 1) % validation_frequency == 0:\n",
" validation_losses = [validate_model(i) for i in range(nb_valid_batches)]\n",
" this_validation_loss = np.mean(validation_losses)\n",
" \n",
" print(\n",
" 'epoch %i, minibatch %i/%i, validation error %f %%' %\n",
" (\n",
" epoch,\n",
" minibatch_index + 1,\n",
" nb_train_batches,\n",
" this_validation_loss * 100.\n",
" )\n",
" )\n",
" \n",
" if this_validation_loss < best_validation_loss:\n",
" if this_validation_loss < best_validation_loss * improvement_threshold:\n",
" patience = max(patience, iteration * patience_increase)\n",
" best_validation_loss = this_validation_loss\n",
" \n",
" test_losses = [test_model(i) for i in range(nb_test_batches)]\n",
" test_score = np.mean(test_losses)\n",
" \n",
" print(\n",
" (\n",
" ' epoch %i, minibatch %i/%i, test error of'\n",
" ' best model %f %%'\n",
" ) %\n",
" (\n",
" epoch,\n",
" minibatch_index + 1,\n",
" nb_train_batches,\n",
" test_score * 100.\n",
" )\n",
" )\n",
"\n",
" # save the best model\n",
" with open('best_model.pkl', 'w') as f:\n",
" cPickle.dump(classifier, f) \n",
" if patience <= iteration:\n",
" done_looping = True\n",
" break \n",
" end_time = timeit.default_timer()\n",
" print 'The code run for %d epochs with %f epochs/sec' %(\n",
" epoch, 1. * epoch / (end_time - start_time)\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
@yusuke0519
Copy link
Author

  • deep learning tutorialとほぼ同じ
  • 微妙に表記などが違うのは、自分の表記に合わせたため

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment