Skip to content

Instantly share code, notes, and snippets.

@mrbkdad
Created October 17, 2017 06:10
Show Gist options
  • Save mrbkdad/ec9bca6695a563450c314b2bfa275259 to your computer and use it in GitHub Desktop.
Save mrbkdad/ec9bca6695a563450c314b2bfa275259 to your computer and use it in GitHub Desktop.
feed forward network sample by pytorch, with input layer, hidden layer, output layer, relu function, cross entropy cost function, SGD optimizer
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"import torch\n",
"from torch.autograd import Variable\n",
"import torchvision.datasets as dsets\n",
"import torchvision.transforms as transforms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Load MNIST data\n",
"\n",
"- Load MNIST datasets(Train, Test Dataset)\n",
"- Split Train datasets and Validation datasets\n",
"- Create Data loader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Processing...\n",
"Done!\n"
]
}
],
"source": [
"## Load MNIST datasets\n",
"# ToTensor() scales pixels to [0,1], but it is only applied when samples are\n",
"# fetched through the DataLoader -- raw .train_data stays uint8 in 0-255.\n",
"mnist_trn = dsets.MNIST('MNIST_torch/', train=True,\n",
" transform=transforms.ToTensor(), download=True)\n",
"mnist_test = dsets.MNIST('MNIST_torch/', train=False,\n",
" transform=transforms.ToTensor(), download=True)\n",
"\n",
"## split Validation datasets\n",
"# Hold out the last 5,000 of the 60,000 training images for validation.\n",
"# NOTE(review): X_val is raw uint8 (0-255), unlike the ToTensor-scaled\n",
"# training batches -- confirm this scale mismatch is intended.\n",
"X_val = mnist_trn.train_data[55000:]\n",
"Y_val = mnist_trn.train_labels[55000:]\n",
"\n",
"# Truncate the dataset in place so the loader only sees the first 55,000.\n",
"mnist_trn.train_data = mnist_trn.train_data[:55000]\n",
"mnist_trn.train_labels = mnist_trn.train_labels[:55000]\n",
"\n",
"# dataset loader\n",
"batch_size = 100\n",
"data_loader = torch.utils.data.DataLoader(dataset=mnist_trn,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Build the graph\n",
"- pytorch 에서는 Model을 torch.nn.Linear을 이용하여 생성\n",
"- Linear 내부에 Weight 와 Bias 를 저장\n",
"- Loss(Cost) 함수를 이용하여 Loss 를 계산하고 Optimizer 을 이용하여 SGD 를 처리\n",
"- Cross Entropy Loss 함수 내부에 Soft Max 내장 되어 있음"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2개의 hidden layers 생성\n",
"각 hidden layers의 차원을 784, 300으로 정의"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential (\n",
" (0): Linear (784 -> 784)\n",
" (1): ReLU ()\n",
" (2): Linear (784 -> 300)\n",
" (3): ReLU ()\n",
" (4): Linear (300 -> 10)\n",
")\n"
]
}
],
"source": [
"# Layer dimensions: 784-d input (flattened 28x28 pixels) -> 784 -> 300 -> 10 classes.\n",
"dim_X = 784\n",
"hidden_dim_1 = 784\n",
"hidden_dim_2 = 300\n",
"dim_Y = 10\n",
"\n",
"layer1 = torch.nn.Linear(dim_X,hidden_dim_1,bias=True)\n",
"layer2 = torch.nn.Linear(hidden_dim_1,hidden_dim_2,bias=True)\n",
"output = torch.nn.Linear(hidden_dim_2,dim_Y,bias=True)\n",
"# ReLU has no parameters, so one instance can safely be reused between layers.\n",
"relu = torch.nn.ReLU()\n",
"\n",
"# No trailing Softmax: CrossEntropyLoss applies LogSoftmax internally.\n",
"model = torch.nn.Sequential(layer1,relu,layer2,relu,output)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Parameter containing:\n",
" 2.9070e-02 -3.5594e-02 -1.9954e-02 ... -2.7415e-02 -2.9358e-02 2.1725e-02\n",
" 1.1178e-02 8.5221e-03 2.8094e-02 ... -2.7357e-02 2.6644e-02 3.1264e-02\n",
" 1.9607e-02 -3.3357e-02 -1.4927e-02 ... -1.8765e-02 -1.5429e-02 3.2533e-02\n",
" ... ⋱ ... \n",
" 2.0248e-02 -1.4154e-02 2.6829e-02 ... -1.9432e-03 2.1957e-02 -2.3967e-02\n",
" 1.3988e-02 2.8683e-02 -7.3375e-04 ... -2.1320e-03 -1.6197e-02 -1.6408e-02\n",
" 9.5783e-03 7.2504e-03 -1.2358e-02 ... -2.0973e-02 1.9776e-03 -2.8294e-02\n",
" [torch.FloatTensor of size 784x784], Parameter containing:\n",
" 1.00000e-02 *\n",
" 3.9799\n",
" -2.9322\n",
" -4.0393\n",
" 3.7418\n",
" -5.7545\n",
" 3.1851\n",
" 2.9619\n",
" -0.7115\n",
" 2.5640\n",
" 0.6892\n",
" [torch.FloatTensor of size 10])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect learned parameters: first hidden layer's weights and output layer's bias.\n",
"model[0].weight,model[4].bias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loss Function, Optimizer"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"learning_rate = 0.01\n",
"# CrossEntropyLoss combines LogSoftmax + NLLLoss and returns the batch mean,\n",
"# so the model feeds it raw logits.\n",
"cost_fn = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.SGD(model.parameters(),lr=learning_rate)\n",
"#optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## for calculating accuracy\n",
"def acc_fn(X, Y):\n",
"    \"\"\"Return the fraction of samples in X whose predicted class equals Y.\n",
"\n",
"    Softmax is monotonic, so argmax over the raw logits yields the same\n",
"    predictions; the previous separate Softmax module was redundant and\n",
"    has been dropped. The tensor-level .sum() replaces the slow Python\n",
"    builtin sum() over an element-wise comparison.\n",
"    NOTE(review): callers pass raw uint8 pixels (0-255) here, while the\n",
"    training batches are ToTensor-scaled to [0,1] -- confirm intended.\n",
"    \"\"\"\n",
"    y_hats = model(Variable(X.view(-1, dim_X).float()))\n",
"    correct = (y_hats.max(1)[1].data.squeeze() == Y).sum()\n",
"    return correct / y_hats.size()[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[500 step] validation accuracy: 74.16%, loss: 0.0151 \n",
"[1000 step] validation accuracy: 88.10%, loss: 0.0057 \n",
"[1500 step] validation accuracy: 90.68%, loss: 0.0039 \n",
"[2000 step] validation accuracy: 91.30%, loss: 0.0042 \n",
"[2500 step] validation accuracy: 91.90%, loss: 0.0048 \n",
"[3000 step] validation accuracy: 92.40%, loss: 0.0040 \n",
"[3500 step] validation accuracy: 92.52%, loss: 0.0025 \n",
"[4000 step] validation accuracy: 92.72%, loss: 0.0039 \n",
"[4500 step] validation accuracy: 93.38%, loss: 0.0031 \n",
"[5000 step] validation accuracy: 93.58%, loss: 0.0030 \n",
"[5500 step] validation accuracy: 93.48%, loss: 0.0037 \n",
"[6000 step] validation accuracy: 93.98%, loss: 0.0020 \n",
"[6500 step] validation accuracy: 94.10%, loss: 0.0017 \n",
"[7000 step] validation accuracy: 94.18%, loss: 0.0019 \n",
"[7500 step] validation accuracy: 94.38%, loss: 0.0026 \n",
"[8000 step] validation accuracy: 94.48%, loss: 0.0039 \n",
"[8500 step] validation accuracy: 94.88%, loss: 0.0028 \n",
"[9000 step] validation accuracy: 94.66%, loss: 0.0016 \n",
"[9500 step] validation accuracy: 95.02%, loss: 0.0024 \n",
"[10000 step] validation accuracy: 95.26%, loss: 0.0022 \n"
]
}
],
"source": [
"import itertools\n",
"# Cycle the loader indefinitely so we can run a fixed number of SGD steps.\n",
"data_iter = itertools.cycle(iter(data_loader))\n",
"\n",
"for i in range(10000):\n",
"    batch_xs, batch_ys = next(data_iter)\n",
"    # Flatten 28x28 images into the 784-d vectors the first Linear expects.\n",
"    batch_xs = Variable(batch_xs.view(-1, dim_X))\n",
"    batch_ys = Variable(batch_ys)\n",
"    y_hats = model(batch_xs)\n",
"    loss = cost_fn(y_hats, batch_ys)\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    if (i+1) % 500 == 0:\n",
"        # CrossEntropyLoss already averages over the batch, so report\n",
"        # loss.data[0] directly (dividing by batch_size again would\n",
"        # under-report the loss by a factor of 100). The previous extra\n",
"        # forward pass over X_val was unused -- acc_fn recomputes it.\n",
"        val_acc = acc_fn(X_val, Y_val)\n",
"        print(\"[{} step] validation accuracy: {:.2%}, loss: {:.4f} \".format(\n",
"            i + 1, val_acc, loss.data[0]))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy : 94.06%, loss : 0.00\n"
]
}
],
"source": [
"X_test = mnist_test.test_data\n",
"Y_test = mnist_test.test_labels\n",
"# CrossEntropyLoss returns the mean per-sample loss, so no further division\n",
"# by the dataset size is needed (the old /X_test.size()[0] divided the mean\n",
"# by 10,000, which is why the loss printed as 0.00).\n",
"# NOTE(review): X_test is raw uint8 (0-255), unlike the ToTensor-scaled\n",
"# training batches -- confirm this scale mismatch is intended.\n",
"test_loss = cost_fn(model(Variable(X_test.float().view(-1, dim_X))), Variable(Y_test)).data[0]\n",
"acc_val = acc_fn(X_test, Y_test)\n",
"print('Test accuracy : {:.2%}, loss : {:.2f}'.format(acc_val, test_loss))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### random test"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1242 image label - 4\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADMZJREFUeJzt3V+IXfW5xvHnMbZeJIVo6wkhDZ0KciB6kcIQgkdLYzWo\nFGMulHpxiCCdXvSEFqooFk3AG5E2oTcWpjQklhzTA21JkHL8E0+wkRKdqFVHT6MNCZ2QP60pdAJK\nm8nbi1mRaZz92zt71t5rT97vB4bZe71r7fWymWfW2nv9+TkiBCCfy5puAEAzCD+QFOEHkiL8QFKE\nH0iK8ANJEX4gKcIPJEX4gaQu7+fKbHM6IdBjEeFO5pvTlt/2bbb/YPsD2w/P5bUA9Je7Pbff9gJJ\nhyTdKmlC0muS7o2IdwvLsOUHeqwfW/5Vkj6IiMMR8XdJuyStm8PrAeijuYR/maQ/zXg+UU37F7ZH\nbI/ZHpvDugDUrOdf+EXEqKRRid1+YJDMZct/TNLyGc+/WE0DMA/MJfyvSbrW9pdtf1bSNyXtqact\nAL3W9W5/RJy1/V+SnpO0QNK2iBivrTMAPdX1ob6uVsZnfqDn+nKSD4D5i/ADSRF+ICnCDyRF+IGk\nCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiB\npAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuh6iW5JsH5E0KWlK0tmIGK6jKQC9N6fwV9ZE\nxF9qeB0AfcRuP5DUXMMfkl60fdD2SB0NAeiPue723xgRx2z/m6QXbP9/RLw8c4bqnwL/GIAB44io\n54XszZLORMQPC/PUszIALUWEO5mv691+2wttf+78Y0lrJb3T7esB6K+57PYvkfRr2+df578j4n9r\n6QpAz9W229/Rytjtn3cef/zxYn18fLxYf+qpp1rWhofLp4UcPny4WMfser7bD2B+I/xAUoQfSIrw\nA0kRfiApwg8kVcdVfZjH1q5dW6w/+OCDxfpHH31UrC9evPiie0J/sOUHkiL8QFKEH0iK8ANJEX4g\nKcIPJEX4gaS4pPcSNzQ0VKzv27evWD969Gix/uGHHxbr1113XcvaihUristOTU0V65gdl/QCKCL8\nQFKEH0iK8ANJEX4gKcIPJEX4gaS4nv8St2bNmmJ92bJlxfru3buL9VtuuaVYn5ycbFnjOH6z2PID\nSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJtj/Pb3ibpG5JORcT11bSrJP1C0pCkI5LuiYi/9q5NlJSu\n2X/yySeLy27ZsqVYf+ihh4r1Xbt2Fevt7geA5nSy5d8u6bYLpj0saW9EXCtpb/UcwDzSNvwR8bKk\n0xdMXidpR/V4h6S7au4LQI91+5l/SUQcrx6fkLSkpn4A9Mmcz+2PiCjdm8/2iKSRua4HQL263fKf\ntL1Ukqrfp1rNGBGjETEcEcNdrgtAD3Qb/j2SNlSPN0gqX/oFYOC0Db/tZyT9TtK/256wfb+kJyTd\navt9SbdUzwHMI20/80fEvS1KX6+5F7Rw2WXl/9GbNm1qWTtz5kxx2a1bt3bVE+Y/zvADkiL8QFKE\nH0iK8ANJEX4gKcIPJMWtu+eBdkNZ33fffS1rIyPlM6tPnDjRTUufWL16dbHOJb2Diy0/kBThB5Ii\n/EBShB9IivADSRF+ICnCDyTFcf4B0O6S3UcffbRYP3jwYMvatm3buuqpU1dffXWxfvr0hfd+xaBg\nyw8kRfiBpAg/kBThB5Ii/EBShB9IivADSXGcfwCUhtiWpLvvvrtYv/3221vWpqamummpNq+88krX\ny15+efnP8+zZs12/NtjyA2kRfiApwg8kRfiBpAg/kBThB5Ii/EBSbY/z294m6RuSTkXE9dW0zZK+\nJenP1WyPRMRvetXkpa50nF6SJiYmivV9+/bV
2E29rrjiipa1xx57rLjsgQMHivXnnnuuq54wrZMt\n/3ZJt80yfWtErKx+CD4wz7QNf0S8LInbsQCXmLl85t9o+y3b22xfWVtHAPqi2/D/RNI1klZKOi7p\nR61mtD1ie8z2WJfrAtADXYU/Ik5GxFREnJP0U0mrCvOORsRwRAx32ySA+nUVfttLZzxdL+mdetoB\n0C+dHOp7RtLXJH3B9oSkTZK+ZnulpJB0RNK3e9gjgB5wRPRvZXb/VjaPrFixolgfHx8v1t94442W\ntZ07dxaXPXToULG+Zs2aYn3jxo3Feum+/a+++mpx2fXr1xfrXM8/u4hwJ/Nxhh+QFOEHkiL8QFKE\nH0iK8ANJEX4gKQ71DQC7fGRm1aqWJ1BKkjZt2tSydsMNN3TV03kLFiwo1hctWlSsP/DAAy1rW7du\nLS577ty5Yh2z41AfgCLCDyRF+IGkCD+QFOEHkiL8QFKEH0iK4/woandJ70svvVSs33TTTS1r+/fv\n76onlHGcH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8k1fa+/cjt5ptvLtYnJyeL9bExRmkbVGz5gaQI\nP5AU4QeSIvxAUoQfSIrwA0kRfiCptsf5bS+X9LSkJZJC0mhE/Nj2VZJ+IWlI0hFJ90TEX3vXKpqw\ncOHCYr3dvfU//vjjOttBjTrZ8p+V9P2IWCFptaTv2F4h6WFJeyPiWkl7q+cA5om24Y+I4xHxevV4\nUtJ7kpZJWidpRzXbDkl39apJAPW7qM/8tockfUXSAUlLIuJ4VTqh6Y8FAOaJjs/tt71I0i8lfS8i\n/jZzfLmIiFb357M9Imlkro0CqFdHW37bn9F08HdGxK+qySdtL63qSyWdmm3ZiBiNiOGIGK6jYQD1\naBt+T2/ifybpvYjYMqO0R9KG6vEGSbvrbw9Ar3Sy2/8fkv5T0tu236ymPSLpCUn/Y/t+SUcl3dOb\nFtGkkydPNt0CeqRt+CNiv6RW9wH/er3tAOgXzvADkiL8QFKEH0iK8ANJEX4gKcIPJMUQ3ShavXp1\nsf78888X60NDQy1rp0+f7qYltMEQ3QCKCD+QFOEHkiL8QFKEH0iK8ANJEX4gKY7zY07aHau/8847\nW9b2799fdzsQx/kBtEH4gaQIP5AU4QeSIvxAUoQfSIrwA0l1PFwXMJt21/MvXry4T53gYrHlB5Ii\n/EBShB9IivADSRF+ICnCDyRF+IGk2h7nt71c0tOSlkgKSaMR8WPbmyV9S9Kfq1kfiYjf9KpRDKbt\n27cX62vXrm1Ze/bZZ2vuBhejk5N8zkr6fkS8bvtzkg7afqGqbY2IH/auPQC90jb8EXFc0vHq8aTt\n9yQt63VjAHrroj7z2x6S9BVJB6pJG22/ZXub7StbLDNie8z22Jw6BVCrjsNve5GkX0r6XkT8TdJP\nJF0jaaWm9wx+NNtyETEaEcMRMVxDvwBq0lH4bX9G08HfGRG/kqSIOBkRUxFxTtJPJa3qXZsA6tY2\n/LYt6WeS3ouILTOmL50x23pJ79TfHoBeaXvrbts3SvqtpLclnasmPyLpXk3v8oekI5K+XX05WHot\nbt0N9Fint+7mvv3AJYb79gMoIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivAD\nSRF+ICnCDyTV7yG6/yLp6IznX6imDaJB7W1Q+5LorVt19valTmfs6/X8n1q5PTao9/Yb1N4GtS+J\n3rrVVG/s9gNJEX4gqabDP9rw+ksGtbdB7Uuit2410lujn/kBNKfpLT+AhjQSftu32f6D7Q9sP9xE\nD63YPmL7bdtvNj3EWDUM2inb78yYdpXtF2y/X/2edZi0hnrbbPtY9d69afuOhnpbbvv/bL9re9z2\nd6vpjb53hb4aed/6vttve4GkQ5JulTQh6TVJ90bEu31tpAXbRyQNR0Tjx4Rtf1XSGUlPR8T11bQn\nJZ2OiCeq
f5xXRsRDA9LbZklnmh65uRpQZunMkaUl3SXpPjX43hX6ukcNvG9NbPlXSfogIg5HxN8l\n7ZK0roE+Bl5EvCzp9AWT10naUT3eoek/nr5r0dtAiIjjEfF69XhS0vmRpRt97wp9NaKJ8C+T9KcZ\nzyc0WEN+h6QXbR+0PdJ0M7NYMmNkpBOSljTZzCzajtzcTxeMLD0w7103I17XjS/8Pu3GiFgp6XZJ\n36l2bwdSTH9mG6TDNR2N3Nwvs4ws/Ykm37tuR7yuWxPhPyZp+YznX6ymDYSIOFb9PiXp1xq80YdP\nnh8ktfp9quF+PjFIIzfPNrK0BuC9G6QRr5sI/2uSrrX9ZduflfRNSXsa6ONTbC+svoiR7YWS1mrw\nRh/eI2lD9XiDpN0N9vIvBmXk5lYjS6vh927gRryOiL7/SLpD09/4/1HSD5rooUVf10j6ffUz3nRv\nkp7R9G7gPzT93cj9kj4vaa+k9yW9KOmqAert55oezfktTQdtaUO93ajpXfq3JL1Z/dzR9HtX6KuR\n940z/ICk+MIPSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBS/wQ8/ilgYWr5cgAAAABJRU5ErkJg\ngg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x222bcc88>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"prediction : 9\n"
]
}
],
"source": [
"# random.randint is inclusive on BOTH ends, so the old upper bound of\n",
"# X_test.size()[0] could produce an out-of-range index (IndexError with\n",
"# probability 1/10001). randrange excludes its stop value.\n",
"rv = random.randrange(X_test.size()[0])\n",
"image = X_test[rv]\n",
"\n",
"print(rv, 'image label - ', Y_test[rv])\n",
"plt.imshow(image.numpy(), cmap='gray')\n",
"plt.show()\n",
"# max(1) returns (values, indices); [1] picks the predicted class index.\n",
"print('prediction : ', model(Variable(image.view(-1, dim_X).float())).max(1)[1].data[0][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment