Skip to content

Instantly share code, notes, and snippets.

@myurasov
Last active November 21, 2017 04:18
Show Gist options
  • Save myurasov/69efc0172b58fbccfcdbf3c097006fa8 to your computer and use it in GitHub Desktop.
Save myurasov/69efc0172b58fbccfcdbf3c097006fa8 to your computer and use it in GitHub Desktop.
Manual Backprop
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import scipy as sp\nimport numpy as np",
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import matplotlib.pyplot as plt\n%matplotlib inline\n\n# apply jt styling to matplotlib\ntry:\n from jupyterthemes import jtplot\n jtplot.style()\nexcept ImportError:\n pass",
"execution_count": 18,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "TRAIN_SIZE = 55000\nTEST_SIZE = 1000\nRESIZE_MNIST_TO = (20, 20)",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "BATCH_SIZE = 100\nLEARNING_RATE = 0.03\n\nINPUT_SIZE = RESIZE_MNIST_TO[0] * RESIZE_MNIST_TO[1]\nL1_SIZE = 200\nL2_SIZE = 100\nOUTPUT_SIZE = 10 # 10 classes",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "np.random.seed(0)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# create downsampled version of MNIST\ndef gen_resampled_mnist(n_train=55000, n_test=1000, size=(10, 10)):\n def _create_subset(src, n):\n i = 0\n X = np.zeros((n, size[0] * size[1]))\n Y = np.zeros((n, 10))\n\n while i < n:\n img = src.images[i].reshape(28, 28)\n img = sp.misc.imresize(img, size)\n X[i], Y[i] = img.flatten(), src.labels[i]\n i += 1\n\n return X, Y\n\n import tensorflow.examples.tutorials.mnist as mnist\n sets = mnist.input_data.read_data_sets('MNIST_data', one_hot=True)\n\n train_X, train_Y = _create_subset(sets.train, n_train)\n test_X, test_Y = _create_subset(sets.test, n_test)\n\n return (train_X, train_Y, test_X, test_Y)",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# transfer (activation) function/it's derivative\n\n\n# logistic sigmoid\n# https://en.wikipedia.org/wiki/Sigmoid_function\ndef sigmoid(x):\n # clip -x to prevent overflow in np.exp\n return 1 / (1 + np.exp(np.clip(-x, -708., 709.)))\n\n\n# derivative of logistic sigmoid\ndef sigmoid_d(x):\n return x * (1 - x)",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "# create train/test sets with downsampled mnist images\ntrain_X, train_Y, test_X, test_Y = gen_resampled_mnist(TRAIN_SIZE, TEST_SIZE, RESIZE_MNIST_TO)\nplt.imshow(train_X[111].reshape(RESIZE_MNIST_TO), cmap=plt.cm.gray)\nprint('label=', np.argmax(train_Y[777]))",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "Extracting MNIST_data/train-images-idx3-ubyte.gz\nExtracting MNIST_data/train-labels-idx1-ubyte.gz\nExtracting MNIST_data/t10k-images-idx3-ubyte.gz\nExtracting MNIST_data/t10k-labels-idx1-ubyte.gz\nlabel= 4\n",
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<matplotlib.figure.Figure at 0x7f206f6ff1d0>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFRCAYAAABKR3dEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X9MG/f9P/CnzzZgAkkEmrIMXFGIknZym1aem58VSiNN\n7EOQWqAVctYyVWSqKq1q+sfUjS7/EHXqpEKWsU3ijyr7IwTRjkhzNBllStjGYkYpUVOXtFPUdhCt\nPwIUjI0NduzvH9H5C8F3tu8M7wOej38I936f3y/edzxznM93pqqqqgSIiGjNSaILICLarBjARESC\nMICJiARhABMRCcIAJiISxCK6gEwUFhYiGo2KLoOIKGtWqxXz8/Mp2wwfwIWFhTh+/LjoMoiINDt/\n/nzKEDZ8AMtHvm+++SYWFhaWtUmSBIfDAb/fj3g8LqK8lIxaF2Dc2lhXdlhXdkTVlZ+fj1/+8peK\nf8EbPoBlCwsLKQM4Go1iYWHBcBvbiHUBxq2NdWWHdWXHsHWJLoCIaLPKyRGwJElobGzEvn37YDKZ\ncP36dVy4cAGxWExXXyKijSwnR8A1NTXYvXs32tracOrUKezcuRP19fW6+xIRbWQ5OQI+fPgw+vr6\nMDMzAwC4dOkSTpw4gXfffReJREJz36UkSYIkSSuWLf1qFEatCzBubawrO6wrO6LqSjee7gC22Wwo\nKSnBxMREctn4+DhsNhtKS0sxOTmpqe/9HA6H4juJDodD74+xKoxaF2Dc2lhXdlhXdta6LqvVqtqu\nO4ALCgoAAOFwOLlMvt5NbtPS935+v5+XoeWAUWtjXdlhXdkReRlaQ0ODYrvuAI5EIgDuHd0GAgEA\n9z48sbRNS9/7xeNxxYlTaxPJqHUBxq2NdWWHdWVnretKN5buEyLhcBjT09MoLy9PLrPb7QiHw5ia\nmtLcl4hoo8vJGenBwUHU1NRg27ZtKCoqwrFjx+Dz+VK+qZZNXyKijSwnV0F4vV4UFRXh1KlTkCQJ\no6OjuHjxIgDA7XYDALq7u9P2JSLaTHISwPF4HL29vejt7V3RJgdvJn2JiDYTY12sR0S0iTCAiYgE\nYQATEQnCACYiEoQBTEQkCAOYiEgQBjARkSAMYCIiQRjARESCMICJiARhABMRCcIAJiIShAFMRCQI\nA5iISBAGMBGRIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhIEAYwEZEgDGAiIkEY\nwEREgjCAiYgEYQATEQnCACYiEsSi+wUsFjQ1NWHPnj0oLi7G7Owsrl69ioGBgZT9m5ub4XK5EIvF\nksu6urowNjamtxQionVFdwBLkoTZ2VmcPXsWk5OTKCsrwyuvvIK5uTl88MEHKdcZHBxET0+P3qGJ\niNY13QG8uLgIj8eT/P727du4ceMGqqqqFANYC0mSIEnSimVLvxqFUesCjFsb68oO68qOqLrSjac7\ngFMNuGvXLly+fFmxj8vlgsvlQiAQwPDwMPr7+xGPx1Vf1+FwIBqNKrYZkVHrAoxbG+vKDuvKzlrX\nZbVaVdtzHsBNTU2IRCIYGhpK2X7lyhX09fUhGAzCbrejpaUFFotl2VF0Kn6/HwsLC8uWSZIEh8MB\nv9+fNsDXklHrAtLXVlpaquv1tf68JpMJ5eXlhpszo25Lua5PPvlEU101NTW6xg8GgymXx+NxzMzM\nYPv27YpHf0rvD60mUdsxPz8fDQ0Niu05DeDGxkZUVlaio6MDd+/eTdlnYmIi+e/x8XF4PB7U1dWl\nDeB4PK44cWptIhm1LkC5tkQioet19a5v1DnbaHXp/VM83fqpThnKRM7jWm/HdGPl7ITIs88+i4cf\nfhhnzpxBKBTKeD29v7BEROtVTgL4ueeew0MPPYSOjg7FP01kTqcTBQUFAICysjLU1tZidHQ0F2UQ\nEa0ruk9BlJSU4KmnnkI0GsXp06eTy2/duoXOzk643W4AQHd3NwCguroabrcbZrMZgUAAQ0ND8Hq9\nessgIlp3dAfw9PQ0XnrpJcV2OXhl7e3teockItoQjHWxHhHRJsIAJiIShAFMRCQIA5iISBAGMBGR\nIAxgIiJBGMBERIIwgImIBMn53dBIPLUbpaS7L2pdXZ2usffu3atpvcXFRXR3d8Nms2m6WUo4HNY0\n7npXXl6uab3XX39d17j//e9/Uy6PRCLo6OhAc3Nz8pYD9/vb3/6ma+yNhEfARESCMICJiARhABMR\nCcIAJiIShAFMRCQIA5iISBAGMBGRIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhI\nEN6OcgMqKipSbDOZTACALVu2IJFIrGivrq7WNfbjjz+uab35+Xl0d3ejuLg4ZV3pbNbbUX7nO99J\nbtNsFBcX6xr322+/Tbl8YWEBADAzM4P8/HxdY2wGPAImIhKEAUxEJAgDmIhIEAYwEZEgut+Ea25u\nhsvlQiwWSy7r6urC2NhYyv6SJKGxsRH79u2DyWTC9evXceHChWXrExFtBjm5CmJwcBA9PT0Z9a2p\nqcHu3bvR1taGWCyGl19+GfX19ejt7c1FKURE68aaX4Z2+PBh9PX1YWZmBgBw6dIlnDhxAu+++67q\n5UeSJK14km+6J/yKIroutcuS5DalPvJlRFrNz89rWk++jEzLJVXA6s216G2pRK5HyyV7gPbtJFPa\nTxYXF5d9TUXEXIrajunGy0kAu1wuuFwuBAIBDA8Po7+/P+WjxW02G0pKSjAxMZFcNj4+DpvNhtLS\nUkxOTiqO4XA4EI1GFduMyKh1AUBFRUXK5X/84x/XtpD77NixQ9N63/3ud3NcyXJG3ZZqQafmxIkT\nOa5kufPnzyu2Pfroo6s6tpq13o5Wq1W1XXcAX7lyBX19fQgGg7Db7WhpaYHFYoHH41nRt6CgAMDy\ni+bl/4nlNiV+v3/F/7qSJMHhcMDv96cMfFFE16V2kb3JZEJFRQW++OKLlEdPv/nNb3SNrfWXKxwO\n47XXXsPXX3+t6ajum2++0TRuOqK3pRK5rry8PE1/NXR2duoa/9q1aymXLy4u4vz58zh+/Djy8vJS\n9jl58qSusbUQtR3z8/PR0NCg2K47gO8/mvV4PKirq0sZwJFIBMC9I+FAIAAAKCwsXNamJB6PK06c\nWptIourKJMASiUTKfno/vSRvT62U6kpntefZqPuYyWTSFMB6t1O6/SQvL0+xj8h5XOvtmG6snJ8Q\nUfvlCYfDmJ6eRnl5eXKZ3W5HOBzG1NRUrkshIjI03QHsdDqTpw/KyspQW1uL0dFRxf6Dg4OoqanB\ntm3bUFRUhGPHjsHn82l+M4GIaL3SfQqiuroabrcbZrMZgUAAQ0ND8Hq9yXa32w0A6O7uBgB4vV4U\nFRXh1KlTkCQJo6OjuHjxot4yiIjWHd0B3N7ertouB68sHo+jt7eX1/0S0abH21FuQGqXvshv2Fit\n1pSnffS+OTM+Pq5pPflN2EAgYMg3u1aL2q1D1cjXl/7oRz9SvNpATWlpqaZxZUNDQymX3717FwAw\nMjICs9msa4zNwFhXlxMRbSIMYCIiQRjARESCMICJiARhABMRCcIAJiIShAFMRCQIA5iISBAGMBGR\nIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAjvB7wB2e32tH2WPpdvqYcffljX2G+99Zam\n9eT7yEYikTW/H7DaQy3lNrWHX2q9py8A7N27V9N6iUQCoVAITz31lKZ7OH/11VeaxpX985//VGwr\nLi5WfGoyLccjYCIiQRjARESCMICJiARhABMRCcIAJiIShAFMRCQIA5iISBAGMBGRIAxgIiJBGMBE\nRIIwgImIBNF9L4gzZ84s+95qteLLL7/E6dOnU/Zvbm6Gy+VCLBZLLuvq6sLY2JjeUoiI1hXdAfzq\nq68u+/6NN97AyMiI6jqDg4Po6enROzQR0bqW01MQFRUV2LlzJ3w+Xy5flohoQ8rp7SgPHjyIjz/+\nGLOzs6r9XC4XXC4XAoEAhoeH0d/fn/YWhJIkQZKkFcuWfjWKXNSVl5ened2DBw8qtsViMQwPD+OJ\nJ56AxbJy8z/wwAOaxwWAGzduaF431TbO1LZt2zSP+/jjjyu2xeNxzMzM4Mknn1Ssrba2VvPYTz/9\ntKb1QqEQXnjhBezatUvT7TDfe+89TePKJicnUy43mUwoLi7G1NQUEolEyj4ifl9FZUW68UxVVVWp\nZylLeXl5eOutt3Du3Dl8+OGHiv3sdjtmZmYQDAZht9vR0tKC999/Hx6PJ2V/q9WKn/zkJ/jzn/+M\naDSai1KJiNaE1WpFQ0MDzp07lzK/cnYE7HQ6sbi4iI8++ki138TERPLf4+Pj8Hg8qKurUwxgmd/v\nx8LCwrJlkiTB4XDA7/ev+U281eSiLj1HwC+++KJiW7oj4F//+teaxwWAJ598UvO6kiRpnrPVPgLe\nvn27IY+A//rXvwo5Aj558mTK5SaTCRUVFfjiiy8Uj4Dn5uZ0ja2FqKzIz89HQ0ODYnvOAvjQoUPw\n+XxZ/3BKG+l+8Xhc8bXV2kTSU5eenydVsKbqk6pfcXGx5nFzQeucZbofpZLJn6Vqp0f0/Ge5ZcsW\nzesC957GoWWb5efn6xo33XwnEgnFPiJ/V9c6K9KeWs3FIDt27EBlZWVGjyFxOp0oKCgAAJSVlaG2\nthajo6O5KIOIaF3JyRHwoUOHcOvWLXzzzTcr2txuNwCgu7sbAFBdXQ232w2z2YxAIIChoSF4vd5c\nlEFEtK7kJID7+voU2+TglbW3t+diSCKidc9Y128REW0iDGAiIkEYwEREgjCAiYgEYQATEQnCACYi\nEoQBTEQkCAOYiEiQnN6OknKntLRU87r/93//p9gWiURw7do1/PCHP0x+JHypO3fuaB4XAPbs2aNp\nvbt37+Kzzz7Dz372s4zuZXG/w4cPaxoXAPbu3avYFgqF8Pzzz+O3v/2t4n0b7Ha75rHNZrOm9eQb\n2oTDYU2vke7mV+mEQqGUy+X7ZczPzxvy/ixGwyNgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxE\nJAgDmIhIEAYwEZEgDGAiIkEYwEREgjCAiYgEYQATEQnCACYiEoQBTEQkCG9HaVD19fWa1/3BD36g\n2BYMBgEAjz32GIqKila05+fnax4XANrb2zWtFwqF0NTUhNdffz1lXenIt0HUYnp6WrHNarUmv8r/\nvt8nn3yieeyqqipN6y0sLAAAPv/8cxQWFma9/tjYmKZxZXfv3k25PJFIJNt5O8r0eARMRCQIA5iI\nSBAGMBGRIAxgIiJBMnoTzul04siRIygvL0coFEJra2uyTZIkNDY2Yt++fTCZTLh+/TouXLiAWCyW\n8rWy7U9EtFFldAQ8Pz+PgYEB/OUvf1nRVlNTg927d6OtrQ2nTp3Czp07Vd/Bz7Y/EdFGlVEA37x5\nEyMjI5iamlrRdvjwYXi9XszMzCAYDOLSpUs4cOAATCZTytfKtj8R0Ual6zpgm82GkpISTExMJJeN\nj4/DZrOhtLQUk5OTuvovJUnSims95e/1XAO6GnJRl55TMvK1vqnIjxNXeqx4NBrVPK7a66YzPz+v\na309c602plyX/FWtjxZq20qNXHM4HNY8th5K872Rfyf1jKtEVwAXFBQAWL4TyDuj3Kan/1IOh0Mx\nHBwORxZVrx09df373//WvO7TTz+dts/x48c1v/5qamxsFF1CSi+++KLoElI6efKkpvX0BtGjjz6q\n2r4Rfye1UPrwjkxXAEciEQD3jmwDgQAAJD+VI7fp6b+U3+9PfvpHJkkSHA4H/H6/oT51k4u6Tpw4\noXn8X/ziF4ptoVAIx48fx/nz57Fly5YV7Xo/CTc7O6tpvfn5ebz44ot47733UtaVjp5A+fbbb9PW\n9c477yh+4kzrzwwADz74oKb1QqEQnnnmGXR0dMBms2W9/k9/+lNN48r8fn/K5Rv5d1KL/Px8NDQ0\nKLbrCuBwOIzp6WmUl5fj66+/BgDY7XaEw+GU54uz7b9UPB5XnDi1NpH01GWxaN80mXyUd8uWLavy\nUWS9V7Mo1ZWOngBeXFxM26ewsFDxPwY9p220/KxL2Ww2TR9F1ivdfr0Rfye1jqcmo73WZDLBYrHA\nbDYDuBcOckAMDg6ipqYG27ZtQ1FREY4dOwafz5f8TPj9su1PRLRRZXSYtX//fjQ3Nye/7+zsxNTU\nFFpbW+H1elFUVIRTp05BkiSMjo7i4sWLyb5utxsA0N3dDQBp+xMRbRYZBbDP54PP50vZFo/H0dvb\ni97e3pTtcvBm2p+IaLMw1rUiRESbCO8HbFBPPvmk5nW3b9+u2Cafu9++fXvKN4D0fiBG65t48vWw\nn376qaZ39fv7+zWNC6hf8ie/iXLy5EnFN/oeeeQRzWP//Oc/17SefNXQp59+mvYSzlS++uorTeNS\nbvEImIhIEAYwEZEgDGAiIkEYwEREgjCAiYgEYQATEQnCACYiEoQBTEQkCAOYiEgQBjARkSAMYCIi\nQRjARESCMICJiARhABMRCcLbURqU0kMPM1FaWqrYJj+R+h//+EfK2z6OjIxoHld+XS3k2z6+8MIL\nmh5PdefOHU3jAurP35MkCZWVlfjXv/6l+Hyv+vp6zWPLj5fPlvw08XfeeUfTLUQnJyc1jUu5xSNg\nIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhIEAYwEZEgDGAiIkEYwEREgjCAiYgEyehe\nEE6nE0eOHEF5eTlCoRBaW1vvrWyxoKmpCXv27EFxcTFmZ2dx9epVDAwMKL5Wc3MzXC4XYrFYcllX\nVxfGxsb0/SREROtMRgE8Pz+PgYEBbN26FUePHk0ulyQJs7OzOHv2LCYnJ1FWVoZXXnkFc3Nz+OCD\nDxRfb3BwED09PfqrJyJaxzI6BXHz5k2MjIxgampq2fLFxUV4PB7cuXMHiUQCt2/fxo0bN1BVVbUq\nxRIRbSQ5vR2lJEnYtWsXLl++rNrP5XLB5XIhEAhgeHgY/f39irf6W/rakiStWLb0q1Hkoq6zZ8+u\nyromkwkPPvggnn/++ZS3fQwGg5rH1UOSJDgcDvzvf/9Luy+kouWWjLInnnhCsS2RSGBubg579+5V\nHGPfvn2ax75+/bqm9SKRCADg/fff1zRfeint2xv5d1LPuEpyGsBNTU2IRCIYGhpS7HPlyhX09fUh\nGAzCbrejpaUFFosFHo9H9bUdDgei0ahimxEZtS4AePDBB0WXkJKIOQsEAmn7zM3NKbb9+Mc/zmU5\nWdm9e7ewsdUYdd9f67qsVqtqe84CuLGxEZWVlejo6MDdu3cV+01MTCT/PT4+Do/Hg7q6urQB7Pf7\nsbCwsGyZfNTk9/uFHAUoyUVdRUVFOa7qHvkI+PPPPzfkEbDWOdNzBHzo0CHFNvkIuLi4WHGMP/zh\nD5rH/uyzzzStF4lE8Oabb+I///mPpvlaXFzUNG46G/l3Uov8/Hw0NDQotuckgJ999lk89NBD6Ojo\nyPoO/5k+/SAejytOnFqbSHrqWq2fR/6TKJFIpBxD9DxqnTM9AZzJuiaTSbHfli1bNI+d6qkk2dA6\nX6u9nTfi76TW8dRkdELEZDLBYrHAbDYDuHf5mfwYl+eeey4ZvpkcPTmdThQUFAAAysrKUFtbi9HR\n0UzKICLaUDI6At6/fz+am5uT33d2dmJqagpvv/02nnrqKUSjUZw+fTrZfuvWLXR2dgIA3G43AKC7\nuxsAUF1dDbfbDbPZjEAggKGhIXi93pz9QERE60VGAezz+eDz+VK2vfTSS6rrysEra29vz7A0IqKN\nzVjXihARbSIMYCIiQRjARESCMICJiARhABMRCcIAJiIShAFMRCQIA5iISJCc3g2Ncme1booj3wsi\nGAwa8rP6WqW765SaRx55RLEtFothaGgI3//+95Mfv79fWVmZ5rGXfoI0G/INrxYXFzfUdtxseARM\nRCQIA5iISBAGMBGRIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhIEAYwEZEgDGAi\nIkEYwEREgjCAiYgEYQATEQnC+wHThvC9731P87r19fWKbeFwGENDQ6irq4PNZkvZ5/bt25rHfv/9\n9zWvK9/bmdYvbkEiIkEYwEREgjCAiYgEyegcsNPpxJEjR1BeXo5QKITW1tZkW3NzM1wuF2KxWHJZ\nV1cXxsbGUr6WJElobGzEvn37YDKZcP36dVy4cGHZ+kREm0FGATw/P4+BgQFs3boVR48eXdE+ODiI\nnp6ejAasqanB7t270dbWhlgshpdffhn19fXo7e3NrnIionUuo1MQN2/exMjICKampnQPePjwYXi9\nXszMzCAYDOLSpUs4cOAATCaT7tcmIlpPcnIZmsvlgsvlQiAQwPDwMPr7+1M+Kttms6GkpAQTExPJ\nZePj47DZbCgtLcXk5KTiGJIkrbjsRv7eaJfjGLUuwLi1iawrHA4rtkUikWVfUwmFQjmvKVPcjpkR\nVVe68XQH8JUrV9DX14dgMAi73Y6WlhZYLBZ4PJ4VfQsKCgAs3+Hn5+eXtSlxOByIRqOKbUZk1LoA\n49Ymoq6l72koaWtrW5Wx9QYCt2N21rouq9Wq2q47gO8/mvV4PKirq0sZwPJRhM1mQyAQAAAUFhYu\na1Pi9/uxsLCwbJkkSXA4HPD7/SmPuEUxal2AcWvTW9cDDzygeezf//73im2RSARtbW341a9+pXiQ\noGfs48ePa15XkqQNtx1Xi6i68vPz0dDQoNie80/CJRIJxbZwOIzp6WmUl5fj66+/BgDY7XaEw+G0\n55fj8bjixKm1iWTUugDj1iaiLqVPuC1VUFCg2G/Lli25Lilj3I7ZWeu60o2V0d8/JpMJFosFZrMZ\nAGCxWGCx3Mtup9OZPDIoKytDbW0tRkdHFV9rcHAQNTU12LZtG4qKinDs2DH4fD7V4CYi2ogyOgLe\nv38/mpubk993dnZiamoKra2tqK6uhtvthtlsRiAQwNDQELxeb7Kv2+0GAHR3dwMAvF4vioqKcOrU\nKUiShNHRUVy8eDGXPxMR0bqQUQD7fD74fL6Ube3t7arrysEri8fj6O3t5XW/RLTpGetaESKiTYS3\no6QN4fHHH9e87mOPPabYFgwGAQCPPPIIioqKUvb505/+pHnsW7duaVpPkiTs2rVL87hkDDwCJiIS\nhAFMRCQIA5iISBAGMBGRIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhIEAYwEZEg\nDGAiIkEYwEREgvB2lLQhyE/X1kLtEVryE7w//PBDxWfC/f3vf9c8tta6jfbYd9KGW5GISBAGMBGR\nIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKEAUxEJAgDmIhIEAYwEZEgDGAiIkEyuheE0+nEkSNH\nUF5ejlAohNbW1mTbmTNnlvW1Wq348ssvcfr06ZSv1dzcDJfLhVgsllzW1dWFsbExLfUTEa1bGQXw\n/Pw8BgYGsHXrVhw9enRZ26uvvrrs+zfeeAMjIyOqrzc4OIienp4sSyUi2lgyCuCbN28CAPbu3ava\nr6KiAjt37oTP59NfGRHRBpfT21EePHgQH3/8MWZnZ1X7uVwuuFwuBAIBDA8Po7+/H/F4XHUdSZJW\n3IJP/t5ot+Yzal2AcWvTW9fly5c1j622riRJcDgceOaZZ9Luo1po/Xk36nZcLaLqSjdezgI4Ly8P\nLpcL586dU+135coV9PX1IRgMwm63o6WlBRaLBR6PR3U9h8OBaDSq2GZERq0LMG5trCs7rCs7a12X\n1WpVbc9ZADudTiwuLuKjjz5S7TcxMZH89/j4ODweD+rq6tIGsN/vx8LCwrJl8tGJ3+9flaMTrYxa\nF2Dc2lhXdlhXdkTVlZ+fj4aGBsX2nAXwoUOH4PP5sv7hEolERv3i8bjia6u1iWTUugDj1sa6ssO6\nsrPWdaU9tZrJi5hMJlgsFpjNZgCAxWKBxfL/s3vHjh2orKzEtWvX0r6W0+lEQUEBAKCsrAy1tbWq\nj4QhItqoMjoC3r9/P5qbm5Pfd3Z2YmpqKnk98KFDh3Dr1i188803K9Z1u90AgO7ubgBAdXU13G43\nzGYzAoEAhoaG4PV6df8gRETrTUYB7PP5VC8t6+vrU2yTg1fW3t6eYWlERBubsa4VISLaRBjARESC\nMICJiARhABMRCcIAJiIShAFMRCQIA5iISBAGMBGRIAxgIiJBGMBERIIwgImIBGEAExEJwgAmIhKE\nAUxEJAgDmIhIEAYwEZEgDGAiIkEYwEREgjCAiYgEYQATEQnCACYiEoQBTEQkCAOYiEgQi+gCMpWf\nn79imSRJsFqtyM/PRzweF1BVakatCzBubawrO6wrO6LqSpVbS5mqqqoSa1SLJoWFhTh+/LjoMoiI\nNDt//jzm5+dXLDd8AAP3QjgajYoug4goa1arNWX4AuvkFIRS8URERqd28Mg34YiIBGEAExEJwgAm\nIhKEAUxEJIjh34STJAmNjY3Yt28fTCYTrl+/jgsXLiAWi+nqq4fFYkFTUxP27NmD4uJizM7O4urV\nqxgYGEjZv7m5GS6Xa1kdXV1dGBsby2ldWsZaqzk7c+bMsu+tViu+/PJLnD59OmX/1Zozp9OJI0eO\noLy8HKFQCK2trcm2bOcil3OnVFe2+xqQ27lTmy+R+5paXUbZ1zJh+ACuqanB7t270dbWhlgshpdf\nfhn19fXo7e3V1VcPSZIwOzuLs2fPYnJyEmVlZXjllVcwNzeHDz74IOU6g4OD6OnpyWkdSrIZa63m\n7NVXX132/RtvvIGRkRHVdVZjzubn5zEwMICtW7fi6NGjy9qynYtczp1SXVr2NSB3c6c2X9mOsxbz\nBRhnX8uE4U9BHD58GF6vFzMzMwgGg7h06RIOHDgAk8mkq68ei4uL8Hg8uHPnDhKJBG7fvo0bN26g\nqqoqp+OshbWas6UqKiqwc+dO+Hy+VRtDyc2bNzEyMoKpqakVbdnORS7nTqku0fua2nxlay3m634i\n97VMGPoI2GazoaSkBBMTE8ll4+PjsNlsKC0txeTkpKa+uSZJEnbt2oXLly8r9nG5XHC5XAgEAhge\nHkZ/f/+qzh+sAAACi0lEQVSqfSQy07FEzdnBgwfx8ccfY3Z2VrXfWs5ZtnMhau4y2deAtZs77mv6\nGDqACwoKAADhcDi5TP5QhtympW+uNTU1IRKJYGhoKGX7lStX0NfXh2AwCLvdjpaWFlgsFng8npzX\nks1YIuYsLy8PLpcL586dU+23lnMGZD8Xova3dPsasHZzx31NP0OfgohEIgDu/e8pKywsXNampW8u\nNTY2orKyEr/73e9w9+7dlH0mJiYwNzeHRCKB8fFxeDweuFyuVaknm7FEzJnT6cTi4iI++ugj1X5r\nOWdA9nMhYu4y2deAtZs77mv6GTqAw+EwpqenUV5enlxmt9sRDodXnPvJpm+uPPvss3j44Ydx5swZ\nhEKhjNdLJNbu9htqY4mYs0OHDsHn82X9591qz1m2c7HWc6d1XwPWbn/jvpY9QwcwcO/dyZqaGmzb\ntg1FRUU4duwYfD5fyknKpq9ezz33HB566CF0dHQgGAyq9nU6nck/s8rKylBbW4vR0dGc16RlrLWc\nsx07dqCyshLXrl1L23e15sxkMsFiscBsNgO4d5mXxXLvTFy2c5HLuVOrK5t9Dcjt3KnVJXJfU6sL\nMMa+lgnD3w1t6bWDkiRhdHQUPT09iEajcLvdAIDu7u60fXOppKQEb775JqLR6LI/BW/duoXOzs4V\ndb322msoKyuD2WxGIBDA0NAQvF7vqpzkTzeWqDkDgPr6elRUVKC9vX1F21rN2YEDB9Dc3Lxs2dTU\nFFpbW9POxWrOnVJdb7/9tuq+lqquXM6d2nyJ3NfU6gKMsa9lwvABTES0URn+FAQR0UbFACYiEoQB\nTEQkCAOYiEgQBjARkSAMYCIiQRjARESCMICJiARhABMRCfL/ABFe0pvig4iqAAAAAElFTkSuQmCC\n"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Model:\n\n# L0 (input) [INPUT_SIZE]\n# W1 [INPUT_SIZE,L1_SIZE]\n# L1: sigmoid(L0*W1+B1) [L1_SIZE]\n# W2 [L1_SIZE,L2_SIZE]\n# L2: sigmoid(L1*W2+B2) [L2_SIZE]\n# W3: [L2_SIZE,OUTPUT_SIZE]\n# L3 (output): sigmoid(L2*W3+B3)",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# init weights and biases from normal dist with 0.1 sqv\n\nw1 = np.random.normal(scale=1, size=(INPUT_SIZE, L1_SIZE))\nw2 = np.random.normal(scale=1, size=(L1_SIZE, L2_SIZE))\nw3 = np.random.normal(scale=1, size=(L2_SIZE, OUTPUT_SIZE))\n\nb1 = np.random.normal(scale=1, size=(L1_SIZE))\nb2 = np.random.normal(scale=1, size=(L2_SIZE))\nb3 = np.random.normal(scale=1, size=(OUTPUT_SIZE))",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def eval_on_test():\n\n l0 = test_X\n l1 = sigmoid(np.dot(l0, w1) + b1)\n l2 = sigmoid(np.dot(l1, w2) + b2)\n l3 = sigmoid(np.dot(l2, w3) + b3)\n\n return np.mean((-np.log(\\\n [y[m] for m, y in zip(np.argmax(test_Y, axis=1), l3)]\\\n )))",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "STEPS = 5000000 // BATCH_SIZE\n\nsamples = 0\n\nfor step in range(STEPS):\n\n batch_start = step * BATCH_SIZE\n if len(train_X) - batch_start <= 0: batch_start = 0\n\n train_X_batch = train_X[batch_start:batch_start + BATCH_SIZE]\n train_Y_batch = train_Y[batch_start:batch_start + BATCH_SIZE]\n\n # forward pass\n\n l0 = train_X_batch\n l1 = sigmoid(np.dot(l0, w1) + b1)\n l2 = sigmoid(np.dot(l1, w2) + b2)\n l3 = sigmoid(np.dot(l2, w3) + b3)\n\n # compute \"deltas\" (errors * transfer_function_derivative(layer_output))\n\n # for output layer: errors = true_labels - layer_output\n l3_delta = (train_Y_batch - l3) * sigmoid_d(l3) # [BATCH_SIZE,OUTPUT_SIZE]\n # for hidden layers: errors = upper_layer_deltas * upper_layer_weights.T\n l2_delta = np.dot(l3_delta, w3.T) * sigmoid_d(l2)\n l1_delta = np.dot(l2_delta, w2.T) * sigmoid_d(l1)\n\n # update weights and biases\n\n # weight = weight + LR * input * delta\n w3 += LEARNING_RATE * np.dot(l2.T, l3_delta)\n # bias = bias + LR * 1. * delta (like weights with inputs being 1.)\n # sum along batch index axis - same as multiplying by matrix of ones\n b3 += LEARNING_RATE * np.sum(l3_delta, axis=0)\n\n w2 += LEARNING_RATE * np.dot(l1.T, l2_delta)\n b2 += LEARNING_RATE * np.sum(l2_delta, axis=0)\n w1 += LEARNING_RATE * np.dot(l0.T, l1_delta)\n b1 += LEARNING_RATE * np.sum(l1_delta, axis=0)\n\n mae = np.mean(np.abs(l3_delta))\n ce = np.mean((-np.log(\\\n [y[m] for m, y in zip(np.argmax(train_Y_batch, axis=1), l3)]\\\n )))\n\n samples += BATCH_SIZE\n\n if samples % 25000 < BATCH_SIZE:\n print('samples seen: %.0fK, mae: %.7f, crossentropy: %.5f, test ce: %.5f' %\n (samples / 1000., mae, ce, eval_on_test()))",
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": "samples seen: 25K, mae: 0.0063756, crossentropy: 5.19431, test ce: 5.41616\nsamples seen: 50K, mae: 0.0075110, crossentropy: 5.28344, test ce: 4.12838\nsamples seen: 75K, mae: 0.0024622, crossentropy: 3.27827, test ce: 3.32374\nsamples seen: 100K, mae: 0.0012724, crossentropy: 3.12343, test ce: 3.37333\nsamples seen: 125K, mae: 0.0005859, crossentropy: 3.02297, test ce: 3.35083\nsamples seen: 150K, mae: 0.0010816, crossentropy: 2.04804, test ce: 2.55793\nsamples seen: 175K, mae: 0.0004957, crossentropy: 1.97337, test ce: 2.62282\nsamples seen: 200K, mae: 0.0003062, crossentropy: 1.96530, test ce: 2.65418\nsamples seen: 225K, mae: 0.0002304, crossentropy: 1.95872, test ce: 2.67450\nsamples seen: 250K, mae: 0.0001969, crossentropy: 1.94469, test ce: 2.68437\nsamples seen: 275K, mae: 0.0002511, crossentropy: 1.90612, test ce: 2.68340\nsamples seen: 300K, mae: 0.0001725, crossentropy: 1.89080, test ce: 2.70477\nsamples seen: 325K, mae: 0.0001390, crossentropy: 1.88303, test ce: 2.72119\nsamples seen: 350K, mae: 0.0001186, crossentropy: 1.87708, test ce: 2.73362\nsamples seen: 375K, mae: 0.0001043, crossentropy: 1.87177, test ce: 2.74346\nsamples seen: 400K, mae: 0.0000934, crossentropy: 1.86671, test ce: 2.75150\nsamples seen: 425K, mae: 0.0000848, crossentropy: 1.86172, test ce: 2.75819\nsamples seen: 450K, mae: 0.0000778, crossentropy: 1.85671, test ce: 2.76384\nsamples seen: 475K, mae: 0.0000719, crossentropy: 1.85163, test ce: 2.76862\nsamples seen: 500K, mae: 0.0000669, crossentropy: 1.84641, test ce: 2.77266\nsamples seen: 525K, mae: 0.0000626, crossentropy: 1.84101, test ce: 2.77602\nsamples seen: 550K, mae: 0.0000588, crossentropy: 1.83540, test ce: 2.77877\nsamples seen: 575K, mae: 0.0000555, crossentropy: 1.82953, test ce: 2.78093\nsamples seen: 600K, mae: 0.0000525, crossentropy: 1.82334, test ce: 2.78254\nsamples seen: 625K, mae: 0.0000499, crossentropy: 1.81677, test ce: 2.78359\nsamples seen: 650K, mae: 0.0000475, crossentropy: 1.80975, test ce: 2.78408\nsamples seen: 675K, mae: 0.0000454, crossentropy: 1.80216, test ce: 2.78397\nsamples seen: 700K, mae: 0.0000435, crossentropy: 1.79388, test ce: 2.78322\nsamples seen: 725K, mae: 0.0000417, crossentropy: 1.78472, test ce: 2.78173\nsamples seen: 750K, mae: 0.0000401, crossentropy: 1.77439, test ce: 2.77937\nsamples seen: 775K, mae: 0.0000387, crossentropy: 1.76246, test ce: 2.77588\nsamples seen: 800K, mae: 0.0000371, crossentropy: 1.74789, test ce: 2.76851\nsamples seen: 825K, mae: 0.0000361, crossentropy: 1.72925, test ce: 2.76041\nsamples seen: 850K, mae: 0.0000356, crossentropy: 1.70321, test ce: 2.74830\nsamples seen: 875K, mae: 0.0000370, crossentropy: 1.65515, test ce: 2.72387\nsamples seen: 900K, mae: 0.0001488, crossentropy: 1.32677, test ce: 2.57070\nsamples seen: 925K, mae: 0.0000847, crossentropy: 1.36105, test ce: 2.59207\nsamples seen: 950K, mae: 0.0000652, crossentropy: 1.37678, test ce: 2.60391\nsamples seen: 975K, mae: 0.0000547, crossentropy: 1.38707, test ce: 2.61186\nsamples seen: 1000K, mae: 0.0000481, crossentropy: 1.39460, test ce: 2.61791\nsamples seen: 1025K, mae: 0.0000434, crossentropy: 1.40042, test ce: 2.62291\nsamples seen: 1050K, mae: 0.0000399, crossentropy: 1.40500, test ce: 2.62725\nsamples seen: 1075K, mae: 0.0000372, crossentropy: 1.40860, test ce: 2.63114\n",
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-11-587e0dbdd971>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mw2\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml2_delta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mb2\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml2_delta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mw1\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml0\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml1_delta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0mb1\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml1_delta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/69efc0172b58fbccfcdbf3c097006fa8"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"gist": {
"id": "69efc0172b58fbccfcdbf3c097006fa8",
"data": {
"description": "Manual Backprop",
"public": true
}
},
"language_info": {
"nbconvert_exporter": "python",
"name": "python",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"version": "3.5.2",
"pygments_lexer": "ipython3",
"mimetype": "text/x-python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment