@henryturner27
Last active September 20, 2018 02:44
Shallow Network to solve MNIST
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from os import listdir\n",
"import PIL\n",
"from random import shuffle\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# if gpu is to be used\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device('cuda:0' if use_cuda else 'cpu')\n",
"\n",
"data_path = 'data/mnist'\n",
"trn_data_path = f'{data_path}/train'\n",
"test_data_path = f'{data_path}/test'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"temp_train_data = []\n",
"temp_val_data = []\n",
"for label in listdir(trn_data_path):\n",
" for jpg in listdir(trn_data_path+'/'+label):\n",
" file_path = trn_data_path + '/' + label + '/' + jpg\n",
" img = PIL.Image.open(file_path)\n",
" lr_img = img.rotate(np.random.choice(range(-10,-5)))\n",
" rr_img = img.rotate(np.random.choice(range(5,10)))\n",
" data_arr = np.asarray(img).flatten() / 255.0 * 0.99 + 0.01\n",
" data_arr = torch.FloatTensor([data_arr]).to(device)\n",
" lr_data_arr = np.asarray(lr_img).flatten() / 255.0 * 0.99 + 0.01\n",
" lr_data_arr = torch.FloatTensor([lr_data_arr]).to(device)\n",
" rr_data_arr = np.asarray(rr_img).flatten() / 255.0 * 0.99 + 0.01\n",
" rr_data_arr = torch.FloatTensor([rr_data_arr]).to(device)\n",
" label_arr = np.zeros(10) + 0.01\n",
" label_arr[int(label)] = 0.99\n",
" label_arr = torch.FloatTensor([label_arr]).to(device)\n",
" if np.random.choice(100) > 19: \n",
" temp_train_data.append([data_arr, lr_data_arr, rr_data_arr, label_arr])\n",
" else:\n",
" temp_val_data.append([data_arr, lr_data_arr, rr_data_arr, label_arr])\n",
" \n",
"shuffle(temp_train_data)\n",
"shuffle(temp_val_data)\n",
"\n",
"train_data = torch.FloatTensor().to(device)\n",
"train_labels = torch.FloatTensor().to(device)\n",
"validation_data = torch.FloatTensor().to(device)\n",
"validation_labels = torch.FloatTensor().to(device)\n",
"\n",
"for x in temp_train_data:\n",
" train_data = torch.cat([train_data, x[0]])\n",
" train_data = torch.cat([train_data, x[1]])\n",
" train_data = torch.cat([train_data, x[2]])\n",
" train_labels = torch.cat([train_labels, x[-1]])\n",
" train_labels = torch.cat([train_labels, x[-1]])\n",
" train_labels = torch.cat([train_labels, x[-1]])\n",
" \n",
"for x in temp_val_data:\n",
" validation_data = torch.cat([validation_data, x[0]])\n",
" validation_data = torch.cat([validation_data, x[1]])\n",
" validation_data = torch.cat([validation_data, x[2]])\n",
" validation_labels = torch.cat([validation_labels, x[-1]])\n",
" validation_labels = torch.cat([validation_labels, x[-1]])\n",
" validation_labels = torch.cat([validation_labels, x[-1]])\n",
" \n",
"del temp_train_data\n",
"del temp_val_data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"100923"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"25077"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(validation_data)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"class Network(nn.Module):\n",
" def __init__(self):\n",
" super(Network, self).__init__()\n",
" self.l1 = nn.Linear(784, 300)\n",
" self.l2 = nn.Linear(300, 10)\n",
" \n",
" def forward(self, x):\n",
" x = F.relu(self.l1(x))\n",
" x = F.dropout(x, p=0.6, training=self.training)\n",
" x = torch.sigmoid(self.l2(x))\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def train(train_data, training_labels, validation_data, validation_labels, epoch, batch_size):\n",
" train_loss = []\n",
" val_loss = []\n",
" train_acc = 0\n",
" val_acc = 0\n",
" for bs in range(batch_size):\n",
" model.train()\n",
" batch_start = int(len(train_data)/batch_size*bs)\n",
" batch_end = int(len(train_data)/batch_size*(bs+1))\n",
" prediction = model(train_data[batch_start:batch_end])\n",
" actual = train_labels[batch_start:batch_end]\n",
" loss = F.mse_loss(prediction, actual).mean()\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_acc += (prediction.argmax(1) == actual.argmax(1)).cpu().detach().numpy().sum()\n",
" train_loss.append(loss.sum().cpu().detach().numpy())\n",
" train_loss = np.asarray(train_loss).mean()\n",
" train_acc = train_acc / len(train_data)\n",
" print('Epoch {} training loss: {}'.format(epoch, train_loss))\n",
" print('Epoch {} training accuracy: {}'.format(epoch, train_acc))\n",
"\n",
" for bs in range(batch_size):\n",
" model.eval()\n",
" batch_start = int(len(validation_data)/batch_size*bs)\n",
" batch_end = int(len(validation_data)/batch_size*(bs+1))\n",
" prediction = model(validation_data[batch_start:batch_end])\n",
" actual = validation_labels[batch_start:batch_end]\n",
" loss = F.mse_loss(prediction, actual).mean()\n",
" val_acc += (prediction.argmax(1) == actual.argmax(1)).cpu().detach().numpy().sum()\n",
" val_loss.append(loss.sum().cpu().detach().numpy())\n",
" val_loss = np.asarray(val_loss).mean()\n",
" val_acc = val_acc / len(validation_data)\n",
" print('Epoch {} validation loss: {}'.format(epoch, val_loss))\n",
" print('Epoch {} validation accuracy: {}\\n'.format(epoch, val_acc))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 training loss: 0.07884197682142258\n",
"Epoch 0 training accuracy: 0.48285326436986614\n",
"Epoch 0 validation loss: 0.044085800647735596\n",
"Epoch 0 validation accuracy: 0.7827491326713721\n",
"\n",
"Epoch 1 training loss: 0.03727114200592041\n",
"Epoch 1 training accuracy: 0.7965280461341815\n",
"Epoch 1 validation loss: 0.026454336941242218\n",
"Epoch 1 validation accuracy: 0.8639390676715716\n",
"\n",
"Epoch 2 training loss: 0.026401059702038765\n",
"Epoch 2 training accuracy: 0.8554145239439969\n",
"Epoch 2 validation loss: 0.020730365067720413\n",
"Epoch 2 validation accuracy: 0.8848745862742753\n",
"\n",
"Epoch 3 training loss: 0.02180325612425804\n",
"Epoch 3 training accuracy: 0.8766980767515828\n",
"Epoch 3 validation loss: 0.01779206097126007\n",
"Epoch 3 validation accuracy: 0.8970371256529888\n",
"\n",
"Epoch 4 training loss: 0.01913328282535076\n",
"Epoch 4 training accuracy: 0.891986960355915\n",
"Epoch 4 validation loss: 0.015950201079249382\n",
"Epoch 4 validation accuracy: 0.9064481397296328\n",
"\n",
"Epoch 5 training loss: 0.017314637079834938\n",
"Epoch 5 training accuracy: 0.900746113373562\n",
"Epoch 5 validation loss: 0.014626124873757362\n",
"Epoch 5 validation accuracy: 0.9121106990469354\n",
"\n",
"Epoch 6 training loss: 0.016011804342269897\n",
"Epoch 6 training accuracy: 0.907909990785054\n",
"Epoch 6 validation loss: 0.013589252717792988\n",
"Epoch 6 validation accuracy: 0.9179327670774016\n",
"\n",
"Epoch 7 training loss: 0.01495556440204382\n",
"Epoch 7 training accuracy: 0.9132407875310881\n",
"Epoch 7 validation loss: 0.012766270898282528\n",
"Epoch 7 validation accuracy: 0.9232763089683774\n",
"\n",
"Epoch 8 training loss: 0.014054970815777779\n",
"Epoch 8 training accuracy: 0.9186508526302231\n",
"Epoch 8 validation loss: 0.012104556895792484\n",
"Epoch 8 validation accuracy: 0.9265861147665191\n",
"\n",
"Epoch 9 training loss: 0.013417595066130161\n",
"Epoch 9 training accuracy: 0.9230502462273218\n",
"Epoch 9 validation loss: 0.011507791467010975\n",
"Epoch 9 validation accuracy: 0.9309726043785141\n",
"\n",
"Epoch 10 training loss: 0.012768560089170933\n",
"Epoch 10 training accuracy: 0.9266371392051366\n",
"Epoch 10 validation loss: 0.01100181881338358\n",
"Epoch 10 validation accuracy: 0.9346413047812737\n",
"\n",
"Epoch 11 training loss: 0.01228629145771265\n",
"Epoch 11 training accuracy: 0.9290944581512638\n",
"Epoch 11 validation loss: 0.010575709864497185\n",
"Epoch 11 validation accuracy: 0.9370339354787255\n",
"\n",
"Epoch 12 training loss: 0.011768396012485027\n",
"Epoch 12 training accuracy: 0.9327507109380418\n",
"Epoch 12 validation loss: 0.010171184316277504\n",
"Epoch 12 validation accuracy: 0.9390676715715596\n",
"\n",
"Epoch 13 training loss: 0.011383026838302612\n",
"Epoch 13 training accuracy: 0.9352278469724443\n",
"Epoch 13 validation loss: 0.009807510301470757\n",
"Epoch 13 validation accuracy: 0.9412609163775572\n",
"\n",
"Epoch 14 training loss: 0.010945308953523636\n",
"Epoch 14 training accuracy: 0.9386562032440573\n",
"Epoch 14 validation loss: 0.009531907737255096\n",
"Epoch 14 validation accuracy: 0.9425369860828647\n",
"\n",
"Epoch 15 training loss: 0.010687393136322498\n",
"Epoch 15 training accuracy: 0.9395876063929927\n",
"Epoch 15 validation loss: 0.009243899025022984\n",
"Epoch 15 validation accuracy: 0.9443314591059536\n",
"\n",
"Epoch 16 training loss: 0.01043588574975729\n",
"Epoch 16 training accuracy: 0.9416981262943036\n",
"Epoch 16 validation loss: 0.009006698615849018\n",
"Epoch 16 validation accuracy: 0.9456075288112613\n",
"\n",
"Epoch 17 training loss: 0.01012509036809206\n",
"Epoch 17 training accuracy: 0.9436104753128622\n",
"Epoch 17 validation loss: 0.008757558651268482\n",
"Epoch 17 validation accuracy: 0.9470032300514416\n",
"\n",
"Epoch 18 training loss: 0.009891458787024021\n",
"Epoch 18 training accuracy: 0.9450472142128157\n",
"Epoch 18 validation loss: 0.008560443297028542\n",
"Epoch 18 validation accuracy: 0.9487179487179487\n",
"\n",
"Epoch 19 training loss: 0.009660371579229832\n",
"Epoch 19 training accuracy: 0.9462461480534665\n",
"Epoch 19 validation loss: 0.008387024514377117\n",
"Epoch 19 validation accuracy: 0.9498345097100929\n",
"\n",
"Epoch 20 training loss: 0.00941749569028616\n",
"Epoch 20 training accuracy: 0.9477027040416951\n",
"Epoch 20 validation loss: 0.00819950457662344\n",
"Epoch 20 validation accuracy: 0.9513498424851458\n",
"\n",
"Epoch 21 training loss: 0.00916802603751421\n",
"Epoch 21 training accuracy: 0.9496348701485291\n",
"Epoch 21 validation loss: 0.008044376969337463\n",
"Epoch 21 validation accuracy: 0.9525461578338716\n",
"\n",
"Epoch 22 training loss: 0.009039959870278835\n",
"Epoch 22 training accuracy: 0.9502789255174737\n",
"Epoch 22 validation loss: 0.007887438870966434\n",
"Epoch 22 validation accuracy: 0.9532240698648163\n",
"\n",
"Epoch 23 training loss: 0.008864056318998337\n",
"Epoch 23 training accuracy: 0.9516760302408767\n",
"Epoch 23 validation loss: 0.007779685780405998\n",
"Epoch 23 validation accuracy: 0.9540216134306336\n",
"\n",
"Epoch 24 training loss: 0.008710344322025776\n",
"Epoch 24 training accuracy: 0.9526470675663625\n",
"Epoch 24 validation loss: 0.007641170173883438\n",
"Epoch 24 validation accuracy: 0.9552976831359413\n",
"\n",
"Epoch 25 training loss: 0.008589557372033596\n",
"Epoch 25 training accuracy: 0.9527560615518762\n",
"Epoch 25 validation loss: 0.007509811315685511\n",
"Epoch 25 validation accuracy: 0.9558559636320134\n",
"\n",
"Epoch 26 training loss: 0.00840373057872057\n",
"Epoch 26 training accuracy: 0.9538955441277013\n",
"Epoch 26 validation loss: 0.007363432552665472\n",
"Epoch 26 validation accuracy: 0.9566136300195398\n",
"\n",
"Epoch 27 training loss: 0.008271198719739914\n",
"Epoch 27 training accuracy: 0.9555007282779941\n",
"Epoch 27 validation loss: 0.007314777933061123\n",
"Epoch 27 validation accuracy: 0.9569725246241576\n",
"\n",
"Epoch 28 training loss: 0.00815747119486332\n",
"Epoch 28 training accuracy: 0.9560853323821131\n",
"Epoch 28 validation loss: 0.0071959588676691055\n",
"Epoch 28 validation accuracy: 0.9579694540814292\n",
"\n",
"Epoch 29 training loss: 0.008058251813054085\n",
"Epoch 29 training accuracy: 0.9564321314269294\n",
"Epoch 29 validation loss: 0.007111657410860062\n",
"Epoch 29 validation accuracy: 0.958328348686047\n",
"\n",
"Epoch 30 training loss: 0.007924782112240791\n",
"Epoch 30 training accuracy: 0.9574229858406904\n",
"Epoch 30 validation loss: 0.007042724639177322\n",
"Epoch 30 validation accuracy: 0.9585277345775013\n",
"\n",
"Epoch 31 training loss: 0.0077895959839224815\n",
"Epoch 31 training accuracy: 0.9573932602082776\n",
"Epoch 31 validation loss: 0.006950486917048693\n",
"Epoch 31 validation accuracy: 0.9595246640347729\n",
"\n",
"Epoch 32 training loss: 0.007699040696024895\n",
"Epoch 32 training accuracy: 0.9581165839303231\n",
"Epoch 32 validation loss: 0.006855316460132599\n",
"Epoch 32 validation accuracy: 0.9599633129959724\n",
"\n",
"Epoch 33 training loss: 0.007561581674963236\n",
"Epoch 33 training accuracy: 0.9593947861240748\n",
"Epoch 33 validation loss: 0.006781487260013819\n",
"Epoch 33 validation accuracy: 0.9603620847788811\n",
"\n",
"Epoch 34 training loss: 0.007479057181626558\n",
"Epoch 34 training accuracy: 0.9598109449778544\n",
"Epoch 34 validation loss: 0.006722167134284973\n",
"Epoch 34 validation accuracy: 0.9611995055229892\n",
"\n",
"Epoch 35 training loss: 0.007352054584771395\n",
"Epoch 35 training accuracy: 0.960861250656441\n",
"Epoch 35 validation loss: 0.006634308956563473\n",
"Epoch 35 validation accuracy: 0.9612393827012801\n",
"\n",
"Epoch 36 training loss: 0.007264348212629557\n",
"Epoch 36 training accuracy: 0.9613269522309087\n",
"Epoch 36 validation loss: 0.006579964887350798\n",
"Epoch 36 validation accuracy: 0.9613988914144436\n",
"\n",
"Epoch 37 training loss: 0.007171745412051678\n",
"Epoch 37 training accuracy: 0.962387166453633\n",
"Epoch 37 validation loss: 0.0065085249952971935\n",
"Epoch 37 validation accuracy: 0.9617976631973522\n",
"\n",
"Epoch 38 training loss: 0.007131265476346016\n",
"Epoch 38 training accuracy: 0.961802562349514\n",
"Epoch 38 validation loss: 0.006452173460274935\n",
"Epoch 38 validation accuracy: 0.9619970490888065\n",
"\n",
"Epoch 39 training loss: 0.007064123172312975\n",
"Epoch 39 training accuracy: 0.9626249715129356\n",
"Epoch 39 validation loss: 0.006395110860466957\n",
"Epoch 39 validation accuracy: 0.9624755752282969\n",
"\n",
"Epoch 40 training loss: 0.006988056004047394\n",
"Epoch 40 training accuracy: 0.9635662832060086\n",
"Epoch 40 validation loss: 0.006371591705828905\n",
"Epoch 40 validation accuracy: 0.9632731187941141\n",
"\n",
"Epoch 41 training loss: 0.006856718100607395\n",
"Epoch 41 training accuracy: 0.9639031737066873\n",
"Epoch 41 validation loss: 0.00629306398332119\n",
"Epoch 41 validation accuracy: 0.9633129959724049\n",
"\n",
"Epoch 42 training loss: 0.006834879517555237\n",
"Epoch 42 training accuracy: 0.9638040882653112\n",
"Epoch 42 validation loss: 0.0062795961275696754\n",
"Epoch 42 validation accuracy: 0.9629541013677873\n",
"\n",
"Epoch 43 training loss: 0.00673295883461833\n",
"Epoch 43 training accuracy: 0.964547229075632\n",
"Epoch 43 validation loss: 0.0062139034271240234\n",
"Epoch 43 validation accuracy: 0.9636718905770227\n",
"\n",
"Epoch 44 training loss: 0.006661310326308012\n",
"Epoch 44 training accuracy: 0.9653993638714664\n",
"Epoch 44 validation loss: 0.006167777813971043\n",
"Epoch 44 validation accuracy: 0.96446943414284\n",
"\n",
"Epoch 45 training loss: 0.006662233266979456\n",
"Epoch 45 training accuracy: 0.9655380834893929\n",
"Epoch 45 validation loss: 0.006144508719444275\n",
"Epoch 45 validation accuracy: 0.96446943414284\n",
"\n",
"Epoch 46 training loss: 0.006595049984753132\n",
"Epoch 46 training accuracy: 0.9657560714604203\n",
"Epoch 46 validation loss: 0.006099705584347248\n",
"Epoch 46 validation accuracy: 0.9647485743908761\n",
"\n",
"Epoch 47 training loss: 0.006484870798885822\n",
"Epoch 47 training accuracy: 0.9662911328438513\n",
"Epoch 47 validation loss: 0.006039380561560392\n",
"Epoch 47 validation accuracy: 0.9648682059257487\n",
"\n",
"Epoch 48 training loss: 0.0064193918369710445\n",
"Epoch 48 training accuracy: 0.9667271087859061\n",
"Epoch 48 validation loss: 0.005995721090584993\n",
"Epoch 48 validation accuracy: 0.9655062407784025\n",
"\n",
"Epoch 49 training loss: 0.006370623596012592\n",
"Epoch 49 training accuracy: 0.9673711641548507\n",
"Epoch 49 validation loss: 0.00597657123580575\n",
"Epoch 49 validation accuracy: 0.9651872233520756\n",
"\n",
"Epoch 50 training loss: 0.0062427399680018425\n",
"Epoch 50 training accuracy: 0.9677377802879423\n",
"Epoch 50 validation loss: 0.005949628539383411\n",
"Epoch 50 validation accuracy: 0.965944889739602\n",
"\n",
"Epoch 51 training loss: 0.006284354254603386\n",
"Epoch 51 training accuracy: 0.967390981243126\n",
"Epoch 51 validation loss: 0.005894490983337164\n",
"Epoch 51 validation accuracy: 0.9657853810264385\n",
"\n",
"Epoch 52 training loss: 0.006267307326197624\n",
"Epoch 52 training accuracy: 0.9674999752286396\n",
"Epoch 52 validation loss: 0.00585838733240962\n",
"Epoch 52 validation accuracy: 0.9665829245922558\n",
"\n",
"Epoch 53 training loss: 0.006208399776369333\n",
"Epoch 53 training accuracy: 0.9680746707886211\n",
"Epoch 53 validation loss: 0.005838208366185427\n",
"Epoch 53 validation accuracy: 0.9667025561271284\n",
"\n",
"Epoch 54 training loss: 0.006156044080853462\n",
"Epoch 54 training accuracy: 0.9681043964210339\n",
"Epoch 54 validation loss: 0.005806652829051018\n",
"Epoch 54 validation accuracy: 0.9667424333054193\n",
"\n",
"Epoch 55 training loss: 0.006079786457121372\n",
"Epoch 55 training accuracy: 0.9689565312168683\n",
"Epoch 55 validation loss: 0.005748886149376631\n",
"Epoch 55 validation accuracy: 0.9672608366232005\n",
"\n",
"Epoch 56 training loss: 0.005994039122015238\n",
"Epoch 56 training accuracy: 0.9696005865858129\n",
"Epoch 56 validation loss: 0.005743207409977913\n",
"Epoch 56 validation accuracy: 0.9671810822666188\n",
"\n",
"Epoch 57 training loss: 0.006007436662912369\n",
"Epoch 57 training accuracy: 0.9690060739375563\n",
"Epoch 57 validation loss: 0.0057016052305698395\n",
"Epoch 57 validation accuracy: 0.9678589942975635\n",
"\n",
"Epoch 58 training loss: 0.005948720499873161\n",
"Epoch 58 training accuracy: 0.9699473856306293\n",
"Epoch 58 validation loss: 0.0056551373563706875\n",
"Epoch 58 validation accuracy: 0.9680583801890178\n",
"\n",
"Epoch 59 training loss: 0.005832315888255835\n",
"Epoch 59 training accuracy: 0.9706211666319867\n",
"Epoch 59 validation loss: 0.005674091167747974\n",
"Epoch 59 validation accuracy: 0.9678589942975635\n",
"\n",
"Epoch 60 training loss: 0.005850379820913076\n",
"Epoch 60 training accuracy: 0.9699870198071797\n",
"Epoch 60 validation loss: 0.005655990913510323\n",
"Epoch 60 validation accuracy: 0.9678191171192726\n",
"\n",
"Epoch 61 training loss: 0.005793438758701086\n",
"Epoch 61 training accuracy: 0.970462629925785\n",
"Epoch 61 validation loss: 0.00561485206708312\n",
"Epoch 61 validation accuracy: 0.968018503010727\n",
"\n",
"Epoch 62 training loss: 0.00577025581151247\n",
"Epoch 62 training accuracy: 0.9707301606175005\n",
"Epoch 62 validation loss: 0.0055701094679534435\n",
"Epoch 62 validation accuracy: 0.9681381345455996\n",
"\n",
"Epoch 63 training loss: 0.00572052551433444\n",
"Epoch 63 training accuracy: 0.9713643074423075\n",
"Epoch 63 validation loss: 0.005567863583564758\n",
"Epoch 63 validation accuracy: 0.9682178889021813\n",
"\n",
"Epoch 64 training loss: 0.005703536327928305\n",
"Epoch 64 training accuracy: 0.9711661365595553\n",
"Epoch 64 validation loss: 0.0055429707281291485\n",
"Epoch 64 validation accuracy: 0.9685369063285082\n",
"\n",
"Epoch 65 training loss: 0.0056631918996572495\n",
"Epoch 65 training accuracy: 0.9714633928836836\n",
"Epoch 65 validation loss: 0.005500172730535269\n",
"Epoch 65 validation accuracy: 0.9688160465765443\n",
"\n",
"Epoch 66 training loss: 0.0056600007228553295\n",
"Epoch 66 training accuracy: 0.9721074482526283\n",
"Epoch 66 validation loss: 0.005500392057001591\n",
"Epoch 66 validation accuracy: 0.9689755552897077\n",
"\n",
"Epoch 67 training loss: 0.0055995844304561615\n",
"Epoch 67 training accuracy: 0.9723848874884813\n",
"Epoch 67 validation loss: 0.005494162440299988\n",
"Epoch 67 validation accuracy: 0.9687761693982534\n",
"\n",
"Epoch 68 training loss: 0.0055932835675776005\n",
"Epoch 68 training accuracy: 0.972087631164353\n",
"Epoch 68 validation loss: 0.005469959229230881\n",
"Epoch 68 validation accuracy: 0.9692148183594529\n",
"\n",
"Epoch 69 training loss: 0.005538692697882652\n",
"Epoch 69 training accuracy: 0.9722362593264172\n",
"Epoch 69 validation loss: 0.005439432337880135\n",
"Epoch 69 validation accuracy: 0.9692945727160346\n",
"\n",
"Epoch 70 training loss: 0.005462995730340481\n",
"Epoch 70 training accuracy: 0.9723353447677933\n",
"Epoch 70 validation loss: 0.005439480766654015\n",
"Epoch 70 validation accuracy: 0.9694142042509072\n",
"\n",
"Epoch 71 training loss: 0.005392757710069418\n",
"Epoch 71 training accuracy: 0.9731379368429397\n",
"Epoch 71 validation loss: 0.005391059909015894\n",
"Epoch 71 validation accuracy: 0.9699326075686885\n",
"\n",
"Epoch 72 training loss: 0.005394927691668272\n",
"Epoch 72 training accuracy: 0.973663089682233\n",
"Epoch 72 validation loss: 0.005383166950196028\n",
"Epoch 72 validation accuracy: 0.9696534673206524\n",
"\n",
"Epoch 73 training loss: 0.00538697699084878\n",
"Epoch 73 training accuracy: 0.9732766564608663\n",
"Epoch 73 validation loss: 0.005384756717830896\n",
"Epoch 73 validation accuracy: 0.9698129760338159\n",
"\n",
"Epoch 74 training loss: 0.005329283885657787\n",
"Epoch 74 training accuracy: 0.9735441871525816\n",
"Epoch 74 validation loss: 0.005368389189243317\n",
"Epoch 74 validation accuracy: 0.9700921162818519\n",
"\n",
"Epoch 75 training loss: 0.005318328272551298\n",
"Epoch 75 training accuracy: 0.9738612605649852\n",
"Epoch 75 validation loss: 0.005321497097611427\n",
"Epoch 75 validation accuracy: 0.9700921162818519\n",
"\n",
"Epoch 76 training loss: 0.005331655498594046\n",
"Epoch 76 training accuracy: 0.9734550102553432\n",
"Epoch 76 validation loss: 0.005318239331245422\n",
"Epoch 76 validation accuracy: 0.9702117478167245\n",
"\n",
"Epoch 77 training loss: 0.005248404107987881\n",
"Epoch 77 training accuracy: 0.9742873279629024\n",
"Epoch 77 validation loss: 0.005320506636053324\n",
"Epoch 77 validation accuracy: 0.9699326075686885\n",
"\n",
"Epoch 78 training loss: 0.005253283306956291\n",
"Epoch 78 training accuracy: 0.9737027238587834\n",
"Epoch 78 validation loss: 0.005250290967524052\n",
"Epoch 78 validation accuracy: 0.9703313793515971\n",
"\n",
"Epoch 79 training loss: 0.005223565269261599\n",
"Epoch 79 training accuracy: 0.9743963219484161\n",
"Epoch 79 validation loss: 0.005253038369119167\n",
"Epoch 79 validation accuracy: 0.9707301511345057\n",
"\n",
"Epoch 80 training loss: 0.0051938313990831375\n",
"Epoch 80 training accuracy: 0.974346779227728\n",
"Epoch 80 validation loss: 0.005250085145235062\n",
"Epoch 80 validation accuracy: 0.9708896598476692\n",
"\n",
"Epoch 81 training loss: 0.005150769371539354\n",
"Epoch 81 training accuracy: 0.9743765048601409\n",
"Epoch 81 validation loss: 0.005202730186283588\n",
"Epoch 81 validation accuracy: 0.9710491685608327\n",
"\n",
"Epoch 82 training loss: 0.005094591993838549\n",
"Epoch 82 training accuracy: 0.9751493713028745\n",
"Epoch 82 validation loss: 0.005250499118119478\n",
"Epoch 82 validation accuracy: 0.9707301511345057\n",
"\n",
"Epoch 83 training loss: 0.005153241567313671\n",
"Epoch 83 training accuracy: 0.9747431209932325\n",
"Epoch 83 validation loss: 0.005220973864197731\n",
"Epoch 83 validation accuracy: 0.9708896598476692\n",
"\n",
"Epoch 84 training loss: 0.005067689809948206\n",
"Epoch 84 training accuracy: 0.9752682738325258\n",
"Epoch 84 validation loss: 0.0051800827495753765\n",
"Epoch 84 validation accuracy: 0.9710890457391235\n",
"\n",
"Epoch 85 training loss: 0.004997973330318928\n",
"Epoch 85 training accuracy: 0.9754961703476908\n",
"Epoch 85 validation loss: 0.005149488337337971\n",
"Epoch 85 validation accuracy: 0.9715675718786139\n",
"\n",
"Epoch 86 training loss: 0.004979357589036226\n",
"Epoch 86 training accuracy: 0.9756249814214797\n",
"Epoch 86 validation loss: 0.005171299912035465\n",
"Epoch 86 validation accuracy: 0.971527694700323\n",
"\n",
"Epoch 87 training loss: 0.004984478931874037\n",
"Epoch 87 training accuracy: 0.9756051643332045\n",
"Epoch 87 validation loss: 0.00515236658975482\n",
"Epoch 87 validation accuracy: 0.9710092913825418\n",
"\n",
"Epoch 88 training loss: 0.00498884217813611\n",
"Epoch 88 training accuracy: 0.9755853472449293\n",
"Epoch 88 validation loss: 0.005136134568601847\n",
"Epoch 88 validation accuracy: 0.9713681859871596\n",
"\n",
"Epoch 89 training loss: 0.004946259316056967\n",
"Epoch 89 training accuracy: 0.9756249814214797\n",
"Epoch 89 validation loss: 0.005148548167198896\n",
"Epoch 89 validation accuracy: 0.9710092913825418\n",
"\n",
"Epoch 90 training loss: 0.004885654430836439\n",
"Epoch 90 training accuracy: 0.9766653785559288\n",
"Epoch 90 validation loss: 0.005129080731421709\n",
"Epoch 90 validation accuracy: 0.9716872034134865\n",
"\n",
"Epoch 91 training loss: 0.004882515873759985\n",
"Epoch 91 training accuracy: 0.9761699513490483\n",
"Epoch 91 validation loss: 0.005134152248501778\n",
"Epoch 91 validation accuracy: 0.9714080631654504\n",
"\n",
"Epoch 92 training loss: 0.004912896081805229\n",
"Epoch 92 training accuracy: 0.9759024206573328\n",
"Epoch 92 validation loss: 0.005110086407512426\n",
"Epoch 92 validation accuracy: 0.9713283088088687\n",
"\n",
"Epoch 93 training loss: 0.004869314841926098\n",
"Epoch 93 training accuracy: 0.9768635494386809\n",
"Epoch 93 validation loss: 0.005069723352789879\n",
"Epoch 93 validation accuracy: 0.9718068349483591\n",
"\n",
"Epoch 94 training loss: 0.004821186885237694\n",
"Epoch 94 training accuracy: 0.9767842810855801\n",
"Epoch 94 validation loss: 0.0050615654326975346\n",
"Epoch 94 validation accuracy: 0.97184671212665\n",
"\n",
"Epoch 95 training loss: 0.004831160884350538\n",
"Epoch 95 training accuracy: 0.9766455614676536\n",
"Epoch 95 validation loss: 0.0050490666180849075\n",
"Epoch 95 validation accuracy: 0.9719264664832317\n",
"\n",
"Epoch 96 training loss: 0.004834283608943224\n",
"Epoch 96 training accuracy: 0.9763284880552501\n",
"Epoch 96 validation loss: 0.0050497399643063545\n",
"Epoch 96 validation accuracy: 0.9718865893049408\n",
"\n",
"Epoch 97 training loss: 0.0047189281322062016\n",
"Epoch 97 training accuracy: 0.977626507337277\n",
"Epoch 97 validation loss: 0.005037985742092133\n",
"Epoch 97 validation accuracy: 0.9720062208398134\n",
"\n",
"Epoch 98 training loss: 0.0047546131536364555\n",
"Epoch 98 training accuracy: 0.9774085193662495\n",
"Epoch 98 validation loss: 0.0050339424051344395\n",
"Epoch 98 validation accuracy: 0.9718865893049408\n",
"\n",
"Epoch 99 training loss: 0.004727411083877087\n",
"Epoch 99 training accuracy: 0.9775373304400384\n",
"Epoch 99 validation loss: 0.005019115284085274\n",
"Epoch 99 validation accuracy: 0.9723252382661403\n",
"\n"
]
}
],
"source": [
"model = Network().to(device)\n",
"for epoch in range(100):\n",
" LR = 0.001\n",
" if epoch >= 15 & epoch < 30:\n",
" LR = 0.0001\n",
" elif epoch >= 40 & epoch < 60:\n",
" LR = 0.00001\n",
" elif epoch >= 60:\n",
" LR = 0.000005\n",
" optimizer = optim.Adam(model.parameters(), lr=LR)\n",
" train(train_data, train_labels, validation_data, validation_labels, epoch, batch_size=500)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/paperspace/anaconda3/envs/pytorch4/lib/python3.6/site-packages/torch/serialization.py:241: UserWarning: Couldn't retrieve source code for container of type Network. It won't be checked for correctness upon loading.\n",
" \"type \" + obj.__name__ + \". It won't be checked \"\n"
]
}
],
"source": [
"torch.save(model, 'models/mnist_model.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}