{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 94-775/95-865: Prediction and Model Validation\n",
"\n",
"Author: George H. Chen (georgechen [at symbol] cmu.edu)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# the lines below are just for aesthetics\n",
"plt.style.use('ggplot') # if you want your plot to look at ggplot (like how R makes plots)\n",
"%config InlineBackend.figure_format = 'retina' # if you use a Mac with Retina display"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from tensorflow.python import keras\n",
"from keras.datasets import mnist\n",
"\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_images = train_images[:2000]\n",
"train_labels = train_labels[:2000]\n",
"test_images = test_images[:500]\n",
"test_labels = test_labels[:500]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"flattened_train_images = train_images.reshape(len(train_images), -1) # flattens out each training image\n",
"flattened_test_images = test_images.reshape(len(test_images), -1) # flattens out each test image"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"flattened_train_images = flattened_train_images.astype(np.float32) / 255 # rescale to be between 0 and 1\n",
"flattened_test_images = flattened_test_images.astype(np.float32) / 255 # rescale to be between 0 and 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2000, 784)\n"
]
}
],
"source": [
"print(flattened_train_images.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(500, 784)\n"
]
}
],
"source": [
"print(flattened_test_images.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x11a8dd150>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAHwCAYAAAC7cCafAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dfaxcVb3/8fehFWhLWxUfihIEjNAEhUoPUlChlMhF8fLcRQlQcvO7N/gIFDVetWhVzPWayrMiEYEIxHYFchGDT8QWypM/PcWiRhEQuPwai1BQbEHBlvP7Y/bg8XRmes7ec87M+c77lTT7zF57zVpd3T2fWXv2Q9/g4CCSJCmOHTrdAUmS1F6GuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhTM5E53YIz4NBxJUhR9o63gzF2SpGCiztwB6Otr/GFnYGAAgP7+/vHszoTmmJXjuJXjuI2eY1ZON49blUeydzTcU0q7A18AjgZ2BTYANwOfzzn/qZN9kyRpourYYfmU0puBtcC/AT8DLgIeAc4B7k0p7dqpvkmSNJF1cub+deB1wNk558vqK1NKFwJLgC8BH+hQ3yRJmrA6MnNPKe0NHAU8BnxtWPHngOeAM1JK08a5a5IkTXidmrkvKJY/zjm/NLQg57wppXQ3tfCfB/yk2ZuklNY2Wp9zBv5xosRws2fPblmubTlm5Thu5Thuo+eYlRN13Dr1nfu+xfLBJuUPFct9xqEvkiSF0qmZ+8xi+WyT8vr6V7Z6k5zz3CZFg9D80oZuvvShWzlm5Thu5Thuo+eYldPN41blUrhuvYlN/QJ17zQnSdIodSrc6zPzmU3KZwzbTpIkjVCnwv13xbLZd+pvKZbNvpOXJElNdCrcVxfLo1JK/9SHlNJ04J3AX4GfjnfHJEma6DoS7jnn3wM/BvYEPjys+PPANODbOefnxrlrkiRNeJ28Q92HgHuAS1NKRwK/BQ4GjqB2OP4zHeybJEkTVsfOli9m7/3AtdRC/WPAm4FLgUNyzk93qm+SJE1kHX0qXM75/1F7cIwkSWqTbr3OXZIklWS4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBTO50x2Q1Jvmzp3btGzq1Knb3eYjH/lI6bYXL15cui7At7/97dJ1L7vsskpt33fffZXqqzc4c5ckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKC8XnukkqZM2dOpfq33XZb07Lp06dvd5sZM2aUbntwcLB0XYAzzjijdN1jjz22Utu77rprpfrqDR0L95TSY8CbmhT/Mec8axy7I0lSGJ2euT8LXNxg/ebx7ogkSVF0Otz/nHNe1uE+SJIUiifUSZIUTKdn7jullE4H9gCeA34JrMk5b+1styRJmrj6qp41WlaLE+oeBf4t53zHCN5jbaP1OecDAdaubVjM7NmzAXjggQdG1lk5ZiVFHrcpU6ZUqr/PPvs0LZs0aRIAW7c2/5xf32aiafV3Gon777+/4frI+9pY6uZxmzt3bv3HvtHW7eRh+WuAI4FZwDTgbcCVwJ7AD1JKB3Sua5IkTVwdm7k3k1JaDnwMuDnnfELJtxkE6Otr/GFnYGAAgP7+/pJv33scs3Iij1vV69xXrVrVtKx+nfumTZuablPlOvdOevbZZyvVb3ade+R9bSx187gNyecJNXNv5hvF8rCO9kKSpAmqG8P9yWI5raO9kCRpgurGcD+kWD7S0V5IkjRBdSTcU0r7pZRe3WD9m4DLi5fXj2+vJEmKoVPXuS8E/jOltJrapW+bgDcDxwA7A98Hlneob5IkTWidCvfVwL7A26kdhp8G/Bm4C7gOuC7n3F2n8UuSNEF0JNyLG9Rs9yY1ksbWO97xjtJ1b7rppkptz5w5s2lZ/TLWVttUuYy31SV2I/Hiiy+Wrlv1ka3z5s1ruH7atGktywHuu+++Sm1X+XtrfHXjCXWSJKkCw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgunI89wl/cPUqVNL1z3wwAMrtX399deXrrvbbrtVaruTHnrooUr1v/KVr5Suu2LFikpt33333Q3X9/X1tSwHWLp0aaW2/+u//qtSfY0fZ+6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIyPfJU67Morryxd99RTT21jT3pH1Ufl7rLLLqXr3nHHHZXanj9/fum6+++/f6W2NXE4c5ckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYAx3SZKC8XnuUkVz585tWjZ16tTtbnPMMceUbruvr6903aqqPpf8e9/7XtOyJUuWAHDRRRc13Wb58uWl2/7DH/5Qui7AL37xi9J1//SnP1Vqe8GCBS3LW+0TndxfNL6cuUuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMD7yVQLmzJlTuu5tt93WtGz69Onb3WbGjBml2x4cHCxdF+AHP/hB6bqnnnpqpbYPP/zwpmUvvPACAA8++GDTbZYuXVq67auuuqp0XYCnnnqqdN3777+/UtsvvfRSw/X1x7m22ieqPF4Y4MADDyxd97777qvUtkbHmbskScG0ZeaeUjoZOByYAxwATAduyDmf3qLOocBSYB6wM/AwcDVwWc55azv6JUlSL2rXYfml1EJ9M7AemN1q45TSccBNwN+AlcAzwL8CFwHvBBa2qV+SJPWcdh2WXwLsA8wAPthqw5TSDOCbwFZgfs75/+ScP0Ft1n8vcHJKaVGb+iVJUs9pS7jnnFfnnB/KOY/k7J6TgdcCK3LOA0Pe42/UjgDAdj4gSJKk5jpxQt2CYvnDBmVrgOeBQ1NKO41flyRJiqMTl8LtWyy3ucYl57wlpfQosB+wN/DbVm+UUlrbaH3OGYCBgYFGxcyePbtlubYVfcymTJlSum79crdGJ
k2atN1tdtihcxetvOtd7ypdd82aNZXabjUmu+++OwAXX3xx022q/JudcsoppesCbNmypVL9KuqXvJUprzJmANdff33pus8//3yltsdK1N9tnfitMrNYPtukvL7+lePQF0mSwunGm9jUP3Zu9/v7nPPcJkWDAP39/Q0L65/QmpVrW9HHrMpNbFatWtW0rD473bRpU9NtqtzEpqq77rqrdN2xvIlNfcZ+7rnnNt1m//33L912J29iU9XWrY2vFB7JTWz++te/Vmr79NObXt28Xd16E5tu/t1W5SZVnZi512fmM5uUzxi2nSRJGoVOhPvviuU+wwtSSpOBvYAtwCPj2SlJkqLoRLjXj2Ee3aDsMGAqcE/O+YXx65IkSXF0ItxvBDYCi1JKL3/JkVLaGbigeHlFB/olSVII7bq3/PHA8cXLWcXykJTStcXPG3POHwfIOf8lpfQf1EL+9pTSCmq3nz2W2mVyN1K7Ja0kSSqhXTP3OcCZxZ9/KdbtPWTdyUM3zjnfTO1BM2uAk4CPAn8HzgMWjfBOd5IkqYG+qs+D7lKD0PxmDt186UO36vYx22efbc7PHJXPfe5zpesuWtT8UQgjuTxp48aNpdvesGFD6boAF1xwwfY3auLGG2+s1HYr3b6/dVKVS+Gq/r5fubL8QdXTTjutUttjpZv3tSH/Xq3vXNSAz3OXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAmd7oDUt1OO+1Uuu7y5csrtf2+972vdN1NmzY1Ldtll10A2Lx5c9NtFi9eXLrt+uMqy5oyZUql+uote+yxR6e7oBFy5i5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQF4/Pc1TXe/va3l65b5XnsVR133HFNy6688koAzjrrrKbb3HHHHW3vk6Te5sxdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoLxka/qGhdeeGHpun19fZXarvLY1VZ1N2/eXPn9paF22KH1nKzV/4WXXnqp3d1Rl3LmLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAXj89zVNu9///sr1Z8zZ07puoODg5XavuWWWyrVl8ZLs2ey15/j3ur/QtX/J+vWratUX+PHmbskScG0ZeaeUjoZOByYAxwATAduyDmf3mDbPYFHW7zdypzzonb0S5KkXtSuw/JLqYX6ZmA9MHsEde4Hbm6w/tdt6pMkST2pXeG+hFqoP0xtBr96BHXW5ZyXtal9SZJUaEu455xfDvOUUjveUpIkldTJs+XfkFI6C9gVeBq4N+f8y9G8QUppbaP1OWcABgYGGtabPXt2y3JtayRjNnPmzEpt7LTTTqXr1s8ULuu8884rXff007c5teRl7mvlOG7NbW9fb1Ve9f9JlcnbIYccUqntsRJ1X+tkuL+n+POylNLtwJk558c70iNJkgLoRLg/D3yR2sl0jxTr9geWAUcAP0kpzck5P7e9N8o5z21SNAjQ39/fsLD+Ca1ZubY1kjGrep17/YhLGTvuuGOlti+88MLSdS+++OKmZe5r5ThuzW3durXh+vG4zr3K/9GPfvSjldoeK928r1X59xr3cM85Pwl8dtjqNSmlo4C7gIOBfwcuGe++SZIUQdfcxCbnvAW4qnh5WCf7IknSRNY14V54qlhO62gvJEmawLot3OcVy0dabiVJkpoa93BPKR2cUtrm7KeU0gJqN8MBuH58eyVJUhzturf88cDxxctZxfKQlNK1xc8bc84fL37+b2C/4rK39cW6/YEFxc/n55zvaUe/JEnqRe06W34OcOawdXsXfwD+F6iH+3XACcBBwHuBVwB/BDJwec75zjb1SZKkntSu288uo3ad+ki2/RbwrXa0q+4yZcqUSvWrXKv+5JNPVmp75cqVleqrt1S5m+KyZcva15FRWrVqVaX6n/rUp9rUE421bjuhTpIkVWS4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUTLue5y511AsvvFCp/oYNG9rUE00EVR7ZCrB06dLSdT/xiU9Uanv9+vUN18+aNQuAJ554omndr371q5Xa3rx5c6X6Gj/O3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgfJ67Qrjllls63QWNszlz5pSuW/WZ6qecckrput/97ncrtX3SSSc1XD8wMABAf39/pfdXDM7cJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYH/mqtunr6+tY/eOPP75S2+ecc06l+hq9JUuWNC173etet91tzj///NJtz5w5s3RdgBtuuKF03cWLF1dqWxoJZ+6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMIa7JEnBGO6SJAVjuEuSFIzhLklSMD7PXW0zODjYsfqzZs2q1Pall15auu7VV1/dtGzKlCkAzJkzp+k2Tz/9dOm2582bV7ouwBlnnFG67gEHHFCp7d13371pWV9fHwDLly9vus3jjz9euu0f/ehHpesCfP3rX69UXxprlcM9pbQrcAJwDPA24I3Ai8CvgGuAa3LOLzWodyiwFJgH7Aw8DFwNXJZz3lq1X5Ik9ap2HJZfCHwTOBj4v8DFwE3AW4GrgJxS6htaIaV0HLAGOAz4H+BrwI7ARcCKNvRJkqSe1Y7D8g8CxwK3Dp2hp5Q+DfwMOAk4kVrgk1KaQe3DwFZgfs55oFh/PrAKODmltCjnbMhLklRC5Zl7znlVzvl7ww+955yfAL5RvJw/pOhk4LXAinqwF9v/jdpheoAPVu2XJEm9aqzPlv97sdwyZN2CYvnDBtuvAZ4HDk0p7TSWHZMkKaoxO1s+pTQZWFy8HBrk+xbLB4fXyTlvSSk9CuwH7A38djttrG20PucMwMDAQKNiZs+e3bJc2xrJmL3qVa+q1MYOO5T/rFk/u7qsU045pXTdI488smnZXnvtBcANN9zQdJutW8ufPzpt2rTSdQFe/epXl647derUSm2P5N+s1Ta77bZb6baPOuqo0nWh9dUP2/Pcc89VarsZf6+VE3XcxnLm/mVqJ9V9P+c89LqTmcXy2Sb16utfOVYdkyQpsjGZuaeUzgY+BjwAjPZC2vpH9e1e
9JxzntukaBCgv7+/YWH9E1qzcm1rJGO2cOHCSm185zvfKV23yuwXYOXKlaXrtrrOvT5jP+2005pu43Xu26rP2Fvd+2DDhg2l2/7pT39aui7AJZdc0rG2m/H3WjndPG5V7v3R9pl7SunDwCXAb4Ajcs7PDNukPjOfSWMzhm0nSZJGoa3hnlI6F7gc+DW1YH+iwWa/K5b7NKg/GdiL2gl4j7Szb5Ik9Yq2hXtK6ZPUbkKzjlqwP9lk01XF8ugGZYcBU4F7cs4vtKtvkiT1kraEe3EDmi8Da4Ejc84bW2x+I7ARWJRSevlLjpTSzsAFxcsr2tEvSZJ6UTvuLX8m8AVqd5y7Ezg7pTR8s8dyztcC5Jz/klL6D2ohf3tKaQXwDLW73O1brC9/dpMkST2uHWfL71UsJwHnNtnmDuDa+ouc880ppcOBz1C7PW39wTHnAZfmnKs9XkySpB5WOdxzzsuAZSXq3Q28r2r7EsCkSZMq1f/Qhz5Uuu5JJ53UtOw1r3kNALfeemvTbf7yl7+Ubvstb3lL6bqdds899zQtq98kZt26dU23Wb16dem2P/vZz5auK00EY337WUmSNM4Md0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKpvLz3KW6e++9t1L9n//856XrHnTQQZXarmLWrFlNy/r6+ra7zetf//q292mknn766dJ1V6xYUantc845p2nZwMAAAO9+97srtSH1KmfukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFY7hLkhSMj3xV26xfv75S/RNPPLF03bPOOqtS20uXLq1Uv1MuueSSSvWvuOKK0nUffvjhSm1LGjvO3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgDHdJkoIx3CVJCsZwlyQpGMNdkqRgfJ67usaGDRtK1122bFmltqvWb2ZgYACA/v7+MXl/SWrEmbskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwk6u+QUppV+AE4BjgbcAbgReBXwHXANfknF8asv2ewKMt3nJlznlR1X5JktSrKoc7sBC4AtgArAYeB14PnAhcBbw3pbQw5zw4rN79wM0N3u/XbeiTJEk9qx3h/iBwLHDrsBn6p4GfASdRC/qbhtVbl3Ne1ob2JUnSEJXDPee8qsn6J1JK3wC+BMxn23CXJEljoB0z91b+Xiy3NCh7Q0rpLGBX4Gng3pzzL8e4P5IkhTdm4Z5SmgwsLl7+sMEm7yn+DK1zO3BmzvnxEbaxttH6nDMAAwMDDevNnj27Zbm25ZiV47iV47iNnmNWTtRxG8tL4b4MvBX4fs75R0PWPw98EZgLvKr4czi1k/HmAz9JKU0bw35JkhRa3+Dg8JPYq0spnQ1cAjwAvDPn/MwI6kwG7gIOBs7NOV9SoQuDAH19fQ0L65/Q+vv7KzTRWxyzchy3chy30XPMyunmcRuSz43DrIW2z9xTSh+mFuy/AY4YSbAD5Jy3ULt0DuCwdvdLkqRe0dZwTymdC1xO7Vr1I3LOT4zyLZ4qlh6WlySppLaFe0rpk8BFwDpqwf5kibeZVywfaVe/JEnqNW0J95TS+dROoFsLHJlz3thi24NTSjs2WL8AWFK8vL4d/ZIkqRe1497yZwJfALYCdwJnp5SGb/ZYzvna4uf/BvYrLntbX6zbH1hQ/Hx+zvmeqv2SJKlXteM6972K5STg3Cbb3AFcW/x8HbUHzRwEvBd4BfBHIAOX55zvbEOfJEnqWe24/ewyYNkotv8W8K2q7UqSpMZ8nrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScEY7pIkBWO4S5IUjOEuSVIwhrskScH0DQ4OdroPYyHkX0qS1JP6RlvBmbskScFM7nQHxkjLTzkppbUAOee549Odic8xK8dxK8dxGz3HrJyo4+bMXZKkYAx3SZKCMdwlSQrGcJckKRjDXZKkYKJe5y5JUs9y5i5JUjCGuyRJwRjukiQFY7hLkhSM4S5JUjCGuyRJwRjukiQFE/WpcA2llHYHvgAcDewKbABuBj6fc/5TJ/vWrVJKjwFvalL8x5zzrHHsTtdIKZ0MHA7MAQ4ApgM35JxPb1HnUGApMA/YGXgYuBq4LOe8dcw73QVGM24ppT2BR1u83cqc86Kx6Gc3SSntCpwAHAO8DXgj8CLwK+Aa4Jqc80sN6vX0/jbacYu2v/VMuKeU3gzcA7wO+C7wAPAO4Bzg6JTSO3POT3ewi93sWeDiBus3j3dHushSauG0GVgPzG61cUrpOOAm4G/ASuAZ4F+Bi4B3AgvHsrNdZFTjVrif2ofw4X7dxn51s4XAFdQmI6uBx4HXAycCVwHvTSktzDm/fEcy9zegxLgVQuxvPRPuwNepBfvZOefL6itTShcCS4AvAR/oUN+63Z9zzss63Ykus4RaOD1MbSa6utmGKaUZwDeBrcD8nPNAsf58YBVwckppUc55xZj3uvNGPG5DrOvx/e9B4Fjg1mEzzU8DPwNOohZYNxXr3d9qRjVuQ4TY33riO/eU0t7AUcBjwNeGFX8OeA44I6U0bZy7pgkq57w65/xQg0/9jZwMvBZYUf9FW7zH36jNZAE+OAbd7DqjHDcBOedVOefvDT/0nnN+AvhG8XL+kCL3N0qNWyi9MnNfUCx/3OAfelNK6W5q4T8P+Ml4d24C2CmldDqwB7UPQr8E1vTC93ZtUt//ftigbA3wPHBoSmmnnPML49etCeMNKaWzqJ0n8zRwb875lx3uU7f4e7HcMmSd+9v2NRq3uhD7W0/M3IF9i+WDTcofKpb7jENfJqJZwHXUvrq4mNqhvYdSSod3tFcTR9P9L+e8hdpJPJOBvcezUxPIe6jNtL5ULO9PKa1OKe3R2W51VkppMrC4eDk0yN3fWmgxbnUh9rdeCfeZxfLZJuX19a8ch75MNNcAR1IL+GnUzjq9EtgT+EFK6YDOdW3CcP8r53ngi8Bc4FXFn/r39POBn/T4V2lfBt4KfD/n/KMh693fWms2bqH2t145LL89fcXS7wGHyTl/ftiqXwMfSCltBj4GLKN2uYnKc/9rIOf8JPDZYavXpJSOAu4CDgb+HbhkvPvWaSmls6n9/3sAOGOU1Xt2f2s1btH2t16Zudc/qc5sUj5j2HbavvoJKYd1tBcTg/tfGxWHlq8qXvbc/pdS+jC1gPkNcETO+Zlhm7i/NTC
CcWtoou5vvRLuvyuWzb5Tf0uxbPadvLb1ZLGcMIepOqjp/ld8/7cXtRN7HhnPTk1wTxXLntr/UkrnApdTO4J2RHHm93Dub8OMcNxamXD7W6+Ee/1a2qNSSv/0d04pTad2U4e/Aj8d745NYIcUy575BVHBqmJ5dIOyw4CpwD09fOZyGfOKZc/sfymlT1K7Cc06agH1ZJNN3d+GGMW4tTLh9reeCPec8++BH1M7CezDw4o/T+3T2Ldzzs+Nc9e6Wkppv5TSqxusfxO1T8EA149vryakG4GNwKKUUn99ZUppZ+CC4uUVnehYN0spHZxS2rHB+gXUboYDPbL/FTeg+TKwFjgy57yxxebub4XRjFu0/a1vcLA3zqlocPvZ31I7QeIIaofjD/X2s/8spbQM+E9qRz4eBTYBb6Z2r+adge8DJ+ScX+xUHzslpXQ8cHzxchbwL9Q+1d9ZrNuYc/74sOyP82gAAAE7SURBVO1vpHY70BXUbgd6LLXLlm4EUi/c2GU045ZSuh3YD7id2l3tAPbnH9dxn59zrodVWCmlM4Frqd1x7jIaf1f+WM752iF1en5/G+24RdvfeuZs+Zzz74tPsfUHx7yP2j2HL6X24JgRnVzRY1ZT+2XwdmqH4acBf6Z25uh1wHXRf0G0MAc4c9i6vfnHtcP/C7wc7jnnm4v7AnyG2m0v6w/yOA+4tIfGcTTjdh21KzEOAt4LvAL4I5CBy3POd9Ib9iqWk4Bzm2xzB7UgA9zfCqMdt1D7W8/M3CVJ6hU98Z27JEm9xHCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkKxnCXJCkYw12SpGAMd0mSgjHcJUkK5v8DAlse3siJRgQAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 248,
"width": 251
},
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# we could plot out what training images look like\n",
"plt.imshow(train_images[1].reshape(28, 28), cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using $k$-nearest neighbors"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=None, n_neighbors=1, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"classifier = KNeighborsClassifier(n_neighbors=1)\n",
"classifier.fit(flattened_train_images, train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"predicted_train_labels = classifier.predict(flattened_train_images)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([5, 0, 4, ..., 5, 2, 0], dtype=uint8)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_train_labels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
}
],
"source": [
"error_rate = np.mean(predicted_train_labels != train_labels)\n",
"print(error_rate)"
]
},
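{
"cell_type": "markdown",
"metadata": {},
"source": [
"A training error of 0 is expected here rather than impressive: with $k=1$, each training image's nearest neighbor in the training set is itself (at distance 0), so it always gets its own label back. The sanity check below (assuming no two training images are pixel-for-pixel identical) confirms this for the first few training images, which is exactly why we evaluate on held-out data in the sections that follow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: the fitted 1-NN model finds each training image itself as its\n",
"# nearest neighbor, at distance 0 (assuming there are no exact duplicate images)\n",
"distances, neighbor_indices = classifier.kneighbors(flattened_train_images[:5])\n",
"print(distances) # should be all zeros\n",
"print(neighbor_indices) # should be [[0], [1], [2], [3], [4]]"
]
},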
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing hyperparameter $k$ using simple data splitting"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"num_train_images = len(flattened_train_images)\n",
"shuffled_indices = np.random.permutation(num_train_images)\n",
"\n",
"train_frac = 0.7\n",
"num_actual_training_examples = int(train_frac*num_train_images)\n",
"smaller_train_indices = shuffled_indices[:num_actual_training_examples]\n",
"validation_indices = shuffled_indices[num_actual_training_examples:]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"k: 1 error rate: 0.08666666666666667\n",
"k: 2 error rate: 0.105\n",
"k: 3 error rate: 0.085\n",
"k: 4 error rate: 0.095\n",
"k: 5 error rate: 0.08666666666666667\n",
"k: 6 error rate: 0.09666666666666666\n",
"k: 7 error rate: 0.095\n",
"k: 8 error rate: 0.09833333333333333\n",
"k: 9 error rate: 0.1\n",
"k: 10 error rate: 0.105\n",
"k: 11 error rate: 0.10666666666666667\n",
"k: 12 error rate: 0.10333333333333333\n",
"k: 13 error rate: 0.115\n",
"k: 14 error rate: 0.11166666666666666\n",
"k: 15 error rate: 0.10833333333333334\n",
"k: 16 error rate: 0.11\n",
"k: 17 error rate: 0.10833333333333334\n",
"k: 18 error rate: 0.115\n",
"k: 19 error rate: 0.11333333333333333\n",
"k: 20 error rate: 0.11666666666666667\n",
"k: 21 error rate: 0.11333333333333333\n",
"k: 22 error rate: 0.11833333333333333\n",
"k: 23 error rate: 0.12\n",
"k: 24 error rate: 0.12\n",
"k: 25 error rate: 0.12166666666666667\n",
"k: 26 error rate: 0.12\n",
"k: 27 error rate: 0.12\n",
"k: 28 error rate: 0.12166666666666667\n",
"k: 29 error rate: 0.12666666666666668\n",
"k: 30 error rate: 0.12666666666666668\n",
"k: 31 error rate: 0.12666666666666668\n",
"k: 32 error rate: 0.13\n",
"k: 33 error rate: 0.13\n",
"k: 34 error rate: 0.12666666666666668\n",
"k: 35 error rate: 0.13\n",
"k: 36 error rate: 0.13166666666666665\n",
"k: 37 error rate: 0.13\n",
"k: 38 error rate: 0.13166666666666665\n",
"k: 39 error rate: 0.12666666666666668\n",
"k: 40 error rate: 0.13333333333333333\n",
"k: 41 error rate: 0.13666666666666666\n",
"k: 42 error rate: 0.14\n",
"k: 43 error rate: 0.14166666666666666\n",
"k: 44 error rate: 0.14333333333333334\n",
"k: 45 error rate: 0.145\n",
"k: 46 error rate: 0.14666666666666667\n",
"k: 47 error rate: 0.14166666666666666\n",
"k: 48 error rate: 0.14166666666666666\n",
"k: 49 error rate: 0.14666666666666667\n",
"k: 50 error rate: 0.14333333333333334\n",
"Best k: 3 error rate: 0.085\n"
]
}
],
"source": [
"lowest_error = np.inf\n",
"best_k = None\n",
"for k in range(1, 51):\n",
" classifier = KNeighborsClassifier(n_neighbors=k)\n",
" classifier.fit(flattened_train_images[smaller_train_indices],\n",
" train_labels[smaller_train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[validation_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[validation_indices])\n",
" print('k:', k, 'error rate:', error)\n",
" \n",
" if error < lowest_error:\n",
" lowest_error = error\n",
" best_k = k\n",
"\n",
"print('Best k:', best_k, 'error rate:', lowest_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing hyperparameter $k$ using 5-fold cross validation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"k: 1 cross validation error: 0.092\n",
"k: 2 cross validation error: 0.10899999999999999\n",
"k: 3 cross validation error: 0.10099999999999998\n",
"k: 4 cross validation error: 0.10400000000000001\n",
"k: 5 cross validation error: 0.1025\n",
"k: 6 cross validation error: 0.10850000000000001\n",
"k: 7 cross validation error: 0.10500000000000001\n",
"k: 8 cross validation error: 0.1115\n",
"k: 9 cross validation error: 0.11199999999999999\n",
"k: 10 cross validation error: 0.1155\n",
"k: 11 cross validation error: 0.11750000000000001\n",
"k: 12 cross validation error: 0.11950000000000001\n",
"k: 13 cross validation error: 0.12300000000000003\n",
"k: 14 cross validation error: 0.12400000000000003\n",
"k: 15 cross validation error: 0.131\n",
"k: 16 cross validation error: 0.1335\n",
"k: 17 cross validation error: 0.133\n",
"k: 18 cross validation error: 0.13649999999999998\n",
"k: 19 cross validation error: 0.13899999999999998\n",
"k: 20 cross validation error: 0.1405\n",
"k: 21 cross validation error: 0.14350000000000002\n",
"k: 22 cross validation error: 0.14200000000000002\n",
"k: 23 cross validation error: 0.142\n",
"k: 24 cross validation error: 0.14100000000000001\n",
"k: 25 cross validation error: 0.14300000000000002\n",
"k: 26 cross validation error: 0.146\n",
"k: 27 cross validation error: 0.14850000000000002\n",
"k: 28 cross validation error: 0.14950000000000002\n",
"k: 29 cross validation error: 0.152\n",
"k: 30 cross validation error: 0.1545\n",
"k: 31 cross validation error: 0.15700000000000003\n",
"k: 32 cross validation error: 0.1595\n",
"k: 33 cross validation error: 0.159\n",
"k: 34 cross validation error: 0.15750000000000003\n",
"k: 35 cross validation error: 0.1605\n",
"k: 36 cross validation error: 0.1625\n",
"k: 37 cross validation error: 0.16399999999999998\n",
"k: 38 cross validation error: 0.1645\n",
"k: 39 cross validation error: 0.16450000000000004\n",
"k: 40 cross validation error: 0.1665\n",
"k: 41 cross validation error: 0.16999999999999998\n",
"k: 42 cross validation error: 0.1745\n",
"k: 43 cross validation error: 0.17650000000000002\n",
"k: 44 cross validation error: 0.175\n",
"k: 45 cross validation error: 0.17850000000000002\n",
"k: 46 cross validation error: 0.18\n",
"k: 47 cross validation error: 0.1805\n",
"k: 48 cross validation error: 0.1815\n",
"k: 49 cross validation error: 0.1825\n",
"k: 50 cross validation error: 0.1825\n",
"Best k: 1 cross validation error: 0.092\n"
]
}
],
"source": [
"from sklearn.model_selection import KFold\n",
"\n",
"lowest_cross_val_error = np.inf\n",
"best_k = None\n",
"\n",
"indices = range(num_train_images)\n",
"kf = KFold(n_splits=5, shuffle=True, random_state=0)\n",
"for k in range(1, 51):\n",
" errors = []\n",
" for train_indices, val_indices in kf.split(indices):\n",
" classifier = KNeighborsClassifier(n_neighbors=k)\n",
" classifier.fit(flattened_train_images[train_indices],\n",
" train_labels[train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[val_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[val_indices])\n",
" errors.append(error)\n",
" \n",
" cross_val_error = np.mean(errors)\n",
" print('k:', k, 'cross validation error:', cross_val_error)\n",
"\n",
" if cross_val_error < lowest_cross_val_error:\n",
" lowest_cross_val_error = cross_val_error\n",
" best_k = k\n",
"\n",
"print('Best k:', best_k, 'cross validation error:', lowest_cross_val_error)"
]
},
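{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above spells out the mechanics of cross validation. As an aside, scikit-learn's `cross_val_score` helper computes the same kind of per-fold scores for a fixed classifier; the sketch below should reproduce the $k=1$ row using the same 5-fold splitter (`cross_val_score` reports accuracy per fold, so the cross validation error is one minus the mean accuracy)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# same idea as the manual loop above, shown for a single value of k:\n",
"# cross_val_score returns one accuracy per fold, so the 5-fold cross\n",
"# validation error is 1 minus the mean per-fold accuracy\n",
"knn_classifier = KNeighborsClassifier(n_neighbors=1)\n",
"fold_accuracies = cross_val_score(knn_classifier, flattened_train_images, train_labels,\n",
"                                  cv=KFold(n_splits=5, shuffle=True, random_state=0))\n",
"print('k: 1 cross validation error:', 1 - np.mean(fold_accuracies))"
]
},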
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using different classifiers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's simple to work with other classifiers in scikit-learn. For example, here is how one can use random forest classifiers using the number of trees and the max depth as hyperparameters (there are other hyperparameters as well, but we're just using the scikit-learn default values in this demo--if you care about actually tuning the performance of a random forest classifier carefully, then you should look into what the other hyperparameters do by reading the documentation)."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"rf_classifier = RandomForestClassifier(n_estimators=50, max_depth=None, random_state=0)\n",
"rf_classifier.fit(flattened_train_images, train_labels)\n",
"rf_predicted_train_labels = rf_classifier.predict(flattened_train_images)\n",
"rf_error = np.mean(rf_predicted_train_labels != train_labels)\n",
"print(rf_error) # training set error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we see cross-validation for random forests. Importantly, now we sweep over two hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hyperparameter: (50, 3) cross validation error: 0.2765\n",
"Hyperparameter: (50, 4) cross validation error: 0.20350000000000001\n",
"Hyperparameter: (50, 5) cross validation error: 0.155\n",
"Hyperparameter: (50, None) cross validation error: 0.0905\n",
"Hyperparameter: (100, 3) cross validation error: 0.2635\n",
"Hyperparameter: (100, 4) cross validation error: 0.19899999999999998\n",
"Hyperparameter: (100, 5) cross validation error: 0.1545\n",
"Hyperparameter: (100, None) cross validation error: 0.08549999999999999\n",
"Hyperparameter: (150, 3) cross validation error: 0.263\n",
"Hyperparameter: (150, 4) cross validation error: 0.1955\n",
"Hyperparameter: (150, 5) cross validation error: 0.15\n",
"Hyperparameter: (150, None) cross validation error: 0.08149999999999999\n",
"Hyperparameter: (200, 3) cross validation error: 0.2625\n",
"Hyperparameter: (200, 4) cross validation error: 0.1975\n",
"Hyperparameter: (200, 5) cross validation error: 0.149\n",
"Hyperparameter: (200, None) cross validation error: 0.0825\n",
"Best hyperparameter: (150, None) cross validation error: 0.08149999999999999\n"
]
}
],
"source": [
"lowest_cross_val_error = np.inf\n",
"best_hyperparameter_setting = None\n",
"\n",
"hyperparameter_settings = [(num_trees, max_depth)\n",
" for num_trees in [50, 100, 150, 200]\n",
" for max_depth in [3, 4, 5, None]]\n",
"\n",
"indices = range(num_train_images)\n",
"kf = KFold(n_splits=5, shuffle=True, random_state=0)\n",
"for hyperparameter_setting in hyperparameter_settings:\n",
" num_trees, max_depth = hyperparameter_setting\n",
" errors = []\n",
" for train_indices, val_indices in kf.split(indices):\n",
" classifier = RandomForestClassifier(n_estimators=num_trees,\n",
" max_depth=max_depth,\n",
" random_state=0)\n",
" classifier.fit(flattened_train_images[train_indices],\n",
" train_labels[train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[val_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[val_indices])\n",
" errors.append(error)\n",
" \n",
" cross_val_error = np.mean(errors)\n",
" print('Hyperparameter:', hyperparameter_setting, 'cross validation error:', cross_val_error)\n",
"\n",
" if cross_val_error < lowest_cross_val_error:\n",
" lowest_cross_val_error = cross_val_error\n",
" best_hyperparameter_setting = hyperparameter_setting\n",
"\n",
"print('Best hyperparameter:', best_hyperparameter_setting, 'cross validation error:', lowest_cross_val_error)"
]
},
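{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same search can also be written with scikit-learn's `GridSearchCV`, which runs the cross validation loop internally over a dictionary of hyperparameter values. The sketch below mirrors the grid above and should land on a similar hyperparameter setting; note that `best_score_` is reported as mean accuracy, so the corresponding error is one minus that score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# sketch: the same (num_trees, max_depth) grid as above, with GridSearchCV\n",
"# handling the 5-fold cross validation loop for us\n",
"param_grid = {'n_estimators': [50, 100, 150, 200],\n",
"              'max_depth': [3, 4, 5, None]}\n",
"grid_search = GridSearchCV(RandomForestClassifier(random_state=0),\n",
"                           param_grid,\n",
"                           cv=KFold(n_splits=5, shuffle=True, random_state=0))\n",
"grid_search.fit(flattened_train_images, train_labels)\n",
"print('Best hyperparameters:', grid_search.best_params_)\n",
"print('Cross validation error:', 1 - grid_search.best_score_)"
]
},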
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finally actually looking at the test data"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.138\n"
]
}
],
"source": [
"final_knn_classifier = KNeighborsClassifier(n_neighbors=best_k)\n",
"final_knn_classifier.fit(flattened_train_images, train_labels)\n",
"predicted_test_labels = final_knn_classifier.predict(flattened_test_images)\n",
"test_set_error = np.mean(predicted_test_labels != test_labels)\n",
"print(test_set_error)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.082\n"
]
}
],
"source": [
"final_rf_classifier = RandomForestClassifier(n_estimators=best_hyperparameter_setting[0],\n",
" max_depth=best_hyperparameter_setting[1],\n",
" random_state=0)\n",
"final_rf_classifier.fit(flattened_train_images, train_labels)\n",
"predicted_test_labels = final_rf_classifier.predict(flattened_test_images)\n",
"test_set_error = np.mean(predicted_test_labels != test_labels)\n",
"print(test_set_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that in general the cross validation error is not going to perfectly match up with the test set error."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}