@georgehc
Created February 14, 2019 00:49
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 94-775/95-865: Prediction and Model Validation\n",
"\n",
"Author: George H. Chen (georgechen [at symbol] cmu.edu)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# the lines below are just for aesthetics\n",
"plt.style.use('ggplot') # if you want your plots to look like ggplot (similar to how R makes plots)\n",
"%config InlineBackend.figure_format = 'retina' # if you use a Mac with Retina display"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.datasets import mnist\n",
"\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# use only the first 2000 training images and the first 500 test images\n",
"train_images = train_images[:2000]\n",
"train_labels = train_labels[:2000]\n",
"test_images = test_images[:500]\n",
"test_labels = test_labels[:500]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"flattened_train_images = train_images.reshape(len(train_images), -1) # flattens out each training image\n",
"flattened_test_images = test_images.reshape(len(test_images), -1) # flattens out each test image"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"flattened_train_images = flattened_train_images.astype(np.float32) / 255 # rescale to be between 0 and 1\n",
"flattened_test_images = flattened_test_images.astype(np.float32) / 255 # rescale to be between 0 and 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2000, 784)\n"
]
}
],
"source": [
"print(flattened_train_images.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(500, 784)\n"
]
}
],
"source": [
"print(flattened_test_images.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x12e27b7b8>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot one training image to see what the data look like\n",
"plt.imshow(train_images[1].reshape(28, 28), cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification using $k$-nearest neighbors"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=1, n_neighbors=1, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"classifier = KNeighborsClassifier(n_neighbors=1)\n",
"classifier.fit(flattened_train_images, train_labels)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"predicted_train_labels = classifier.predict(flattened_train_images)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([5, 0, 4, ..., 5, 2, 0], dtype=uint8)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_train_labels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
}
],
"source": [
"error_rate = np.mean(predicted_train_labels != train_labels)\n",
"print(error_rate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing hyperparameter $k$ using simple data splitting"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"num_train_images = len(flattened_train_images)\n",
"shuffled_indices = np.random.permutation(num_train_images)\n",
"\n",
"# fit on 70% of the (shuffled) training data and hold out the remaining 30% for validation\n",
"train_frac = 0.7\n",
"smaller_train_indices = shuffled_indices[:int(train_frac*num_train_images)]\n",
"validation_indices = shuffled_indices[int(train_frac*num_train_images):]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"k: 1 error rate: 0.08666666666666667\n",
"k: 2 error rate: 0.105\n",
"k: 3 error rate: 0.085\n",
"k: 4 error rate: 0.095\n",
"k: 5 error rate: 0.08666666666666667\n",
"k: 6 error rate: 0.09666666666666666\n",
"k: 7 error rate: 0.095\n",
"k: 8 error rate: 0.09833333333333333\n",
"k: 9 error rate: 0.1\n",
"k: 10 error rate: 0.105\n",
"k: 11 error rate: 0.10666666666666667\n",
"k: 12 error rate: 0.10333333333333333\n",
"k: 13 error rate: 0.115\n",
"k: 14 error rate: 0.11166666666666666\n",
"k: 15 error rate: 0.10833333333333334\n",
"k: 16 error rate: 0.11\n",
"k: 17 error rate: 0.10833333333333334\n",
"k: 18 error rate: 0.115\n",
"k: 19 error rate: 0.11333333333333333\n",
"k: 20 error rate: 0.11666666666666667\n",
"k: 21 error rate: 0.11333333333333333\n",
"k: 22 error rate: 0.11833333333333333\n",
"k: 23 error rate: 0.12\n",
"k: 24 error rate: 0.12\n",
"k: 25 error rate: 0.12166666666666667\n",
"k: 26 error rate: 0.12\n",
"k: 27 error rate: 0.12\n",
"k: 28 error rate: 0.12166666666666667\n",
"k: 29 error rate: 0.12666666666666668\n",
"k: 30 error rate: 0.12666666666666668\n",
"k: 31 error rate: 0.12666666666666668\n",
"k: 32 error rate: 0.13\n",
"k: 33 error rate: 0.13\n",
"k: 34 error rate: 0.12666666666666668\n",
"k: 35 error rate: 0.13\n",
"k: 36 error rate: 0.13166666666666665\n",
"k: 37 error rate: 0.13\n",
"k: 38 error rate: 0.13166666666666665\n",
"k: 39 error rate: 0.12666666666666668\n",
"k: 40 error rate: 0.13333333333333333\n",
"k: 41 error rate: 0.13666666666666666\n",
"k: 42 error rate: 0.14\n",
"k: 43 error rate: 0.14166666666666666\n",
"k: 44 error rate: 0.14333333333333334\n",
"k: 45 error rate: 0.145\n",
"k: 46 error rate: 0.14666666666666667\n",
"k: 47 error rate: 0.14166666666666666\n",
"k: 48 error rate: 0.14166666666666666\n",
"k: 49 error rate: 0.14666666666666667\n",
"k: 50 error rate: 0.14333333333333334\n",
"Best k: 3 error rate: 0.085\n"
]
}
],
"source": [
"lowest_error = np.inf\n",
"best_k = None\n",
"for k in range(1, 51):\n",
" classifier = KNeighborsClassifier(n_neighbors=k)\n",
" classifier.fit(flattened_train_images[smaller_train_indices],\n",
" train_labels[smaller_train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[validation_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[validation_indices])\n",
" print('k:', k, 'error rate:', error)\n",
" \n",
" if error < lowest_error:\n",
" lowest_error = error\n",
" best_k = k\n",
"\n",
"print('Best k:', best_k, 'error rate:', lowest_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing hyperparameter $k$ using 5-fold cross validation"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"k: 1 cross validation error: 0.092\n",
"k: 2 cross validation error: 0.10899999999999999\n",
"k: 3 cross validation error: 0.10099999999999998\n",
"k: 4 cross validation error: 0.10400000000000001\n",
"k: 5 cross validation error: 0.1025\n",
"k: 6 cross validation error: 0.10850000000000001\n",
"k: 7 cross validation error: 0.10500000000000001\n",
"k: 8 cross validation error: 0.1115\n",
"k: 9 cross validation error: 0.11199999999999999\n",
"k: 10 cross validation error: 0.1155\n",
"k: 11 cross validation error: 0.11750000000000001\n",
"k: 12 cross validation error: 0.11950000000000001\n",
"k: 13 cross validation error: 0.12300000000000003\n",
"k: 14 cross validation error: 0.12400000000000003\n",
"k: 15 cross validation error: 0.131\n",
"k: 16 cross validation error: 0.1335\n",
"k: 17 cross validation error: 0.133\n",
"k: 18 cross validation error: 0.13649999999999998\n",
"k: 19 cross validation error: 0.13899999999999998\n",
"k: 20 cross validation error: 0.1405\n",
"k: 21 cross validation error: 0.14350000000000002\n",
"k: 22 cross validation error: 0.14200000000000002\n",
"k: 23 cross validation error: 0.142\n",
"k: 24 cross validation error: 0.14100000000000001\n",
"k: 25 cross validation error: 0.14300000000000002\n",
"k: 26 cross validation error: 0.146\n",
"k: 27 cross validation error: 0.14850000000000002\n",
"k: 28 cross validation error: 0.14950000000000002\n",
"k: 29 cross validation error: 0.152\n",
"k: 30 cross validation error: 0.1545\n",
"k: 31 cross validation error: 0.15700000000000003\n",
"k: 32 cross validation error: 0.1595\n",
"k: 33 cross validation error: 0.159\n",
"k: 34 cross validation error: 0.15750000000000003\n",
"k: 35 cross validation error: 0.1605\n",
"k: 36 cross validation error: 0.1625\n",
"k: 37 cross validation error: 0.16399999999999998\n",
"k: 38 cross validation error: 0.1645\n",
"k: 39 cross validation error: 0.16450000000000004\n",
"k: 40 cross validation error: 0.1665\n",
"k: 41 cross validation error: 0.16999999999999998\n",
"k: 42 cross validation error: 0.1745\n",
"k: 43 cross validation error: 0.17650000000000002\n",
"k: 44 cross validation error: 0.175\n",
"k: 45 cross validation error: 0.17850000000000002\n",
"k: 46 cross validation error: 0.18\n",
"k: 47 cross validation error: 0.1805\n",
"k: 48 cross validation error: 0.1815\n",
"k: 49 cross validation error: 0.1825\n",
"k: 50 cross validation error: 0.1825\n",
"Best k: 1 cross validation error: 0.092\n"
]
}
],
"source": [
"from sklearn.model_selection import KFold\n",
"\n",
"lowest_cross_val_error = np.inf\n",
"best_k = None\n",
"\n",
"indices = range(num_train_images)\n",
"kf = KFold(n_splits=5, shuffle=True, random_state=0)\n",
"for k in range(1, 51):\n",
" errors = []\n",
" for train_indices, val_indices in kf.split(indices):\n",
" classifier = KNeighborsClassifier(n_neighbors=k)\n",
" classifier.fit(flattened_train_images[train_indices],\n",
" train_labels[train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[val_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[val_indices])\n",
" errors.append(error)\n",
" \n",
" cross_val_error = np.mean(errors)\n",
" print('k:', k, 'cross validation error:', cross_val_error)\n",
"\n",
" if cross_val_error < lowest_cross_val_error:\n",
" lowest_cross_val_error = cross_val_error\n",
" best_k = k\n",
"\n",
"print('Best k:', best_k, 'cross validation error:', lowest_cross_val_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using different classifiers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's simple to work with other classifiers in scikit-learn. For example, here is how to use a random forest classifier, treating the number of trees and the max depth as hyperparameters. There are other hyperparameters as well, but this demo just uses the scikit-learn default values for them. If you care about carefully tuning the performance of a random forest classifier, read the documentation to see what the other hyperparameters do."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/georgehc/anaconda3/envs/py35/lib/python3.5/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.\n",
" from numpy.core.umath_tests import inner1d\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"rf_classifier = RandomForestClassifier(n_estimators=50, max_depth=None, random_state=0)\n",
"rf_classifier.fit(flattened_train_images, train_labels)\n",
"rf_predicted_train_labels = rf_classifier.predict(flattened_train_images)\n",
"rf_error = np.mean(rf_predicted_train_labels != train_labels)\n",
"print(rf_error) # training set error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
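"Next, we run 5-fold cross-validation for random forests. Importantly, we now sweep over two hyperparameters at once (the number of trees and the max depth), so every pair of values in the grid is evaluated."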
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hyperparameter: (50, 3) cross validation error: 0.2765\n",
"Hyperparameter: (50, 4) cross validation error: 0.20400000000000001\n",
"Hyperparameter: (50, 5) cross validation error: 0.1545\n",
"Hyperparameter: (50, None) cross validation error: 0.091\n",
"Hyperparameter: (100, 3) cross validation error: 0.2635\n",
"Hyperparameter: (100, 4) cross validation error: 0.2\n",
"Hyperparameter: (100, 5) cross validation error: 0.1545\n",
"Hyperparameter: (100, None) cross validation error: 0.08549999999999999\n",
"Hyperparameter: (150, 3) cross validation error: 0.263\n",
"Hyperparameter: (150, 4) cross validation error: 0.1955\n",
"Hyperparameter: (150, 5) cross validation error: 0.15\n",
"Hyperparameter: (150, None) cross validation error: 0.08099999999999999\n",
"Hyperparameter: (200, 3) cross validation error: 0.262\n",
"Hyperparameter: (200, 4) cross validation error: 0.1975\n",
"Hyperparameter: (200, 5) cross validation error: 0.1495\n",
"Hyperparameter: (200, None) cross validation error: 0.08299999999999999\n",
"Best hyperparameter: (150, None) cross validation error: 0.08099999999999999\n"
]
}
],
"source": [
"lowest_cross_val_error = np.inf\n",
"best_hyperparameter_setting = None\n",
"\n",
"hyperparameter_settings = [(num_trees, max_depth)\n",
" for num_trees in [50, 100, 150, 200]\n",
" for max_depth in [3, 4, 5, None]]\n",
"\n",
"indices = range(num_train_images)\n",
"kf = KFold(n_splits=5, shuffle=True, random_state=0)\n",
"for hyperparameter_setting in hyperparameter_settings:\n",
" num_trees, max_depth = hyperparameter_setting\n",
" errors = []\n",
" for train_indices, val_indices in kf.split(indices):\n",
" classifier = RandomForestClassifier(n_estimators=num_trees,\n",
" max_depth=max_depth,\n",
" random_state=0)\n",
" classifier.fit(flattened_train_images[train_indices],\n",
" train_labels[train_indices])\n",
" predicted_val_labels = classifier.predict(flattened_train_images[val_indices])\n",
" error = np.mean(predicted_val_labels != train_labels[val_indices])\n",
" errors.append(error)\n",
" \n",
" cross_val_error = np.mean(errors)\n",
" print('Hyperparameter:', hyperparameter_setting, 'cross validation error:', cross_val_error)\n",
"\n",
" if cross_val_error < lowest_cross_val_error:\n",
" lowest_cross_val_error = cross_val_error\n",
" best_hyperparameter_setting = hyperparameter_setting\n",
"\n",
"print('Best hyperparameter:', best_hyperparameter_setting, 'cross validation error:', lowest_cross_val_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finally, evaluating on the test data"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.138\n"
]
}
],
"source": [
"final_knn_classifier = KNeighborsClassifier(n_neighbors=best_k)\n",
"final_knn_classifier.fit(flattened_train_images, train_labels)\n",
"predicted_test_labels = final_knn_classifier.predict(flattened_test_images)\n",
"test_set_error = np.mean(predicted_test_labels != test_labels)\n",
"print(test_set_error)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.082\n"
]
}
],
"source": [
"final_rf_classifier = RandomForestClassifier(n_estimators=best_hyperparameter_setting[0],\n",
" max_depth=best_hyperparameter_setting[1],\n",
" random_state=0)\n",
"final_rf_classifier.fit(flattened_train_images, train_labels)\n",
"predicted_test_labels = final_rf_classifier.predict(flattened_test_images)\n",
"test_set_error = np.mean(predicted_test_labels != test_labels)\n",
"print(test_set_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
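"Note that, in general, the cross-validation error will not perfectly match the test set error. Here the gap is large for the $k$-nearest neighbors classifier (cross-validation error 0.092 versus test set error 0.138) and small for the random forest (about 0.081 versus 0.082)."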
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
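
The training error of 0.0 that the notebook reports for `n_neighbors=1` follows from how nearest-neighbor voting works: each training image is its own nearest neighbor (at distance zero), so it votes for its own label. Below is a minimal plain-Python sketch of that voting rule, assuming Euclidean distance (which corresponds to the `metric='minkowski'`, `p=2` settings printed in the notebook). The helper `knn_predict` and the toy points are hypothetical illustrations rather than anything from the notebook, and tie-breaking here may differ from scikit-learn's.

```python
import math
from collections import Counter


def knn_predict(train_points, train_labels, query, k):
    """Predict the label of `query` by majority vote among its k nearest training points.

    Hypothetical helper for illustration; not scikit-learn's implementation.
    """
    # Rank training indices by Euclidean distance to the query point.
    order = sorted(range(len(train_points)),
                   key=lambda i: math.dist(train_points[i], query))
    # Count the labels of the k closest points; ties go to the label seen first.
    votes = Counter(train_labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]


# Toy 2-D data (made up for this sketch; the notebook uses 784-dimensional images).
toy_points = [(0.0, 0.0), (0.1, 0.2), (1.0, 1.0), (0.9, 1.1), (0.5, 0.4)]
toy_labels = [0, 0, 1, 1, 1]

# With k=1, every training point finds itself first, so the training error is zero.
train_predictions = [knn_predict(toy_points, toy_labels, p, k=1) for p in toy_points]
train_error = sum(p != y for p, y in zip(train_predictions, toy_labels)) / len(toy_labels)
print(train_error)  # prints 0.0
```

This also shows why training error is a poor guide for choosing $k$: it would always favor $k=1$ whenever the training points are distinct, which is why the notebook holds out validation data instead.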
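
The 5-fold cross-validation cells rely on `KFold(n_splits=5, shuffle=True, random_state=0)` to produce the train/validation index pairs. As a rough sketch of what such a splitter does, the plain-Python generator below shuffles the indices once, cuts them into five folds, and holds out each fold in turn. The function name `k_fold_indices` is hypothetical, and because it uses Python's `random` module it will not reproduce the exact folds scikit-learn generates.

```python
import random


def k_fold_indices(num_samples, n_splits, seed=0):
    """Yield (train_indices, val_indices) pairs, one pair per fold.

    Hypothetical helper that mimics the idea behind a shuffled k-fold split.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # shuffle once, up front

    # Spread samples as evenly as possible: the first (num_samples % n_splits)
    # folds receive one extra sample.
    base, extra = divmod(num_samples, n_splits)
    folds, start = [], 0
    for fold_number in range(n_splits):
        size = base + (1 if fold_number < extra else 0)
        folds.append(indices[start:start + size])
        start += size

    # Each fold serves as the validation set exactly once;
    # the remaining folds together form the training set.
    for held_out in range(n_splits):
        val_indices = folds[held_out]
        train_indices = [idx
                         for j, fold in enumerate(folds) if j != held_out
                         for idx in fold]
        yield train_indices, val_indices


# Same sizes as in the notebook: 2000 training images, 5 folds.
splits = list(k_fold_indices(num_samples=2000, n_splits=5))
for train_indices, val_indices in splits:
    print(len(train_indices), len(val_indices))  # prints "1600 400" for each fold
```

Every sample is used for validation exactly once, so the averaged cross-validation error draws on all 2000 training images rather than on a single 30% hold-out set.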
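
The hand-written sweep over `(num_trees, max_depth)` pairs could also be expressed with scikit-learn's `GridSearchCV`, which runs the same kind of grid-by-fold loop internally. The sketch below is one possible way to do that, not the notebook's method. It makes several assumptions to stay small and fast: it uses scikit-learn's bundled 8x8 digits dataset (`load_digits`) as a stand-in for MNIST, a smaller grid than the notebook's, and accuracy as the score, so the error is computed as one minus the best mean accuracy. The numbers it prints will not match the notebook's output.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, KFold

# Stand-in data (assumption): 1797 images of 8x8 digits, pixel values in 0..16.
digits = load_digits()
features = digits.data / 16.0  # rescale to be between 0 and 1
labels = digits.target

# A reduced grid (assumption) over the same two hyperparameters as the notebook.
param_grid = {
    "n_estimators": [10, 30],
    "max_depth": [3, None],
}

search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    param_grid,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring="accuracy",
)
search.fit(features, labels)  # fits one model per (setting, fold) pair

# best_score_ is the mean validation accuracy of the best setting.
best_cv_error = 1.0 - search.best_score_
print("Best hyperparameters:", search.best_params_)
print("Cross-validation error:", best_cv_error)
```

By default `GridSearchCV` refits the best setting on all of the supplied data, which plays the same role as the notebook's final step of retraining on the full training set before touching the test set.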