Skip to content

Instantly share code, notes, and snippets.

@mrbkdad
Created October 16, 2017 08:07
Show Gist options
  • Save mrbkdad/383ca69fcc37424d796e2c7e53be1682 to your computer and use it in GitHub Desktop.
Save mrbkdad/383ca69fcc37424d796e2c7e53be1682 to your computer and use it in GitHub Desktop.
exploring MNIST Dataset in pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"import torch\n",
"import torchvision.datasets as dsets\n",
"from torchvision.transforms import ToTensor"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dsets.MNIST?"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"mnist_trn = dsets.MNIST('MNIST_torch/',train=True,\n",
" transform=ToTensor(),target_transform=ToTensor(),download=True)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"mnist_test = dsets.MNIST('MNIST_torch/',train=False,transform=ToTensor(),target_transform=ToTensor(),download=True)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([60000, 28, 28]), torch.Size([60000]))"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_trn.train_data.size(),mnist_trn.train_labels.size()"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"One hot encoding : (55000, 10)\n"
]
},
{
"data": {
"text/plain": [
"array([[ 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_trn, Y_trn = mnist_trn.train_data[:55000],mnist_trn.train_labels[:55000]\n",
"X_val, Y_val = mnist_trn.train_data[55000:],mnist_trn.train_labels[55000:]\n",
"X_test, Y_test = mnist_test.test_data,mnist_test.test_labels\n",
"\n",
"## 28*28 이미지를 플랫하게 변경\n",
"X_trn = X_trn.view(X_trn.size()[0],-1)\n",
"X_val = X_val.view(X_val.size()[0],-1)\n",
"X_test = X_test.view(X_test.size()[0],-1)\n",
"\n",
"## One hot encoding\n",
"Y_label = np.zeros(Y_trn.size()[0]*10)\n",
"Y_label.shape = (-1,10)\n",
"for i in range(Y_trn.size()[0]):\n",
" Y_label[i][Y_trn[i]] = 1\n",
"Y_trn = torch.Tensor(Y_label)\n",
"print('One hot encoding : ',Y_label.shape)\n",
"Y_label[:10]"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Type of X: <class 'torch.ByteTensor'>\n",
"Type of Y: <class 'torch.FloatTensor'>\n"
]
}
],
"source": [
"print(\"Type of X: \", type(X_trn))\n",
"print(\"Type of Y: \", type(Y_trn))"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([55000, 784])\n",
"torch.Size([55000, 10])\n"
]
}
],
"source": [
"print(X_trn.size())\n",
"print(Y_trn.size())"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training points: 55000\n",
"Number of validation points: 5000\n",
"Number of test points: 10000\n"
]
}
],
"source": [
"num_trn = Y_trn.size()[0]\n",
"num_val = Y_val.size()[0]\n",
"num_test = Y_test.size()[0]\n",
"\n",
"print(\"Number of training points: \", num_trn)\n",
"print(\"Number of validation points: \", num_val)\n",
"print(\"Number of test points: \", num_test)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dimension of X: 784 (28 x 28)\n",
"Dimension of Y: 10\n"
]
}
],
"source": [
"dim_X = X_trn.size()[1]\n",
"pixel_X = int(np.sqrt(dim_X)) # np.sqrt의 출력이 float32이므로, 이를 int 자료형으로 변경\n",
"dim_Y = Y_trn.size()[1]\n",
"\n",
"print(\"Dimension of X: %d (%d x %d)\" % (dim_X, pixel_X, pixel_X))\n",
"print(\"Dimension of Y: \", dim_Y)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def display_mnist(index):\n",
" sample_img = X_trn[index].numpy().reshape(-1, pixel_X)\n",
" sample_label_encoded = torch.max(Y_trn[index],0)[1][0]\n",
" sample_label = np.argmax(sample_label_encoded)\n",
"\n",
" print(\"One-hot encoded class: \", sample_label_encoded)\n",
" print(\"Class: \", sample_label)\n",
" plt.imshow(sample_img,cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"One-hot encoded class: 3\n",
"Class: 0\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADUVJREFUeJzt3W+sVPWdx/HPByw+AB6IDSyKSLfq2o1/bjc3ZBPWDUYh\namqgiZj6QGkWlz6owSabsIZoSrKp0XVbdx81oYFATWvbxH/YbBYasoElrgQ0tdqyLdeGbVngssZq\nbaJplO8+uIfNFe/8Zpg5M2fu/b5fCZmZ850z55vRzz1n5nfO/BwRApDPrKYbANAMwg8kRfiBpAg/\nkBThB5Ii/EBShB9IivADSRF+IKmLBrkx25xOCPRZRLiT5/W057d9m+1f2h6z/VAvrwVgsNztuf22\nZ0v6laRVkk5IOizpnoj4RWEd9vxAnw1iz79c0lhE/Doi/ijpB5LW9PB6AAaol/BfLum3kx6fqJZ9\njO2Nto/YPtLDtgDUrJcv/KY6tPjEYX1EbJO0TeKwHxgmvez5T0i6YtLjJZJO9tYOgEHpJfyHJV1t\n+zO250j6kqTd9bQFoN+6PuyPiA9tPyBpj6TZknZExM9r6wxAX3U91NfVxvjMD/TdQE7yATB9EX4g\nKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+\nICnCDyRF+IGkCD+Q1ECn6EZ3RkZGeqqXbNiwoVjfvn1716/dzsGDB4v1sbGxvm0b7PmBtAg/kBTh\nB5Ii/EBShB9IivADSRF+IKmeZum1fVzSe5I+kvRhRIy2eT6z9HbhwQcfLNYff/zxlrXZs2cX1501\nq/z3/+zZs8V6LzZv3lysP/nkk33b9kzW6Sy9dZzkc3NEvFXD6wAYIA77gaR6DX9I2mv7Fdsb62gI\nwGD0eti/IiJO2l4o6Se2/ysiDkx+QvVHgT8MwJDpac8fESer2zOSnpO0fIrnbIuI0XZfBgIYrK7D\nb3uu7fnn7ktaLemNuhoD0F+9HPYvkvSc7XOv8/2I+LdaugLQdz2N81/wxhjn74s333yzZW3p0qXF\ndZsc529n3bp1xfrzzz8/oE6ml07H+RnqA5Ii/EBShB9IivADSRF+ICnCDyTFUN8McMstt7SsLVy4\nsLjuww8/XKy3GyocHx8v1q+88spiveT9998v1letWlWsHzp0qOttT2cM9QEoIvxAUoQfSIrwA0kR\nfiApwg8kRfiBpJiiewbYt29f1+sePny4WL/++uuL9RtuuKFYb3ceQcncuXOL9Ysvvrjr1wZ7fiAt\nwg8kRfiBpAg/kBThB5Ii/EBShB9IinH+5MbGxnqq33jjjcV6u58G72Xdas4IdIk9P5AU4QeSIvxA\nUoQfSIrwA0kRfiApwg8k1Xac3/YOSV+QdCYirquWLZD0Q0nLJB2XdHdE/K5/baJf5s2bV6w/8cQT\nxfrq1auL9X5O8T3IOSdmok72/Dsl3Xbesock7YuIqyXtqx4DmEbahj8iDkh6+7zFayTtqu7vkrS2\n5r4A9Fm3n/kXRcQpSapuy3NCARg6fT+33/ZGSRv7vR0AF6bbPf+47cWSVN2eafXEiNgWEaMRMdrl\ntgD0Qbfh3y1pfXV/vaQX6mkHwKC0Db/tpyX9p6Q/s33C9gZJj0laZfuYpFXVYwDTSNvP/BFxT4tS\n60nhMTTWri0PxNx3333F+p133llnOxginOEHJEX4gaQIP5AU4QeSIvxAUoQfSIqf7h4CS5YsKdb3\n79/f9Wtfeumlxfr8+fOL9X5ektvO6dOni/V33313QJ3MTOz5gaQIP5AU4QeSIvxAUoQfSIrwA0kR\nfiApxvmHwEUXlf8zLF26dECdDJc9e/YU66+99tqAOpmZ2PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8\nQFIe5DTHtplTuQtbt24t1h955JGuX3vWrPLf/yav52/X28qVK4v1Xn4HYTqLCHfyPPb8QFKEH0iK\n8ANJEX4gKcIPJEX4gaQIP5BU2+v5be+Q9AVJZyLiumrZVkl/K+l/q6dtiYh/7VeT2e3cubNYv/32\n21vWev3d/gULFhTrTRrkOSozUSd7/p2Sbpti+ZMRMVL9I/jANNM2/BFxQNLbA+gFwAD18pn/Ads/\ns73D9iW1dQRgILoN/7clfVbSiKRTkr7Z6om2N9o+YvtIl9sC0AddhT8ixiPio4g4K+k7kpYXnrst\nIkYjYrTbJgHUr6vw21486eEXJb1RTzsABqWTob6nJa2U9GnbJyR9XdJK2yOSQtJxSV/pY48A+oDr\n+ZMbGRkp1u+9995ifdOmTXW28zHtrue/+eabi/UDBw7U2c60wfX8AIoIP5AU4QeSIvxAUoQfSIrw\nA0kxRXdyV111VbF+1113DagTDBp7fiApwg8kRfiBpAg/kBThB5Ii/EBShB9IinH+GeDaa69tWdu8\neXNx3Xbj/JdddllXPWH4secHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQY5x8C8+bNK9bbXVN/zTXX\ntKytX7++uG67n8c+e/Zssd6LgwcPFutPPfVUsZ71p7nrwp4fSIrwA0kRfiApwg8kRfiBpAg/kBTh\nB5JqO85v+wpJ35X0J5LOStoWEf9ie4GkH0paJum4pLsj4nf9a7VZt956a8vali1biuva5RmT58yZ\nU6wvX768WC/pdZy+3frvvPNOsf7yyy+3rN1///3FdcfHx4t19KaTPf+Hkv4uIj4n6S8lfdX2n0t6\nSNK+iLha0r7qMYBpom34I+JURLxa3X9P0lFJl0taI2lX9bRdktb2q0kA9bugz/y2l0n6vKRDkhZF\nxClp4g+EpIV1Nwegfzo+t9/2PEnPSPpaRPy+3efYSettlLSxu/YA9EtHe37bn9JE8L8XEc9Wi8dt\nL67qiyWdmWrdiNgWEaMRMVpHwwDq0Tb8ntjFb5d0NCK+Nam0W9K5S8bWS3qh/vYA9Esnh/0rJN0r\n6XXbP62WbZH0mKQf2d4g6TeS1vWnxeFwxx13tKzddNNNxXWbvGy2V7t27SrW2112u3///jrbQY3a\nhj8iDkpq9QH/lnrbATAonOEHJEX4gaQIP5AU4QeSIvxAUoQfSMoRMbiN2YPbWM1GRkZa1l588cXi\nuu2mue51nL90We2mTZuK67700kvFervLaj/44INiHYMXER2de8+eH0iK8ANJEX4gKcIPJEX4gaQI\nP5AU4QeSYpy/BitWrCjW200l/eijjxbrx44dK9ZPnz7dsrZ3797iuph5GOcHUET4gaQIP5AU4QeS\nIvxAUoQfSIrwA0kxzg/MMIzzAygi/EBShB9IivADSRF+ICnCDyRF+IGk2obf9hW2/932Uds/t/1g\ntXyr7f+x/dPqX+sJ7AEMnbYn+dheLGlxRLxqe76kVyStlXS3pD9ExD91vDFO8gH6rtOTfC7q4IVO\nSTpV3X/P9lFJl/fWHoCmXdBnftvLJH1e0qFq0QO2f2Z7h+1LWqyz0fYR20d66hRArTo+t9/2PEn7\nJX0jIp61vUjSW5JC0j9o4qPB37R5DQ77gT7r9LC/o/Db/pSkH0vaExHfmqK+TNKPI+K6Nq9D+IE+\nq+3CHtuWtF3S0cnBr74IPOeLkt640CYBNKeTb/v/StJ/SHpd0rm5pLdIukfSiCYO+49L+kr15WDp\ntdjzA31W62F/XQg/0H9czw+giPADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BU2x/wrNlbkv570uNPV8uG0bD2Nqx9SfTWrTp7u7LTJw70ev5PbNw+EhGjjTVQMKy9\nDWtfEr11q6neOOwHkiL8QFJNh39bw9svGdbehrUvid661UhvjX7mB9Ccpvf8ABrSSPht32b7l7bH\nbD/URA+t2D5u+/Vq5uFGpxirpkE7Y/uNScsW2P6J7WPV7ZTTpDXU21DM3FyYWbrR927YZrwe+GG/\n7dmSfiVplaQTkg5LuicifjHQRlqwfVzSaEQ0PiZs+68l/UHSd8/NhmT7HyW9HRGPVX84L4mIvx+S\n3rbqAmdu7lNvrWaW/rIafO/qnPG6Dk3s+ZdLGouIX0fEHyX9QNKaBvoYehFxQNLb5y1eI2lXdX+X\nJv7nGbgWvQ2FiDgVEa9W99+TdG5m6Ubfu0JfjWgi/JdL+u2kxyc0XFN+h6S9tl+xvbHpZqaw6NzM\nSNXtwob7OV/bmZsH6byZpYfmvetmxuu6NRH+qWYTGaYhhxUR8ReSbpf01erwFp35tqTPamIat1OS\nvtlkM9XM0s9I+lpE/L7JXiaboq9G3rcmwn9C0hWTHi+RdLKBPqYUESer2zOSntPEx5RhMn5uktTq\n9kzD/fy/iBiPiI8i4qyk76jB966aWfoZSd+LiGerxY2/d1P11dT71kT4D0u62vZnbM+R9CVJuxvo\n4xNsz62+iJHtuZJWa/hmH94taX11f72kFxrs5WOGZebmVjNLq+H3bthmvG7kJJ9qKOOfJc2WtCMi\nvjHwJqZg+081sbeXJq54/H6Tvdl+WtJKTVz1NS7p65Kel/QjSUsl/UbSuogY+BdvLXpbqQucublP\nvbWaWfqQGnzv6pzxupZ+OMMPyIkz/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPV/FVbvx/KN\nRsIAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f3a0c496dd8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"random_index = random.randint(0, num_trn - 1)\n",
"display_mnist(random_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment