Created
October 16, 2017 08:07
-
-
Save mrbkdad/383ca69fcc37424d796e2c7e53be1682 to your computer and use it in GitHub Desktop.
exploring MNIST Dataset in pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import random\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"import torch\n", | |
"import torchvision.datasets as dsets\n", | |
"from torchvision.transforms import ToTensor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# IPython help query: the trailing `?` displays the documentation of dsets.MNIST\n", | |
"dsets.MNIST?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Download (if missing) and load the MNIST training split into 'MNIST_torch/'.\n", | |
"# ToTensor converts each image when the dataset is indexed. It is NOT used as\n", | |
"# target_transform: labels are plain class indices, and ToTensor only accepts\n", | |
"# PIL images / ndarrays, so indexing the dataset would raise on the label.\n", | |
"mnist_trn = dsets.MNIST('MNIST_torch/', train=True,\n", | |
"                        transform=ToTensor(), download=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Download (if missing) and load the MNIST test split into 'MNIST_torch/'.\n", | |
"# ToTensor is applied to images only: labels are plain class indices, and\n", | |
"# ToTensor only accepts PIL images / ndarrays, so using it as\n", | |
"# target_transform would raise as soon as the dataset is indexed.\n", | |
"mnist_test = dsets.MNIST('MNIST_torch/', train=False,\n", | |
"                         transform=ToTensor(), download=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([60000, 28, 28]), torch.Size([60000]))" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Shapes of the raw training images and of their labels\n", | |
"tuple(tensor.size() for tensor in (mnist_trn.train_data, mnist_trn.train_labels))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"One hot encoding : (55000, 10)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", | |
" [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", | |
" [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
" [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 85, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Split the training set: first 55000 samples for training, the rest for validation\n", | |
"X_trn, Y_trn = mnist_trn.train_data[:55000], mnist_trn.train_labels[:55000]\n", | |
"X_val, Y_val = mnist_trn.train_data[55000:], mnist_trn.train_labels[55000:]\n", | |
"X_test, Y_test = mnist_test.test_data, mnist_test.test_labels\n", | |
"\n", | |
"## Flatten each 28*28 image into a single row\n", | |
"X_trn = X_trn.view(X_trn.size()[0], -1)\n", | |
"X_val = X_val.view(X_val.size()[0], -1)\n", | |
"X_test = X_test.view(X_test.size()[0], -1)\n", | |
"\n", | |
"## One hot encoding of the training labels (10 columns, one per class)\n", | |
"# Row i gets a 1 in column Y_trn[i]; integer-array indexing sets all rows in\n", | |
"# one assignment instead of a Python loop over every label.\n", | |
"Y_label = np.zeros((Y_trn.size()[0], 10))\n", | |
"Y_label[np.arange(len(Y_label)), Y_trn.numpy()] = 1\n", | |
"Y_trn = torch.Tensor(Y_label)\n", | |
"print('One hot encoding : ',Y_label.shape)\n", | |
"Y_label[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Type of X: <class 'torch.ByteTensor'>\n", | |
"Type of Y: <class 'torch.FloatTensor'>\n" | |
] | |
} | |
], | |
"source": [ | |
"# Report the tensor classes of the training inputs and the one-hot labels\n", | |
"for caption, tensor in ((\"Type of X: \", X_trn), (\"Type of Y: \", Y_trn)):\n", | |
"    print(caption, type(tensor))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([55000, 784])\n", | |
"torch.Size([55000, 10])\n" | |
] | |
} | |
], | |
"source": [ | |
"# Sizes of the flattened training images and the one-hot training labels\n", | |
"for tensor in (X_trn, Y_trn):\n", | |
"    print(tensor.size())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Number of training points: 55000\n", | |
"Number of validation points: 5000\n", | |
"Number of test points: 10000\n" | |
] | |
} | |
], | |
"source": [ | |
"# The number of samples in each split is the first dimension of its label tensor\n", | |
"num_trn, num_val, num_test = (labels.size(0) for labels in (Y_trn, Y_val, Y_test))\n", | |
"\n", | |
"print(\"Number of training points: \", num_trn)\n", | |
"print(\"Number of validation points: \", num_val)\n", | |
"print(\"Number of test points: \", num_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dimension of X: 784 (28 x 28)\n", | |
"Dimension of Y: 10\n" | |
] | |
} | |
], | |
"source": [ | |
"# Per-sample widths are the second dimension of the 2-D tensors\n", | |
"dim_X, dim_Y = X_trn.size(1), Y_trn.size(1)\n", | |
"# np.sqrt returns a float, so cast the image side length to int\n", | |
"pixel_X = int(np.sqrt(dim_X))\n", | |
"\n", | |
"print(\"Dimension of X: %d (%d x %d)\" % (dim_X, pixel_X, pixel_X))\n", | |
"print(\"Dimension of Y: \", dim_Y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 119, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def display_mnist(index):\n", | |
"    \"\"\"Print the label of one training sample and draw its image.\n", | |
"\n", | |
"    Reads the notebook globals X_trn (flattened images), Y_trn (one-hot\n", | |
"    labels) and pixel_X (image side length).\n", | |
"\n", | |
"    Args:\n", | |
"        index: row of X_trn / Y_trn to display.\n", | |
"    \"\"\"\n", | |
"    # Restore the 2-D image from its flattened row\n", | |
"    sample_img = X_trn[index].numpy().reshape(-1, pixel_X)\n", | |
"    # Keep the whole one-hot row here: np.argmax has to run on the full\n", | |
"    # vector, because applied to an already-reduced scalar index it always\n", | |
"    # yields 0.\n", | |
"    sample_label_encoded = Y_trn[index].numpy()\n", | |
"    sample_label = int(np.argmax(sample_label_encoded))\n", | |
"\n", | |
"    print(\"One-hot encoded class: \", sample_label_encoded)\n", | |
"    print(\"Class: \", sample_label)\n", | |
"    plt.imshow(sample_img, cmap='gray')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 120, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"One-hot encoded class: 3\n", | |
"Class: 0\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADUVJREFUeJzt3W+sVPWdx/HPByw+AB6IDSyKSLfq2o1/bjc3ZBPWDUYh\namqgiZj6QGkWlz6owSabsIZoSrKp0XVbdx81oYFATWvbxH/YbBYasoElrgQ0tdqyLdeGbVngssZq\nbaJplO8+uIfNFe/8Zpg5M2fu/b5fCZmZ850z55vRzz1n5nfO/BwRApDPrKYbANAMwg8kRfiBpAg/\nkBThB5Ii/EBShB9IivADSRF+IKmLBrkx25xOCPRZRLiT5/W057d9m+1f2h6z/VAvrwVgsNztuf22\nZ0v6laRVkk5IOizpnoj4RWEd9vxAnw1iz79c0lhE/Doi/ijpB5LW9PB6AAaol/BfLum3kx6fqJZ9\njO2Nto/YPtLDtgDUrJcv/KY6tPjEYX1EbJO0TeKwHxgmvez5T0i6YtLjJZJO9tYOgEHpJfyHJV1t\n+zO250j6kqTd9bQFoN+6PuyPiA9tPyBpj6TZknZExM9r6wxAX3U91NfVxvjMD/TdQE7yATB9EX4g\nKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+\nICnCDyRF+IGkCD+Q1ECn6EZ3RkZGeqqXbNiwoVjfvn1716/dzsGDB4v1sbGxvm0b7PmBtAg/kBTh\nB5Ii/EBShB9IivADSRF+IKmeZum1fVzSe5I+kvRhRIy2eT6z9HbhwQcfLNYff/zxlrXZs2cX1501\nq/z3/+zZs8V6LzZv3lysP/nkk33b9kzW6Sy9dZzkc3NEvFXD6wAYIA77gaR6DX9I2mv7Fdsb62gI\nwGD0eti/IiJO2l4o6Se2/ysiDkx+QvVHgT8MwJDpac8fESer2zOSnpO0fIrnbIuI0XZfBgIYrK7D\nb3uu7fnn7ktaLemNuhoD0F+9HPYvkvSc7XOv8/2I+LdaugLQdz2N81/wxhjn74s333yzZW3p0qXF\ndZsc529n3bp1xfrzzz8/oE6ml07H+RnqA5Ii/EBShB9IivADSRF+ICnCDyTFUN8McMstt7SsLVy4\nsLjuww8/XKy3GyocHx8v1q+88spiveT9998v1letWlWsHzp0qOttT2cM9QEoIvxAUoQfSIrwA0kR\nfiApwg8kRfiBpJiiewbYt29f1+sePny4WL/++uuL9RtuuKFYb3ceQcncuXOL9Ysvvrjr1wZ7fiAt\nwg8kRfiBpAg/kBThB5Ii/EBShB9IinH+5MbGxnqq33jjjcV6u58G72Xdas4IdIk9P5AU4QeSIvxA\nUoQfSIrwA0kRfiApwg8k1Xac3/YOSV+QdCYirquWLZD0Q0nLJB2XdHdE/K5/baJf5s2bV6w/8cQT\nxfrq1auL9X5O8T3IOSdmok72/Dsl3Xbesock7YuIqyXtqx4DmEbahj8iDkh6+7zFayTtqu7vkrS2\n5r4A9Fm3n/kXRcQpSapuy3NCARg6fT+33/ZGSRv7vR0AF6bbPf+47cWSVN2eafXEiNgWEaMRMdrl\ntgD0Qbfh3y1pfXV/vaQX6mkHwKC0Db/tpyX9p6Q/s33C9gZJj0laZfuYpFXVYwDTSNvP/BFxT4tS\n60nhMTTWri0PxNx3333F+p133llnOxginOEHJEX4gaQIP5AU4QeSIvxAUoQfSIqf7h4CS5YsKdb3\n79/f9Wtfeumlxfr8+fOL9X5ektvO6dOni/V33313QJ3MTOz5gaQIP5AU4QeSIvxAUoQfSIrwA0kR\nfiApxvmHwEUXlf8zLF26dECdDJc9e/YU66+99tqAOpmZ2PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8\nQFIe5DTHtplTuQtbt24t1h955JGuX3vWrPLf
/yav52/X28qVK4v1Xn4HYTqLCHfyPPb8QFKEH0iK\n8ANJEX4gKcIPJEX4gaQIP5BU2+v5be+Q9AVJZyLiumrZVkl/K+l/q6dtiYh/7VeT2e3cubNYv/32\n21vWev3d/gULFhTrTRrkOSozUSd7/p2Sbpti+ZMRMVL9I/jANNM2/BFxQNLbA+gFwAD18pn/Ads/\ns73D9iW1dQRgILoN/7clfVbSiKRTkr7Z6om2N9o+YvtIl9sC0AddhT8ixiPio4g4K+k7kpYXnrst\nIkYjYrTbJgHUr6vw21486eEXJb1RTzsABqWTob6nJa2U9GnbJyR9XdJK2yOSQtJxSV/pY48A+oDr\n+ZMbGRkp1u+9995ifdOmTXW28zHtrue/+eabi/UDBw7U2c60wfX8AIoIP5AU4QeSIvxAUoQfSIrw\nA0kxRXdyV111VbF+1113DagTDBp7fiApwg8kRfiBpAg/kBThB5Ii/EBShB9IinH+GeDaa69tWdu8\neXNx3Xbj/JdddllXPWH4secHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQY5x8C8+bNK9bbXVN/zTXX\ntKytX7++uG67n8c+e/Zssd6LgwcPFutPPfVUsZ71p7nrwp4fSIrwA0kRfiApwg8kRfiBpAg/kBTh\nB5JqO85v+wpJ35X0J5LOStoWEf9ie4GkH0paJum4pLsj4nf9a7VZt956a8vali1biuva5RmT58yZ\nU6wvX768WC/pdZy+3frvvPNOsf7yyy+3rN1///3FdcfHx4t19KaTPf+Hkv4uIj4n6S8lfdX2n0t6\nSNK+iLha0r7qMYBpom34I+JURLxa3X9P0lFJl0taI2lX9bRdktb2q0kA9bugz/y2l0n6vKRDkhZF\nxClp4g+EpIV1Nwegfzo+t9/2PEnPSPpaRPy+3efYSettlLSxu/YA9EtHe37bn9JE8L8XEc9Wi8dt\nL67qiyWdmWrdiNgWEaMRMVpHwwDq0Tb8ntjFb5d0NCK+Nam0W9K5S8bWS3qh/vYA9Esnh/0rJN0r\n6XXbP62WbZH0mKQf2d4g6TeS1vWnxeFwxx13tKzddNNNxXWbvGy2V7t27SrW2112u3///jrbQY3a\nhj8iDkpq9QH/lnrbATAonOEHJEX4gaQIP5AU4QeSIvxAUoQfSMoRMbiN2YPbWM1GRkZa1l588cXi\nuu2mue51nL90We2mTZuK67700kvFervLaj/44INiHYMXER2de8+eH0iK8ANJEX4gKcIPJEX4gaQI\nP5AU4QeSYpy/BitWrCjW200l/eijjxbrx44dK9ZPnz7dsrZ3797iuph5GOcHUET4gaQIP5AU4QeS\nIvxAUoQfSIrwA0kxzg/MMIzzAygi/EBShB9IivADSRF+ICnCDyRF+IGk2obf9hW2/932Uds/t/1g\ntXyr7f+x/dPqX+sJ7AEMnbYn+dheLGlxRLxqe76kVyStlXS3pD9ExD91vDFO8gH6rtOTfC7q4IVO\nSTpV3X/P9lFJl/fWHoCmXdBnftvLJH1e0qFq0QO2f2Z7h+1LWqyz0fYR20d66hRArTo+t9/2PEn7\nJX0jIp61vUjSW5JC0j9o4qPB37R5DQ77gT7r9LC/o/Db/pSkH0vaExHfmqK+TNKPI+K6Nq9D+IE+\nq+3CHtuWtF3S0cnBr74IPOeLkt640CYBNKeTb/v/StJ/SHpd0rm5pLdIukfSiCYO+49L+kr15WDp\ntdjzA31W62F/XQg/0H9czw+giPADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BU2x/wrNlbkv570uNPV8uG0bD2Nqx9SfTWrTp7u7LTJw70ev5PbNw+EhGjjTVQMKy9\nDWtfEr11q6neOOwHkiL8QFJNh39bw9svGdbehrUvid661UhvjX7mB9Ccpvf8ABrSSPht32b7l7bH\nbD/URA+t
2D5u+/Vq5uFGpxirpkE7Y/uNScsW2P6J7WPV7ZTTpDXU21DM3FyYWbrR927YZrwe+GG/\n7dmSfiVplaQTkg5LuicifjHQRlqwfVzSaEQ0PiZs+68l/UHSd8/NhmT7HyW9HRGPVX84L4mIvx+S\n3rbqAmdu7lNvrWaW/rIafO/qnPG6Dk3s+ZdLGouIX0fEHyX9QNKaBvoYehFxQNLb5y1eI2lXdX+X\nJv7nGbgWvQ2FiDgVEa9W99+TdG5m6Ubfu0JfjWgi/JdL+u2kxyc0XFN+h6S9tl+xvbHpZqaw6NzM\nSNXtwob7OV/bmZsH6byZpYfmvetmxuu6NRH+qWYTGaYhhxUR8ReSbpf01erwFp35tqTPamIat1OS\nvtlkM9XM0s9I+lpE/L7JXiaboq9G3rcmwn9C0hWTHi+RdLKBPqYUESer2zOSntPEx5RhMn5uktTq\n9kzD/fy/iBiPiI8i4qyk76jB966aWfoZSd+LiGerxY2/d1P11dT71kT4D0u62vZnbM+R9CVJuxvo\n4xNsz62+iJHtuZJWa/hmH94taX11f72kFxrs5WOGZebmVjNLq+H3bthmvG7kJJ9qKOOfJc2WtCMi\nvjHwJqZg+081sbeXJq54/H6Tvdl+WtJKTVz1NS7p65Kel/QjSUsl/UbSuogY+BdvLXpbqQucublP\nvbWaWfqQGnzv6pzxupZ+OMMPyIkz/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPV/FVbvx/KN\nRsIAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f3a0c496dd8>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Draw a uniformly random row of the training set and display it\n", | |
"random_index = random.randrange(num_trn)\n", | |
"display_mnist(random_index)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment