Skip to content

Instantly share code, notes, and snippets.

@Z30G0D
Last active May 5, 2018 16:25
Show Gist options
  • Save Z30G0D/c9384ce52c6fe2b561720b7851715079 to your computer and use it in GitHub Desktop.
Save Z30G0D/c9384ce52c6fe2b561720b7851715079 to your computer and use it in GitHub Desktop.
Combined loss function for imubit
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imubit\n",
"This is my solution for the challenge. I will explain my steps through this notebook. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import mnist\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from IPython.display import display, Math, Latex\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I used the \"mnist\" package to ease my work. located here:https://github.com/datapythonista/mnist"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use the retry module or similar alternatives.\n",
"WARNING:tensorflow:From <ipython-input-2-6441638eb0cd>:1: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please write your own downloading logic.\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST-data\\train-images-idx3-ubyte.gz\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use tf.data to implement this functionality.\n",
"Extracting MNIST-data\\train-labels-idx1-ubyte.gz\n",
"Extracting MNIST-data\\t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST-data\\t10k-labels-idx1-ubyte.gz\n",
"WARNING:tensorflow:From C:\\Users\\zeogo\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
]
}
],
"source": [
"mnist = tf.contrib.learn.datasets.load_dataset(\"mnist\")\n",
"train_data = mnist.train.images # Returns np.array\n",
"train_labels = np.asarray(mnist.train.labels, dtype=np.int32)\n",
"eval_data = mnist.test.images # Returns np.array\n",
"eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_images = train_data\n",
"test_images = eval_data\n",
"test_labels = eval_labels\n",
"train_images = np.reshape(train_images, (-1,28,28))\n",
"test_images = np.reshape(test_images,(-1,28,28))\n",
"train_labels = np.reshape(train_labels, (-1,1))\n",
"test_labels = np.reshape(test_labels, (-1,1))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((55000, 28, 28), (55000, 1), (10000, 28, 28), (10000, 1))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_images.shape, train_labels.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADa9JREFUeJzt3X+IHfW5x/HPczXFsA2SUEwXs3VjXOqtYm1ZVGi4KtUSNZhEbGgQkktLt5gKVvtHRSUNSKWR29iLSCW1oSmktgHdGkK5SZFSW7hRN6KNbWwTltj8WHcjKdYmSIj79I+dLZu45ztnz5k5M+Z5vyCcH8+ZmYdDPjtzznfmfM3dBSCe/6i6AQDVIPxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4I6v5MbMzNOJwRK5u7WzOva2vOb2RIz+4uZHTCzB9pZF4DOslbP7Tez8yT9VdLNkg5LekXSKnf/c2IZ9vxAyTqx579G0gF3H3b3U5J+IWlZG+sD0EHthP9iSYemPD6cPXcGMxswsyEzG2pjWwAK1s4XftMdWnzosN7dN0naJHHYD9RJO3v+w5J6pjxeIOloe+0A6JR2wv+KpD4zW2hmH5P0FUnbi2kLQNlaPux399Nmdo+knZLOk7TZ3f9UWGcAStXyUF9LG+MzP1C6jpzkA+Cji/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoDo6RTdQF729vcn69ddfn6xfd911yfpTTz2VrL/++uvJeiew5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoNqapdfMDkp6T9IHkk67e3/O65mlFx2zevXqhrWNGzcml503b15b2x4dHU3Wu7u721p/SrOz9BZxks+N7v5OAesB0EEc9gNBtRt+l7TLzPaY2UARDQHojHYP+7/g7kfN7CJJvzGzN939xakvyP4o8IcBqJm29vzufjS7HZM0KOmaaV6zyd37874MBNBZLYffzLrMbM7kfUlfkvRGUY0BKFc7h/3zJQ2a2eR6fu7u/1dIVwBK13L43X1Y0mcL7AU4Q95Y+0MPPZSsDww0/qqpq6urpZ4mvf/++8n67t2721p/JzDUBwRF+IGgCD8QFOEHgiL8QFCEHwiKn+5Gbd11113J+n333Vfatvft25esr1mzJlkfGhoqsp1SsOcHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY50epVq5c2bC2YsWK5LK33357W9seHx9vWHv00UeTy27YsCFZP3HiREs91Ql7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+c8CiRYsa1tauXZtc9uTJk8n6gQMHkvU77rgjWV+6dGnDWjbnQ8vGxsaS9bvvvrthbXBwsK1tnwvY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULnj/Ga2WdJSSWPufmX23DxJv5TUK+mgpJXu/vfy2jy3zZ49O1nv6+tL1nfs2NGwtmDBgpZ6qoMjR44k68uXL0/W9+zZU2Q755xm9vw/lbTkrOcekPSCu/dJeiF7DOAjJDf87v6ipONnPb1M0pbs/hZJ6T/BAGqn1c/88919RJKy24uKawlAJ5R+br+ZDUgaKHs7AGam1T3/qJl1S1J22/AKC3ff5O797t7f4rYAlKDV8G+XNDlN6RpJzxfTDoBOyQ2/mT0j6f8lfdrMDpvZ1yR9X9LNZrZf0s3ZYwAfIebunduYWec2ViPXXnttsv70008n61dccUXL23733XeT9fPPT3/t09XV1fK2pfTv269bty657JNPPpmsnzp1qqWeznXu3tQPJXCGHxAU4QeCIvxAUIQfCIrwA0ERfiAofrq7A5YsOfuiyDPlDeXl/cR1arj24YcfTi6bd9nsjTfemKzneeKJJxrW8n4WHOVizw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQXFJbwfceeedyfq2bdtK2/bbb7+drD/22GPJ+tatW5P1Y8eOzbgnlItLegEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIzzd8DcuXOT9eHh4WT9wgsvLLKdGXnrrbeS9fvvvz9ZHxwcLLIdNIFxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVO44v5ltlrRU0pi7X5k9t17S1yVNXsz9oLv/OndjQcf58+RNk33JJZck67feemvD2k033ZRctq+vL1m//PLLk/W8/z/r169vWHvkkUeSy6I1RY7z/1TSdLNOPO7uV2f/coMPoF5yw+/uL0o63oFeAHRQO5/57zGzP5rZZjNLn78KoHZaDf+PJC2SdLWkEUk/aPRCMxswsyEzG2pxWwBK0FL43X3U3T9w93FJP5Z0TeK1m9y93937W20SQPFaCr+ZdU95uELSG8W0A6BTcqfoNrNnJN0g6RNmdljSdyXdYGZXS3JJByV9o8QeAZSA6/mDyzvH4M0330zWL7300mR9165dDWurV69OLjs2NpasY3pczw8gifADQRF+ICjCDwRF+IGgCD8QVO44P85tZulRoZ6enrbWn7oc+fhxrherEnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gFi5cmKzPmjWrrfWPj483rJ0+fbqtdaM97PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjG+TNz5sxJ1i+44IKGtRMnTiSXPXnyZEs9Nau7u7thbd26dcllV61aVXQ7Z7j33ntLXT9ax54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4LKnaLbzHok/UzSJyWNS9rk7v9rZvMk/VJSr6SDkla6+99z1lXbKbqHh4eT9d7e3oa1vXv3Jpfdv39/Ky017ZZbbmlYmz17dqnbfvnll5P1xYsXN6xxPX85ipyi+7Skb7v7f0q6TtI3zewzkh6Q9IK790l6IXsM4CMiN/zuPuLur2b335O0T9LFkpZJ2pK9bIuk5WU1CaB4M/rMb2a9kj4n6SVJ8919RJr4AyHpoqKbA1Ceps/tN7OPS3pW0rfc/R95c7xNWW5A0kBr7QEoS1N7fjObpYngb3X357KnR82sO6t3Sxqbbll33+Tu/e7eX0TDAIqRG36b2MX/RNI+d984pbRd0prs/hpJzxffHoCyNDPUt1jS7yXt1cRQnyQ9qInP/dskfUrS3yR92d2Tcy7XeajvtddeS9avuuqqDnXSWXnDbTt37kzW165dm6wfOnRoxj2hPc0O9eV+5nf3P0hqtLIvzqQpAPXBGX5AUIQfCIrwA0ERfiAowg8ERfiBoHLH+QvdWI3H+W+77bZk/fHHH29Yu+yyy4puZ0aOHDnSsLZ79+7kshs2bEjWh4aGWuoJ1Snykl4A5yDCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4mzZ07t2Gtp6eng5182LFjxxrWRkZGOtgJ6oBxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8wDmGcX4ASYQfCIrwA0ERfiAowg8ERfiBoAg/EFRu+M2sx8x+a2b7zOxPZnZv9vx6MztiZq9l/24tv10ARck9ycfMuiV1u/urZjZH0h5JyyWtlPRPd/+fpjfGST5A6Zo9yef8JlY0Imkku/+eme2TdHF77QGo2ow+85tZr6TPSXope+oeM/ujmW02s2l/58rMBsxsyMyY9wmokabP7Tezj0v6naTvuftzZjZf0juSXNIjmvho8NWcdXDYD5Ss2cP+psJvZrMk7ZC00903TlPvlbTD3a/MWQ/hB0pW2IU9ZmaSfiJp39TgZ18ETloh6Y2ZNgmgOs18279Y0u8l7ZU0nj39oKRVkq7WxGH/QUnfyL4cTK2LPT9QskIP+4tC+IHycT0/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAULk/4FmwdyS9NeXxJ7Ln6qiuvdW1L4neWlVkb5c0+8KOXs//oY2bDbl7f2UNJNS1t7r2JdFbq6rqjcN+ICjCDwRVdfg3Vbz9lLr2Vte+JHprVSW9VfqZH0B1qt7zA6hIJeE3syVm9hczO2BmD1TRQyNmdtDM9mYzD1c6xVg2DdqYmb0x5bl5ZvYbM9uf3U47TVpFvdVi5ubEzNKVvnd1m/G644f9ZnaepL9KulnSYUmvSFrl7n/uaCMNmNlBSf3uXvmYsJn9l6R/SvrZ5GxIZvaYpOPu/v3sD+dcd/9OTXpbrxnO3FxSb41mlv5vVfjeFTnjdRGq2PNfI+mAuw+7+ylJv5C0rII+as/dX5R0/Kynl0nakt3foon/PB3XoLdacPcRd381u/+epMmZpSt97xJ9VaKK8F8s6dCUx4dVrym/XdIuM9tjZgNVNzON+ZMzI2W3F1Xcz9lyZ27upLNmlq7Ne9fKjNdFqyL8080mUqchhy+4++cl3SLpm9nhLZrzI0mLNDGN24ikH1TZTDaz9LOSvuXu/6iyl6mm6auS962K8B+W1DPl8QJJRyvoY1rufjS7HZM0qImPKXUyOjlJanY7VnE//+buo+7+gbuPS/qxKnzvspmln5W01d2fy56u/L2brq+q3rcqwv+KpD4zW2hmH5P0FUnbK+jjQ8ysK/siRmbWJelLqt/sw9slrcnur5H0fIW9nKEuMzc3mllaFb93dZvxupKTfLKhjB9KOk/SZnf/XsebmIaZXaqJvb00ccXjz6vszcyekXSDJq76GpX0XUm/krRN0qck/U3Sl92941+8NejtBs1w5uaSems0s/RLqvC9K3LG60L64Qw/ICbO8AOCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/ENS/AJvlHp92kzVQAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11a21458fd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(train_images[6785],cmap='gray')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, seems like the shape are okay, let's handle the labels as requested. let's start with the the test set"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 1]), list)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_labels=[1 if i % 2 != 0 else 0 for i in test_labels]\n",
"np.unique(test_labels), type(test_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, let's go preprocess our training set:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_images_odd = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 != 0])\n",
"train_images_even = np.array([train_images[i,:,:] for i in range(train_images.shape[0]) if train_labels[i] % 2 == 0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's also seperate the training images to \"odd\" and \"even\" training images, this is in order to create the triplets required\n",
"in the exercise more easily."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((55000, 28, 28), (27027, 28, 28), (27973, 28, 28))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_images.shape, train_images_even.shape, train_images_odd.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"image_triplets = np.zeros((10000, 3, train_images.shape[1], train_images.shape[2])) #initialize triplet data array (10000 triplets of pictures)\n",
"triplet_labels = np.zeros((10000, 1)) # initialize triplet label array\n",
"for i in range(10000):\n",
" rand_mnist = np.random.randint(55000, size=1) #a random image from mnist training set\n",
" if train_labels[rand_mnist] % 2 != 0: # checking if number is odd\n",
" rand_even = np.random.randint(train_images_even.shape[0], size=2) #randomizing 2 even numbers\n",
"# combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_even[rand_even[0], :],train_images_even[rand_even[1], :]), axis=0) # combine 3 images\n",
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n",
" image_triplets[i, 1, :, :] = train_images_even[rand_even[0], :, :]\n",
" image_triplets[i, 2, :, :] = train_images_even[rand_even[1], :, :]\n",
" triplet_labels[i] = 0 # adding appropriate binary label\n",
" else:\n",
" rand_odd = np.random.randint(train_images_odd.shape[0], size=2)#randomizing 2 odd numbers\n",
" #combine_images = np.concatenate((np.ravel(train_images[rand_mnist,:]), train_images_odd[rand_odd[0], :],train_images_odd[rand_odd[1], :]), axis=0) # combine 3 images\n",
" image_triplets[i, 0, :, :] = train_images[rand_mnist,:, :] #add 1 odd and 2 even images to triplet array\n",
" image_triplets[i, 1, :, :] = train_images_odd[rand_odd[0], :, :]\n",
" image_triplets[i, 2, :, :] = train_images_odd[rand_odd[1], :, :]\n",
" triplet_labels[i] = 1"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10000, 3, 28, 28), (10000, 1))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_triplets.shape,triplet_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great, we have our triplets, We got 10000 samples in each sample we have 3 images (either 1 odd and 2 even <u>or</u> 2 odd and 1 even)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAC8lJREFUeJzt3WGoXPWZx/Hvs27zJi0YKWajcU2tQXYJ1C43suCyKMXqroEYodKAS2RL0xcVtuCLFd9UWAqybLu7+KKQYmiKbdqgcQ1xMS2yrF1Y5Eap1TabNmhsYmKywULsqxJ99sU9WW7jvWfmzpyZM+nz/UCYmfOfO/NjyO/+z8w5d/6RmUiq5w/6DiCpH5ZfKsryS0VZfqkoyy8VZfmloiy/VJTll4qy/FJRfzjNJ4sITyeUJiwzY5j7jTXzR8RdEXE0Io5FxMPjPJak6YpRz+2PiCuAXwB3ACeBeWB7Zv685Wec+aUJm8bMfwtwLDPfyMzfAt8Hto7xeJKmaJzyXwucWHT7ZLPtd0TEzog4HBGHx3guSR0b5wO/pXYtPrRbn5m7gF3gbr80S8aZ+U8C1y26vR44NV4cSdMyTvnngY0R8YmIWAV8HjjQTSxJkzbybn9mXoiIB4FDwBXA7sz8WWfJJE3UyIf6Rnoy3/NLEzeVk3wkXb4sv1SU5ZeKsvxSUZZfKsryS0VZfqkoyy8VZfmloiy/VJTll4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXypqqkt0a/asXr26dXzv3r2t4xcuXGgdv/fee1ecSdPhzC8VZfmloiy/VJTll4qy/FJRll8qyvJLRY11nD8ijgPvAe8DFzJzrotQmp4HHnigdfzuu+9uHX/88cc7TKNp6uIkn9sz81wHjyNpitztl4oat/wJ/DAiXo6InV0EkjQd4+7235qZpyLiauBHEfE/mfni4js0vxT8xSDNmLFm/sw81VyeBZ4BblniPrsyc84PA6XZMnL5I2J1RHzs4nXgs8DrXQWTNFnj7PavBZ6JiIuP873MfL6TVJImbuTyZ+YbwKc6zKIebNu2bayff/PNNztKomnzUJ9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkol+jWWJ5/3q9wuFw580tFWX6pKMsvFWX5paIsv1SU5ZeKsvxSUR7n11iOHj3adwSNyJlfKsryS0VZfqkoyy8VZfmloiy/VJTll4oaeJw/InYDW4Czmbmp2XYV8ANgA3AcuC8zfz25mOrL/Px83xE0IcPM/N8G7rpk28PAC5m5EXihuS3pMjKw/Jn5IvDuJZu3Anua63uAezrOJWnCRn3PvzYzTwM0l1d3F0nSNEz83P6I2AnsnPTzSFqZUWf+MxGxDqC5PLvcHTNzV2bOZebciM8laQJGLf8BYEdzfQfwbDdxJE3LwPJHxF7gv4GbIuJkRHwBeAy4IyJ+CdzR3JZ0GRn4nj8zty8z9JmOs2gC1q9f3zp+0003tY4/9dRTXcbRDPEMP6koyy8VZfmloiy/VJTll4qy/FJRfnX377krr7yydfyaa66ZUhLNGmd+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyoqMnN6TxYxvScTAJs2bWodf/XVV1vHz58/3zq+Zs2aFWfSZGVmDHM/Z36pKMsvFWX5paIsv1SU5ZeKsvxSUZZfKsq/5/89984777SOHzt2rHV848aNXcbRDHHml4qy/FJRll8qyvJLRVl+qSjLLxVl+aWiBh7nj4jdwBbgbGZuarY9CnwR+N/mbo9k5r9PKqRGd+7cudbxEydOtI7feOONXcbRDBlm5v82cNcS2/85M29u/ll86TIzsPyZ+SLw7hSySJqicd7zPxgRP42I3RHhdzlJl5lRy/9N4JPAzcBp4OvL3TEidkbE4Yg4POJzSZqAkcqfmWcy8/3M/AD4FnBLy313ZeZcZs6NGlJS90Yqf0SsW3RzG/B6N3EkTcswh/r2ArcBH4+Ik8BXgdsi4mYggePAlyaYUdIEDCx/Zm5fYvMTE8iiy9D999/fOv7kk09OKYlWyjP8pKIsv1SU5ZeKsvxSUZZfKsryS0X51d0ay5YtW1rHPdQ3u5z5paIsv1SU5ZeKsvxSUZZfKsryS0VZfqkoj/NrLJs3b+47gkbkzC8VZfmloiy/VJTll4qy/FJRll8qyvJLRXmcv7h9+/a1jt9+++1TSqJpc+aXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIGHuePiOuA7wB/BHwA7MrMf42Iq4AfABuA48B9mfnryUXVJLz99tut45nZOn799de3jt95553Ljh06dKj1ZzVZw8z8F4CHMvNPgD8HvhwRfwo8DLyQmRuBF5rbki4TA8ufmacz85Xm+nvAEeBaYCuwp7nbHuCeSYWU1L0VveePiA3Ap4GXgLWZeRoWfkEAV3cdTtLkDH1uf0R8FHga+Epmno+IYX9uJ7BztHiSJmWomT8iPsJC8b+bmfubzWciYl0zvg44u9TPZuauzJzLzLkuAkvqxsDyx8IU/wRwJDO/sWjoALCjub4DeLb7eJImZZjd/luBvwFei4ifNNseAR4D9kXEF4BfAZ+bTERN0nPPPdc6/tZbb7WO33DDDa3jq1atWnEmTcfA8mfmfwHLvcH/TLdxJE2LZ/hJRVl+qSjLLxVl+aWiLL9UlOWXivKru9Vq//79reMPPfRQ6/ig8wDUH2d+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK4/xqdfDgwdbxzZs3t47Pz893GUcdcuaXirL8UlGWXyrK8ktFWX6pKMsvFWX5paJi0BLMnT5ZxPSeTCoqM4daS8+ZXyrK8ktFWX6pKMsvFWX5paIsv1SU5ZeKGlj+iLguIv4jIo5ExM8i4u+a7Y9GxNsR8ZPm319PPq6krgw8ySci1gHrMvOViPgY8DJwD3Af8JvM/Kehn8yTfKSJG/Ykn4Hf5JOZp4HTzfX3IuIIcO148ST1bUXv+SNiA/Bp4KVm04MR8dOI2B0Ra5b5mZ0RcTgiDo+VVFKnhj63PyI+Cvwn8LXM3B8Ra4FzQAL/wMJbg78d8Bju9ksTNuxu/1Dlj4iPAAeBQ5n5jSXGNwAHM3PTgMex/NKEdfaHPRERwBPAkcXFbz4IvGgb8PpKQ0rqzzCf9v8F8GPgNeCDZvMjwHbgZhZ2+48DX2o+HGx7LGd+acI63e3viuWXJs+/55fUyvJLRVl+qSjLLxVl+aWiLL9UlOWXirL8UlGWXyrK8ktFWX6pKMsvFWX5paIsv1TUwC/w7Ng54K1Ftz/ebJtFs5ptVnOB2UbVZbbrh73jVP+e/0NPHnE4M+d6C9BiVrPNai4w26j6yuZuv1SU5ZeK6rv8u3p+/jazmm1Wc4HZRtVLtl7f80vqT98zv6Se9FL+iLgrIo5GxLGIeLiPDMuJiOMR8Vqz8nCvS4w1y6CdjYjXF227KiJ+FBG/bC6XXCatp2wzsXJzy8rSvb52s7bi9dR3+yPiCuAXwB3ASWAe2J6ZP59qkGVExHFgLjN7PyYcEX8J/Ab4zsXVkCLiH4F3M/Ox5hfnmsz8+xnJ9igrXLl5QtmWW1n6AXp87bpc8boLfcz8twDHMvONzPwt8H1gaw85Zl5mvgi8e8nmrcCe5voeFv7zTN0y2WZCZp7OzFea6+8BF1eW7vW1a8nViz7Kfy1wYtHtk8zWkt8J/DAiXo6InX2HWcLaiysjNZdX95znUgNXbp6mS1aWnpnXbpQVr7vWR/mXWk1klg453JqZfwb8FfDlZvdWw/km8EkWlnE7DXy9zzDNytJPA1/JzPN9ZllsiVy9vG59lP8kcN2i2+uBUz3kWFJmnmouzwLPsPA2ZZacubhIanN5tuc8/y8zz2Tm+5n5AfAtenztmpWlnwa+m5n7m829v3ZL5errdeuj/PPAxoj4RESsAj4PHOghx4dExOrmgxgiYjXwWWZv9eEDwI7m+g7g2R6z/I5ZWbl5uZWl6fm1m7UVr3s5yac5lPEvwBXA7sz82tRDLCEibmBhtoeFv3j8Xp/ZImIvcBsLf/V1Bvgq8G/APuCPgV8Bn8vMqX/wtky221jhys0TyrbcytIv0eNr1+WK153k8Qw/qSbP8JOKsvxSUZZfKsryS0VZfqkoyy8VZfmloiy/VNT/Afz9Z1kTvWa/AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11a23ea62e8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(image_triplets[2541,2,:,:],cmap='gray')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to have our input fit to the tensorflow module(i.e. batch x height x width x nchannels)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"test_images = np.array(test_images)\n",
"test_labels = np.array(test_labels)\n",
"image_tri = np.reshape(image_triplets, [-1, 3, 28, 28 , 1])\n",
"test_labels = np.reshape(test_labels, [-1, 1])\n",
"test_images = np.reshape(test_images,[-1, 28,28 ,1])\n",
"# noisy.shape\n",
"# noisy_lab = np.zeros((noisy.shape[0],1))\n",
"\n",
"# for i in range(0, triplet_labels.shape[0], 3):\n",
"# noisy_lab[i],noisy_lab[i+1] ,noisy_lab[i+2] = triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))],triplet_labels[(int(i / 3))]\n",
"\n",
"# noisy = noisy.reshape(noisy.shape[0],28, 28, 1)\n",
"# noisy.shape, noisy_lab.shape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10000, 3, 28, 28, 1), (10000, 1), (10000, 28, 28, 1), (10000, 1))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_tri.shape,triplet_labels.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's change to one hot encoder"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10000, 3, 28, 28, 1), (10000, 2), (10000, 28, 28, 1), (10000, 2))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_y = np.array(triplet_labels[:, 0]) # Input\n",
"train_y = np.reshape(train_y, (-1, 1))\n",
"enc = OneHotEncoder()\n",
"enc.fit(train_y)\n",
"out = enc.transform(train_y).toarray()\n",
"triplet_labels = out\n",
"train_y = np.array(test_labels) # Input\n",
"train_y = np.reshape(train_y, (-1, 1))\n",
"enc = OneHotEncoder()\n",
"enc.fit(train_y)\n",
"out = enc.transform(train_y).toarray()\n",
"test_labels=out\n",
"image_tri.shape,triplet_labels.shape, test_images.shape, test_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we have our engineered dataset , let's move to the neural network, try to train it and test it on the test set.<br>\n",
"My solution is based on the next derivation(z = g(y1,y2,y3) and equals to 1 or 0, Y is a set of y and X is a set of x), assume we want to maximize :<br>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\\begin{eqnarray}\n",
"log p(z | X) = log \\sum_{y : z = g(Y)} p(Y | X)\n",
"\\end{eqnarray}\n",
"where:\n",
"\\begin{eqnarray}\n",
"P(Y | X) = p(y1 | x1) p(y2 | x2) p(y3 | x3)\n",
"\\end{eqnarray}\n",
"\n",
"Is plugged in to the right side.<br>\n",
"\n",
"So basically every triplet label will give me a possible set of y labels and according to that I could have the appropriate loss.<br>\n",
"This creates a situation as I don't train my system on the actual labels of z because my predictions are related to the y.<br>\n",
"In order to implement this I tried conditioning every sample inside the batch (a label for a triplet).\n",
"If it's 0 it gets(1,0,0), (0,1,0) and (0,0,1) sum of losses.\n",
"If it's 1 gets(1,1,0), (0,1,1) and (1,0,1) sum of losses.\n",
"\n",
"All losses are cross entropy since we are dealing with classification.<br>\n",
"First, all three pictures are forward propagated through the network and then the loss is calculated accordingly.\n",
"I had issues with the implementation here since I am not sure how to condition on each sample in a batch inside the graph."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 256\n",
"learning_rate = 0.0005\n",
"\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
" # Input data.\n",
" # attached to the graph.\n",
" tf_train_dataset1 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n",
" tf_train_dataset2 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n",
" tf_train_dataset3 = tf.placeholder(tf.float32, shape=(BATCH_SIZE, train_images.shape[1], train_images.shape[2], 1))\n",
"\n",
"\n",
" tf_train_labels = tf.placeholder(tf.float32, shape=(BATCH_SIZE, 2))\n",
" traindrop = tf.placeholder(tf.bool)\n",
" tf_test_dataset = tf.placeholder(tf.float32, shape=(None, train_images.shape[1], train_images.shape[2], 1))\n",
"# tf_test_labels = tf.placeholder(tf.float32, shape=(10000, 2))\n",
"# tf_test_dataset = tf.to_float(tf.constant(test_images))\n",
" \n",
" \n",
" #model\n",
" def model(data):\n",
" conv1 = tf.contrib.layers.conv2d(data, num_outputs=32, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu,\n",
" scope='conv1')\n",
" pool1 = tf.contrib.layers.max_pool2d(conv1, [2,2], 2, scope='pool1')\n",
" conv2 = tf.contrib.layers.conv2d(pool1, num_outputs=64, kernel_size=5,padding ='SAME', activation_fn=tf.nn.relu,\n",
" scope='conv2')\n",
" pool2 = tf.contrib.layers.max_pool2d(conv2, [2,2], 2, scope='pool2')\n",
" flat = tf.layers.flatten(pool2)\n",
" dense1 = tf.contrib.layers.fully_connected(flat, 1024, activation_fn=tf.nn.relu, scope='dense1')\n",
" #dropout = tf.nn.dropout(dense1, 0.5)\n",
" dropout = tf.layers.dropout(inputs=dense1, rate=0.6, training=traindrop, name='dropout')\n",
" out = tf.layers.dense(inputs=dropout, units=2, name='out')\n",
" return out\n",
" # multiple inputs\n",
" \n",
" with tf.variable_scope('net'):\n",
" out1 = model(tf_train_dataset1)\n",
" with tf.variable_scope('net', reuse=True):\n",
" out2 = model(tf_train_dataset2)\n",
" out3 = model(tf_train_dataset3)\n",
" zero_class = np.zeros(shape=(BATCH_SIZE, 2))\n",
" zero_class[:,0] = 1\n",
" one_class = np.zeros(shape=(BATCH_SIZE, 2))\n",
" one_class[:,1] = 1\n",
" zero_class = tf.constant(zero_class)\n",
" one_class = tf.constant(one_class)\n",
" loss=0\n",
" #Loss definition\n",
" for i in range(BATCH_SIZE):\n",
" lossout11 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out1[i,:])\n",
" lossout10 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out1[i,:])\n",
" lossout21 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out2[i,:])\n",
" lossout20 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out2[i,:])\n",
" lossout31 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = one_class[i,:], logits = out3[i,:])\n",
" lossout30 = tf.nn.softmax_cross_entropy_with_logits_v2(labels = zero_class[i,:], logits = out3[i,:])\n",
" if tf_train_labels[i,:]==[1,0]:\n",
" loss = loss + (lossout11*lossout20*lossout30) +(lossout10*lossout21*lossout30) + (lossout10*lossout20*lossout31)\n",
" else:\n",
" loss = loss + (lossout11*lossout21*lossout30) + (lossout10*lossout21*lossout31) + (lossout11*lossout20*lossout31)\n",
"\n",
" # Optimizer\n",
" optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)\n",
"\n",
" # predictions for test\n",
"# train_pred = tf.nn.softmax(normalized_logits)\n",
" predictions_test = tf.nn.softmax(model(tf_test_dataset))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we finished our neural network model... nothing special, few conv layers, pooling and dropout. eventually evaluating the one hot encoder (2 units)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def accuracy(predictions, labels):\n",
" #return (100.0 * np.sum(predictions*labels > 0))/predictions.shape[0]\n",
" return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))/ predictions.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized weights ,biases and other variables\n",
"Periods number is: 39\n",
"Training batch loss 0: 311.717865\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 6.089136\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.035357\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000270\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000010\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n",
"Training batch loss 0: 0.000000\n",
"[[0.51597804 0.484022 ]\n",
" [0.51549846 0.48450157]\n",
" [0.50839037 0.4916096 ]] [[0. 1.]\n",
" [1. 0.]\n",
" [0. 1.]]\n",
"j = 0, acc = 48.94\n",
"All acc : 49.379999999999995\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-18-aff961868499>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[0mbatch_labels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mY\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mj\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mBATCH_SIZE_TEST\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mj\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mBATCH_SIZE_TEST\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[0mfeed_dict\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mtf_test_dataset\u001b[0m \u001b[1;33m:\u001b[0m \u001b[0mbatch_data\u001b[0m \u001b[1;33m,\u001b[0m \u001b[0mtraindrop\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 35\u001b[1;33m \u001b[0mtest_predictions\u001b[0m\u001b[1;33m=\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mpredictions_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 36\u001b[0m \u001b[0macc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_predictions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_labels\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 37\u001b[0m \u001b[0mall_acc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mall_acc\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0macc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 903\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 904\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 905\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 906\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 907\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1138\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1139\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1140\u001b[1;33m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m 1141\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1142\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1319\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1320\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[1;32m-> 1321\u001b[1;33m run_metadata)\n\u001b[0m\u001b[0;32m 1322\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1323\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1325\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1326\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1327\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1328\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1329\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m 1310\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1311\u001b[0m return self._call_tf_sessionrun(\n\u001b[1;32m-> 1312\u001b[1;33m options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[0;32m 1313\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1314\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Miniconda2\\envs\\jupy\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[1;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[0;32m 1418\u001b[0m return tf_session.TF_Run(\n\u001b[0;32m 1419\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1420\u001b[1;33m status, run_metadata)\n\u001b[0m\u001b[0;32m 1421\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1422\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"epoch = 1\n",
"TRAIN_DATASIZE,_,_,_,_ = image_tri.shape\n",
"PERIOD = int(TRAIN_DATASIZE/BATCH_SIZE) #Number of iterations for each epoch\n",
"\n",
"with tf.Session(graph=graph) as session:\n",
" tf.global_variables_initializer().run()\n",
" print('Initialized weights ,biases and other variables')\n",
" for step in range(epoch):\n",
" idxs = np.random.permutation(TRAIN_DATASIZE) #shuffled ordering\n",
" X_random = image_tri[idxs]\n",
" Y_random = triplet_labels[idxs]\n",
" print('Periods number is:',PERIOD)\n",
" for i in range(PERIOD):\n",
" batch_data = X_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n",
" batch_labels = Y_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE]\n",
" feed_dict = {\n",
" tf_train_dataset1 : batch_data[:,0],\n",
" tf_train_dataset2 : batch_data[:,1],\n",
" tf_train_dataset3 : batch_data[:,2],\n",
" tf_train_labels : batch_labels, traindrop:True\n",
" }\n",
" _,l= session.run( [optimizer, loss] ,feed_dict=feed_dict)\n",
" if (i % 2 == 0):\n",
" print('Training batch loss %d: %f' % (step, l))\n",
" #evaluate in batches since our test set is 10000 samples\n",
" X = test_images\n",
" Y = test_labels\n",
" BATCH_SIZE_TEST = 5000\n",
" PERIOD_TEST = int(X.shape[0]/BATCH_SIZE_TEST)\n",
" all_acc = []\n",
" for j in range(PERIOD_TEST):\n",
" batch_data = X[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n",
" batch_labels = Y[j * BATCH_SIZE_TEST:(j+1) * BATCH_SIZE_TEST]\n",
" feed_dict = {tf_test_dataset : batch_data , traindrop:False}\n",
" test_predictions= session.run( predictions_test, feed_dict=feed_dict)\n",
" acc = accuracy(test_predictions, batch_labels) \n",
" all_acc = np.append(all_acc,acc)\n",
" if (j%500 == 0):\n",
" #print(train_predictions[:,:], batch_labels[:,:])\n",
" print(test_predictions[:3,:],batch_labels[:3,:])\n",
" print('j = {}, acc = {}'.format(j,acc))\n",
" print('All acc : {}'.format(np.mean(all_acc)))\n",
"\n",
"# print('train predict')\n",
"# train_predictions = np.array(train_predictions)\n",
"# print(train_predictions[:2,:], train_predictions.shape)\n",
"# print('Minibatch accuracy: %.1f%%' % accuracy(train_predictions, batch_labels))\n",
"# print('test pred')\n",
" #feed_dict={tf_test_dataset:test_images,tf_test_labels:test_labels, traindrop:False}\n",
" #predictions= session.run([pred], feed_dict=feed_dict)\n",
"# predictions = pred.eval()\n",
"# predictions = np.array(predictions)\n",
"# print(predictions[:2,:], predictions.shape)\n",
"# print('Test accuracy: %.1f%%' % accuracy(predictions, test_labels))\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, it doesn't seem to get results, the for loop in the graph makes it painfully slow it seems.<br>\n",
"The loss seems to decline but since there are no predictions( the graph outputs predictions regarding x1,x2,x3 and not z) it's really hard to see whether we get good results. <br>\n",
"Either case the test accuracy is still random (50%~) so more tweaking is needed."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (myenv)",
"language": "python",
"name": "myenv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment