noisy_labels solution for imubit
Created May 5, 2018 16:25
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Imubit\n", | |
"This is my solution to the challenge. I will explain my steps throughout this notebook."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"import matplotlib.pyplot as plt\n", | |
"from sklearn.preprocessing import OneHotEncoder\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"I used the \"mnist\" package to ease my work, located here: https://github.com/datapythonista/mnist"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use the retry module or similar alternatives.\n", | |
"WARNING:tensorflow:From <ipython-input-2-6441638eb0cd>:1: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please write your own downloading logic.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data to implement this functionality.\n", | |
"Extracting MNIST-data\\train-images-idx3-ubyte.gz\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data to implement this functionality.\n", | |
"Extracting MNIST-data\\train-labels-idx1-ubyte.gz\n", | |
"Extracting MNIST-data\\t10k-images-idx3-ubyte.gz\n", | |
"Extracting MNIST-data\\t10k-labels-idx1-ubyte.gz\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n" | |
] | |
} | |
], | |
"source": [ | |
"mnist = tf.contrib.learn.datasets.load_dataset(\"mnist\")\n", | |
"train_data = mnist.train.images # Returns np.array\n", | |
"train_labels = np.asarray(mnist.train.labels, dtype=np.int32)\n", | |
"eval_data = mnist.test.images # Returns np.array\n", | |
"eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_images = train_data\n", | |
"test_images = eval_data\n", | |
"test_labels = eval_labels\n", | |
"train_images = np.reshape(train_images, (-1,28,28))\n", | |
"test_images = np.reshape(test_images,(-1,28,28))\n", | |
"train_labels = np.reshape(train_labels, (-1,1))\n", | |
"test_labels = np.reshape(test_labels, (-1,1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((55000, 28, 28), (55000, 1), (10000, 28, 28), (10000, 1))" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_images.shape, train_labels.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADL1JREFUeJzt3W/InfV9x/H3d64+MSoJRRdtsnRFRmdw6QhhkiIOUdwUomC1ASHDsRSssMIeTHxgFWkoc3XbEwsphibYWut/KaItKsbJCEadxjZrKxLbLDExKpgiIup3D+4r5W68z3VOzr/rJN/3C8I55/pd5/p9OeRz/65zrj+/yEwk1fNHXRcgqRuGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUX88zc4iwtMJpQnLzBhkvZFG/oi4NCJ+GRGvRcSNo2xL0nTFsOf2R8RJwK+Ai4G9wPPA+sz8Rct7HPmlCZvGyL8GeC0zX8/MD4EfAetG2J6kKRol/GcDv533em+z7A9ExMaI2BkRO0foS9KYjfKD30K7Fp/arc/MzcBmcLdfmiWjjPx7gWXzXn8O2DdaOZKmZZTwPw+cExGfj4iTga8Cj46nLEmTNvRuf2Z+FBE3AE8AJwFbMvPnY6tM0kQNfahvqM78zi9N3FRO8pF0/DL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaipTtEtzbdq1arW9qeeeqq1/Y033mhtv+CCC3q2HT58uPW9FTjyS0UZfqkowy8VZfilogy/VJThl4oy/FJRI83SGxF7gMPAx8BHmbm6z/rO0qvf27VrV2v7ueeeO9L2zzrrrJ5tb7755kjbnmWDztI7jpN8/iYzD41hO5KmyN1+qahRw5/ATyPihYjYOI6CJE3HqLv9azNzX0ScAfwsIv43M7fPX6H5o+AfBmnGjDTyZ+a+5vEg8BCwZoF1Nmfm6n4/BkqarqHDHxGnRMSpR54DlwCvjqswSZM1ym7/mcBDEXFkOz/MzMfHUpWkiRs6/Jn5OvCXY6xFJ6DzzjuvZ9vSpUunWImO5qE+qSjDLxVl+KWiDL9UlOGXijL8UlHeulsT1XZ77iVLlkyxEh3NkV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXivI4v0bSdskuwB133DGxvp999tnW9vfee29ifZ8IHPmlogy/VJThl4oy/FJRhl8qyvBLRRl+qSiP82sky5Yta20f5Zr9w4cPt7Zv2rSptf39998fuu8KHPmlogy/VJThl4oy/FJRhl8qyvBLRRl+qai+x/kjYgtwOXAwM1c2y5YA9wIrgD3A1Zn57uTKVFcWLVrU2n7ttddOrO+HH364tf2JJ56YWN8VDDLyfx+49KhlNwJPZuY5wJPNa0nHkb7hz8ztwDtHLV4HbG2ebwWuGHNdkiZs2O/8Z2bmfoDm8YzxlSRpGiZ+bn9EbAQ2TrofScdm2JH/QEQsBWgeD/ZaMTM3Z+bqzFw9ZF+SJmDY8D8KbGiebwAeGU85kqalb/gj4h7gv4E/j4i9EfEPwLeBiyPi18DFzWtJx5G+3/kzc32PpovGXItmUL/77l9zzTVDb/vdd9tPDbn77ruH3rb68ww/qSjDLxVl+KWiDL9UlOGXijL8UlGRmdPrLGJ6nWkgK1eubG1/5plnWtsXL148dN9XXnlla/sjj3ju2DAyMwZZz5FfKsrwS0UZfqkowy8VZfilogy/VJThl4pyiu4T3GmnndbafvPNN7e2j3IcH+Dtt9/u2fbyyy+PtG2NxpFfKsrwS0UZfqkowy8VZfilogy/VJThl4ryOP8J7rLLLmttv+qqq0ba/qFDh1rb227tvWfPnpH61mgc+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pqL7H+SNiC3A5cDAzVzbLbg
H+EXirWe2mzHxsUkWq3eWXX96z7c4775xo3zt27Ghtf/rppyfav4Y3yMj/feDSBZb/e2auav4ZfOk40zf8mbkdeGcKtUiaolG+898QEa9ExJaIGO1eT5Kmbtjwfxf4ArAK2A98p9eKEbExInZGxM4h+5I0AUOFPzMPZObHmfkJ8D1gTcu6mzNzdWauHrZISeM3VPgjYum8l1cCr46nHEnTMsihvnuAC4HPRsRe4JvAhRGxCkhgD/C1CdYoaQL6hj8z1y+w+K4J1KIhXXfddT3bTj/99JG2/dZbb7W233777SNtX93xDD+pKMMvFWX4paIMv1SU4ZeKMvxSUd66+zhwxRVXtLZfdNFFE+v7+uuvb23fvn37xPrWZDnyS0UZfqkowy8VZfilogy/VJThl4oy/FJRHuefAStWrGht37ZtW2v7okWLhu77rrvar85+7DFvzHyicuSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIiM6fXWcT0Opshp556amv7fffd19p+ySWXDN33Sy+91Nq+du3a1vYPPvhg6L7VjcyMQdZz5JeKMvxSUYZfKsrwS0UZfqkowy8VZfilovpezx8Ry4BtwJ8AnwCbM/M/I2IJcC+wAtgDXJ2Z706u1OPXbbfd1to+ynF8gIjeh3Wfe+651vd6HL+uQUb+j4B/zswvAn8NfD0i/gK4EXgyM88BnmxeSzpO9A1/Zu7PzBeb54eB3cDZwDpga7PaVqB9WhlJM+WYvvNHxArgS8AO4MzM3A9zfyCAM8ZdnKTJGfgefhGxCHgA+EZmvtf2PfOo920ENg5XnqRJGWjkj4jPMBf8H2Tmg83iAxGxtGlfChxc6L2ZuTkzV2fm6nEULGk8+oY/5ob4u4DdmXnHvKZHgQ3N8w3AI+MvT9Kk9L2kNyK+DDwL7GLuUB/ATcx97/8xsBz4DfCVzHynz7ZOyEt6161b19p+7733traffPLJI/X/yiuv9Gxbs2ZN63s//PDDkfrW7Bn0kt6+3/kz87+AXhub3MTwkibKM/ykogy/VJThl4oy/FJRhl8qyvBLRTlF9xj0uyR31OP4/WzatKlnm8fx1Ysjv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8V5XH+MXj88cdb288///zW9uXLl7e233rrra3t999/f2u7tBBHfqkowy8VZfilogy/VJThl4oy/FJRhl8qqu99+8fa2Ql6335plgx6335Hfqkowy8VZfilogy/VJThl4oy/FJRhl8qqm/4I2JZRDwdEbsj4ucR8U/N8lsi4v8i4n+af383+XIljUvfk3wiYimwNDNfjIhTgReAK4Crgd9l5r8N3Jkn+UgTN+hJPn3v5JOZ+4H9zfPDEbEbOHu08iR17Zi+80fECuBLwI5m0Q0R8UpEbImIxT3eszEidkbEzpEqlTRWA5/bHxGLgGeAb2XmgxFxJnAISOA25r4aXNdnG+72SxM26G7/QOGPiM8APwGeyMw7FmhfAfwkM1f22Y7hlyZsbBf2REQAdwG75we/+SHwiCuBV4+1SEndGeTX/i8DzwK7gE+axTcB64FVzO327wG+1vw42LYtR35pwsa62z8uhl+aPK/nl9TK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VFTfG3iO2SHgjXmvP9ssm0WzWtus1gXWNqxx1vang6441ev5P9V5xM7MXN1ZAS1mtbZZrQusbVhd1eZuv1SU4ZeK6jr8mzvuv82s1jardYG1DauT2jr9zi+pO12P/JI60kn4I+LSiPhlRLwWETd2UUMvEbEnInY1Mw93OsVYMw3awYh4dd6yJRHxs4j4dfO44DRpHdU2EzM3t8ws3elnN2szXk99tz8iTgJ+BVwM7AWeB9Zn5i+mWkgPEbEHWJ2ZnR8TjogLgN8B247MhhQR/wq8k5nfbv5wLs7Mf5mR2m7hGGdunlBtvWaW/ns6/OzGOeP1OHQx8q8BXs
vM1zPzQ+BHwLoO6ph5mbkdeOeoxeuArc3zrcz955m6HrXNhMzcn5kvNs8PA0dmlu70s2upqxNdhP9s4LfzXu9ltqb8TuCnEfFCRGzsupgFnHlkZqTm8YyO6zla35mbp+momaVn5rMbZsbrcesi/AvNJjJLhxzWZuZfAX8LfL3ZvdVgvgt8gblp3PYD3+mymGZm6QeAb2Tme13WMt8CdXXyuXUR/r3AsnmvPwfs66COBWXmvubxIPAQc19TZsmBI5OkNo8HO67n9zLzQGZ+nJmfAN+jw8+umVn6AeAHmflgs7jzz26hurr63LoI//PAORHx+Yg4Gfgq8GgHdXxKRJzS/BBDRJwCXMLszT78KLCheb4BeKTDWv7ArMzc3GtmaTr+7GZtxutOTvJpDmX8B3ASsCUzvzX1IhYQEX/G3GgPc1c8/rDL2iLiHuBC5q76OgB8E3gY+DGwHPgN8JXMnPoPbz1qu5BjnLl5QrX1mll6Bx1+duOc8Xos9XiGn1STZ/hJRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrq/wGWybdnU+8DKAAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x21bbe08feb8>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(train_images[107],cmap='gray')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Great, it seems the shapes are okay; let's handle the labels as requested, starting with the test set."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([0, 1]), list)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_labels=[1 if i % 2 != 0 else 0 for i in test_labels]\n", | |
"np.unique(test_labels), type(test_labels)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, let's preprocess our training set:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_images_odd = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 != 0])\n", | |
"train_images_even = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 == 0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's flatten the images in order to feed them to the network later (not using a CNN here, just a regular old DNN).<br>\n",
"Let's also separate the training images into \"odd\" and \"even\" training images; this makes it easier to create the triplets\n",
"required in the exercise."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# train_images = np.reshape(train_images, [train_images.shape[0], train_images.shape[1] * train_images.shape[2]])\n", | |
"# train_images_odd = np.reshape(train_images_odd, [train_images_odd.shape[0], train_images_odd.shape[1] * train_images_odd.shape[2]])\n", | |
"# train_images_even = np.reshape(train_images_even, [train_images_even.shape[0], train_images_even.shape[1] * train_images_even.shape[2]])\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((55000, 28, 28), (27027, 28, 28), (27973, 28, 28))" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_images.shape, train_images_even.shape, train_images_odd.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# image_triplets = np.zeros((10000 ,3,train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n", | |
"# triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n", | |
"# for i in range(10000):\n", | |
"# rand_mnist = np.random.randint(60000, size=1) #a random image from mnist training set\n", | |
"# if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n", | |
"# rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n", | |
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n", | |
"# image_triplets[i,:] = combine_images #add 1 odd and 2 even images to triplet array\n", | |
"# triplet_labels[i] = 1 # adding appropriate binary label\n", | |
"# else:\n", | |
"# rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n", | |
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n", | |
"# image_triplets[i,:] = combine_images #add 1 even and 2 odd images to triplet array\n", | |
"# triplet_labels[i] = 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"image_triplets = np.zeros((10000, 3, train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n", | |
"triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n", | |
"for i in range(10000):\n", | |
" rand_mnist = np.random.randint(55000, size=1) #a random image from mnist training set\n", | |
" if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n", | |
" rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n", | |
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n", | |
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n", | |
" image_triplets[i, 1, :, :] = train_images_even[rand_even[0], :, :]\n", | |
" image_triplets[i, 2, :, :] = train_images_even[rand_even[1], :, :]\n", | |
" triplet_labels[i] = 1 # adding appropriate binary label\n", | |
" else:\n", | |
" rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n", | |
" #combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n", | |
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n", | |
" image_triplets[i, 1, :, :] = train_images_odd[rand_odd[0], :, :]\n", | |
" image_triplets[i, 2, :, :] = train_images_odd[rand_odd[1], :, :]\n", | |
" triplet_labels[i] = 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((10000, 3, 28, 28), (10000, 1))" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"image_triplets.shape,triplet_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Great, we have our triplets: 10000 samples, each containing 3 images (either 1 odd and 2 even <u>or</u> 2 odd and 1 even).<br>\n",
"We can easily notice that our training set and test set are different.<br>We have a triplet of pictures per label for the training set and one picture per label for the test set.<br>\n",
"This setup is usually called MIL (Multiple Instance Learning), where we have a bag of instances, but in our case inference will concern a single instance (SI).<br><a href=\"https://arxiv.org/pdf/1406.0281.pdf\">This article, section 3.4</a><br>\n",
"The way we have created our training set, each picture in a triplet has a 66% chance of matching its label and a 33% chance of being mislabeled.<br>\n",
"My solution ravels the pictures and gives every picture in the triplet the same label as the triplet.<br>\n",
"So basically this creates a new training set from the triplet set with noise (mislabeled samples).<br>\n",
"I relied on the fact that deep neural networks are <a href=\"https://arxiv.org/pdf/1705.10694\">robust to noisy labels</a>."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(<matplotlib.image.AxesImage at 0x21bc0aa6d68>, array([0.]))" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADNVJREFUeJzt3W+oXPWdx/HPZ7VFTYNEokm0qYklLqs+sMtFFxsWl8XiLoVYQyQ+Stl1bx5E2MoqBp9UkIisaXcXHxRuSUgCrW390zWWtW3QZU1FxFxZY/5sW43ZNptLrpJgFJSa+N0H96Rc453fzJ05M2duvu8XhJk53znnfBnyueecOXPOzxEhAPn8SdMNAGgG4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kNT5g1yZbX5OCPRZRLiT9/W05bd9q+1f237T9sZelgVgsNztb/ttnyfpN5JukXRE0quS7oyIA4V52PIDfTaILf8Nkt6MiEMR8QdJP5K0qoflARigXsJ/haTfT3t9pJr2KbZHbe+xvaeHdQGoWS9f+M20a/GZ3fqIGJM0JrHbDwyTXrb8RyQtnfb6i5KO9tYOgEHpJfyvSlphe7ntz0taK2lnPW0B6Leud/sj4pTtuyX9QtJ5krZGxP7aOgPQV12f6utqZRzzA303kB/5AJi7CD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iq6yG6Jcn2YUnvSzot6VREjNTRFID+6yn8lb+KiHdrWA6AAWK3H0iq1/CHpF/aHrc9WkdDAAaj193+r0bEUduXSdpl+38i4sXpb6j+KPCHARgyjoh6FmQ/KOmDiNhceE89KwPQUkS4k/d1vdtve57t+WeeS/qapH3dLg/AYPWy279I0k9tn1nODyPi57V0BaDvatvt72hl7PYDfdf33X4AcxvhB5Ii/EBShB9IivADSRF+IKk6ruoDzjnz588v1ufNm1esf/jhh8X6e++9N+ue6saWH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jw/zllXXXVVy9qaNWuK865fv75Yv/LKK4v1o0ePFutLly4t1geBLT+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMV5fvTkoYceKtZPnz7d9bJXr15drF977bVdL7sab6KlXm9p/8ILL/Q0/yCw5QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpNqe57e9VdLXJU1GxHXVtEsk/VjSMkmHJd0RESf61yb6pd396R9++OFifXR0tFg///zW/8V6Pdfey7n4119/vVh/8skni/Xly5cX6y+99NKsexq0Trb82yTdeta0jZKej4gVkp6vXgOYQ9qGPyJelHT8rMmrJG2vnm+XdFvNfQHos26P+RdFxIQkVY+X1dcSgEHo+2/7bY9KKh8YAhi4brf8x2wvkaTqcbLVGyNiLCJGImKky3UB6INuw79T0rrq+TpJz9TTDoBBaRt+249LelnSn9o+YvvvJT0i6Rbbv5V0S/UawBziXq9bntXK7MGtDJKkBQsWFOuPPFL+u33XXXfV2c6nnDx5slhvN8b9E088Uaxv2bKlZW3v3r3FeeeyiCj/gKLCL/yApAg/kBThB5Ii/EBShB9IivADSXHr7jmg3em6++67r2Wt3am6hQsXFutvvfVWsb548eJifcOGDS1ru3btKs47MTFRrKM3bPmBpAg/kBThB5Ii/EBShB9IivADSRF+ICnO888B99xzT7G+cWP3N0/et29fsb527dpi/aOPPirWDx06NOueMBhs+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKW7dPQSuvvrqYn18fLxYv+iii7pe98cff1ysP/vss8V6uyG6T5xg5PZB49
bdAIoIP5AU4QeSIvxAUoQfSIrwA0kRfiCpttfz294q6euSJiPiumrag5L+QdI71dseiIj/6FeT2U1OThbrF154YdfLfvvtt4v1G2+8sVifP39+sc55/uHVyZZ/m6RbZ5j+LxFxffWP4ANzTNvwR8SLko4PoBcAA9TLMf/dtvfa3mq7PJ4UgKHTbfi/J+nLkq6XNCHpO63eaHvU9h7be7pcF4A+6Cr8EXEsIk5HxCeSvi/phsJ7xyJiJCJGum0SQP26Cr/tJdNefkNS+RawAIZOJ6f6Hpd0s6SFto9I+rakm21fLykkHZa0vo89AugDrudP7vLLLy/W9+/fX6yPjY0V6/fff/+se0JvuJ4fQBHhB5Ii/EBShB9IivADSRF+ICmG6E5u+/btxfrFF19crL/88st1toMBYssPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxnr/S7hbVq1evblnbtm1bcd4DBw5009Iftbs199q1a1vWNm3aVJx38eLFxfqqVauK9eeee65Yx/Biyw8kRfiBpAg/kBThB5Ii/EBShB9IivADSXHr7sru3buL9Ztuuqll7Z133mlZk6SJiYli3S7fafmCCy4o1lesWFGsl9x7773F+mOPPVasnzp1qut1oz+4dTeAIsIPJEX4gaQIP5AU4QeSIvxAUoQfSKrt9fy2l0raIWmxpE8kjUXEv9m+RNKPJS2TdFjSHRFxon+tDq9LL720p3q78/ztfotRGkZ78+bNxXl37NhRrOPc1cmW/5Skf4qIP5P0F5I22L5G0kZJz0fECknPV68BzBFtwx8RExHxWvX8fUkHJV0haZWkM8O9bJd0W7+aBFC/WR3z214m6SuSXpG0KCImpKk/EJIuq7s5AP3T8T38bH9B0lOSvhURJ9sdp06bb1TSaHftAeiXjrb8tj+nqeD/ICKeriYfs72kqi+RNDnTvBExFhEjETFSR8MA6tE2/J7axG+RdDAivjuttFPSuur5OknP1N8egH5pe0mv7ZWSdkt6Q1On+iTpAU0d9/9E0pck/U7Smog43mZZQ3tJ78qVK4v122+/vWXtmmuu6Wnd7Q6hxsfHi/VHH320Ze3EiZRnX1Pr9JLetsf8EfErSa0W9tezaQrA8OAXfkBShB9IivADSRF+ICnCDyRF+IGkuHU3cI7h1t0Aigg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCptuG3vdT2f9o+aHu/7X+spj9o+/9s/3f172/73y6AurQdtMP2EklLIuI12/MljUu6TdIdkj6IiM0dr4xBO4C+63TQjvM7WNCEpInq+fu2D0q6orf2ADRtVsf8tpdJ+oqkV6pJd9vea3ur7QUt5hm1vcf2np46BVCrjsfqs/0FSf8laVNEPG17kaR3JYWkhzR1aPB3bZbBbj/QZ53u9ncUftufk/QzSb+IiO/OUF8m6WcRcV2b5RB+oM9qG6jTtiVtkXRwevCrLwLP+IakfbNtEkBzOvm2f6Wk3ZLekPRJNfkBSXdKul5Tu/2HJa2vvhwsLYstP9Bnte7214XwA/1X224/gHMT4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKm2N/Cs2buS/nfa64XVtGE0rL0Na18SvXWrzt6u7PSNA72e/zMrt/dExEhjDRQMa2/D2pdEb91qqjd2+4GkCD+QVNPhH2t4/SXD2tuw9iXRW7ca6a3RY34AzWl6yw+gIY2E3/attn9t+03bG5vooRXbh22/UY083OgQY9UwaJO2902bdontXbZ/Wz3OOExaQ70NxcjNhZGlG/3shm3E64Hv9ts+T9JvJN0i6YikVyXdGREHBtpIC7YPSxqJiMbPCdv+S0kfSNpxZjQk2/8s6XhEPFL94VwQEfcPSW8PapYjN/ept1YjS3
9TDX52dY54XYcmtvw3SHozIg5FxB8k/UjSqgb6GHoR8aKk42dNXiVpe/V8u6b+8wxci96GQkRMRMRr1fP3JZ0ZWbrRz67QVyOaCP8Vkn4/7fURDdeQ3yHpl7bHbY823cwMFp0ZGal6vKzhfs7WduTmQTprZOmh+ey6GfG6bk2Ef6bRRIbplMNXI+LPJf2NpA3V7i068z1JX9bUMG4Tkr7TZDPVyNJPSfpWRJxsspfpZuirkc+tifAfkbR02usvSjraQB8zioij1eOkpJ9q6jBlmBw7M0hq9TjZcD9/FBHHIuJ0RHwi6ftq8LOrRpZ+StIPIuLpanLjn91MfTX1uTUR/lclrbC93PbnJa2VtLOBPj7D9rzqixjZnifpaxq+0Yd3SlpXPV8n6ZkGe/mUYRm5udXI0mr4sxu2Ea8b+ZFPdSrjXyWdJ2lrRGwaeBMzsH2Vprb20tQVjz9ssjfbj0u6WVNXfR2T9G1J/y7pJ5K+JOl3ktZExMC/eGvR282a5cjNfeqt1cjSr6jBz67OEa9r6Ydf+AE58Qs/ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJ/T93Jc3ycrK5KwAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x21bc0aebef0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(image_triplets[2510, 0, :, :], cmap='gray'), triplet_labels[2510, :]\n",
"# plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((30000, 28, 28, 1), (30000, 1))" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"noisy = np.reshape(image_triplets, [-1,28,28])\n", | |
"noisy.shape\n", | |
"noisy_lab = np.zeros((noisy.shape[0],1))\n", | |
"\n", | |
"# each triplet contributes 3 consecutive images in 'noisy', all sharing the triplet's label\n",
"for i in range(0, noisy.shape[0], 3):\n",
"    noisy_lab[i:i+3] = triplet_labels[i // 3]\n",
"\n", | |
"noisy = noisy.reshape(noisy.shape[0],28, 28, 1)\n", | |
"noisy.shape, noisy_lab.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Convert the labels to one-hot encoding."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((30000, 28, 28, 1), (30000, 2), (10000, 28, 28, 1), (10000, 2))" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_y = np.array(noisy_lab[:, 0]) # Input\n", | |
"train_y = np.reshape(train_y, (-1, 1))\n", | |
"enc = OneHotEncoder()\n", | |
"enc.fit(train_y)\n", | |
"out = enc.transform(train_y).toarray()\n", | |
"noisy_lab = out\n", | |
"train_y = np.array(test_labels) # Input\n", | |
"train_y = np.reshape(train_y, (-1, 1))\n", | |
"enc = OneHotEncoder()\n", | |
"enc.fit(train_y)\n", | |
"out = enc.transform(train_y).toarray()\n", | |
"test_labels=out\n", | |
"test_images = np.reshape(test_images,(-1,28,28,1))\n", | |
"noisy.shape,noisy_lab.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"I thought about subsampling the triplets at some point, but it didn't show any signs of improving the score, so I left it commented out."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# rand_triplet = np.random.randint(3, size=10000)\n", | |
"# plt.hist(rand_triplet, bins='auto') \n", | |
"# plt.title(\"Histogram with 'auto' bins\")\n", | |
"# plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This draws uniformly from the three positions in each triplet, which guarantees a 66% chance of picking an image that matches the triplet's label (since 2 of the 3 images inside the triplet match it)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# noisy_images = np.zeros((10000, train_images.shape[1], train_images.shape[2]))\n", | |
"# for i in range(10000):\n", | |
"# rand_triplet = np.random.randint(3, size=1)\n", | |
"# noisy_images[i, :, :] = image_triplets[i, rand_triplet, :, :]\n", | |
"# # for i in range(10000):\n", | |
"# # rand_triplet = np.random.randint(3, size=1)\n", | |
"# # noisy_images[i+9999, :, :] = image_triplets[i, rand_triplet, :, :]\n", | |
"# # for i in range(10000):\n", | |
"# # rand_triplet = np.random.randint(3, size=1)\n", | |
"# # noisy_images[i+19999, :, :] = image_triplets[i, rand_triplet, :, :]\n", | |
"# noisy_labels = np.concatenate([triplet_labels])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# noisy_images.shape, noisy_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We need our input to fit the TensorFlow model (i.e. batch x height x width x nchannels)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# noisy_images = noisy_images.reshape(noisy_images.shape[0],28, 28, 1)\n", | |
"# test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)\n", | |
"# noisy_images.shape, test_images.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# noisy_images.shape,noisy_labels.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, we have our engineered dataset; let's move on to the neural network, train it, and test it on the test set."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From <ipython-input-28-9ba6dd5eb3d7>:41: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"\n", | |
"Future major versions of TensorFlow will allow gradients to flow\n", | |
"into the labels input on backprop by default.\n", | |
"\n", | |
"See tf.nn.softmax_cross_entropy_with_logits_v2.\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"BATCH_SIZE = 64\n", | |
"learning_rate = 0.0005\n", | |
"\n", | |
"graph = tf.Graph()\n", | |
"with graph.as_default():\n", | |
" # Input data.1\n", | |
" tf_train_dataset = tf.placeholder(tf.float32, shape=(None, train_images.shape[1], train_images.shape[2], 1))\n", | |
" tf_train_labels = tf.placeholder(tf.int32, shape=(None, 2))\n", | |
" traindrop = tf.placeholder(tf.bool)\n", | |
" \n", | |
" #tf_test_dataset = tf.placeholder(tf.float32, shape=(10000, train_images.shape[1], train_images.shape[2], 1))\n", | |
"# tf_test_labels = tf.placeholder(tf.float32, shape=(10000, 2))\n", | |
" #tf_test_dataset = tf.to_float(tf.constant(test_images))\n", | |
" \n", | |
" \n", | |
" #model\n", | |
" def model(data):\n", | |
" conv1 = tf.contrib.layers.conv2d(data, num_outputs=32, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu)\n", | |
" pool1 = tf.contrib.layers.max_pool2d(conv1, [2,2], 2)\n", | |
" conv2 = tf.contrib.layers.conv2d(pool1, num_outputs=64, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu)\n", | |
" pool2 = tf.contrib.layers.max_pool2d(conv2, [2,2], 2)\n", | |
" flat = tf.layers.flatten(pool2)\n", | |
" dense1 = tf.contrib.layers.fully_connected(flat, 1024, activation_fn=tf.nn.relu)\n", | |
" #dropout = tf.nn.dropout(dense1, 0.5)\n", | |
" dropout = tf.layers.dropout(inputs=dense1, rate=0.4, training=traindrop)\n", | |
" dense2 = tf.contrib.layers.fully_connected(dropout, 512, activation_fn=tf.nn.relu)\n", | |
" logits = tf.layers.dense(inputs=dense2, units=2)\n", | |
" return logits\n", | |
"\n", | |
" # Label smoothing\n", | |
"# if correct_labels is not None:\n", | |
"# cr0= 0.6\n", | |
"# table = tf.convert_to_tensor([cr0, 1.-cr0])\n", | |
"# tf_train_labels = tf.nn.embedding_lookup(table, tf_train_labels)\n", | |
" \n", | |
" \n", | |
" \n", | |
" #Loss definition\n", | |
" logits = model(tf_train_dataset)\n", | |
" normalized_logits =((0.3/2) + (0.7*logits))\n", | |
" loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=normalized_logits))\n", | |
"\n", | |
"\n", | |
" # Optimizer\n", | |
" optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)\n", | |
"\n", | |
" # predictions for test\n", | |
" train_pred = tf.nn.softmax(normalized_logits)\n", | |
"# test_pred = tf.nn.softmax(model(normalized_logits))\n", | |
"# pred = tf.nn.softmax(model(tf_test_dataset))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, we finished our neural network model... nothing special: a few conv layers, pooling and dropout, ending in a 2-unit output evaluated against the one-hot labels.<br>\n",
"The only thing different here is the change in my predictions to p~t = 0.3/N + 0.7*pt (pt are my logits).\n",
"I modeled the noisy labels as uniform noise.\n",
"I thought about implementing <a href=\"https://arxiv.org/pdf/1512.00567.pdf\">label smoothing</a> at some point, but it didn't work out well either."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def accuracy(predictions, labels):\n", | |
" return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))/ predictions.shape[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Initialized weights ,biases and other variables\n", | |
"Training batch loss at epoch 0: 0.695419\n", | |
"train predict\n", | |
"[[0.50901407 0.49098596]\n", | |
" [0.49824655 0.50175345]] (64, 2)\n", | |
"Minibatch accuracy: 51.6%\n", | |
"test batch = 0, test minibatch acc = 21.875\n", | |
"All Test acc : 27.18349358974359\n", | |
"Training batch loss at epoch 0: 0.629663\n", | |
"train predict\n", | |
"[[0.37293997 0.62706006]\n", | |
" [0.34847245 0.6515275 ]] (64, 2)\n", | |
"Minibatch accuracy: 65.6%\n", | |
"test batch = 0, test minibatch acc = 6.25\n", | |
"All Test acc : 12.880608974358974\n", | |
"Training batch loss at epoch 0: 0.000000\n", | |
"train predict\n", | |
"[[1. 0.]\n", | |
" [1. 0.]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 0: 0.000000\n", | |
"train predict\n", | |
"[[1.0000000e+00 0.0000000e+00]\n", | |
" [1.0000000e+00 7.9840876e-36]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 0: 0.000000\n", | |
"train predict\n", | |
"[[1. 0.]\n", | |
" [1. 0.]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 1: 62.272629\n", | |
"train predict\n", | |
"[[1.0000000e+00 4.8006741e-37]\n", | |
" [1.0000000e+00 1.0597831e-36]] (64, 2)\n", | |
"Minibatch accuracy: 48.4%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 1: 0.640868\n", | |
"train predict\n", | |
"[[0.3552155 0.6447845 ]\n", | |
" [0.37663218 0.6233678 ]] (64, 2)\n", | |
"Minibatch accuracy: 68.8%\n", | |
"test batch = 0, test minibatch acc = 14.0625\n", | |
"All Test acc : 17.037259615384617\n", | |
"Training batch loss at epoch 1: 0.000000\n", | |
"train predict\n", | |
"[[1.0000000e+00 7.3246713e-24]\n", | |
" [1.0000000e+00 2.1906906e-17]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 1: 0.000000\n", | |
"train predict\n", | |
"[[1.0000000e+00 1.3626768e-17]\n", | |
" [1.0000000e+00 1.2033073e-15]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 1: 0.000000\n", | |
"train predict\n", | |
"[[1.0000000e+00 1.4841263e-20]\n", | |
" [1.0000000e+00 7.2137413e-20]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 2: 23.142492\n", | |
"train predict\n", | |
"[[1.0000000e+00 1.2099758e-17]\n", | |
" [1.0000000e+00 4.4781130e-17]] (64, 2)\n", | |
"Minibatch accuracy: 48.4%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 2: 0.695359\n", | |
"train predict\n", | |
"[[0.49258307 0.5074169 ]\n", | |
" [0.4918901 0.50810987]] (64, 2)\n", | |
"Minibatch accuracy: 40.6%\n", | |
"test batch = 0, test minibatch acc = 57.8125\n", | |
"All Test acc : 50.7411858974359\n", | |
"Training batch loss at epoch 2: 0.000134\n", | |
"train predict\n", | |
"[[9.9985874e-01 1.4126042e-04]\n", | |
" [9.9993229e-01 6.7751578e-05]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 2: 0.000000\n", | |
"train predict\n", | |
"[[9.9999988e-01 1.2150628e-07]\n", | |
" [9.9999571e-01 4.2739148e-06]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n", | |
"Training batch loss at epoch 2: 0.000000\n", | |
"train predict\n", | |
"[[1.0000000e+00 2.5489841e-09]\n", | |
" [9.9999988e-01 7.4416107e-08]] (64, 2)\n", | |
"Minibatch accuracy: 100.0%\n", | |
"test batch = 0, test minibatch acc = 42.1875\n", | |
"All Test acc : 49.2588141025641\n" | |
] | |
} | |
], | |
"source": [ | |
"epoch = 3\n", | |
"TRAIN_DATASIZE,_,_,_ = noisy.shape\n", | |
"PERIOD = int(TRAIN_DATASIZE/BATCH_SIZE) #Number of iterations for each epoch\n", | |
"\n", | |
"\n", | |
"with tf.Session(graph=graph) as session:\n", | |
" tf.global_variables_initializer().run()\n", | |
" print('Initialized weights ,biases and other variables')\n", | |
" for step in range(epoch):\n", | |
" idxs = np.random.permutation(TRAIN_DATASIZE) #shuffled ordering\n", | |
" X_random = noisy[idxs]\n", | |
" Y_random = noisy_lab[idxs]\n", | |
" for i in range(PERIOD):\n", | |
" batch_data = noisy[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n", | |
" batch_labels = noisy_lab[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n", | |
" feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels, traindrop:True}\n", | |
" _,l, train_predictions= session.run( [optimizer, loss, train_pred], feed_dict=feed_dict)\n", | |
" if (i % 100 == 0):\n", | |
" print('Training batch loss at epoch %d: %f' % (step, l))\n", | |
" print('train predict')\n", | |
" train_predictions = np.array(train_predictions)\n", | |
" print(train_predictions[:2,:], train_predictions.shape)\n", | |
" print('Minibatch accuracy: %.1f%%' % accuracy(train_predictions, batch_labels))\n", | |
" # Evaluation:\n", | |
" # Eval\n", | |
" X = test_images\n", | |
" Y = test_labels\n", | |
" BATCH_SIZE_TEST = 64\n", | |
" PERIOD_TEST = int(X.shape[0]/BATCH_SIZE)\n", | |
" all_acc = []\n", | |
"\n", | |
" for j in range(PERIOD_TEST):\n", | |
" batch_data = X[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n", | |
" batch_labels = Y[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n", | |
"\n", | |
" feed_dict = {tf_train_dataset : batch_data , traindrop:False}\n", | |
" train_predictions= session.run( train_pred, feed_dict=feed_dict)\n", | |
" acc = accuracy(train_predictions, batch_labels)\n", | |
" all_acc.append(acc)\n", | |
" if (j%200 == 0):\n", | |
" print('test batch number = {}, test minibatch acc = {}'.format(j,acc))\n", | |
" print('All Test acc : {}'.format(np.mean(all_acc)))\n", | |
" \n", | |
"\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So we didn't get a good fit to the test set (50% is as good as a random guess), we got a decent fit to the training set, but every shuffling after epoch ended the NN tried to recoil itself, more work on avoiding overfitting is advised here(smaller net, higher dropout rate...)" | |
] | |
}, | |
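{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For reference, the commented-out idea in the graph cell above (a lookup table built from `cr0`) amounts to training against soft targets derived from a label-noise transition matrix. A minimal NumPy sketch of that direction, assuming a symmetric 30% flip rate between the two classes (`T`, `FLIP_RATE`, and the example labels are all hypothetical):\n", | |
"\n", | |
"```python\n", | |
"import numpy as np\n", | |
"\n", | |
"FLIP_RATE = 0.3\n", | |
"# T[i, j] = P(observed label j | true label i); symmetric noise over 2 classes\n", | |
"T = np.array([[1.0 - FLIP_RATE, FLIP_RATE],\n", | |
"              [FLIP_RATE, 1.0 - FLIP_RATE]])\n", | |
"\n", | |
"noisy_labels = np.array([0, 1, 1, 0])  # observed (possibly flipped) labels\n", | |
"soft_targets = T[noisy_labels]         # one soft-target row per example\n", | |
"print(soft_targets[0])                 # [0.7 0.3]\n", | |
"```\n", | |
"\n", | |
"Feeding `soft_targets` to the cross-entropy instead of hard one-hot labels is the soft-target counterpart of the correction used above, and would be a natural next experiment." | |
] | |
}, | |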
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python (myenv)", | |
"language": "python", | |
"name": "myenv" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |