Combined loss function for imubit
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Imubit\n", | |
"This is my solution for the challenge. I will explain my steps through this notebook. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"import mnist\n", | |
"import matplotlib.pyplot as plt\n", | |
"from sklearn.preprocessing import OneHotEncoder\n", | |
"from IPython.display import display, Math, Latex\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"I used the \"mnist\" package to ease my work. located here:https://github.com/datapythonista/mnist" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use the retry module or similar alternatives.\n", | |
"WARNING:tensorflow:From <ipython-input-2-6441638eb0cd>:1: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please write your own downloading logic.\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data to implement this functionality.\n", | |
"Extracting MNIST-data\\train-images-idx3-ubyte.gz\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use tf.data to implement this functionality.\n", | |
"Extracting MNIST-data\\train-labels-idx1-ubyte.gz\n", | |
"Extracting MNIST-data\\t10k-images-idx3-ubyte.gz\n", | |
"Extracting MNIST-data\\t10k-labels-idx1-ubyte.gz\n", | |
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n" | |
] | |
} | |
], | |
"source": [ | |
"mnist = tf.contrib.learn.datasets.load_dataset(\"mnist\")\n", | |
"train_data = mnist.train.images # Returns np.array\n", | |
"train_labels = np.asarray(mnist.train.labels, dtype=np.int32)\n", | |
"eval_data = mnist.test.images # Returns np.array\n", | |
"eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_images = train_data\n", | |
"test_images = eval_data\n", | |
"test_labels = eval_labels\n", | |
"train_images = np.reshape(train_images, (-1,28,28))\n", | |
"test_images = np.reshape(test_images,(-1,28,28))\n", | |
"train_labels = np.reshape(train_labels, (-1,1))\n", | |
"test_labels = np.reshape(test_labels, (-1,1))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((55000, 28, 28), (55000, 1), (10000, 28, 28), (10000, 1))" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_images.shape, train_labels.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADa9JREFUeJzt3X+IHfW5x/HPczXFsA2SUEwXs3VjXOqtYm1ZVGi4KtUSNZhEbGgQkktLt5gKVvtHRSUNSKWR29iLSCW1oSmktgHdGkK5SZFSW7hRN6KNbWwTltj8WHcjKdYmSIj79I+dLZu45ztnz5k5M+Z5vyCcH8+ZmYdDPjtzznfmfM3dBSCe/6i6AQDVIPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6v5MbMzNOJwRK5u7WzOva2vOb2RIz+4uZHTCzB9pZF4DOslbP7Tez8yT9VdLNkg5LekXSKnf/c2IZ9vxAyTqx579G0gF3H3b3U5J+IWlZG+sD0EHthP9iSYemPD6cPXcGMxswsyEzG2pjWwAK1s4XftMdWnzosN7dN0naJHHYD9RJO3v+w5J6pjxeIOloe+0A6JR2wv+KpD4zW2hmH5P0FUnbi2kLQNlaPux399Nmdo+knZLOk7TZ3f9UWGcAStXyUF9LG+MzP1C6jpzkA+Cji/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoDo6RTdQF729vcn69ddfn6xfd911yfpTTz2VrL/++uvJeiew5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoNqapdfMDkp6T9IHkk67e3/O65mlFx2zevXqhrWNGzcml503b15b2x4dHU3Wu7u721p/SrOz9BZxks+N7v5OAesB0EEc9gNBtRt+l7TLzPaY2UARDQHojHYP+7/g7kfN7CJJvzGzN939xakvyP4o8IcBqJm29vzufjS7HZM0KOmaaV6zyd37874MBNBZLYffzLrMbM7kfUlfkvRGUY0BKFc7h/3zJQ2a2eR6fu7u/1dIVwBK13L43X1Y0mcL7AU4Q95Y+0MPPZSsDww0/qqpq6urpZ4mvf/++8n67t2721p/JzDUBwRF+IGgCD8QFOEHgiL8QFCEHwiKn+5Gbd11113J+n333Vfatvft25esr1mzJlkfGhoqsp1SsOcHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY50epVq5c2bC2YsWK5LK33357W9seHx9vWHv00UeTy27YsCFZP3HiREs91Ql7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+c8CiRYsa1tauXZtc9uTJk8n6gQMHkvU77rgjWV+6dGnDWjbnQ8vGxsaS9bvvvrthbXBwsK1tnwvY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnj/Ga2WdJSSWPufmX23DxJv5TUK+mgpJXu/vfy2jy3zZ49O1nv6+tL1nfs2NGwtmDBgpZ6qoMjR44k68uXL0/W9+zZU2Q755xm9vw/lbTkrOcekPSCu/dJeiF7DOAjJDf87v6ipONnPb1M0pbs/hZJ6T/BAGqn1c/88919RJKy24uKawlAJ5R+br+ZDUgaKHs7AGam1T3/qJl1S1J22/AKC3ff5O797t7f4rYAlKDV8G+XNDlN6RpJzxfTDoBOyQ2/mT0j6f8lfdrMDpvZ1yR9X9LNZrZf0s3ZYwAfIebunduYWec2ViPXXnttsv70008n61dccUXL23733XeT9fPPT3/t09XV1fK2pfTv269bty657JNPPpmsnzp1qqWeznXu3tQPJXCGHxAU4QeCIvxAUIQfCIrwA0ERfiAofrq7A5YsOfuiyDPlDeXl/cR1arj24YcfTi6bd9nsjTfemKzneeKJJxrW8n4WHOVizw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQXFJbwfceeedyfq2bdtK2/bbb7+drD/22GPJ+tatW5P1Y8eOzbgnlItLegEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIzzd8DcuXOT9eHh4WT9wgsvLLKdGXnrrbeS9fvvvz9ZHxwcLLIdNIFxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVO44v5ltlrRU0pi7X5k9t17S1yVNXsz9oLv/OndjQcf58+RNk33JJZck67feemvD2k033ZRctq+vL1m//PLLk/W8/z/r169vWHvkkUeSy6I1RY7z/1TSdLNOPO7uV2f/coMPoF5yw+/uL0o63oFeAHRQO5/57zGzP5rZZjNLn78KoHZaDf+PJC2SdLWkEUk/aPRCMxswsyEzG2pxWwBK0FL43X3U3T9w93FJP5Z0TeK1m9y93937W20SQPFaCr+ZdU95uELSG8W0A6BTcqfoNrNnJN0g6RNmdljSdyXdYGZXS3JJByV9o8QeAZSA6/mDyzvH4M0330zWL7300mR9165dDWurV69OLjs2NpasY3pczw8gifADQRF+ICjCDwRF+IGgCD8QVO44P85tZulRoZ6enrbWn7oc+fhxrherEnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gFi5cmKzPmjWrrfWPj483rJ0+fbqtdaM97PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjG+TNz5sxJ1i+44IKGtRMnTiSXPXnyZEs9Nau7u7thbd26dcllV61aVXQ7Z7j33ntLXT9ax54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4LKnaLbzHok/UzSJyWNS9rk7v9rZvMk/VJSr6SDkla6+99z1lXbKbqHh4eT9d7e3oa1vXv3Jpfdv39/Ky017ZZbbmlYmz17dqnbfvnll5P1xYsXN6xxPX85ipyi+7Skb7v7f0q6TtI3zewzkh6Q9IK790l6IXsM4CMiN/zuPuLur2b335O0T9LFkpZJ2pK9bIuk5WU1CaB4M/rMb2a9kj4n6SVJ8919RJr4AyHpoqKbA1Ceps/tN7OPS3pW0rfc/R95c7xNWW5A0kBr7QEoS1N7fjObpYngb3X357KnR82sO6t3Sxqbbll33+Tu/e7eX0TDAIqRG36b2MX/RNI+d984pbRd0prs/hpJzxffHoCyNDPUt1jS7yXt1cRQnyQ9qInP/dskfUrS3yR92d2Tcy7XeajvtddeS9avuuqqDnXSWXnDbTt37kzW165dm6wfOnRoxj2hPc0O9eV+5nf3P0hqtLIvzqQpAPXBGX5AUIQfCIrwA0ERfiAowg8ERfiBoHLH+QvdWI3H+W+77bZk/fHHH29Yu+yyy4puZ0aOHDnSsLZ79+7kshs2bEjWh4aGWuoJ1Snykl4A5yDCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4mzZ07t2Gtp6eng5182LFjxxrWRkZGOtgJ6oBxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8wDmGcX4ASYQfCIrwA0ERfiAowg8ER
fiBoAg/EFRu+M2sx8x+a2b7zOxPZnZv9vx6MztiZq9l/24tv10ARck9ycfMuiV1u/urZjZH0h5JyyWtlPRPd/+fpjfGST5A6Zo9yef8JlY0Imkku/+eme2TdHF77QGo2ow+85tZr6TPSXope+oeM/ujmW02s2l/58rMBsxsyMyY9wmokabP7Tezj0v6naTvuftzZjZf0juSXNIjmvho8NWcdXDYD5Ss2cP+psJvZrMk7ZC00903TlPvlbTD3a/MWQ/hB0pW2IU9ZmaSfiJp39TgZ18ETloh6Y2ZNgmgOs18279Y0u8l7ZU0nj39oKRVkq7WxGH/QUnfyL4cTK2LPT9QskIP+4tC+IHycT0/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULk/4FmwdyS9NeXxJ7Ln6qiuvdW1L4neWlVkb5c0+8KOXs//oY2bDbl7f2UNJNS1t7r2JdFbq6rqjcN+ICjCDwRVdfg3Vbz9lLr2Vte+JHprVSW9VfqZH0B1qt7zA6hIJeE3syVm9hczO2BmD1TRQyNmdtDM9mYzD1c6xVg2DdqYmb0x5bl5ZvYbM9uf3U47TVpFvdVi5ubEzNKVvnd1m/G644f9ZnaepL9KulnSYUmvSFrl7n/uaCMNmNlBSf3uXvmYsJn9l6R/SvrZ5GxIZvaYpOPu/v3sD+dcd/9OTXpbrxnO3FxSb41mlv5vVfjeFTnjdRGq2PNfI+mAuw+7+ylJv5C0rII+as/dX5R0/Kynl0nakt3foon/PB3XoLdacPcRd381u/+epMmZpSt97xJ9VaKK8F8s6dCUx4dVrym/XdIuM9tjZgNVNzON+ZMzI2W3F1Xcz9lyZ27upLNmlq7Ne9fKjNdFqyL8080mUqchhy+4++cl3SLpm9nhLZrzI0mLNDGN24ikH1TZTDaz9LOSvuXu/6iyl6mm6auS962K8B+W1DPl8QJJRyvoY1rufjS7HZM0qImPKXUyOjlJanY7VnE//+buo+7+gbuPS/qxKnzvspmln5W01d2fy56u/L2brq+q3rcqwv+KpD4zW2hmH5P0FUnbK+jjQ8ysK/siRmbWJelLqt/sw9slrcnur5H0fIW9nKEuMzc3mllaFb93dZvxupKTfLKhjB9KOk/SZnf/XsebmIaZXaqJvb00ccXjz6vszcyekXSDJq76GpX0XUm/krRN0qck/U3Sl92941+8NejtBs1w5uaSems0s/RLqvC9K3LG60L64Qw/ICbO8AOCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/ENS/AJvlHp92kzVQAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x11a21458fd0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(train_images[6785],cmap='gray')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Great, seems like the shape are okay, let's handle the labels as requested. let's start with the the test set" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([0, 1]), list)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_labels=[1 if i % 2 != 0 else 0 for i in test_labels]\n", | |
"np.unique(test_labels), type(test_labels)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, let's go preprocess our training set:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_images_odd = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 != 0])\n", | |
"train_images_even = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 == 0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's also seperate the training images to \"odd\" and \"even\" training images, this is in order to create the triplets required\n", | |
"in the exercise more easily." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((55000, 28, 28), (27027, 28, 28), (27973, 28, 28))" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_images.shape, train_images_even.shape, train_images_odd.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"image_triplets = np.zeros((10000, 3, train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n", | |
"triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n", | |
"for i in range(10000):\n", | |
" rand_mnist = np.random.randint(55000, size=1) #a random image from mnist training set\n", | |
" if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n", | |
" rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n", | |
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n", | |
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n", | |
" image_triplets[i, 1, :, :] = train_images_even[rand_even[0], :, :]\n", | |
" image_triplets[i, 2, :, :] = train_images_even[rand_even[1], :, :]\n", | |
" triplet_labels[i] = 0 # adding appropriate binary label\n", | |
" else:\n", | |
" rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n", | |
" #combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n", | |
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n", | |
" image_triplets[i, 1, :, :] = train_images_odd[rand_odd[0], :, :]\n", | |
" image_triplets[i, 2, :, :] = train_images_odd[rand_odd[1], :, :]\n", | |
" triplet_labels[i] = 1" | |
] | |
}, | |
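{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As an aside, the loop above could also be vectorized with NumPy fancy indexing. This is an untested sketch under the same assumptions (same arrays, 10000 triplets, 28x28 images); it is not the code whose outputs are shown below:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Untested vectorized sketch of the triplet construction above.\n",
  "n = 10000\n",
  "anchor_idx = np.random.randint(train_images.shape[0], size=n)\n",
  "anchor_odd = (train_labels[anchor_idx, 0] % 2 != 0)  # True -> triplet label 0\n",
  "even_idx = np.random.randint(train_images_even.shape[0], size=(n, 2))\n",
  "odd_idx = np.random.randint(train_images_odd.shape[0], size=(n, 2))\n",
  "\n",
  "triplets = np.empty((n, 3, 28, 28))\n",
  "triplets[:, 0] = train_images[anchor_idx]  # the anchor image\n",
  "# companions: two even images when the anchor is odd, two odd images otherwise\n",
  "triplets[anchor_odd, 1:] = train_images_even[even_idx[anchor_odd]]\n",
  "triplets[~anchor_odd, 1:] = train_images_odd[odd_idx[~anchor_odd]]\n",
  "labels = (~anchor_odd).astype(np.float64).reshape(-1, 1)  # 0 = (1 odd, 2 even)"
 ]
},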
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((10000, 3, 28, 28), (10000, 1))" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"image_triplets.shape,triplet_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Great, we have our triplets, We got 10000 samples in each sample we have 3 images (either 1 odd and 2 even <u>or</u> 2 odd and 1 even)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAC8lJREFUeJzt3WGoXPWZx/Hvs27zJi0YKWajcU2tQXYJ1C43suCyKMXqroEYodKAS2RL0xcVtuCLFd9UWAqybLu7+KKQYmiKbdqgcQ1xMS2yrF1Y5Eap1TabNmhsYmKywULsqxJ99sU9WW7jvWfmzpyZM+nz/UCYmfOfO/NjyO/+z8w5d/6RmUiq5w/6DiCpH5ZfKsryS0VZfqkoyy8VZfmloiy/VJTll4qy/FJRfzjNJ4sITyeUJiwzY5j7jTXzR8RdEXE0Io5FxMPjPJak6YpRz+2PiCuAXwB3ACeBeWB7Zv685Wec+aUJm8bMfwtwLDPfyMzfAt8Hto7xeJKmaJzyXwucWHT7ZLPtd0TEzog4HBGHx3guSR0b5wO/pXYtPrRbn5m7gF3gbr80S8aZ+U8C1y26vR44NV4cSdMyTvnngY0R8YmIWAV8HjjQTSxJkzbybn9mXoiIB4FDwBXA7sz8WWfJJE3UyIf6Rnoy3/NLEzeVk3wkXb4sv1SU5ZeKsvxSUZZfKsryS0VZfqkoyy8VZfmloiy/VJTll4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXypqqkt0a/asXr26dXzv3r2t4xcuXGgdv/fee1ecSdPhzC8VZfmloiy/VJTll4qy/FJRll8qyvJLRY11nD8ijgPvAe8DFzJzrotQmp4HHnigdfzuu+9uHX/88cc7TKNp6uIkn9sz81wHjyNpitztl4oat/wJ/DAiXo6InV0EkjQd4+7235qZpyLiauBHEfE/mfni4js0vxT8xSDNmLFm/sw81VyeBZ4BblniPrsyc84PA6XZMnL5I2J1RHzs4nXgs8DrXQWTNFnj7PavBZ6JiIuP873MfL6TVJImbuTyZ+YbwKc6zKIebNu2bayff/PNNztKomnzUJ9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkol+jWWJ5/3q9wuFw580tFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUR7n11iOHj3adwSNyJlfKsryS0VZfqkoyy8VZfmloiy/VJTll4oaeJw/InYDW4Czmbmp2XYV8ANgA3AcuC8zfz25mOrL/Px83xE0IcPM/N8G7rpk28PAC5m5EXihuS3pMjKw/Jn5IvDuJZu3Anua63uAezrOJWnCRn3PvzYzTwM0l1d3F0nSNEz83P6I2AnsnPTzSFqZUWf+MxGxDqC5PLvcHTNzV2bOZebciM8laQJGLf8BYEdzfQfwbDdxJE3LwPJHxF7gv4GbIuJkRHwBeAy4IyJ+CdzR3JZ0GRn4nj8zty8z9JmOs2gC1q9f3zp+0003tY4/9dRTXcbRDPEMP6koyy8VZfmloiy/VJTll4qy/FJRfnX377krr7yydfyaa66ZUhLNGmd+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyoqMnN6TxYxvScTAJs2bWodf/XVV1vHz58/3zq+Zs2aFWfSZGVmDHM/Z36pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsq/5/89984777SOHzt2rHV848aNXcbRDHHml4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiBh7nj4jdwBbgbGZuarY9CnwR+N/mbo9k5r9PKqRGd+7cudbxEydOtI7feOONXcbRDBlm5v82cNcS2/85M29u/ll86TIzsPyZ+SLw7hSySJqicd7zPxgRP42I3RHhdzlJl5lRy/9N4JPAzcBp4OvL3TEidkbE4Yg4POJzSZqAkcqfmWcy8/3M/AD4FnBLy313ZeZcZs6NGlJS90Yqf0SsW3RzG/B6N3EkTcswh/r2ArcBH4+Ik8BXgdsi4mYggePAlyaYUdIEDCx/Zm5fYvMTE8iiy9D999/fOv7kk09OKYlWyjP8pKIsv1SU5ZeKsvxSUZZfKsryS0X51d0ay5YtW1rHPdQ3u5z5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkoj/NrLJs3b+47gkbkzC8VZfmloiy/VJTll4qy/FJRll8qyvJLRXmcv7h9+/a1jt9+++1TSqJpc+aXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIGHuePiOuA7wB/BHwA7MrMf42Iq4AfABuA48B9mfnryUXVJLz99tut45nZOn799de3jt95553Ljh06dKj1ZzVZw8z8F4CHMvNPgD8HvhwRfwo8DLyQmRuBF5rbki4TA8ufmacz85Xm+nvAEeBaYCuwp7nbHuCeSYWU1L0VveePiA3Ap4GXgLWZeRoWfkEAV3cdTtLkDH1uf0R8FHga+Epmno+IYX9uJ7BztHiSJmWomT8iPsJC8b+bmfubzWciYl0zvg44u9TPZuauzJzLzLkuAkvqxsDyx8IU/wRwJDO/sWjoALCjub4DeLb7eJImZZjd/luBvwFei4ifNNseAR4D9kXEF4BfAZ+bTERN0nPPPdc6/tZbb7WO33DDDa3jq1atWnEmTcfA8mfmfwHLvcH/TLdxJE2LZ/hJRVl+qSjLLxVl+aWiLL9UlOWXivKru9Vq//79reMPPfRQ6/ig8wDUH2d+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK4/xqdfDgwdbxzZs3t47Pz893GUcdcuaXirL8UlGWXyrK8ktFWX6pKMsvFWX5paJi0BLMnT5ZxPSeTCoqM4daS8+ZXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKGlj+iLguIv4jIo5ExM8i4u+a7Y9GxNsR8ZPm319PPq6krgw8ySci1gHrMvOViPgY8DJwD3Af8JvM/Kehn8yTfKSJG/Ykn4Hf5JOZp4HTzfX3IuIIcO148ST1bUXv+SNiA/Bp4KVm04MR8dOI2B0Ra5b5mZ0RcTgiDo+VVFKnhj63PyI+Cvwn8LXM3B8Ra4FzQAL/wMJbg78d8Bju9ksTNuxu/1Dlj4iPAAeBQ5n5jSXGNwAHM3PTgMex/NKEdfaHPRERwBPAkcXFbz4IvGgb8PpKQ0rqzzCf9v8F8GPgNeCDZvMjwHbgZhZ2+48DX2o+HGx7LGd+acI63e3viuWXJs+/55fUyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1TUwC/w7Ng54K1Ftz/ebJtFs5ptVnOB2UbVZbbrh73jVP+e/0NPHnE4M+d6C9BiVrPNai4w26j6yuZuv1SU5ZeK6rv8u3p+/jazmm1Wc4HZRtVLtl7f80vqT98zv6Se9FL+iLgrIo5GxLGIeLiPDMuJiOMR8Vqz8nCvS4w1y6CdjYjXF227KiJ+FBG/bC6XXCatp2wzsXJzy8rSv
b52s7bi9dR3+yPiCuAXwB3ASWAe2J6ZP59qkGVExHFgLjN7PyYcEX8J/Ab4zsXVkCLiH4F3M/Ox5hfnmsz8+xnJ9igrXLl5QtmWW1n6AXp87bpc8boLfcz8twDHMvONzPwt8H1gaw85Zl5mvgi8e8nmrcCe5voeFv7zTN0y2WZCZp7OzFea6+8BF1eW7vW1a8nViz7Kfy1wYtHtk8zWkt8J/DAiXo6InX2HWcLaiysjNZdX95znUgNXbp6mS1aWnpnXbpQVr7vWR/mXWk1klg453JqZfwb8FfDlZvdWw/km8EkWlnE7DXy9zzDNytJPA1/JzPN9ZllsiVy9vG59lP8kcN2i2+uBUz3kWFJmnmouzwLPsPA2ZZacubhIanN5tuc8/y8zz2Tm+5n5AfAtenztmpWlnwa+m5n7m829v3ZL5errdeuj/PPAxoj4RESsAj4PHOghx4dExOrmgxgiYjXwWWZv9eEDwI7m+g7g2R6z/I5ZWbl5uZWl6fm1m7UVr3s5yac5lPEvwBXA7sz82tRDLCEibmBhtoeFv3j8Xp/ZImIvcBsLf/V1Bvgq8G/APuCPgV8Bn8vMqX/wtky221jhys0TyrbcytIv0eNr1+WK153k8Qw/qSbP8JOKsvxSUZZfKsryS0VZfqkoyy8VZfmloiy/VNT/Afz9Z1kTvWa/AAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x11a23ea62e8>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(image_triplets[2541,2,:,:],cmap='gray')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We need to have our input fit to the tensorflow module(i.e. batch x height x width x nchannels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_images = np.array(test_images)\n", | |
"test_labels = np.array(test_labels)\n", | |
"image_tri = np.reshape(image_triplets, [-1, 3, 28, 28 , 1])\n", | |
"test_labels = np.reshape(test_labels, [-1, 1])\n", | |
"test_images = np.reshape(test_images,[-1, 28,28 ,1])\n", | |
"# noisy.shape\n", | |
"# noisy_lab = np.zeros((noisy.shape[0],1))\n", | |
"\n", | |
"# for i in range(0, triplet_labels.shape[0], 3):\n", | |
"# noisy_lab[i],noisy_lab[i+1] ,noisy_lab[i+2] = triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))]\n", | |
"\n", | |
"# noisy = noisy.reshape(noisy.shape[0],28, 28, 1)\n", | |
"# noisy.shape, noisy_lab.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((10000, 3, 28, 28, 1), (10000, 1), (10000, 28, 28, 1), (10000, 1))" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"image_tri.shape,triplet_labels.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's change to one hot encoder" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((10000, 3, 28, 28, 1), (10000, 2), (10000, 28, 28, 1), (10000, 2))" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_y = np.array(triplet_labels[:, 0]) # Input\n", | |
"train_y = np.reshape(train_y, (-1, 1))\n", | |
"enc = OneHotEncoder()\n", | |
"enc.fit(train_y)\n", | |
"out = enc.transform(train_y).toarray()\n", | |
"triplet_labels = out\n", | |
"train_y = np.array(test_labels) # Input\n", | |
"train_y = np.reshape(train_y, (-1, 1))\n", | |
"enc = OneHotEncoder()\n", | |
"enc.fit(train_y)\n", | |
"out = enc.transform(train_y).toarray()\n", | |
"test_labels=out\n", | |
"image_tri.shape,triplet_labels.shape, test_images.shape, test_labels.shape" | |
] | |
}, | |
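{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a side note, for binary labels the same one-hot encoding can be produced with plain NumPy by indexing an identity matrix; a tiny illustration (the `demo_labels` array is made up):"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Minimal NumPy alternative to sklearn's OneHotEncoder for binary labels.\n",
  "demo_labels = np.array([0, 1, 1, 0])  # hypothetical example labels\n",
  "np.eye(2)[demo_labels]  # each row of the 2x2 identity matrix is a one-hot vector"
 ]
},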
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, we have our engineered dataset , let's move to the neural network, try to train it and test it on the test set.<br>\n", | |
"My solution is based on the next derivation(z = g(y1,y2,y3) and equals to 1 or 0, Y is a set of y and X is a set of x), assume we want to maximize :<br>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"\\begin{eqnarray}\n", | |
"log p(z | X) = log \\sum_{y : z = g(Y)} p(Y | X)\n", | |
"\\end{eqnarray}\n", | |
"where:\n", | |
"\\begin{eqnarray}\n", | |
"P(Y | X) = p(y1 | x1) p(y2 | x2) p(y3 | x3)\n", | |
"\\end{eqnarray}\n", | |
"\n", | |
"Is plugged in to the right side.<br>\n", | |
"\n", | |
"So basically every triplet label will give me a possible set of y labels and according to that I could have the appropriate loss.<br>\n", | |
"This creates a situation as I don't train my system on the actual labels of z because my predictions are related to the y.<br>\n", | |
"In order to implement this I tried conditioning every sample inside the batch (a label for a triplet).\n", | |
"If it's 0 it gets(1,0,0), (0,1,0) and (0,0,1) sum of losses.\n", | |
"If it's 1 gets(1,1,0), (0,1,1) and (1,0,1) sum of losses.\n", | |
"\n", | |
"All losses are cross entropy since we are dealing with classification.<br>\n", | |
"First, all three pictures are forward propagated through the network and then the loss is calculated accordingly.\n", | |
"I had issues with the implementation here since I am not sure how to condition on each sample in a batch inside the graph." | |
] | |
}, | |
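{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the marginalization concrete, here is a small NumPy illustration (not part of the training graph; the probabilities are made up). For one triplet, p(z | X) sums the product p(y1 | x1) p(y2 | x2) p(y3 | x3) over the y-assignments consistent with z:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustration of p(z | X) = sum over {Y : z = g(Y)} of prod_i p(y_i | x_i).\n",
  "p_odd = np.array([0.9, 0.2, 0.1])  # made-up per-image p(y_i = 1 | x_i)\n",
  "\n",
  "# y-assignments consistent with each z, matching the triplet construction above\n",
  "assignments = {0: [(1, 0, 0), (0, 1, 0), (0, 0, 1)],\n",
  "               1: [(1, 1, 0), (0, 1, 1), (1, 0, 1)]}\n",
  "\n",
  "def p_z_given_x(z, p_odd):\n",
  "    total = 0.0\n",
  "    for y in assignments[z]:\n",
  "        total += np.prod([p if yi == 1 else 1 - p for yi, p in zip(y, p_odd)])\n",
  "    return total\n",
  "\n",
  "p_z_given_x(0, p_odd), p_z_given_x(1, p_odd)  # -> (0.674, 0.236)"
 ]
},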
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"BATCH_SIZE = 256\n", | |
"learning_rate = 0.0005\n", | |
"\n", | |
"\n", | |
"graph = tf.Graph()\n", | |
"with graph.as_default():\n", | |
" # Input data.\n", | |
" # attached to the graph.\n", | |
" tf_train_dataset1 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n", | |
" tf_train_dataset2 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n", | |
" tf_train_dataset3 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n", | |
"\n", | |
"\n", | |
" tf_train_labels = tf.placeholder(tf.float32, shape=(BATCH_SIZE, 2))\n", | |
" traindrop = tf.placeholder(tf.bool)\n", | |
" tf_test_dataset = tf.placeholder(tf.float32, shape=(None, train_images.shape[1], train_images.shape[2], 1))\n", | |
"# tf_test_labels = tf.placeholder(tf.float32, shape=(10000, 2))\n", | |
"# tf_test_dataset = tf.to_float(tf.constant(test_images))\n", | |
" \n", | |
" \n", | |
" #model\n", | |
" def model(data):\n", | |
" conv1 = tf.contrib.layers.conv2d(data, num_outputs=32, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu,\n", | |
" scope='conv1')\n", | |
" pool1 = tf.contrib.layers.max_pool2d(conv1, [2,2], 2, scope='pool1')\n", | |
" conv2 = tf.contrib.layers.conv2d(pool1, num_outputs=64, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu,\n", | |
" scope='conv2')\n", | |
" pool2 = tf.contrib.layers.max_pool2d(conv2, [2,2], 2, scope='pool2')\n", | |
" flat = tf.layers.flatten(pool2)\n", | |
" dense1 = tf.contrib.layers.fully_connected(flat, 1024, activation_fn=tf.nn.relu, scope='dense1')\n", | |
" #dropout = tf.nn.dropout(dense1, 0.5)\n", | |
" dropout = tf.layers.dropout(inputs=dense1, rate=0.6, training=traindrop, name='dropout')\n", | |
" out = tf.layers.dense(inputs=dropout, units=2, name='out')\n", | |
" return out\n", | |
" # multiple inputs\n", | |
" \n", | |
" with tf.variable_scope('net'):\n", | |
" out1 = model(tf_train_dataset1)\n", | |
" with tf.variable_scope('net', reuse=True):\n", | |
" out2 = model(tf_train_dataset2)\n", | |
" out3 = model(tf_train_dataset3)\n", | |
" zero_class = np.zeros(shape=(BATCH_SIZE, 2))\n", | |
" zero_class[:,0] = 1\n", | |
" one_class = np.zeros(shape=(BATCH_SIZE, 2))\n", | |
" one_class[:,1] = 1\n", | |
" zero_class = tf.constant(zero_class)\n", | |
" one_class = tf.constant(one_class)\n", | |
" loss=0\n", | |
" #Loss definition\n", | |
" for i in range(BATCH_SIZE):\n", | |
" lossout11 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out1[i,:])\n", | |
" lossout10 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out1[i,:])\n", | |
" lossout21 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out2[i,:])\n", | |
" lossout20 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out2[i,:])\n", | |
" lossout31 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out3[i,:])\n", | |
" lossout30 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out3[i,:])\n", | |
" if tf_train_labels[i,:]==[1,0]:\n", | |
" loss = loss + (lossout11*lossout20*lossout30) +(lossout10*lossout21*lossout30) + (lossout10*lossout20*lossout31)\n", | |
" else:\n", | |
" loss = loss + (lossout11*lossout21*lossout30) + (lossout10*lossout21*lossout31) + (lossout11*lossout20*lossout31)\n", | |
"\n", | |
" # Optimizer\n", | |
" optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)\n", | |
"\n", | |
" # predictions for test\n", | |
"# train_pred = tf.nn.softmax(normalized_logits)\n", | |
" predictions_test = tf.nn.softmax(model(tf_test_dataset))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Ok, we finished our neural network model... nothing special, few conv layers, pooling and dropout. eventually evaluating the one hot encoder (2 units)" | |
] | |
}, | |
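{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Regarding the per-sample conditioning problem mentioned above: one untested alternative (referenced by the NOTE in the loss code) is to compute the marginal probability p(z | X) for the whole batch at once and select between the two cases with tf.where, instead of a Python if. A sketch, assuming the out1/out2/out3 logits and tf_train_labels defined above; it would have to live inside the `with graph.as_default():` block:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Untested sketch: vectorized marginal-likelihood loss, no Python loop over the batch.\n",
  "p1 = tf.nn.softmax(out1)[:, 1]  # p(y1 = 1 | x1), shape (BATCH_SIZE,)\n",
  "p2 = tf.nn.softmax(out2)[:, 1]\n",
  "p3 = tf.nn.softmax(out3)[:, 1]\n",
  "q1, q2, q3 = 1. - p1, 1. - p2, 1. - p3\n",
  "\n",
  "p_z0 = p1*q2*q3 + q1*p2*q3 + q1*q2*p3  # exactly one odd image\n",
  "p_z1 = p1*p2*q3 + q1*p2*p3 + p1*q2*p3  # exactly two odd images\n",
  "\n",
  "is_zero = tf.equal(tf_train_labels[:, 0], 1.)  # elementwise, one bool per sample\n",
  "marginal = tf.where(is_zero, p_z0, p_z1)\n",
  "loss_v = -tf.reduce_mean(tf.log(marginal + 1e-9))  # maximize log p(z | X)"
 ]
},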
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def accuracy(predictions, labels):\n", | |
" #return (100.0 * np.sum(predictions*labels > 0))/predictions.shape[0]\n", | |
" return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))/ predictions.shape[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Initialized weights ,biases and other variables\n", | |
"Periods number is: 39\n", | |
"Training batch loss 0: 311.717865\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 6.089136\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.035357\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000270\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000010\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n", | |
"Training batch loss 0: 0.000000\n", | |
"[[0.51597804 0.484022 ]\n", | |
" [0.51549846 0.48450157]\n", | |
" [0.50839037 0.4916096 ]] [[0. 1.]\n", | |
" [1. 0.]\n", | |
" [0. 1.]]\n", | |
"j = 0, acc = 48.94\n", | |
"All acc : 49.379999999999995\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-18-aff961868499>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[0mbatch_labels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mY\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mBATCH_SIZE_TEST\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mBATCH_SIZE_TEST\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mtf_test_dataset\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[1;33m,\u001b[0m \u001b[0mtraindrop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 35\u001b[1;33m \u001b[0mtest_predictions\u001b[0m\u001b[1;33m=\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mpredictions_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 36\u001b[0m \u001b[0macc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_predictions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_labels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 37\u001b[0m \u001b[0mall_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mall_acc\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0macc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 903\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 904\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 905\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 906\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 907\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1138\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1139\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1140\u001b[1;33m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m 1141\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1142\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1319\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1320\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[1;32m-> 1321\u001b[1;33m run_metadata)\n\u001b[0m\u001b[0;32m 1322\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1323\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1325\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1326\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1327\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1328\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1329\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m 1310\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1311\u001b[0m return self._call_tf_sessionrun(\n\u001b[1;32m-> 1312\u001b[1;33m options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[0;32m 1313\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1314\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[1;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[0;32m 1418\u001b[0m return tf_session.TF_Run(\n\u001b[0;32m 1419\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1420\u001b[1;33m status, run_metadata)\n\u001b[0m\u001b[0;32m 1421\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1422\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"epoch = 1\n", | |
"TRAIN_DATASIZE,_,_,_,_ = image_tri.shape\n", | |
"PERIOD = int(TRAIN_DATASIZE/BATCH_SIZE) #Number of iterations for each epoch\n", | |
"\n", | |
"with tf.Session(graph=graph) as session:\n", | |
" tf.global_variables_initializer().run()\n", | |
" print('Initialized weights ,biases and other variables')\n", | |
" for step in range(epoch):\n", | |
" idxs = np.random.permutation(TRAIN_DATASIZE) #shuffled ordering\n", | |
" X_random = image_tri[idxs]\n", | |
" Y_random = triplet_labels[idxs]\n", | |
" print('Periods number is:',PERIOD)\n", | |
" for i in range(PERIOD):\n", | |
" batch_data = X_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n", | |
" batch_labels = Y_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n", | |
" feed_dict = {\n", | |
" tf_train_dataset1 : batch_data[:,0],\n", | |
" tf_train_dataset2 : batch_data[:,1],\n", | |
" tf_train_dataset3 : batch_data[:,2],\n", | |
" tf_train_labels : batch_labels, traindrop:True\n", | |
" }\n", | |
" _,l= session.run( [optimizer, loss] ,feed_dict=feed_dict)\n", | |
" if (i % 2 == 0):\n", | |
" print('Training batch loss %d: %f' % (step, l))\n", | |
" #evaluate in batches since our test set is 10000 samples\n", | |
" X = test_images\n", | |
" Y = test_labels\n", | |
" BATCH_SIZE_TEST = 5000\n", | |
" PERIOD_TEST = int(X.shape[0]/BATCH_SIZE_TEST)\n", | |
" all_acc = []\n", | |
" for j in range(PERIOD_TEST):\n", | |
" batch_data = X[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n", | |
" batch_labels = Y[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n", | |
" feed_dict = {tf_test_dataset : batch_data , traindrop:False}\n", | |
" test_predictions= session.run( predictions_test, feed_dict=feed_dict)\n", | |
" acc = accuracy(test_predictions, batch_labels) \n", | |
" all_acc = np.append(all_acc,acc)\n", | |
" if (j%500 == 0):\n", | |
" #print(train_predictions[:,:], batch_labels[:,:])\n", | |
" print(test_predictions[:3,:],batch_labels[:3,:])\n", | |
" print('j = {}, acc = {}'.format(j,acc))\n", | |
" print('All acc : {}'.format(np.mean(all_acc)))\n", | |
"\n", | |
"# print('train predict')\n", | |
"# train_predictions = np.array(train_predictions)\n", | |
"# print(train_predictions[:2,:], train_predictions.shape)\n", | |
"# print('Minibatch accuracy: %.1f%%' % accuracy(train_predictions, batch_labels))\n", | |
"# print('test pred')\n", | |
" #feed_dict={tf_test_dataset:test_images,tf_test_labels:test_labels, traindrop:False}\n", | |
" #predictions= session.run([pred], feed_dict=feed_dict)\n", | |
"# predictions = pred.eval()\n", | |
"# predictions = np.array(predictions)\n", | |
"# print(predictions[:2,:], predictions.shape)\n", | |
"# print('Test accuracy: %.1f%%' % accuracy(predictions, test_labels))\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Well, it doesn't seem to get results, the for loop in the graph makes it painfully slow it seems.<br>\n", | |
"The loss seems to decline but since there are no predictions( the graph outputs predictions regarding x1,x2,x3 and not z) it's really hard to see whether we get good results. <br>\n", | |
"Either case the test accuracy is still random (50%~) so more tweaking is needed." | |
] | |
} | |
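{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "One way to get the missing z predictions (a hypothetical post-processing sketch, not run here): combine the three per-image odd-probabilities, as produced by `predictions_test`, into a triplet-level prediction using the same marginalization as the loss:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hypothetical sketch: triplet-level z prediction from per-image odd-probabilities.\n",
  "def predict_z(p_odd):  # p_odd: array of 3 values p(y_i = 1 | x_i)\n",
  "    q = 1. - p_odd\n",
  "    p_z0 = p_odd[0]*q[1]*q[2] + q[0]*p_odd[1]*q[2] + q[0]*q[1]*p_odd[2]\n",
  "    p_z1 = p_odd[0]*p_odd[1]*q[2] + q[0]*p_odd[1]*p_odd[2] + p_odd[0]*q[1]*p_odd[2]\n",
  "    return int(p_z1 > p_z0)\n",
  "\n",
  "predict_z(np.array([0.9, 0.2, 0.1]))  # -> 0, i.e. one odd and two even"
 ]
}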
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python (myenv)", | |
"language": "python", | |
"name": "myenv" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |