@Z30G0D
Created May 5, 2018 16:25
noisy_labels solution for imubit
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imubit\n",
"This is my solution for the challenge. I will explain my steps through this notebook. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import OneHotEncoder\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I used the \"mnist\" package to ease my work. located here:https://github.com/datapythonista/mnist"
]
},
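{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the tf.contrib loader used below is deprecated (see the warnings), here is a minimal, untested sketch of loading the same data with that mnist package instead; the rest of the notebook keeps the tf.contrib arrays. Note the package returns the full 60000 training images, while the tf.contrib loader holds out 5000 of them for validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Untested alternative sketch using the standalone `mnist` package:\n",
"# import mnist\n",
"# train_images_alt = mnist.train_images() / 255.0 # (60000, 28, 28), scaled like tf's loader\n",
"# train_labels_alt = mnist.train_labels() # (60000,)\n",
"# test_images_alt = mnist.test_images() / 255.0 # (10000, 28, 28)\n",
"# test_labels_alt = mnist.test_labels() # (10000,)"
]
},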
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use the retry module or similar alternatives.\n",
"WARNING:tensorflow:From <ipython-input-2-6441638eb0cd>:1: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please write your own downloading logic.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST-data\\train-images-idx3-ubyte.gz\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST-data\\train-labels-idx1-ubyte.gz\n",
"Extracting MNIST-data\\t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST-data\\t10k-labels-idx1-ubyte.gz\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
]
}
],
"source": [
"mnist = tf.contrib.learn.datasets.load_dataset(\"mnist\")\n",
"train_data = mnist.train.images # Returns np.array\n",
"train_labels = np.asarray(mnist.train.labels, dtype=np.int32)\n",
"eval_data = mnist.test.images # Returns np.array\n",
"eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_images = train_data\n",
"test_images = eval_data\n",
"test_labels = eval_labels\n",
"train_images = np.reshape(train_images, (-1,28,28))\n",
"test_images = np.reshape(test_images,(-1,28,28))\n",
"train_labels = np.reshape(train_labels, (-1,1))\n",
"test_labels = np.reshape(test_labels, (-1,1))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((55000, 28, 28), (55000, 1), (10000, 28, 28), (10000, 1))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_images.shape, train_labels.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADL1JREFUeJzt3W/InfV9x/H3d64+MSoJRRdtsnRFRmdw6QhhkiIOUdwUomC1ASHDsRSssMIeTHxgFWkoc3XbEwsphibYWut/KaItKsbJCEadxjZrKxLbLDExKpgiIup3D+4r5W68z3VOzr/rJN/3C8I55/pd5/p9OeRz/65zrj+/yEwk1fNHXRcgqRuGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUX88zc4iwtMJpQnLzBhkvZFG/oi4NCJ+GRGvRcSNo2xL0nTFsOf2R8RJwK+Ai4G9wPPA+sz8Rct7HPmlCZvGyL8GeC0zX8/MD4EfAetG2J6kKRol/GcDv533em+z7A9ExMaI2BkRO0foS9KYjfKD30K7Fp/arc/MzcBmcLdfmiWjjPx7gWXzXn8O2DdaOZKmZZTwPw+cExGfj4iTga8Cj46nLEmTNvRuf2Z+FBE3AE8AJwFbMvPnY6tM0kQNfahvqM78zi9N3FRO8pF0/DL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaipTtEtzbdq1arW9qeeeqq1/Y033mhtv+CCC3q2HT58uPW9FTjyS0UZfqkowy8VZfilogy/VJThl4oy/FJRI83SGxF7gMPAx8BHmbm6z/rO0qvf27VrV2v7ueeeO9L2zzrrrJ5tb7755kjbnmWDztI7jpN8/iYzD41hO5KmyN1+qahRw5/ATyPihYjYOI6CJE3HqLv9azNzX0ScAfwsIv43M7fPX6H5o+AfBmnGjDTyZ+a+5vEg8BCwZoF1Nmfm6n4/BkqarqHDHxGnRMSpR54DlwCvjqswSZM1ym7/mcBDEXFkOz/MzMfHUpWkiRs6/Jn5OvCXY6xFJ6DzzjuvZ9vSpUunWImO5qE+qSjDLxVl+KWiDL9UlOGXijL8UlHeulsT1XZ77iVLlkyxEh3NkV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXivI4v0bSdskuwB133DGxvp999tnW9vfee29ifZ8IHPmlogy/VJThl4oy/FJRhl8qyvBLRRl+qSiP82sky5Yta20f5Zr9w4cPt7Zv2rSptf39998fuu8KHPmlogy/VJThl4oy/FJRhl8qyvBLRRl+qai+x/kjYgtwOXAwM1c2y5YA9wIrgD3A1Zn57uTKVFcWLVrU2n7ttddOrO+HH364tf2JJ56YWN8VDDLyfx+49KhlNwJPZuY5wJPNa0nHkb7hz8ztwDtHLV4HbG2ebwWuGHNdkiZs2O/8Z2bmfoDm8YzxlSRpGiZ+bn9EbAQ2TrofScdm2JH/QEQsBWgeD/ZaMTM3Z+bqzFw9ZF+SJmDY8D8KbGiebwAeGU85kqalb/gj4h7gv4E/j4i9EfEPwLeBiyPi18DFzWtJx5G+3/kzc32PpovGXItmUL/77l9zzTVDb/vdd9tPDbn77ruH3rb68ww/qSjDLxVl+KWiDL9UlOGXijL8UlGRmdPrLGJ6nWkgK1eubG1/5plnWtsXL148dN9XXnlla/sjj3ju2DAyMwZZz5FfKsrwS0UZfqkowy8VZfilogy/VJThl4pyiu4T3GmnndbafvPNN7e2j3IcH+Dtt9/u2fbyyy+PtG2NxpFfKsrwS0UZfqkowy8VZfilogy/VJThl4ryOP8J7rLLLmttv+qqq0ba/qFDh1rb227tvWfPnpH61mgc+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pqL7H+SNiC3A5cDAzVzbLbgH+EXirWe2mzHxsUkWq3eWXX96z7c4775xo3zt27Ghtf/rppyfav4Y3yMj/feDSBZb/e2auav4ZfOk40zf8mbkdeGcKtUiaolG+898QEa9ExJaIGO1eT5Kmbtjwfxf4ArAK2A98p9eKEbExInZGxM4h+5I0AUOFPzMPZObHmfkJ8D1gTcu6mzNzdWauHrZISeM3VPgjYum8l1cCr46nHEnTMsihvnuAC4HPRsRe4JvAhRGxCkhgD/C1CdYoaQL6hj8z1y+w+K4J1KIhXXfddT3bTj/99JG2/dZbb7W233777SNtX93xDD+pKMMvFWX4paIMv1SU4ZeKMvxSUd66+zhwxRVXtLZfdNFFE+v7+uuvb23fvn37xPrWZDnyS0UZfqkowy8VZfilogy/VJThl4oy/FJRHuefAStWrGht37ZtW2v7okWLhu77rrvar85+7DFvzHyicuSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIiM6fXWcT0Opshp556amv7fffd19p+ySWXDN33Sy+91Nq+du3a1vYPPvhg6L7VjcyMQdZz5JeKMvxSUYZfKsrwS0UZfqkowy8VZfilovpezx8Ry4BtwJ8AnwCbM/M/I2IJcC+wAtgDXJ2Z706u1OPXbbfd1to+ynF8gIjeh3Wfe+651vd6HL+uQUb+j4B/zswvAn8NfD0i/gK4EXgyM88BnmxeSzpO9A1/Zu7PzBeb54eB3cDZwDpga7PaVqB9WhlJM+WYvvNHxArgS8AO4MzM3A9zfyCAM8ZdnKTJGfgefhGxCHgA+EZmvtf2PfOo920ENg5XnqRJGWjkj4jPMBf8H2Tmg83iAxGxtGlfChxc6L2ZuTkzV2fm6nEULGk8+oY/5ob4u4DdmXnHvKZHgQ3N8w3AI+MvT9Kk9L2kNyK+DDwL7GLuUB/ATcx97/8xsBz4DfCVzHynz7ZOyEt6161b19p+7733traffPLJI/X/yiuv9Gxbs2ZN63s//PDDkfrW7Bn0kt6+3/kz87+AXhub3MTwkibKM/ykogy/VJThl4oy/FJRhl8qyvBLRTlF9xj0uyR31OP4/WzatKlnm8fx1Ysjv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8V5XH+MXj88cdb288///zW9uXLl7e233rrra3t999/f2u7tBBHfqkowy8VZfilogy/VJThl4oy/FJRhl8qqu99+8fa2Ql6335plgx6335Hfqkowy8VZfilogy/VJThl4oy/FJRhl8qqm/4I2JZRDwdEbsj4ucR8U/N8lsi4v8i4n+af383+XIljUvfk3wiYimwNDNfjIhTgReAK4Crgd9l5r8N3Jkn+UgTN+hJPn3v5JOZ+4H9zfPDEbEbOHu08iR17Zi+80fECuBLwI5m0Q0R8UpEbImIxT3eszEidkbEzpEqlTRWA5/bHxGLgGeAb2XmgxFxJnAISOA25r4aXNdnG+72SxM26G7/QOGPiM8APwGeyMw7FmhfAfwkM1f22Y7hlyZsbBf2REQAdwG75we/+SHwiCuBV4+1SEndGeTX/i8DzwK7g
E+axTcB64FVzO327wG+1vw42LYtR35pwsa62z8uhl+aPK/nl9TK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VFTfG3iO2SHgjXmvP9ssm0WzWtus1gXWNqxx1vang6441ev5P9V5xM7MXN1ZAS1mtbZZrQusbVhd1eZuv1SU4ZeK6jr8mzvuv82s1jardYG1DauT2jr9zi+pO12P/JI60kn4I+LSiPhlRLwWETd2UUMvEbEnInY1Mw93OsVYMw3awYh4dd6yJRHxs4j4dfO44DRpHdU2EzM3t8ws3elnN2szXk99tz8iTgJ+BVwM7AWeB9Zn5i+mWkgPEbEHWJ2ZnR8TjogLgN8B247MhhQR/wq8k5nfbv5wLs7Mf5mR2m7hGGdunlBtvWaW/ns6/OzGOeP1OHQx8q8BXsvM1zPzQ+BHwLoO6ph5mbkdeOeoxeuArc3zrcz955m6HrXNhMzcn5kvNs8PA0dmlu70s2upqxNdhP9s4LfzXu9ltqb8TuCnEfFCRGzsupgFnHlkZqTm8YyO6zla35mbp+momaVn5rMbZsbrcesi/AvNJjJLhxzWZuZfAX8LfL3ZvdVgvgt8gblp3PYD3+mymGZm6QeAb2Tme13WMt8CdXXyuXUR/r3AsnmvPwfs66COBWXmvubxIPAQc19TZsmBI5OkNo8HO67n9zLzQGZ+nJmfAN+jw8+umVn6AeAHmflgs7jzz26hurr63LoI//PAORHx+Yg4Gfgq8GgHdXxKRJzS/BBDRJwCXMLszT78KLCheb4BeKTDWv7ArMzc3GtmaTr+7GZtxutOTvJpDmX8B3ASsCUzvzX1IhYQEX/G3GgPc1c8/rDL2iLiHuBC5q76OgB8E3gY+DGwHPgN8JXMnPoPbz1qu5BjnLl5QrX1mll6Bx1+duOc8Xos9XiGn1STZ/hJRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrq/wGWybdnU+8DKAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x21bbe08feb8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(train_images[107],cmap='gray')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, seems like the shape are okay, let's handle the labels as requested. let's start with the the test set"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 1]), list)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_labels=[1 if i % 2 != 0 else 0 for i in test_labels]\n",
"np.unique(test_labels), type(test_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, let's go preprocess our training set:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_images_odd = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 != 0])\n",
"train_images_even = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 == 0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's flatten the images in order to feed them later to the network (not using CNN here, regular old DNN).</br>\n",
"Let's also seperate the training images to \"odd\" and \"even\" training images, this is in order to create the triplets required\n",
"in the exercise more easily."
]
},
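{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a side note, the same odd/even split can be written with a boolean mask instead of list comprehensions; a minimal equivalent sketch (the `_alt` names are only there to avoid clobbering the arrays above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"odd_mask = (train_labels[:, 0] % 2) == 1 # True where the digit is odd\n",
"train_images_odd_alt = train_images[odd_mask]\n",
"train_images_even_alt = train_images[~odd_mask]\n",
"train_images_odd_alt.shape, train_images_even_alt.shape # should match the shapes below"
]
},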
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# train_images = np.reshape(train_images, [train_images.shape[0], train_images.shape[1] * train_images.shape[2]])\n",
"# train_images_odd = np.reshape(train_images_odd, [train_images_odd.shape[0], train_images_odd.shape[1] * train_images_odd.shape[2]])\n",
"# train_images_even = np.reshape(train_images_even, [train_images_even.shape[0], train_images_even.shape[1] * train_images_even.shape[2]])\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((55000, 28, 28), (27027, 28, 28), (27973, 28, 28))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_images.shape, train_images_even.shape, train_images_odd.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# image_triplets = np.zeros((10000 ,3,train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n",
"# triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n",
"# for i in range(10000):\n",
"# rand_mnist = np.random.randint(60000, size=1) #a random image from mnist training set\n",
"# if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n",
"# rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n",
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n",
"# image_triplets[i,:] = combine_images #add 1 odd and 2 even images to triplet array\n",
"# triplet_labels[i] = 1 # adding appropriate binary label\n",
"# else:\n",
"# rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n",
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n",
"# image_triplets[i,:] = combine_images #add 1 even and 2 odd images to triplet array\n",
"# triplet_labels[i] = 0"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"image_triplets = np.zeros((10000, 3, train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n",
"triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n",
"for i in range(10000):\n",
" rand_mnist = np.random.randint(55000, size=1) #a random image from mnist training set\n",
" if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n",
" rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n",
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n",
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n",
" image_triplets[i, 1, :, :] = train_images_even[rand_even[0], :, :]\n",
" image_triplets[i, 2, :, :] = train_images_even[rand_even[1], :, :]\n",
" triplet_labels[i] = 1 # adding appropriate binary label\n",
" else:\n",
" rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n",
" #combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n",
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n",
" image_triplets[i, 1, :, :] = train_images_odd[rand_odd[0], :, :]\n",
" image_triplets[i, 2, :, :] = train_images_odd[rand_odd[1], :, :]\n",
" triplet_labels[i] = 0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10000, 3, 28, 28), (10000, 1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_triplets.shape,triplet_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, we have our triplets, We got 10000 samples in each sample we have 3 images (either 1 odd and 2 even <u>or</u> 2 odd and 1 even)<br>\n",
"Ok, we can easily notice that our training set and test set are different.<br>We have a triplet of pictures per 1 label for the training set and one picture per label for the test set.<br>\n",
"This problem is usually called MIL (Multiple Instance labeling) where we have a bag of instances, but in our case our inference will be concerning SI (Single Instance).<br><a href=\"https://arxiv.org/pdf/1406.0281.pdf\">This article, section 3.4</a><br>\n",
"In the way we have created our training set, there is a 66% chance of a picture in the triplet to be matched with it's label and 33% of it being mislabeled.<br>\n",
"My solution will ravel the pictures and give every picture in the triplet the same label as the triplet<br>\n",
"So basically this will create a new training set from the triplet set with noise(mislabeled samples)<br>\n",
"I reilied on the fact that deep neural networks are <a href=\"https://arxiv.org/pdf/1705.10694\">robust to noisy labeling</a>"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<matplotlib.image.AxesImage at 0x21bc0aa6d68>, array([0.]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADNVJREFUeJzt3W+oXPWdx/HPZ7VFTYNEokm0qYklLqs+sMtFFxsWl8XiLoVYQyQ+Stl1bx5E2MoqBp9UkIisaXcXHxRuSUgCrW390zWWtW3QZU1FxFxZY/5sW43ZNptLrpJgFJSa+N0H96Rc453fzJ05M2duvu8XhJk53znnfBnyueecOXPOzxEhAPn8SdMNAGgG4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kNT5g1yZbX5OCPRZRLiT9/W05bd9q+1f237T9sZelgVgsNztb/ttnyfpN5JukXRE0quS7oyIA4V52PIDfTaILf8Nkt6MiEMR8QdJP5K0qoflARigXsJ/haTfT3t9pJr2KbZHbe+xvaeHdQGoWS9f+M20a/GZ3fqIGJM0JrHbDwyTXrb8RyQtnfb6i5KO9tYOgEHpJfyvSlphe7ntz0taK2lnPW0B6Leud/sj4pTtuyX9QtJ5krZGxP7aOgPQV12f6utqZRzzA303kB/5AJi7CD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iq6yG6Jcn2YUnvSzot6VREjNTRFID+6yn8lb+KiHdrWA6AAWK3H0iq1/CHpF/aHrc9WkdDAAaj193+r0bEUduXSdpl+38i4sXpb6j+KPCHARgyjoh6FmQ/KOmDiNhceE89KwPQUkS4k/d1vdtve57t+WeeS/qapH3dLg/AYPWy279I0k9tn1nODyPi57V0BaDvatvt72hl7PYDfdf33X4AcxvhB5Ii/EBShB9IivADSRF+IKk6ruoDzjnz588v1ufNm1esf/jhh8X6e++9N+ue6saWH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jw/zllXXXVVy9qaNWuK865fv75Yv/LKK4v1o0ePFutLly4t1geBLT+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMV5fvTkoYceKtZPnz7d9bJXr15drF977bVdL7sab6KlXm9p/8ILL/Q0/yCw5QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpNqe57e9VdLXJU1GxHXVtEsk/VjSMkmHJd0RESf61yb6pd396R9++OFifXR0tFg///zW/8V6Pdfey7n4119/vVh/8skni/Xly5cX6y+99NKsexq0Trb82yTdeta0jZKej4gVkp6vXgOYQ9qGPyJelHT8rMmrJG2vnm+XdFvNfQHos26P+RdFxIQkVY+X1dcSgEHo+2/7bY9KKh8YAhi4brf8x2wvkaTqcbLVGyNiLCJGImKky3UB6INuw79T0rrq+TpJz9TTDoBBaRt+249LelnSn9o+YvvvJT0i6Rbbv5V0S/UawBziXq9bntXK7MGtDJKkBQsWFOuPPFL+u33XXXfV2c6nnDx5slhvN8b9E088Uaxv2bKlZW3v3r3FeeeyiCj/gKLCL/yApAg/kBThB5Ii/EBShB9IivADSXHr7jmg3em6++67r2Wt3am6hQsXFutvvfVWsb548eJifcOGDS1ru3btKs47MTFRrKM3bPmBpAg/kBThB5Ii/EBShB9IivADSRF+ICnO888B99xzT7G+cWP3N0/et29fsb527dpi/aOPPirWDx06NOueMBhs+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKW7dPQSuvvrqYn18fLxYv+iii7pe98cff1ysP/vss8V6uyG6T5xg5PZB49bdAIoIP5AU4QeSIvxAUoQfSIrwA0kRfiCpttfz294q6euSJiPiumrag5L+QdI71dseiIj/6FeT2U1OThbrF154YdfLfvvtt4v1G2+8sVifP39+sc55/uHVyZZ/m6RbZ5j+LxFxffWP4ANzTNvwR8SLko4PoBcAA9TLMf/dtvfa3mq7PJ4UgKHTbfi/J+nLkq6XNCHpO63eaHvU9h7be7pcF4A+6Cr8EXEsIk5HxCeSvi/phsJ7xyJiJCJGum0SQP26Cr/tJdNefkNS+RawAIZOJ6f6Hpd0s6SFto9I+rakm21fLykkHZa0vo89AugDrudP7vLLLy/W9+/fX6yPjY0V6/fff/+se0JvuJ4fQBHhB5Ii/EBShB9IivADSRF+ICmG6E5u+/btxfrFF19crL/88st1toMBYssPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxnr/S7hbVq1evblnbtm1bcd4DBw5009Iftbs199q1a1vWNm3aVJx38eLFxfqqVauK9eeee65Yx/Biyw8kRfiBpAg/kBThB5Ii/EBShB9IivADSXHr7sru3buL9Ztuuqll7Z133mlZk6SJiYli3S7fafmCCy4o1lesWFGsl9x7773F+mOPPVasnzp1qut1oz+4dTeAIsIPJEX4gaQIP5AU4QeSIvxAUoQfSKrt9fy2l0raIWmxpE8kjUXEv9m+RNKPJS2TdFjSHRFxon+tDq9LL720p3q78/ztfotRGkZ78+bNxXl37NhRrOPc1cmW/5Skf4qIP5P0F5I22L5G0kZJz0fECknPV68BzBFtwx8RExHxWvX8fUkHJV0haZWkM8O9bJd0W7+aBFC/WR3z214m6SuSXpG0KCImpKk/EJIuq7s5AP3T8T38bH9B0lOSvhURJ9sdp06bb1TSaHftAeiXjrb8tj+nqeD/ICKeriYfs72kqi+RNDnTvBExFhEjETFSR8MA6tE2/J7axG+RdDAivjuttFPSuur5OknP1N8egH5pe0mv7ZWSdkt6Q1On+iTpAU0d9/9E0pck/U7Smog43mZZQ3tJ78qVK4v122+/vWXtmmuu6Wnd7Q6hxsfHi/VHH320Ze3EiZRnX1Pr9JLetsf8EfErSa0W9tezaQrA8OAXfkBShB9IivADSRF+ICnCDyRF+IGkuHU3cI7h1t0Aigg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCptuG3vdT2f9o+aHu/7X+spj9o+/9s/3f172/73y6AurQdtMP2EklLIuI12/MljUu6TdIdkj6IiM0dr4xBO4C+63TQjvM7WNCEpInq+fu2D0q6orf2ADRtVsf8tpdJ+oqkV6pJd9vea3ur7QUt5hm1vcf2np46BVCrjsfqs/0FSf8laVNEPG17kaR3JYWkhzR1aPB3bZbBbj/QZ53u9ncUftufk/QzSb+IiO/OUF8m6WcRcV2b5RB+oM9qG6jTtiVtkXRwevCrLwLP+Iakf
bNtEkBzOvm2f6Wk3ZLekPRJNfkBSXdKul5Tu/2HJa2vvhwsLYstP9Bnte7214XwA/1X224/gHMT4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKm2N/Cs2buS/nfa64XVtGE0rL0Na18SvXWrzt6u7PSNA72e/zMrt/dExEhjDRQMa2/D2pdEb91qqjd2+4GkCD+QVNPhH2t4/SXD2tuw9iXRW7ca6a3RY34AzWl6yw+gIY2E3/attn9t+03bG5vooRXbh22/UY083OgQY9UwaJO2902bdontXbZ/Wz3OOExaQ70NxcjNhZGlG/3shm3E64Hv9ts+T9JvJN0i6YikVyXdGREHBtpIC7YPSxqJiMbPCdv+S0kfSNpxZjQk2/8s6XhEPFL94VwQEfcPSW8PapYjN/ept1YjS39TDX52dY54XYcmtvw3SHozIg5FxB8k/UjSqgb6GHoR8aKk42dNXiVpe/V8u6b+8wxci96GQkRMRMRr1fP3JZ0ZWbrRz67QVyOaCP8Vkn4/7fURDdeQ3yHpl7bHbY823cwMFp0ZGal6vKzhfs7WduTmQTprZOmh+ey6GfG6bk2Ef6bRRIbplMNXI+LPJf2NpA3V7i068z1JX9bUMG4Tkr7TZDPVyNJPSfpWRJxsspfpZuirkc+tifAfkbR02usvSjraQB8zioij1eOkpJ9q6jBlmBw7M0hq9TjZcD9/FBHHIuJ0RHwi6ftq8LOrRpZ+StIPIuLpanLjn91MfTX1uTUR/lclrbC93PbnJa2VtLOBPj7D9rzqixjZnifpaxq+0Yd3SlpXPV8n6ZkGe/mUYRm5udXI0mr4sxu2Ea8b+ZFPdSrjXyWdJ2lrRGwaeBMzsH2Vprb20tQVjz9ssjfbj0u6WVNXfR2T9G1J/y7pJ5K+JOl3ktZExMC/eGvR282a5cjNfeqt1cjSr6jBz67OEa9r6Ydf+AE58Qs/ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJ/T93Jc3ycrK5KwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x21bc0aebef0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(image_triplets[2510,0,:,:],cmap='gray'),triplet_labels[2509,:]\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((30000, 28, 28, 1), (30000, 1))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"noisy = np.reshape(image_triplets, [-1,28,28])\n",
"noisy.shape\n",
"noisy_lab = np.zeros((noisy.shape[0],1))\n",
"\n",
"for i in range(0, triplet_labels.shape[0], 3):\n",
" noisy_lab[i],noisy_lab[i+1] ,noisy_lab[i+2] = triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))]\n",
"\n",
"noisy = noisy.reshape(noisy.shape[0],28, 28, 1)\n",
"noisy.shape, noisy_lab.shape"
]
},
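{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch over the variables above): every image inside a triplet should now carry its triplet's label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each consecutive group of 3 rows of noisy_lab should equal one triplet label.\n",
"assert np.all(noisy_lab.reshape(-1, 3) == triplet_labels), 'labels were not replicated per triplet'"
]
},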
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Turn to one hot encoder"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((30000, 28, 28, 1), (30000, 2), (10000, 28, 28, 1), (10000, 2))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_y = np.array(noisy_lab[:, 0]) # Input\n",
"train_y = np.reshape(train_y, (-1, 1))\n",
"enc = OneHotEncoder()\n",
"enc.fit(train_y)\n",
"out = enc.transform(train_y).toarray()\n",
"noisy_lab = out\n",
"train_y = np.array(test_labels) # Input\n",
"train_y = np.reshape(train_y, (-1, 1))\n",
"enc = OneHotEncoder()\n",
"enc.fit(train_y)\n",
"out = enc.transform(train_y).toarray()\n",
"test_labels=out\n",
"test_images = np.reshape(test_images,(-1,28,28,1))\n",
"noisy.shape,noisy_lab.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I thought about subsampling the triplets at some point, didn't show any signs of improving the score. So I left it commented out."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# rand_triplet = np.random.randint(3, size=10000)\n",
"# plt.hist(rand_triplet, bins='auto') \n",
"# plt.title(\"Histogram with 'auto' bins\")\n",
"# plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a uniform distribution of the numbers in the triplets. This will guarantee us that we have 66% chance of hitting the right label (since we have 2 images inside the triplet that match the label)"
]
},
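{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny simulation of that match rate (a sketch, independent of the real data): only position 0 of each triplet matches its label, so a uniformly drawn position matches with probability 1/3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Monte Carlo check: a uniformly drawn triplet position matches the label iff it is position 0.\n",
"pos = np.random.randint(3, size=100000)\n",
"print('empirical match rate ~', np.mean(pos == 0)) # ~0.333"
]
},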
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# noisy_images = np.zeros((10000, train_images.shape[1], train_images.shape[2]))\n",
"# for i in range(10000):\n",
"# rand_triplet = np.random.randint(3, size=1)\n",
"# noisy_images[i, :, :] = image_triplets[i, rand_triplet, :, :]\n",
"# # for i in range(10000):\n",
"# # rand_triplet = np.random.randint(3, size=1)\n",
"# # noisy_images[i+9999, :, :] = image_triplets[i, rand_triplet, :, :]\n",
"# # for i in range(10000):\n",
"# # rand_triplet = np.random.randint(3, size=1)\n",
"# # noisy_images[i+19999, :, :] = image_triplets[i, rand_triplet, :, :]\n",
"# noisy_labels = np.concatenate([triplet_labels])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# noisy_images.shape, noisy_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to have our input fit to the tensorflow module(i.e. batch x height x width x nchannels)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# noisy_images = noisy_images.reshape(noisy_images.shape[0],28, 28, 1)\n",
"# test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)\n",
"# noisy_images.shape, test_images.shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# noisy_images.shape,noisy_labels.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our engineered dataset , let's move to the neural network, try to train it and test it on the test set."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From <ipython-input-28-9ba6dd5eb3d7>:41: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"\n",
"Future major versions of TensorFlow will allow gradients to flow\n",
"into the labels input on backprop by default.\n",
"\n",
"See tf.nn.softmax_cross_entropy_with_logits_v2.\n",
"\n"
]
}
],
"source": [
"BATCH_SIZE = 64\n",
"learning_rate = 0.0005\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
" # Input data.1\n",
" tf_train_dataset = tf.placeholder(tf.float32, shape=(None, train_images.shape[1], train_images.shape[2], 1))\n",
" tf_train_labels = tf.placeholder(tf.int32, shape=(None, 2))\n",
" traindrop = tf.placeholder(tf.bool)\n",
" \n",
" #tf_test_dataset = tf.placeholder(tf.float32, shape=(10000, train_images.shape[1], train_images.shape[2], 1))\n",
"# tf_test_labels = tf.placeholder(tf.float32, shape=(10000, 2))\n",
" #tf_test_dataset = tf.to_float(tf.constant(test_images))\n",
" \n",
" \n",
" #model\n",
" def model(data):\n",
" conv1 = tf.contrib.layers.conv2d(data, num_outputs=32, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu)\n",
" pool1 = tf.contrib.layers.max_pool2d(conv1, [2,2], 2)\n",
" conv2 = tf.contrib.layers.conv2d(pool1, num_outputs=64, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu)\n",
" pool2 = tf.contrib.layers.max_pool2d(conv2, [2,2], 2)\n",
" flat = tf.layers.flatten(pool2)\n",
" dense1 = tf.contrib.layers.fully_connected(flat, 1024, activation_fn=tf.nn.relu)\n",
" #dropout = tf.nn.dropout(dense1, 0.5)\n",
" dropout = tf.layers.dropout(inputs=dense1, rate=0.4, training=traindrop)\n",
" dense2 = tf.contrib.layers.fully_connected(dropout, 512, activation_fn=tf.nn.relu)\n",
" logits = tf.layers.dense(inputs=dense2, units=2)\n",
" return logits\n",
"\n",
" # Label smoothing\n",
"# if correct_labels is not None:\n",
"# cr0= 0.6\n",
"# table = tf.convert_to_tensor([cr0, 1.-cr0])\n",
"# tf_train_labels = tf.nn.embedding_lookup(table, tf_train_labels)\n",
" \n",
" \n",
" \n",
" #Loss definition\n",
" logits = model(tf_train_dataset)\n",
" normalized_logits =((0.3/2) + (0.7*logits))\n",
" loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=normalized_logits))\n",
"\n",
"\n",
" # Optimizer\n",
" optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)\n",
"\n",
" # predictions for test\n",
" train_pred = tf.nn.softmax(normalized_logits)\n",
"# test_pred = tf.nn.softmax(model(normalized_logits))\n",
"# pred = tf.nn.softmax(model(tf_test_dataset))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we finished our neural network model... nothing special, few conv layers, pooling and dropout. eventually evaluating the one hot encoder (2 units).<br>\n",
"The only thing different here is the change in my predictions to p~t=0.3/N+0.7pt (pt are my logits).\n",
"I modeled the noisy labels as uniform noise.\n",
"I thought about implementing <a href= \"https://arxiv.org/pdf/1512.00567.pdf\">label smoothing</a> at some point but it didn't turn out well as well."
]
},
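{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal numpy sketch of this forward correction applied to actual probabilities (assuming uniform label noise with flip rate 0.3 and N = 2 classes):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Forward noise correction on probabilities: p~ = noise/N + (1 - noise) * p\n",
"p = np.array([[0.9, 0.1], [0.2, 0.8]]) # toy softmax outputs\n",
"noise = 0.3 # assumed uniform label-noise rate\n",
"p_noisy = noise / 2 + (1 - noise) * p # rows still sum to 1\n",
"print(p_noisy) # [[0.78 0.22] [0.29 0.71]]"
]
},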
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def accuracy(predictions, labels):\n",
" return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))/ predictions.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized weights ,biases and other variables\n",
"Training batch loss at epoch 0: 0.695419\n",
"train predict\n",
"[[0.50901407 0.49098596]\n",
" [0.49824655 0.50175345]] (64, 2)\n",
"Minibatch accuracy: 51.6%\n",
"test batch = 0, test minibatch acc = 21.875\n",
"All Test acc : 27.18349358974359\n",
"Training batch loss at epoch 0: 0.629663\n",
"train predict\n",
"[[0.37293997 0.62706006]\n",
" [0.34847245 0.6515275 ]] (64, 2)\n",
"Minibatch accuracy: 65.6%\n",
"test batch = 0, test minibatch acc = 6.25\n",
"All Test acc : 12.880608974358974\n",
"Training batch loss at epoch 0: 0.000000\n",
"train predict\n",
"[[1. 0.]\n",
" [1. 0.]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 0: 0.000000\n",
"train predict\n",
"[[1.0000000e+00 0.0000000e+00]\n",
" [1.0000000e+00 7.9840876e-36]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 0: 0.000000\n",
"train predict\n",
"[[1. 0.]\n",
" [1. 0.]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 1: 62.272629\n",
"train predict\n",
"[[1.0000000e+00 4.8006741e-37]\n",
" [1.0000000e+00 1.0597831e-36]] (64, 2)\n",
"Minibatch accuracy: 48.4%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 1: 0.640868\n",
"train predict\n",
"[[0.3552155 0.6447845 ]\n",
" [0.37663218 0.6233678 ]] (64, 2)\n",
"Minibatch accuracy: 68.8%\n",
"test batch = 0, test minibatch acc = 14.0625\n",
"All Test acc : 17.037259615384617\n",
"Training batch loss at epoch 1: 0.000000\n",
"train predict\n",
"[[1.0000000e+00 7.3246713e-24]\n",
" [1.0000000e+00 2.1906906e-17]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 1: 0.000000\n",
"train predict\n",
"[[1.0000000e+00 1.3626768e-17]\n",
" [1.0000000e+00 1.2033073e-15]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 1: 0.000000\n",
"train predict\n",
"[[1.0000000e+00 1.4841263e-20]\n",
" [1.0000000e+00 7.2137413e-20]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 2: 23.142492\n",
"train predict\n",
"[[1.0000000e+00 1.2099758e-17]\n",
" [1.0000000e+00 4.4781130e-17]] (64, 2)\n",
"Minibatch accuracy: 48.4%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 2: 0.695359\n",
"train predict\n",
"[[0.49258307 0.5074169 ]\n",
" [0.4918901 0.50810987]] (64, 2)\n",
"Minibatch accuracy: 40.6%\n",
"test batch = 0, test minibatch acc = 57.8125\n",
"All Test acc : 50.7411858974359\n",
"Training batch loss at epoch 2: 0.000134\n",
"train predict\n",
"[[9.9985874e-01 1.4126042e-04]\n",
" [9.9993229e-01 6.7751578e-05]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 2: 0.000000\n",
"train predict\n",
"[[9.9999988e-01 1.2150628e-07]\n",
" [9.9999571e-01 4.2739148e-06]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n",
"Training batch loss at epoch 2: 0.000000\n",
"train predict\n",
"[[1.0000000e+00 2.5489841e-09]\n",
" [9.9999988e-01 7.4416107e-08]] (64, 2)\n",
"Minibatch accuracy: 100.0%\n",
"test batch = 0, test minibatch acc = 42.1875\n",
"All Test acc : 49.2588141025641\n"
]
}
],
"source": [
"epoch = 3\n",
"TRAIN_DATASIZE,_,_,_ = noisy.shape\n",
"PERIOD = int(TRAIN_DATASIZE/BATCH_SIZE) #Number of iterations for each epoch\n",
"\n",
"\n",
"with tf.Session(graph=graph) as session:\n",
" tf.global_variables_initializer().run()\n",
" print('Initialized weights ,biases and other variables')\n",
" for step in range(epoch):\n",
" idxs = np.random.permutation(TRAIN_DATASIZE) #shuffled ordering\n",
" X_random = noisy[idxs]\n",
" Y_random = noisy_lab[idxs]\n",
" for i in range(PERIOD):\n",
" batch_data = noisy[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n",
" batch_labels = noisy_lab[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n",
" feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels, traindrop:True}\n",
" _,l, train_predictions= session.run( [optimizer, loss, train_pred], feed_dict=feed_dict)\n",
" if (i % 100 == 0):\n",
" print('Training batch loss at epoch %d: %f' % (step, l))\n",
" print('train predict')\n",
" train_predictions = np.array(train_predictions)\n",
" print(train_predictions[:2,:], train_predictions.shape)\n",
" print('Minibatch accuracy: %.1f%%' % accuracy(train_predictions, batch_labels))\n",
" # Evaluation:\n",
" # Eval\n",
" X = test_images\n",
" Y = test_labels\n",
" BATCH_SIZE_TEST = 64\n",
" PERIOD_TEST = int(X.shape[0]/BATCH_SIZE)\n",
" all_acc = []\n",
"\n",
" for j in range(PERIOD_TEST):\n",
" batch_data = X[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n",
" batch_labels = Y[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n",
"\n",
" feed_dict = {tf_train_dataset : batch_data , traindrop:False}\n",
" train_predictions= session.run( train_pred, feed_dict=feed_dict)\n",
" acc = accuracy(train_predictions, batch_labels)\n",
" all_acc.append(acc)\n",
" if (j%200 == 0):\n",
" print('test batch number = {}, test minibatch acc = {}'.format(j,acc))\n",
" print('All Test acc : {}'.format(np.mean(all_acc)))\n",
" \n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we didn't get a good fit to the test set (50% is as good as a random guess), we got a decent fit to the training set, but every shuffling after epoch ended the NN tried to recoil itself, more work on avoiding overfitting is advised here(smaller net, higher dropout rate...)"
]
},
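{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a minimal, untested sketch of the kind of smaller, more regularized model suggested above (same tf.contrib API as the model cell; it would replace `model` in the graph):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def smaller_model(data, training):\n",
"    # One conv block instead of two, a narrower dense layer, stronger dropout.\n",
"    conv = tf.contrib.layers.conv2d(data, num_outputs=16, kernel_size=5, padding='SAME', activation_fn=tf.nn.relu)\n",
"    pool = tf.contrib.layers.max_pool2d(conv, [2, 2], 2)\n",
"    flat = tf.layers.flatten(pool)\n",
"    dense = tf.contrib.layers.fully_connected(flat, 128, activation_fn=tf.nn.relu)\n",
"    dropout = tf.layers.dropout(inputs=dense, rate=0.6, training=training)\n",
"    return tf.layers.dense(inputs=dropout, units=2)"
]
},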
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (myenv)",
"language": "python",
"name": "myenv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}