Created
October 16, 2017 08:08
-
-
Save mrbkdad/4ffdb472a169c1bbfb8b5b1b8e89ef9c to your computer and use it in GitHub Desktop.
simple softmax regression by pytorch with mnist datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"import torch\n", | |
"from torch.autograd import Variable\n", | |
"import torchvision.datasets as dsets\n", | |
"import torchvision.transforms as transforms" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 0. Load MNIST data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", | |
"Processing...\n", | |
"Done!\n" | |
] | |
} | |
], | |
"source": [ | |
"mnist_trn = dsets.MNIST('MNIST_torch/', train=True,\n", | |
" transform=transforms.ToTensor(), download=True)\n", | |
"mnist_test = dsets.MNIST('MNIST_torch/', train=False,\n", | |
" transform=transforms.ToTensor(), download=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"X_trn = mnist_trn.train_data[:55000]\n", | |
"Y_trn = mnist_trn.train_labels[:55000] ## One hot incoding X\n", | |
"X_val = mnist_trn.train_data[55000:]\n", | |
"Y_val = mnist_trn.train_labels[55000:]\n", | |
"X_test = mnist_test.test_data\n", | |
"Y_test = mnist_test.test_labels" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Number of training points: 55000\n", | |
"Number of validation points: 5000\n", | |
"Number of test points: 10000\n" | |
] | |
} | |
], | |
"source": [ | |
"num_trn = Y_trn.size()[0]\n", | |
"num_val = Y_val.size()[0]\n", | |
"num_test = Y_test.size()[0]\n", | |
"\n", | |
"print(\"Number of training points: \", num_trn)\n", | |
"print(\"Number of validation points: \", num_val)\n", | |
"print(\"Number of test points: \", num_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(55000, 10)\n" | |
] | |
} | |
], | |
"source": [ | |
"## sciky learn one hot encoder 이용 하여 encoding\n", | |
"from sklearn.preprocessing import OneHotEncoder\n", | |
"ohe = OneHotEncoder()\n", | |
"ohe.fit(Y_trn.view(-1,1).numpy()) \n", | |
"Y_trn_oh = ohe.transform(Y_trn.view(-1,1).numpy()).toarray()\n", | |
"Y_val_oh = ohe.transform(Y_val.view(-1,1).numpy()).toarray()\n", | |
"Y_test_oh = ohe.transform(Y_test.view(-1,1).numpy()).toarray()\n", | |
"\n", | |
"print(Y_trn_oh.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dimension of X: 784 (28 x 28)\n", | |
"Dimension of Y: 10\n" | |
] | |
} | |
], | |
"source": [ | |
"dim_X = X_trn.size()[1]*X_trn.size()[2]\n", | |
"pixel_X = int(np.sqrt(dim_X)) # np.sqrt의 출력이 float32이므로, 이를 int 자료형으로 변경\n", | |
"dim_Y = Y_trn_oh.shape[1]\n", | |
"\n", | |
"print(\"Dimension of X: %d (%d x %d)\" % (dim_X, pixel_X, pixel_X))\n", | |
"print(\"Dimension of Y: \", dim_Y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 1. Build the graph\n", | |
"- PyTorch에서는 torch.nn.Linear를 이용하여 Model을 생성\n", | |
"- Linear 내부에 Weight와 Bias를 저장\n", | |
"- Loss(Cost) 함수를 이용하여 Loss를 계산하고 Optimizer를 이용하여 SGD를 수행\n", | |
"- CrossEntropyLoss 함수 내부에 Softmax가 내장되어 있음" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 134, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Linear (784 -> 10)\n" | |
] | |
} | |
], | |
"source": [ | |
"model = torch.nn.Linear(784,10,bias=True)\n", | |
"print(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 135, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Parameter containing:\n", | |
" 1.00000e-02 *\n", | |
" -1.6169 -1.3665 -1.3400 ... -3.1919 -0.0318 -2.7797\n", | |
" 1.2666 -2.1546 -3.0524 ... -3.4247 -1.5329 3.2291\n", | |
" 1.8188 2.0751 -0.9358 ... 1.4454 -1.9341 -0.6081\n", | |
" ... ⋱ ... \n", | |
" 1.5825 -1.3580 0.9915 ... -1.5497 3.0195 -1.8963\n", | |
" 1.1116 -2.5431 -3.3340 ... -0.6030 -2.5569 -0.6509\n", | |
" -2.7273 -1.5376 0.7860 ... 3.3613 0.1742 0.2450\n", | |
" [torch.FloatTensor of size 10x784], Parameter containing:\n", | |
" 1.00000e-02 *\n", | |
" -1.2841\n", | |
" -2.5258\n", | |
" 2.6121\n", | |
" 2.2594\n", | |
" -3.5567\n", | |
" -2.5424\n", | |
" 2.6417\n", | |
" -0.4884\n", | |
" 3.2215\n", | |
" 0.6781\n", | |
" [torch.FloatTensor of size 10])" | |
] | |
}, | |
"execution_count": 135, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.weight,model.bias" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"y hat은 model(x_val) 형식으로 model을 호출하여, 현재 model의 state dict에 저장된 weight와 bias로 계산된다. 이때 입력은 torch.autograd.Variable로 감싸며, TensorFlow의 placeholder처럼 입력 데이터를 원소로 하는 array 형태로 model에 전달한다." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"logit: [ 52.03900909 -62.28118896 -65.3991394 -116.99193573 80.05447388\n", | |
" 59.60297394 -80.47360229 46.66101837 19.46802902 12.53259563]\n", | |
"max: [[ 80.05447388]]\n", | |
"max index [[4]]\n", | |
"softmax: [ 6.80829308e-13 0.00000000e+00 0.00000000e+00 0.00000000e+00\n", | |
" 1.00000000e+00 1.31227973e-09 0.00000000e+00 3.14344981e-15\n", | |
" 4.87126000e-27 4.73826802e-30]\n", | |
"max: [ 1.]\n", | |
"max index: [4]\n" | |
] | |
} | |
], | |
"source": [ | |
"X = Variable(X_trn[0].float().view(-1,dim_X)) ## -1이 None 과 같은 역할\n", | |
"Y_hat = model(X)\n", | |
"print('logit:',Y_hat.data[0].numpy())\n", | |
"print('max:',Y_hat.data.max(1)[0].numpy())\n", | |
"print('max index',Y_hat.data.max(1)[1].numpy())\n", | |
"\n", | |
"softmax = torch.nn.Softmax()\n", | |
"Y_softmax = softmax(Variable(Y_hat.data[0].view(1,-1)))\n", | |
"print('softmax:',Y_softmax.data[0].numpy())\n", | |
"print('max:',Y_softmax.data.max(1)[0][0].numpy())\n", | |
"print('max index:',Y_softmax.data.max(1)[1][0].numpy())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Loss Function, Optimizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 142, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learning_rate = 0.005\n", | |
"cost_fn = torch.nn.CrossEntropyLoss()\n", | |
"optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)\n", | |
"#optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 143, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"batch_size = 100\n", | |
"\n", | |
"mnist_trn.train_data = mnist_trn.train_data[:55000]\n", | |
"mnist_trn.train_labels = mnist_trn.train_labels[:55000]\n", | |
"\n", | |
"# dataset loader\n", | |
"data_loader = torch.utils.data.DataLoader(dataset=mnist_trn,\n", | |
" batch_size=batch_size,\n", | |
" shuffle=False,\n", | |
" num_workers=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 144, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([55000, 28, 28])" | |
] | |
}, | |
"execution_count": 144, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mnist_trn.train_data.size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 145, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def acc_fn(X,Y):\n", | |
" y_hats = model(Variable(X.view(-1,dim_X).float()))\n", | |
" val_acc = sum(softmax(y_hats).max(1)[1].data.squeeze() == Y)\n", | |
" return val_acc/y_hats.size()[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[500 step] validation accuracy: 91.16%, loss: 0.00 \n", | |
"[1000 step] validation accuracy: 91.24%, loss: 0.00 \n", | |
"[1500 step] validation accuracy: 91.24%, loss: 0.00 \n", | |
"[2000 step] validation accuracy: 91.30%, loss: 0.00 \n", | |
"[2500 step] validation accuracy: 91.26%, loss: 0.00 \n", | |
"[3000 step] validation accuracy: 91.24%, loss: 0.00 \n", | |
"[3500 step] validation accuracy: 91.24%, loss: 0.00 \n", | |
"[4000 step] validation accuracy: 91.22%, loss: 0.00 \n", | |
"[4500 step] validation accuracy: 91.14%, loss: 0.00 \n", | |
"[5000 step] validation accuracy: 91.10%, loss: 0.00 \n", | |
"[5500 step] validation accuracy: 91.12%, loss: 0.00 \n", | |
"[6000 step] validation accuracy: 91.10%, loss: 0.00 \n", | |
"[6500 step] validation accuracy: 91.14%, loss: 0.00 \n", | |
"[7000 step] validation accuracy: 91.22%, loss: 0.00 \n", | |
"[7500 step] validation accuracy: 91.18%, loss: 0.00 \n", | |
"[8000 step] validation accuracy: 91.18%, loss: 0.00 \n", | |
"[8500 step] validation accuracy: 91.22%, loss: 0.00 \n", | |
"[9000 step] validation accuracy: 91.18%, loss: 0.00 \n", | |
"[9500 step] validation accuracy: 91.16%, loss: 0.00 \n", | |
"[10000 step] validation accuracy: 91.12%, loss: 0.00 \n" | |
] | |
} | |
], | |
"source": [ | |
"import itertools\n", | |
"data_iter = itertools.cycle(iter(data_loader))\n", | |
"\n", | |
"for i in range(10000):\n", | |
" batch_xs,batch_ys = next(data_iter)\n", | |
" batch_xs = Variable(batch_xs.view(-1,dim_X))\n", | |
" batch_ys = Variable(batch_ys)\n", | |
" y_hats = model(batch_xs)\n", | |
" loss = cost_fn(y_hats,batch_ys)\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" if (i+1) % 500 == 0:\n", | |
" y_hats = model(Variable(X_val.view(-1,dim_X).float()))\n", | |
" val_acc = acc_fn(X_val,Y_val)\n", | |
" print(\"[{} step] validation accuracy: {:.2%}, loss: {:.4f} \".format(i + 1,\n", | |
" val_acc,loss.data[0]/batch_size))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 147, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test accuracy : 89.41%, loss : 0.01\n" | |
] | |
} | |
], | |
"source": [ | |
"test_loss = cost_fn(model(Variable(X_test.float().view(-1,dim_X))),Variable(Y_test)).data[0]/X_test.size()[0]\n", | |
"acc_val = acc_fn(X_test,Y_test)\n", | |
"print('Test accuracy : {:.2%}, loss : {:.2f}'.format(acc_val,test_loss))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment