Last active
September 20, 2018 02:44
-
-
Save henryturner27/69fc2343cb733651ed568d1a8d2dbb42 to your computer and use it in GitHub Desktop.
Shallow Network to solve MNIST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from os import listdir\n", | |
"import PIL\n", | |
"from random import shuffle\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim\n", | |
"import torch.nn.functional as F" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# if gpu is to be used\n", | |
"use_cuda = torch.cuda.is_available()\n", | |
"device = torch.device('cuda:0' if use_cuda else 'cpu')\n", | |
"\n", | |
"data_path = 'data/mnist'\n", | |
"trn_data_path = f'{data_path}/train'\n", | |
"test_data_path = f'{data_path}/test'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"temp_train_data = []\n", | |
"temp_val_data = []\n", | |
"for label in listdir(trn_data_path):\n", | |
" for jpg in listdir(trn_data_path+'/'+label):\n", | |
" file_path = trn_data_path + '/' + label + '/' + jpg\n", | |
" img = PIL.Image.open(file_path)\n", | |
" lr_img = img.rotate(np.random.choice(range(-10,-5)))\n", | |
" rr_img = img.rotate(np.random.choice(range(5,10)))\n", | |
" data_arr = np.asarray(img).flatten() / 255.0 * 0.99 + 0.01\n", | |
" data_arr = torch.FloatTensor([data_arr]).to(device)\n", | |
" lr_data_arr = np.asarray(lr_img).flatten() / 255.0 * 0.99 + 0.01\n", | |
" lr_data_arr = torch.FloatTensor([lr_data_arr]).to(device)\n", | |
" rr_data_arr = np.asarray(rr_img).flatten() / 255.0 * 0.99 + 0.01\n", | |
" rr_data_arr = torch.FloatTensor([rr_data_arr]).to(device)\n", | |
" label_arr = np.zeros(10) + 0.01\n", | |
" label_arr[int(label)] = 0.99\n", | |
" label_arr = torch.FloatTensor([label_arr]).to(device)\n", | |
" if np.random.choice(100) > 19: \n", | |
" temp_train_data.append([data_arr, lr_data_arr, rr_data_arr, label_arr])\n", | |
" else:\n", | |
" temp_val_data.append([data_arr, lr_data_arr, rr_data_arr, label_arr])\n", | |
" \n", | |
"shuffle(temp_train_data)\n", | |
"shuffle(temp_val_data)\n", | |
"\n", | |
"train_data = torch.FloatTensor().to(device)\n", | |
"train_labels = torch.FloatTensor().to(device)\n", | |
"validation_data = torch.FloatTensor().to(device)\n", | |
"validation_labels = torch.FloatTensor().to(device)\n", | |
"\n", | |
"for x in temp_train_data:\n", | |
" train_data = torch.cat([train_data, x[0]])\n", | |
" train_data = torch.cat([train_data, x[1]])\n", | |
" train_data = torch.cat([train_data, x[2]])\n", | |
" train_labels = torch.cat([train_labels, x[-1]])\n", | |
" train_labels = torch.cat([train_labels, x[-1]])\n", | |
" train_labels = torch.cat([train_labels, x[-1]])\n", | |
" \n", | |
"for x in temp_val_data:\n", | |
" validation_data = torch.cat([validation_data, x[0]])\n", | |
" validation_data = torch.cat([validation_data, x[1]])\n", | |
" validation_data = torch.cat([validation_data, x[2]])\n", | |
" validation_labels = torch.cat([validation_labels, x[-1]])\n", | |
" validation_labels = torch.cat([validation_labels, x[-1]])\n", | |
" validation_labels = torch.cat([validation_labels, x[-1]])\n", | |
" \n", | |
"del temp_train_data\n", | |
"del temp_val_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"100923" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(train_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"25077" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(validation_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Network(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(Network, self).__init__()\n", | |
" self.l1 = nn.Linear(784, 300)\n", | |
" self.l2 = nn.Linear(300, 10)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x = F.relu(self.l1(x))\n", | |
" x = F.dropout(x, p=0.6, training=self.training)\n", | |
" x = torch.sigmoid(self.l2(x))\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(train_data, training_labels, validation_data, validation_labels, epoch, batch_size):\n", | |
" train_loss = []\n", | |
" val_loss = []\n", | |
" train_acc = 0\n", | |
" val_acc = 0\n", | |
" for bs in range(batch_size):\n", | |
" model.train()\n", | |
" batch_start = int(len(train_data)/batch_size*bs)\n", | |
" batch_end = int(len(train_data)/batch_size*(bs+1))\n", | |
" prediction = model(train_data[batch_start:batch_end])\n", | |
" actual = train_labels[batch_start:batch_end]\n", | |
" loss = F.mse_loss(prediction, actual).mean()\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" train_acc += (prediction.argmax(1) == actual.argmax(1)).cpu().detach().numpy().sum()\n", | |
" train_loss.append(loss.sum().cpu().detach().numpy())\n", | |
" train_loss = np.asarray(train_loss).mean()\n", | |
" train_acc = train_acc / len(train_data)\n", | |
" print('Epoch {} training loss: {}'.format(epoch, train_loss))\n", | |
" print('Epoch {} training accuracy: {}'.format(epoch, train_acc))\n", | |
"\n", | |
" for bs in range(batch_size):\n", | |
" model.eval()\n", | |
" batch_start = int(len(validation_data)/batch_size*bs)\n", | |
" batch_end = int(len(validation_data)/batch_size*(bs+1))\n", | |
" prediction = model(validation_data[batch_start:batch_end])\n", | |
" actual = validation_labels[batch_start:batch_end]\n", | |
" loss = F.mse_loss(prediction, actual).mean()\n", | |
" val_acc += (prediction.argmax(1) == actual.argmax(1)).cpu().detach().numpy().sum()\n", | |
" val_loss.append(loss.sum().cpu().detach().numpy())\n", | |
" val_loss = np.asarray(val_loss).mean()\n", | |
" val_acc = val_acc / len(validation_data)\n", | |
" print('Epoch {} validation loss: {}'.format(epoch, val_loss))\n", | |
" print('Epoch {} validation accuracy: {}\\n'.format(epoch, val_acc))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 0 training loss: 0.07884197682142258\n", | |
"Epoch 0 training accuracy: 0.48285326436986614\n", | |
"Epoch 0 validation loss: 0.044085800647735596\n", | |
"Epoch 0 validation accuracy: 0.7827491326713721\n", | |
"\n", | |
"Epoch 1 training loss: 0.03727114200592041\n", | |
"Epoch 1 training accuracy: 0.7965280461341815\n", | |
"Epoch 1 validation loss: 0.026454336941242218\n", | |
"Epoch 1 validation accuracy: 0.8639390676715716\n", | |
"\n", | |
"Epoch 2 training loss: 0.026401059702038765\n", | |
"Epoch 2 training accuracy: 0.8554145239439969\n", | |
"Epoch 2 validation loss: 0.020730365067720413\n", | |
"Epoch 2 validation accuracy: 0.8848745862742753\n", | |
"\n", | |
"Epoch 3 training loss: 0.02180325612425804\n", | |
"Epoch 3 training accuracy: 0.8766980767515828\n", | |
"Epoch 3 validation loss: 0.01779206097126007\n", | |
"Epoch 3 validation accuracy: 0.8970371256529888\n", | |
"\n", | |
"Epoch 4 training loss: 0.01913328282535076\n", | |
"Epoch 4 training accuracy: 0.891986960355915\n", | |
"Epoch 4 validation loss: 0.015950201079249382\n", | |
"Epoch 4 validation accuracy: 0.9064481397296328\n", | |
"\n", | |
"Epoch 5 training loss: 0.017314637079834938\n", | |
"Epoch 5 training accuracy: 0.900746113373562\n", | |
"Epoch 5 validation loss: 0.014626124873757362\n", | |
"Epoch 5 validation accuracy: 0.9121106990469354\n", | |
"\n", | |
"Epoch 6 training loss: 0.016011804342269897\n", | |
"Epoch 6 training accuracy: 0.907909990785054\n", | |
"Epoch 6 validation loss: 0.013589252717792988\n", | |
"Epoch 6 validation accuracy: 0.9179327670774016\n", | |
"\n", | |
"Epoch 7 training loss: 0.01495556440204382\n", | |
"Epoch 7 training accuracy: 0.9132407875310881\n", | |
"Epoch 7 validation loss: 0.012766270898282528\n", | |
"Epoch 7 validation accuracy: 0.9232763089683774\n", | |
"\n", | |
"Epoch 8 training loss: 0.014054970815777779\n", | |
"Epoch 8 training accuracy: 0.9186508526302231\n", | |
"Epoch 8 validation loss: 0.012104556895792484\n", | |
"Epoch 8 validation accuracy: 0.9265861147665191\n", | |
"\n", | |
"Epoch 9 training loss: 0.013417595066130161\n", | |
"Epoch 9 training accuracy: 0.9230502462273218\n", | |
"Epoch 9 validation loss: 0.011507791467010975\n", | |
"Epoch 9 validation accuracy: 0.9309726043785141\n", | |
"\n", | |
"Epoch 10 training loss: 0.012768560089170933\n", | |
"Epoch 10 training accuracy: 0.9266371392051366\n", | |
"Epoch 10 validation loss: 0.01100181881338358\n", | |
"Epoch 10 validation accuracy: 0.9346413047812737\n", | |
"\n", | |
"Epoch 11 training loss: 0.01228629145771265\n", | |
"Epoch 11 training accuracy: 0.9290944581512638\n", | |
"Epoch 11 validation loss: 0.010575709864497185\n", | |
"Epoch 11 validation accuracy: 0.9370339354787255\n", | |
"\n", | |
"Epoch 12 training loss: 0.011768396012485027\n", | |
"Epoch 12 training accuracy: 0.9327507109380418\n", | |
"Epoch 12 validation loss: 0.010171184316277504\n", | |
"Epoch 12 validation accuracy: 0.9390676715715596\n", | |
"\n", | |
"Epoch 13 training loss: 0.011383026838302612\n", | |
"Epoch 13 training accuracy: 0.9352278469724443\n", | |
"Epoch 13 validation loss: 0.009807510301470757\n", | |
"Epoch 13 validation accuracy: 0.9412609163775572\n", | |
"\n", | |
"Epoch 14 training loss: 0.010945308953523636\n", | |
"Epoch 14 training accuracy: 0.9386562032440573\n", | |
"Epoch 14 validation loss: 0.009531907737255096\n", | |
"Epoch 14 validation accuracy: 0.9425369860828647\n", | |
"\n", | |
"Epoch 15 training loss: 0.010687393136322498\n", | |
"Epoch 15 training accuracy: 0.9395876063929927\n", | |
"Epoch 15 validation loss: 0.009243899025022984\n", | |
"Epoch 15 validation accuracy: 0.9443314591059536\n", | |
"\n", | |
"Epoch 16 training loss: 0.01043588574975729\n", | |
"Epoch 16 training accuracy: 0.9416981262943036\n", | |
"Epoch 16 validation loss: 0.009006698615849018\n", | |
"Epoch 16 validation accuracy: 0.9456075288112613\n", | |
"\n", | |
"Epoch 17 training loss: 0.01012509036809206\n", | |
"Epoch 17 training accuracy: 0.9436104753128622\n", | |
"Epoch 17 validation loss: 0.008757558651268482\n", | |
"Epoch 17 validation accuracy: 0.9470032300514416\n", | |
"\n", | |
"Epoch 18 training loss: 0.009891458787024021\n", | |
"Epoch 18 training accuracy: 0.9450472142128157\n", | |
"Epoch 18 validation loss: 0.008560443297028542\n", | |
"Epoch 18 validation accuracy: 0.9487179487179487\n", | |
"\n", | |
"Epoch 19 training loss: 0.009660371579229832\n", | |
"Epoch 19 training accuracy: 0.9462461480534665\n", | |
"Epoch 19 validation loss: 0.008387024514377117\n", | |
"Epoch 19 validation accuracy: 0.9498345097100929\n", | |
"\n", | |
"Epoch 20 training loss: 0.00941749569028616\n", | |
"Epoch 20 training accuracy: 0.9477027040416951\n", | |
"Epoch 20 validation loss: 0.00819950457662344\n", | |
"Epoch 20 validation accuracy: 0.9513498424851458\n", | |
"\n", | |
"Epoch 21 training loss: 0.00916802603751421\n", | |
"Epoch 21 training accuracy: 0.9496348701485291\n", | |
"Epoch 21 validation loss: 0.008044376969337463\n", | |
"Epoch 21 validation accuracy: 0.9525461578338716\n", | |
"\n", | |
"Epoch 22 training loss: 0.009039959870278835\n", | |
"Epoch 22 training accuracy: 0.9502789255174737\n", | |
"Epoch 22 validation loss: 0.007887438870966434\n", | |
"Epoch 22 validation accuracy: 0.9532240698648163\n", | |
"\n", | |
"Epoch 23 training loss: 0.008864056318998337\n", | |
"Epoch 23 training accuracy: 0.9516760302408767\n", | |
"Epoch 23 validation loss: 0.007779685780405998\n", | |
"Epoch 23 validation accuracy: 0.9540216134306336\n", | |
"\n", | |
"Epoch 24 training loss: 0.008710344322025776\n", | |
"Epoch 24 training accuracy: 0.9526470675663625\n", | |
"Epoch 24 validation loss: 0.007641170173883438\n", | |
"Epoch 24 validation accuracy: 0.9552976831359413\n", | |
"\n", | |
"Epoch 25 training loss: 0.008589557372033596\n", | |
"Epoch 25 training accuracy: 0.9527560615518762\n", | |
"Epoch 25 validation loss: 0.007509811315685511\n", | |
"Epoch 25 validation accuracy: 0.9558559636320134\n", | |
"\n", | |
"Epoch 26 training loss: 0.00840373057872057\n", | |
"Epoch 26 training accuracy: 0.9538955441277013\n", | |
"Epoch 26 validation loss: 0.007363432552665472\n", | |
"Epoch 26 validation accuracy: 0.9566136300195398\n", | |
"\n", | |
"Epoch 27 training loss: 0.008271198719739914\n", | |
"Epoch 27 training accuracy: 0.9555007282779941\n", | |
"Epoch 27 validation loss: 0.007314777933061123\n", | |
"Epoch 27 validation accuracy: 0.9569725246241576\n", | |
"\n", | |
"Epoch 28 training loss: 0.00815747119486332\n", | |
"Epoch 28 training accuracy: 0.9560853323821131\n", | |
"Epoch 28 validation loss: 0.0071959588676691055\n", | |
"Epoch 28 validation accuracy: 0.9579694540814292\n", | |
"\n", | |
"Epoch 29 training loss: 0.008058251813054085\n", | |
"Epoch 29 training accuracy: 0.9564321314269294\n", | |
"Epoch 29 validation loss: 0.007111657410860062\n", | |
"Epoch 29 validation accuracy: 0.958328348686047\n", | |
"\n", | |
"Epoch 30 training loss: 0.007924782112240791\n", | |
"Epoch 30 training accuracy: 0.9574229858406904\n", | |
"Epoch 30 validation loss: 0.007042724639177322\n", | |
"Epoch 30 validation accuracy: 0.9585277345775013\n", | |
"\n", | |
"Epoch 31 training loss: 0.0077895959839224815\n", | |
"Epoch 31 training accuracy: 0.9573932602082776\n", | |
"Epoch 31 validation loss: 0.006950486917048693\n", | |
"Epoch 31 validation accuracy: 0.9595246640347729\n", | |
"\n", | |
"Epoch 32 training loss: 0.007699040696024895\n", | |
"Epoch 32 training accuracy: 0.9581165839303231\n", | |
"Epoch 32 validation loss: 0.006855316460132599\n", | |
"Epoch 32 validation accuracy: 0.9599633129959724\n", | |
"\n", | |
"Epoch 33 training loss: 0.007561581674963236\n", | |
"Epoch 33 training accuracy: 0.9593947861240748\n", | |
"Epoch 33 validation loss: 0.006781487260013819\n", | |
"Epoch 33 validation accuracy: 0.9603620847788811\n", | |
"\n", | |
"Epoch 34 training loss: 0.007479057181626558\n", | |
"Epoch 34 training accuracy: 0.9598109449778544\n", | |
"Epoch 34 validation loss: 0.006722167134284973\n", | |
"Epoch 34 validation accuracy: 0.9611995055229892\n", | |
"\n", | |
"Epoch 35 training loss: 0.007352054584771395\n", | |
"Epoch 35 training accuracy: 0.960861250656441\n", | |
"Epoch 35 validation loss: 0.006634308956563473\n", | |
"Epoch 35 validation accuracy: 0.9612393827012801\n", | |
"\n", | |
"Epoch 36 training loss: 0.007264348212629557\n", | |
"Epoch 36 training accuracy: 0.9613269522309087\n", | |
"Epoch 36 validation loss: 0.006579964887350798\n", | |
"Epoch 36 validation accuracy: 0.9613988914144436\n", | |
"\n", | |
"Epoch 37 training loss: 0.007171745412051678\n", | |
"Epoch 37 training accuracy: 0.962387166453633\n", | |
"Epoch 37 validation loss: 0.0065085249952971935\n", | |
"Epoch 37 validation accuracy: 0.9617976631973522\n", | |
"\n", | |
"Epoch 38 training loss: 0.007131265476346016\n", | |
"Epoch 38 training accuracy: 0.961802562349514\n", | |
"Epoch 38 validation loss: 0.006452173460274935\n", | |
"Epoch 38 validation accuracy: 0.9619970490888065\n", | |
"\n", | |
"Epoch 39 training loss: 0.007064123172312975\n", | |
"Epoch 39 training accuracy: 0.9626249715129356\n", | |
"Epoch 39 validation loss: 0.006395110860466957\n", | |
"Epoch 39 validation accuracy: 0.9624755752282969\n", | |
"\n", | |
"Epoch 40 training loss: 0.006988056004047394\n", | |
"Epoch 40 training accuracy: 0.9635662832060086\n", | |
"Epoch 40 validation loss: 0.006371591705828905\n", | |
"Epoch 40 validation accuracy: 0.9632731187941141\n", | |
"\n", | |
"Epoch 41 training loss: 0.006856718100607395\n", | |
"Epoch 41 training accuracy: 0.9639031737066873\n", | |
"Epoch 41 validation loss: 0.00629306398332119\n", | |
"Epoch 41 validation accuracy: 0.9633129959724049\n", | |
"\n", | |
"Epoch 42 training loss: 0.006834879517555237\n", | |
"Epoch 42 training accuracy: 0.9638040882653112\n", | |
"Epoch 42 validation loss: 0.0062795961275696754\n", | |
"Epoch 42 validation accuracy: 0.9629541013677873\n", | |
"\n", | |
"Epoch 43 training loss: 0.00673295883461833\n", | |
"Epoch 43 training accuracy: 0.964547229075632\n", | |
"Epoch 43 validation loss: 0.0062139034271240234\n", | |
"Epoch 43 validation accuracy: 0.9636718905770227\n", | |
"\n", | |
"Epoch 44 training loss: 0.006661310326308012\n", | |
"Epoch 44 training accuracy: 0.9653993638714664\n", | |
"Epoch 44 validation loss: 0.006167777813971043\n", | |
"Epoch 44 validation accuracy: 0.96446943414284\n", | |
"\n", | |
"Epoch 45 training loss: 0.006662233266979456\n", | |
"Epoch 45 training accuracy: 0.9655380834893929\n", | |
"Epoch 45 validation loss: 0.006144508719444275\n", | |
"Epoch 45 validation accuracy: 0.96446943414284\n", | |
"\n", | |
"Epoch 46 training loss: 0.006595049984753132\n", | |
"Epoch 46 training accuracy: 0.9657560714604203\n", | |
"Epoch 46 validation loss: 0.006099705584347248\n", | |
"Epoch 46 validation accuracy: 0.9647485743908761\n", | |
"\n", | |
"Epoch 47 training loss: 0.006484870798885822\n", | |
"Epoch 47 training accuracy: 0.9662911328438513\n", | |
"Epoch 47 validation loss: 0.006039380561560392\n", | |
"Epoch 47 validation accuracy: 0.9648682059257487\n", | |
"\n", | |
"Epoch 48 training loss: 0.0064193918369710445\n", | |
"Epoch 48 training accuracy: 0.9667271087859061\n", | |
"Epoch 48 validation loss: 0.005995721090584993\n", | |
"Epoch 48 validation accuracy: 0.9655062407784025\n", | |
"\n", | |
"Epoch 49 training loss: 0.006370623596012592\n", | |
"Epoch 49 training accuracy: 0.9673711641548507\n", | |
"Epoch 49 validation loss: 0.00597657123580575\n", | |
"Epoch 49 validation accuracy: 0.9651872233520756\n", | |
"\n", | |
"Epoch 50 training loss: 0.0062427399680018425\n", | |
"Epoch 50 training accuracy: 0.9677377802879423\n", | |
"Epoch 50 validation loss: 0.005949628539383411\n", | |
"Epoch 50 validation accuracy: 0.965944889739602\n", | |
"\n", | |
"Epoch 51 training loss: 0.006284354254603386\n", | |
"Epoch 51 training accuracy: 0.967390981243126\n", | |
"Epoch 51 validation loss: 0.005894490983337164\n", | |
"Epoch 51 validation accuracy: 0.9657853810264385\n", | |
"\n", | |
"Epoch 52 training loss: 0.006267307326197624\n", | |
"Epoch 52 training accuracy: 0.9674999752286396\n", | |
"Epoch 52 validation loss: 0.00585838733240962\n", | |
"Epoch 52 validation accuracy: 0.9665829245922558\n", | |
"\n", | |
"Epoch 53 training loss: 0.006208399776369333\n", | |
"Epoch 53 training accuracy: 0.9680746707886211\n", | |
"Epoch 53 validation loss: 0.005838208366185427\n", | |
"Epoch 53 validation accuracy: 0.9667025561271284\n", | |
"\n", | |
"Epoch 54 training loss: 0.006156044080853462\n", | |
"Epoch 54 training accuracy: 0.9681043964210339\n", | |
"Epoch 54 validation loss: 0.005806652829051018\n", | |
"Epoch 54 validation accuracy: 0.9667424333054193\n", | |
"\n", | |
"Epoch 55 training loss: 0.006079786457121372\n", | |
"Epoch 55 training accuracy: 0.9689565312168683\n", | |
"Epoch 55 validation loss: 0.005748886149376631\n", | |
"Epoch 55 validation accuracy: 0.9672608366232005\n", | |
"\n", | |
"Epoch 56 training loss: 0.005994039122015238\n", | |
"Epoch 56 training accuracy: 0.9696005865858129\n", | |
"Epoch 56 validation loss: 0.005743207409977913\n", | |
"Epoch 56 validation accuracy: 0.9671810822666188\n", | |
"\n", | |
"Epoch 57 training loss: 0.006007436662912369\n", | |
"Epoch 57 training accuracy: 0.9690060739375563\n", | |
"Epoch 57 validation loss: 0.0057016052305698395\n", | |
"Epoch 57 validation accuracy: 0.9678589942975635\n", | |
"\n", | |
"Epoch 58 training loss: 0.005948720499873161\n", | |
"Epoch 58 training accuracy: 0.9699473856306293\n", | |
"Epoch 58 validation loss: 0.0056551373563706875\n", | |
"Epoch 58 validation accuracy: 0.9680583801890178\n", | |
"\n", | |
"Epoch 59 training loss: 0.005832315888255835\n", | |
"Epoch 59 training accuracy: 0.9706211666319867\n", | |
"Epoch 59 validation loss: 0.005674091167747974\n", | |
"Epoch 59 validation accuracy: 0.9678589942975635\n", | |
"\n", | |
"Epoch 60 training loss: 0.005850379820913076\n", | |
"Epoch 60 training accuracy: 0.9699870198071797\n", | |
"Epoch 60 validation loss: 0.005655990913510323\n", | |
"Epoch 60 validation accuracy: 0.9678191171192726\n", | |
"\n", | |
"Epoch 61 training loss: 0.005793438758701086\n", | |
"Epoch 61 training accuracy: 0.970462629925785\n", | |
"Epoch 61 validation loss: 0.00561485206708312\n", | |
"Epoch 61 validation accuracy: 0.968018503010727\n", | |
"\n", | |
"Epoch 62 training loss: 0.00577025581151247\n", | |
"Epoch 62 training accuracy: 0.9707301606175005\n", | |
"Epoch 62 validation loss: 0.0055701094679534435\n", | |
"Epoch 62 validation accuracy: 0.9681381345455996\n", | |
"\n", | |
"Epoch 63 training loss: 0.00572052551433444\n", | |
"Epoch 63 training accuracy: 0.9713643074423075\n", | |
"Epoch 63 validation loss: 0.005567863583564758\n", | |
"Epoch 63 validation accuracy: 0.9682178889021813\n", | |
"\n", | |
"Epoch 64 training loss: 0.005703536327928305\n", | |
"Epoch 64 training accuracy: 0.9711661365595553\n", | |
"Epoch 64 validation loss: 0.0055429707281291485\n", | |
"Epoch 64 validation accuracy: 0.9685369063285082\n", | |
"\n", | |
"Epoch 65 training loss: 0.0056631918996572495\n", | |
"Epoch 65 training accuracy: 0.9714633928836836\n", | |
"Epoch 65 validation loss: 0.005500172730535269\n", | |
"Epoch 65 validation accuracy: 0.9688160465765443\n", | |
"\n", | |
"Epoch 66 training loss: 0.0056600007228553295\n", | |
"Epoch 66 training accuracy: 0.9721074482526283\n", | |
"Epoch 66 validation loss: 0.005500392057001591\n", | |
"Epoch 66 validation accuracy: 0.9689755552897077\n", | |
"\n", | |
"Epoch 67 training loss: 0.0055995844304561615\n", | |
"Epoch 67 training accuracy: 0.9723848874884813\n", | |
"Epoch 67 validation loss: 0.005494162440299988\n", | |
"Epoch 67 validation accuracy: 0.9687761693982534\n", | |
"\n", | |
"Epoch 68 training loss: 0.0055932835675776005\n", | |
"Epoch 68 training accuracy: 0.972087631164353\n", | |
"Epoch 68 validation loss: 0.005469959229230881\n", | |
"Epoch 68 validation accuracy: 0.9692148183594529\n", | |
"\n", | |
"Epoch 69 training loss: 0.005538692697882652\n", | |
"Epoch 69 training accuracy: 0.9722362593264172\n", | |
"Epoch 69 validation loss: 0.005439432337880135\n", | |
"Epoch 69 validation accuracy: 0.9692945727160346\n", | |
"\n", | |
"Epoch 70 training loss: 0.005462995730340481\n", | |
"Epoch 70 training accuracy: 0.9723353447677933\n", | |
"Epoch 70 validation loss: 0.005439480766654015\n", | |
"Epoch 70 validation accuracy: 0.9694142042509072\n", | |
"\n", | |
"Epoch 71 training loss: 0.005392757710069418\n", | |
"Epoch 71 training accuracy: 0.9731379368429397\n", | |
"Epoch 71 validation loss: 0.005391059909015894\n", | |
"Epoch 71 validation accuracy: 0.9699326075686885\n", | |
"\n", | |
"Epoch 72 training loss: 0.005394927691668272\n", | |
"Epoch 72 training accuracy: 0.973663089682233\n", | |
"Epoch 72 validation loss: 0.005383166950196028\n", | |
"Epoch 72 validation accuracy: 0.9696534673206524\n", | |
"\n", | |
"Epoch 73 training loss: 0.00538697699084878\n", | |
"Epoch 73 training accuracy: 0.9732766564608663\n", | |
"Epoch 73 validation loss: 0.005384756717830896\n", | |
"Epoch 73 validation accuracy: 0.9698129760338159\n", | |
"\n", | |
"Epoch 74 training loss: 0.005329283885657787\n", | |
"Epoch 74 training accuracy: 0.9735441871525816\n", | |
"Epoch 74 validation loss: 0.005368389189243317\n", | |
"Epoch 74 validation accuracy: 0.9700921162818519\n", | |
"\n", | |
"Epoch 75 training loss: 0.005318328272551298\n", | |
"Epoch 75 training accuracy: 0.9738612605649852\n", | |
"Epoch 75 validation loss: 0.005321497097611427\n", | |
"Epoch 75 validation accuracy: 0.9700921162818519\n", | |
"\n", | |
"Epoch 76 training loss: 0.005331655498594046\n", | |
"Epoch 76 training accuracy: 0.9734550102553432\n", | |
"Epoch 76 validation loss: 0.005318239331245422\n", | |
"Epoch 76 validation accuracy: 0.9702117478167245\n", | |
"\n", | |
"Epoch 77 training loss: 0.005248404107987881\n", | |
"Epoch 77 training accuracy: 0.9742873279629024\n", | |
"Epoch 77 validation loss: 0.005320506636053324\n", | |
"Epoch 77 validation accuracy: 0.9699326075686885\n", | |
"\n", | |
"Epoch 78 training loss: 0.005253283306956291\n", | |
"Epoch 78 training accuracy: 0.9737027238587834\n", | |
"Epoch 78 validation loss: 0.005250290967524052\n", | |
"Epoch 78 validation accuracy: 0.9703313793515971\n", | |
"\n", | |
"Epoch 79 training loss: 0.005223565269261599\n", | |
"Epoch 79 training accuracy: 0.9743963219484161\n", | |
"Epoch 79 validation loss: 0.005253038369119167\n", | |
"Epoch 79 validation accuracy: 0.9707301511345057\n", | |
"\n", | |
"Epoch 80 training loss: 0.0051938313990831375\n", | |
"Epoch 80 training accuracy: 0.974346779227728\n", | |
"Epoch 80 validation loss: 0.005250085145235062\n", | |
"Epoch 80 validation accuracy: 0.9708896598476692\n", | |
"\n", | |
"Epoch 81 training loss: 0.005150769371539354\n", | |
"Epoch 81 training accuracy: 0.9743765048601409\n", | |
"Epoch 81 validation loss: 0.005202730186283588\n", | |
"Epoch 81 validation accuracy: 0.9710491685608327\n", | |
"\n", | |
"Epoch 82 training loss: 0.005094591993838549\n", | |
"Epoch 82 training accuracy: 0.9751493713028745\n", | |
"Epoch 82 validation loss: 0.005250499118119478\n", | |
"Epoch 82 validation accuracy: 0.9707301511345057\n", | |
"\n", | |
"Epoch 83 training loss: 0.005153241567313671\n", | |
"Epoch 83 training accuracy: 0.9747431209932325\n", | |
"Epoch 83 validation loss: 0.005220973864197731\n", | |
"Epoch 83 validation accuracy: 0.9708896598476692\n", | |
"\n", | |
"Epoch 84 training loss: 0.005067689809948206\n", | |
"Epoch 84 training accuracy: 0.9752682738325258\n", | |
"Epoch 84 validation loss: 0.0051800827495753765\n", | |
"Epoch 84 validation accuracy: 0.9710890457391235\n", | |
"\n", | |
"Epoch 85 training loss: 0.004997973330318928\n", | |
"Epoch 85 training accuracy: 0.9754961703476908\n", | |
"Epoch 85 validation loss: 0.005149488337337971\n", | |
"Epoch 85 validation accuracy: 0.9715675718786139\n", | |
"\n", | |
"Epoch 86 training loss: 0.004979357589036226\n", | |
"Epoch 86 training accuracy: 0.9756249814214797\n", | |
"Epoch 86 validation loss: 0.005171299912035465\n", | |
"Epoch 86 validation accuracy: 0.971527694700323\n", | |
"\n", | |
"Epoch 87 training loss: 0.004984478931874037\n", | |
"Epoch 87 training accuracy: 0.9756051643332045\n", | |
"Epoch 87 validation loss: 0.00515236658975482\n", | |
"Epoch 87 validation accuracy: 0.9710092913825418\n", | |
"\n", | |
"Epoch 88 training loss: 0.00498884217813611\n", | |
"Epoch 88 training accuracy: 0.9755853472449293\n", | |
"Epoch 88 validation loss: 0.005136134568601847\n", | |
"Epoch 88 validation accuracy: 0.9713681859871596\n", | |
"\n", | |
"Epoch 89 training loss: 0.004946259316056967\n", | |
"Epoch 89 training accuracy: 0.9756249814214797\n", | |
"Epoch 89 validation loss: 0.005148548167198896\n", | |
"Epoch 89 validation accuracy: 0.9710092913825418\n", | |
"\n", | |
"Epoch 90 training loss: 0.004885654430836439\n", | |
"Epoch 90 training accuracy: 0.9766653785559288\n", | |
"Epoch 90 validation loss: 0.005129080731421709\n", | |
"Epoch 90 validation accuracy: 0.9716872034134865\n", | |
"\n", | |
"Epoch 91 training loss: 0.004882515873759985\n", | |
"Epoch 91 training accuracy: 0.9761699513490483\n", | |
"Epoch 91 validation loss: 0.005134152248501778\n", | |
"Epoch 91 validation accuracy: 0.9714080631654504\n", | |
"\n", | |
"Epoch 92 training loss: 0.004912896081805229\n", | |
"Epoch 92 training accuracy: 0.9759024206573328\n", | |
"Epoch 92 validation loss: 0.005110086407512426\n", | |
"Epoch 92 validation accuracy: 0.9713283088088687\n", | |
"\n", | |
"Epoch 93 training loss: 0.004869314841926098\n", | |
"Epoch 93 training accuracy: 0.9768635494386809\n", | |
"Epoch 93 validation loss: 0.005069723352789879\n", | |
"Epoch 93 validation accuracy: 0.9718068349483591\n", | |
"\n", | |
"Epoch 94 training loss: 0.004821186885237694\n", | |
"Epoch 94 training accuracy: 0.9767842810855801\n", | |
"Epoch 94 validation loss: 0.0050615654326975346\n", | |
"Epoch 94 validation accuracy: 0.97184671212665\n", | |
"\n", | |
"Epoch 95 training loss: 0.004831160884350538\n", | |
"Epoch 95 training accuracy: 0.9766455614676536\n", | |
"Epoch 95 validation loss: 0.0050490666180849075\n", | |
"Epoch 95 validation accuracy: 0.9719264664832317\n", | |
"\n", | |
"Epoch 96 training loss: 0.004834283608943224\n", | |
"Epoch 96 training accuracy: 0.9763284880552501\n", | |
"Epoch 96 validation loss: 0.0050497399643063545\n", | |
"Epoch 96 validation accuracy: 0.9718865893049408\n", | |
"\n", | |
"Epoch 97 training loss: 0.0047189281322062016\n", | |
"Epoch 97 training accuracy: 0.977626507337277\n", | |
"Epoch 97 validation loss: 0.005037985742092133\n", | |
"Epoch 97 validation accuracy: 0.9720062208398134\n", | |
"\n", | |
"Epoch 98 training loss: 0.0047546131536364555\n", | |
"Epoch 98 training accuracy: 0.9774085193662495\n", | |
"Epoch 98 validation loss: 0.0050339424051344395\n", | |
"Epoch 98 validation accuracy: 0.9718865893049408\n", | |
"\n", | |
"Epoch 99 training loss: 0.004727411083877087\n", | |
"Epoch 99 training accuracy: 0.9775373304400384\n", | |
"Epoch 99 validation loss: 0.005019115284085274\n", | |
"Epoch 99 validation accuracy: 0.9723252382661403\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"model = Network().to(device)\n", | |
"for epoch in range(100):\n", | |
" LR = 0.001\n", | |
" if epoch >= 15 & epoch < 30:\n", | |
" LR = 0.0001\n", | |
" elif epoch >= 40 & epoch < 60:\n", | |
" LR = 0.00001\n", | |
" elif epoch >= 60:\n", | |
" LR = 0.000005\n", | |
" optimizer = optim.Adam(model.parameters(), lr=LR)\n", | |
" train(train_data, train_labels, validation_data, validation_labels, epoch, batch_size=500)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/paperspace/anaconda3/envs/pytorch4/lib/python3.6/site-packages/torch/serialization.py:241: UserWarning: Couldn't retrieve source code for container of type Network. It won't be checked for correctness upon loading.\n", | |
" \"type \" + obj.__name__ + \". It won't be checked \"\n" | |
] | |
} | |
], | |
"source": [ | |
"torch.save(model, 'models/mnist_model.pt')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment