Skip to content

Instantly share code, notes, and snippets.

@marcospgp
Created August 19, 2022 01:09
Show Gist options
  • Save marcospgp/d4b2166ebb76a00bafe09e7e1ea0be47 to your computer and use it in GitHub Desktop.
Save marcospgp/d4b2166ebb76a00bafe09e7e1ea0be47 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "backpropagation-from-scratch.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"authorship_tag": "ABX9TyOKZgNOYvWf6UcwNMEXHseB",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/marcospgp/backpropagation-from-scratch/blob/master/backpropagation-from-scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CKRRjmhbFcx7",
"colab_type": "text"
},
"source": [
"# Loading MNIST dataset from Keras\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kjvPSnDA4J_w",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
},
"outputId": "f0d75622-d005-4b0d-f4e1-560d5d2d9f34"
},
"source": [
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"\n",
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path=\"mnist.npz\")\n",
"\n",
"plt.imshow(x_train[0],cmap='gray');"
],
"execution_count": 178,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAN9klEQVR4nO3df4xV9ZnH8c+zWP6QojBrOhKKSyEGg8ZON4gbl6w1hvojGhw1TSexoZE4/YNJaLIhNewf1WwwZBU2SzTNTKMWNl1qEzUgaQouoOzGhDgiKo5LdQ2mTEaowZEf/mCHefaPezBTnfu9w7nn3nOZ5/1Kbu6957nnnicnfDi/7pmvubsATH5/VXYDAJqDsANBEHYgCMIOBEHYgSAuaubCzIxT/0CDubuNN72uLbuZ3Wpmh8zsPTN7sJ7vAtBYlvc6u5lNkfRHSUslHZH0qqQudx9IzMOWHWiwRmzZF0t6z93fd/czkn4raVkd3weggeoJ+2xJfxrz/kg27S+YWbeZ9ZtZfx3LAlCnhp+gc/c+SX0Su/FAmerZsg9KmjPm/bezaQBaUD1hf1XSlWb2HTObKulHkrYV0xaAouXejXf3ETPrkbRD0hRJT7n724V1BqBQuS+95VoYx+xAwzXkRzUALhyEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBJF7yGZcGKZMmZKsX3rppQ1dfk9PT9XaxRdfnJx3wYIFyfrKlSuT9ccee6xqraurKznv559/nqyvW7cuWX/44YeT9TLUFXYzOyzppKSzkkbcfVERTQEoXhFb9pvc/aMCvgdAA3HMDgRRb9hd0k4ze83Musf7gJl1m1m/mfXXuSwAdah3N36Juw+a2bckvWhm/+Pue8d+wN37JPVJkpl5ncsDkFNdW3Z3H8yej0l6XtLiIpoCULzcYTezaWY2/dxrST+QdLCoxgAUq57d+HZJz5vZue/5D3f/QyFdTTJXXHFFsj516tRk/YYbbkjWlyxZUrU2Y8aM5Lz33HNPsl6mI0eOJOsbN25M1js7O6vWTp48mZz3jTfeSNZffvnlZL0V5Q67u78v6bsF9gKggbj0BgRB2IEgCDsQBGEHgiDsQBDm3rwftU3WX9B1dHQk67t3707WG32baasaHR1N1u+///5k/dSpU7mXPTQ0lKx//PHHyfqhQ4dyL7vR3N3Gm86WHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeC4Dp7Adra2pL1ffv2Jevz5s0rsp1C1ep9eHg4Wb/pppuq1s6cOZOcN+rvD+rFdXYgOMIOBEHYgSAIOxAEYQeCIOxAEIQdCIIhmwtw/PjxZH316tXJ+h133JGsv/7668l6rT+pnHLgwIFkfenSpcn66dOnk/Wrr766am3VqlXJeVEstuxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EAT3s7eASy65JFmvNbxwb29v1dqKFSuS8953333J+pYtW5J1tJ7c97Ob2VNmdszMDo6Z1mZmL5rZu9nzzCKbBVC8iezG/1rSrV+Z9qCkXe5+paRd2XsALaxm2N19r6Sv/h50maRN2etNku4quC8ABcv72/h2dz83WNaHktqrfdDMuiV151wOgILUfSOMu3vqxJu790nqkzhBB5Qp76W3o2Y2S5Ky52PFtQSgEfKGfZuk5dnr5ZK2FtMOgEapuRtvZlskfV/SZWZ2RNIvJK2T9DszWyHpA0k/bGSTk92JEyfqmv+TTz7JPe8DDzyQrD/zzDPJeq0x1tE6aobd3buqlG4uuBcADcTPZYEgCDsQBGEHgiDsQBCEHQiCW1wngWnTplWtvfDCC8l5b7zxxmT9tttuS9Z37tyZrKP5GLIZCI6wA0EQdiAIwg
4EQdiBIAg7EARhB4LgOvskN3/+/GR9//79yfrw8HCyvmfPnmS9v7+/au2JJ55IztvMf5uTCdfZgeAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIrrMH19nZmaw//fTTyfr06dNzL3vNmjXJ+ubNm5P1oaGhZD0qrrMDwRF2IAjCDgRB2IEgCDsQBGEHgiDsQBBcZ0fSNddck6xv2LAhWb/55vyD/fb29ibra9euTdYHBwdzL/tClvs6u5k9ZWbHzOzgmGkPmdmgmR3IHrcX2SyA4k1kN/7Xkm4dZ/q/untH9vh9sW0BKFrNsLv7XknHm9ALgAaq5wRdj5m9me3mz6z2ITPrNrN+M6v+x8gANFzesP9S0nxJHZKGJK2v9kF373P3Re6+KOeyABQgV9jd/ai7n3X3UUm/krS42LYAFC1X2M1s1pi3nZIOVvssgNZQ8zq7mW2R9H1Jl0k6KukX2fsOSS7psKSfunvNm4u5zj75zJgxI1m/8847q9Zq3StvNu7l4i/t3r07WV+6dGmyPllVu85+0QRm7Bpn8pN1dwSgqfi5LBAEYQeCIOxAEIQdCIKwA0FwiytK88UXXyTrF12Uvlg0MjKSrN9yyy1Vay+99FJy3gsZf0oaCI6wA0EQdiAIwg4EQdiBIAg7EARhB4KoedcbYrv22muT9XvvvTdZv+6666rWal1Hr2VgYCBZ37t3b13fP9mwZQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBILjOPsktWLAgWe/p6UnW77777mT98ssvP++eJurs2bPJ+tBQ+q+Xj46OFtnOBY8tOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EwXX2C0Cta9ldXeMNtFtR6zr63Llz87RUiP7+/mR97dq1yfq2bduKbGfSq7llN7M5ZrbHzAbM7G0zW5VNbzOzF83s3ex5ZuPbBZDXRHbjRyT9o7svlPR3klaa2UJJD0ra5e5XStqVvQfQomqG3d2H3H1/9vqkpHckzZa0TNKm7GObJN3VqCYB1O+8jtnNbK6k70naJ6nd3c/9OPlDSe1V5umW1J2/RQBFmPDZeDP7pqRnJf3M3U+MrXlldMhxB2109z53X+Tui+rqFEBdJhR2M/uGKkH/jbs/l00+amazsvosScca0yKAItTcjTczk/SkpHfcfcOY0jZJyyWty563NqTDSaC9fdwjnC8tXLgwWX/88ceT9auuuuq8eyrKvn37kvVHH320am3r1vQ/GW5RLdZEjtn/XtKPJb1lZgeyaWtUCfnvzGyFpA8k/bAxLQIoQs2wu/t/Sxp3cHdJNxfbDoBG4eeyQBCEHQiCsANBEHYgCMIOBMEtrhPU1tZWtdbb25uct6OjI1mfN29erp6K8MorryTr69evT9Z37NiRrH/22Wfn3RMagy07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgQR5jr79ddfn6yvXr06WV+8eHHV2uzZs3P1VJRPP/20am3jxo3JeR955JFk/fTp07l6Quthyw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQYS5zt7Z2VlXvR4DAwPJ+vbt25P1kZGRZD11z/nw8HByXsTBlh0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgjB3T3/AbI6kzZLaJbmkPnf/NzN7SNIDkv6cfXSNu/++xnelFwagbu4+7qjLEwn7LEmz3H2/mU2X9Jqku1QZj/2Uuz820SYIO9B41cI+kfHZhyQNZa9Pmtk7ksr90ywAztt5HbOb2VxJ35O0L5vUY2ZvmtlTZjazyjzdZtZvZv11dQqgLjV347/8oNk3Jb0saa27P2dm7ZI+UuU4/p9V2dW/v8Z3sBsPNFjuY3ZJMrNvSNouaYe7bxinPlfSdne/psb3EHagwaqFveZuvJmZpCclvTM26NmJu3M6JR2st0kAjTORs/FLJP2XpLckjWaT10jqktShym78YUk/zU7mpb6LLTvQYHXtxheFsAONl3s3HsDkQNiBIAg7EARhB4
Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQii2UM2fyTpgzHvL8umtaJW7a1V+5LoLa8ie/ubaoWm3s/+tYWb9bv7otIaSGjV3lq1L4ne8mpWb+zGA0EQdiCIssPeV/LyU1q1t1btS6K3vJrSW6nH7ACap+wtO4AmIexAEKWE3cxuNbNDZvaemT1YRg/VmNlhM3vLzA6UPT5dNobeMTM7OGZam5m9aGbvZs/jjrFXUm8Pmdlgtu4OmNntJfU2x8z2mNmAmb1tZquy6aWuu0RfTVlvTT9mN7Mpkv4oaamkI5JeldTl7gNNbaQKMzssaZG7l/4DDDP7B0mnJG0+N7SWmf2LpOPuvi77j3Kmu/+8RXp7SOc5jHeDeqs2zPhPVOK6K3L48zzK2LIvlvSeu7/v7mck/VbSshL6aHnuvlfS8a9MXiZpU/Z6kyr/WJquSm8twd2H3H1/9vqkpHPDjJe67hJ9NUUZYZ8t6U9j3h9Ra4337pJ2mtlrZtZddjPjaB8zzNaHktrLbGYcNYfxbqavDDPeMusuz/Dn9eIE3dctcfe/lXSbpJXZ7mpL8soxWCtdO/2lpPmqjAE4JGl9mc1kw4w/K+ln7n5ibK3MdTdOX01Zb2WEfVDSnDHvv51NawnuPpg9H5P0vCqHHa3k6LkRdLPnYyX38yV3P+ruZ919VNKvVOK6y4YZf1bSb9z9uWxy6etuvL6atd7KCPurkq40s++Y2VRJP5K0rYQ+vsbMpmUnTmRm0yT9QK03FPU2Scuz18slbS2xl7/QKsN4VxtmXCWvu9KHP3f3pj8k3a7KGfn/lfRPZfRQpa95kt7IHm+X3ZukLars1v2fKuc2Vkj6a0m7JL0r6T8ltbVQb/+uytDeb6oSrFkl9bZElV30NyUdyB63l73uEn01Zb3xc1kgCE7QAUEQdiAIwg4EQdiBIAg7EARhB4Ig7EAQ/w8ie3GmjcGk5QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Klx9qPmxF9jI",
"colab_type": "text"
},
"source": [
"# Utility functions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sdyvaUKoF7ux",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
" # Numerically stable sigmoid function based on\n",
" # http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/\n",
" \n",
" x = np.clip(x, -500, 500) # We get an overflow warning without this\n",
" \n",
" return np.where(\n",
" x >= 0,\n",
" 1 / (1 + np.exp(-x)),\n",
" np.exp(x) / (1 + np.exp(x))\n",
" )\n",
"\n",
"def dsigmoid(x):\n",
"    \"\"\"Derivative of the sigmoid, evaluated at pre-activation x.\"\"\"\n",
"    s = sigmoid(x)\n",
"    return s * (1 - s)\n",
"\n",
"def softmax(x):\n",
" # Numerically stable softmax based on (same source as sigmoid)\n",
" # http://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/\n",
" b = x.max()\n",
" y = np.exp(x - b)\n",
" return y / y.sum()\n",
"\n",
"def cross_entropy_loss(y, yHat):\n",
" return -np.sum(y * np.log(yHat))\n",
"\n",
"def integer_to_one_hot(x, max):\n",
" # x: integer to convert to one hot encoding\n",
" # max: the size of the one hot encoded array\n",
" result = np.zeros(10)\n",
" result[x] = 1\n",
" return result"
],
"execution_count": 179,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "_xIZEupTHyNM",
"colab_type": "text"
},
"source": [
"# Initialize architecture, weights, and biases"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zBeGvbu6FaM_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 865
},
"outputId": "af98eda1-651f-4d66-ebad-4f4885cd0c8d"
},
"source": [
"import math\n",
"\n",
"# Initialize weights of each layer with a normal distribution of mean 0 and\n",
"# standard deviation 1/sqrt(n), where n is the number of inputs to the layer.\n",
"# For inputs with unit mean square (e.g. zero mean and unit variance), the\n",
"# weighted input is then itself a random variable with mean 0 and standard\n",
"# deviation close to 1 (exactly 1 when the biases are initialized as 0).\n",
"\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(80085)\n",
"\n",
"# Neural network layer sizes: 784 -> 32 -> 32 -> 10\n",
"\n",
"weights = [\n",
" rng.normal(0, 1/math.sqrt(784), (32, 784)),\n",
" rng.normal(0, 1/math.sqrt(32), (32, 32)),\n",
" rng.normal(0, 1/math.sqrt(32), (10, 32))\n",
"]\n",
"\n",
"biases = [np.zeros(32), np.zeros(32), np.zeros(10)]\n",
"\n",
"# Plot histogram of layer weights to check probability distribution\n",
"\n",
"print(\"Weight distribution per layer:\")\n",
"for layer_number, layer in enumerate(weights, start=1):\n",
"    n_neurons, n_inputs = layer.shape\n",
"    plt.figure()\n",
"    plt.suptitle(\n",
"        f\"Layer {layer_number} with {n_neurons} neurons, {n_inputs} inputs each \"\n",
"        f\"({layer.size} weights in total)\"\n",
"    )\n",
"    plt.hist(layer.flatten(), bins=100);\n"
],
"execution_count": 180,
"outputs": [
{
"output_type": "stream",
"text": [
"Weight distribution per layer:\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAEVCAYAAAA7PDgXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3debgcVb3u8e8rgaBMCRBjSIIBiSh4NOIWcECRoMwGFREHjBifqAc8KngF9XpEj1dBvSI8HsUcORIUQQQxkZOLMgVxAN0RZBSyQTAJCdmEJISZyO/+sVaTSrN7j13VvXfez/P001WrVlWtGn+1VlVXKyIwMzOrwvNaXQAzM9t0OOiYmVllHHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQadA0iOSdu1l+L2SDqyyTNZckvaTdGery9EKks6V9NUBjvN7Sa8uq0ztaiD7iaT9JS0tqRxnS/piGdMeKkmnSvpJ7h4v6Q5Jo/sar8+g044nWkkTJM2XdL+kkDSlGdONiK0j4p48jwEfoHVlfIukWyStkbRK0qWSJhaGf0vSYknrJP1N0gebsQzDkaSdc8AvfkLSSYU8n5D0d0kPS+qU9MYeprNF3vEbngAi4rqI2L2sZakrT0jarYp5lUHSEcC6iLgx98+UtChvg6WSviFpVCH/QklPFLbhnXXTe5+k+yQ9KumXkrYvDJsiaYGk1ZJWSPpu3bQPkPSXPO97JM0uc9mbuZ8M5VwSER+LiP8Y5HwXSvpIWfmLIuIB4Bqgz+3S9jWd4o5X8AxwOfCuioszELcDB0XEGGAnYDHw/cLwR4EjgO2AmcCZkl5feSnrSNqs6nlGxD9ywN86IrYG/oW0jS/JZdoHOA04irS+zgEu7aGs/wvorq7kI97HgB8X+l8AfArYEdgHmA58pm6cEwrb8tmTtqQ9gR8AxwLjgceA7xXG+x6wEpgATAPeDPxrHndz4NI8/nbAe4BvS3pVcxbTmuR84KN95oqIXj/AvcCBPaSPBS4jHeSrc/ekPOzdwKK6/CcC83L3aOBbwD+AB4CzgefnYfsDS4GTgRXAj3sp2ygggCm95DkO+FWhfzHw80L/EmBa7g5gN1K0fhp4CnikNn5eF58BbgbWAj8DtuzHOhwNfB24vZc884GTGgyrrZOTSAfmcuC4uuk3Wp8fAn5XN70Adsvd55KC4QJSIDwQeDmwEFgD3Aa8vTDuucB/Av8DrANuAF6Shwk4I5fxYeAW4BV9rZ8elvdLwDWF/vcAfyr0b5WXYUIhbRfgDuAQYGkv096/OLy3bVpY758HHsx5318YdyHwkUL/s+sa+G0u46N5H3oP6WR9WV6vDwHXAc9rUM6XAVfkfHcCRxeGHQbcmNfxEuDUunHfCPwhz2cJ8KG+tl0P898CeJx8TDfIcyIbH1sbrY+6vF8Dflrofwnp+Nom998BHFoY/k3gB7l7fF6XLygM/zPw3h7ms2Uu9465/wvAemDb3P8fwHf6ex4qTHevvM7XAT/P+8lX+zo+aXwuORlYlqd3JzC9wXo7tz/z6WG8/wP8E3giz/e7Of31ed2tzd+v7yP/mXkfehhYBOxXmMepwE8K/aNIFxMv7vX47scJ4F56Djo7kGoaLwC2yRvil4WN+RDw8kL+G4F35e4zSCfZ7fO4vwK+Xlix64HT83Se30vZ+hN0diUdfM8j1Tjuq+1Medhq8oHPc0/GX+1hXfwpT2d70oHysV7mvXOe9zN5x/tQg3zPzzvQwQ2G19bJV4DNgUPzxh3bj/X5IfoOOmuBN+R1tA3QRTrRbgEcQDowdi/kXwXsndf/+cCFedhBecccQwpAL6cQGPrzyePdXVxXwLZ5uvsAmwGfyPuTCnkuA95B3cmiwbqsDzo9btPCev923hffTAoitXWxkAZBp3495/6vk05sm+fPfsVlKOTbinSgH5fX8atJQW+PQrn+JW+vV5JOmEfmYS/O2+
u9eR47sOGiquG266EMewKP9rGtfgmcVuhfSLoIfRD4PbB/Ydg84OS68R8BXpO7PwqcRzqfTARuBd5RyPtT4Pi8/V9HOulOblCu37LhXPObvD8dUhj2jn6eh2rniS1I541P5nX6TlIQKQaD3o7PcymcS4Dd8/bdKfdPoXHwP7e/8+lh3IVsvH9uTzrfHZu3/3tz/w495c9pH8j70ChSsFvBhouyUykEnZx2M4WL1J4+g25ei4hVEXFJRDwWEetIkfLNediTpCuBD8CzVespwGWSRIr+n46Ih/K4XwOOKUz+GeBLEfFkRDw+2DLmstxDOginAW8Cfg3cL+llubzXRcQzA5jkWRFxf0Q8RNpJp/Uy739Eal7bEfjfwN8aZD0b+GsuWyNPA1+JiKcjYgHpgN29n+uzL/Mi4vd5PUwDtiadTJ6KiKtJJ/T3FvJfGhF/ioj1pBNXbR08TTp4X0Y6md4REcsHUA5IV+njgYsLaetITW2/A54k1YRmRy1KSe8ANouISwc4r5q+tukX8754LamWcPQg5/M0qfnoxXk7XldbhjqHA/dGxI8iYn2keyqXkFoQiIiFEXFLRDwTETcDF5CPPeB9wJURcUGex6qIuKkw7Ubbrt4Y0nrvkaQPAx2kmkLNyaQLuYnAHOBXkl6Sh21NurgpWkvaXyAFgz1JV9RLgU5SUKu5APh30va/DvhCRCxpULxrgTfnpvlXAmfl/i2B1wK/HeBxsy/ppHtWXqe/IF2oFPV4fDYo3z9JFzF7SNo8Iu6NiLsb5K03kPnUOwxYHBE/zvvVBaRz0hGNRoiIn+R9aH1E/N9c7t7mt4607zQ06KAj6QWSfpBvDD5M2mnGFNrZ5wLvyxv3WOCiHIzGka5mFuWb7GtI92fGFSbfHRFPDLZsPbiWdJXwpty9kHSQvjn3D8SKQvdjpIOpV/lkNheYV3+PStI3gVeQmk96OgHVrMonivp592d99qV48O4ELKkLxPeRTiQ1Pa6DHKC+S2rCWSlpjqRtB1AOSPe3LomIRwpps0hX/XuSrjo/QLqA2UnSVsA3gH8b4HyKetumqyPi0UL/faR1NBjfJNUif5Nvhp/SIN+LgX1q2zNv0/cDL4J0j0vSNZK6Ja0l3XvZMY87mXRl30h/99/VbAgIG5F0JKnWdkhEPFhLj4gbImJdDtBzSbWdQ/PgR0g11qJtgXWSnkfaZ39BquXtSGq+Pz3P72XAhcAHSdt/T+Czkg5rUPba8b4XqYn3CtKxvi/QFRGrGNhxsxOwrO74rA94jY7P54iILtK9sVNJx8mFkvq7T/V7Pj2otfQU1R/bG5H0mfxwztq8jrZjw77Wk21IrTsNDeVBgpNIEW+fiNiWdEKH1DxCRFxPqoLuR7r6qt2QfJDU5rpnRIzJn+0i3UCu6e3kOxi1nXC/3H0tfQedZpdhFPBCCgeepC+T7kG8LSIeHuR0+1qfj5IOrto8X9TDNIrLej8wOZ8IanYmtT/3KSLOiojXAHsALyXd3O8XSc8nXc3PrRs0DbgsIu7KV/eXk5ojXw9MJdWir5O0gnTimpCfgJrS33n3YmwObDU7k9YR1K1bclBoJJ+QT4qIXYG3AydKmt5D1iXAtYXtOSbSjfmP5+E/JTULTY6I7Ug1ZRXGfclzJzlgXYBUeOKSlHAw8F/AERFxSx/TiEK5bgOevfGv9NOE0cBdpGafnUn3EZ7MQeFHbAhYrwDuiohf5+1/J6nGeUiD+f6BdG56B2k93p6nfygbjvf+nIdqlgMT8wV0zeQ+lr3oOeeSiPhpRLyRdIER5ADbZPXzvT/Pr6h4bG+UX9J+wGdJNfuxudVmLRu2KXX5R5Huif+1t0L1N+hsLmnLwmcUKaI9DqzJjz5+qYfxziNd+T4dEb8DyFfQ/wWcIemFubATJR3Uz7KQx9mStNMCjM79jVwLvIV0f2gpqXp+MKmt8sYG4zxAaioYFEnvlLS7pOdJGke6L3BjrvUg6XOkYHxgPsgGpR/r86
/AnpKm5XV0ah+TvIF09fRZSZtL2p9U/b6wr7JIem2+Ct+cdEJ+gtRUiqQPSbq3j0m8g3SFfU1d+p+BwyTtquStpIB2a/5MJgWmacBHSNtuGs+9Gh2sLys9jr0fqenr5zn9JuCduda/G6lGVrTRPiTpcEm75ZPXWlIzS09Nu5cBL5V0bN4Gm+d1+/I8fBvgoYh4QtLepP2o5nzgQElHSxolaQdJDZuAG4mIp4Ar2dBsh6QD8vTfFREbNS9JGiPpoNr5QdL7SReilxfKdYTS71+2It2X+EUOxA8Cfwc+nscdQ6rx3pzHvRGYqvTYtJSa7A4vDK8v+2Oke4DHsyHI/IFUI7w25xnIeeiPpG11Qi7fDNJ9sf6q3w92z8symnSMPE7P+8FQ1Z/DFpD2q/fl5XgP6eLwsgb5tyHdQ+oGRkn6d55bWy3am9QsXF+b2kh/g84C0oqpfU4FvkO6Af4gcD0bdq6iH5OuUn5Sl34y6UrqeqWmuSvpf7tkzeOkKjukdsmG934i4q6c97rc/zBwD/D7iPhng9HOIbW5rpH0ywZ5ejORtE7Wkar4z5BOqjVfI11ldGnD7xo+P4j5QC/rMy/7V3LaYtJ9kYbyyeYI0lXkg6RHWT8YEY3uRxVtSzqQV5Oq7atITUqQAsPv+xh/JulpxfortPNIQW8hqc3/LOCjEfG33Na8ovYhPcDyTO5vtG0HYkVenvtJJ86PFdbFGaTa/AOk2tn5deOeCszN+9DRpFrZlaR98Y/A9yKiPsCS7y+8jXR/4f5chtqDNZAeJf6KpHWk+xwXFcb9B+mK/iTSuriJQg1jgGqPONd8kdS8sqCwz/6/PGxz4KtseJDgE6SHG+7K5bqNdNI/n/QQwDZ5OWreSboQ7Cbty08Dn87j3g18mLTdHyYFjkuAH/ZS9mtzmf5U6N+GdBugpl/noXxMvJN0UbGG3LxLur/UH/XnktGknwA8SNq2LwQ+189pDcSZwFFKv306K1/cHk7aN1aRajGHF5pIN8pPusd8Oak2eh8pQPZ2Ifd+Uq27V+r9NsLQ5OaSlcBeEbG4tBlZ25P0G+CTEXFHq8vSX7mW95OImNTqsrSKpN+TfnvTqEVgkyTpBuDsiPhRq8vSDnJt8Vrg1X3dj+/ph5fN9HHgzw44FhFva3UZbOAi4g2tLkM7kPRm0u9pHiRd0b+Snlt3NkkRsZL0E4k+lRZ0cvu9gCPLmoeZWUV2JzVjbkVqmj8qBv5zAKPk5jUzM7Oitn/3mpmZjRwOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrjIOOmZlVxkHHzMwq46BjZmaVKftP3Eqx4447xpQpU1pdDDOzYWXRokUPRsS4VpZhWAadKVOm0NnZ2epimJkNK5Lua3UZ3LxmZmaVcdAxM7PKOOiYmVllHHTMzKwypQQdSZ+WdJukWyVdIGlLSbtIukFSl6SfSdoi5x2d+7vy8ClllMnMzFqv6UFH0kTg34COiHgFsBlwDHA6cEZE7AasBmblUWYBq3P6GTmfmZmNQGU1r40Cni9pFPACYDlwAHBxHj4XODJ3z8j95OHTJamkcpmZWQs1PehExDLgW8A/SMFmLbAIWBMR63O2pcDE3D0RWJLHXZ/z71A/XUmzJXVK6uzu7m52sc3MrAJlNK+NJdVedgF2ArYCDh7qdCNiTkR0RETHuHEt/UGtmZkNUhlvJDgQ+HtEdANI+gXwBmCMpFG5NjMJWJbzLwMmA0tzc9x2wKoSymVWiimn/M+z3feedlgLS2LW/sq4p/MPYF9JL8j3ZqYDtwPXAEflPDOBebl7fu4nD786IqKEcpmZWYuVcU/nBtIDAX8BbsnzmAOcDJwoqYt0z+acPMo5wA45/UTglGaXyczM2kMpL/yMiC8BX6pLvgfYu4e8TwDvLqMcZmbWXvxGAjMzq4yDjpmZVWZY/p+OWSv4KTWzoXNNx8zMKu
OajlkvirUbMxs613TMzKwyrumY1RlK7abRuL4HZJa4pmNmZpVx0DEzs8o46JiZWWV8T8dsEPxUm9nguKZjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZP71mhp9GM6tK02s6knaXdFPh87CkT0naXtIVkhbn77E5vySdJalL0s2S9mp2mczMrD00PehExJ0RMS0ipgGvAR4DLgVOAa6KiKnAVbkf4BBgav7MBr7f7DKZmVl7KLt5bTpwd0TcJ2kGsH9OnwssBE4GZgDnRUQA10saI2lCRCwvuWxmlfEfwJklZT9IcAxwQe4eXwgkK4DxuXsisKQwztKcthFJsyV1Surs7u4uq7xmZlai0oKOpC2AtwM/rx+WazUxkOlFxJyI6IiIjnHjxjWplGZmVqUym9cOAf4SEQ/k/gdqzWaSJgArc/oyYHJhvEk5zawp3LRl1j7KDDrvZUPTGsB8YCZwWv6eV0g/QdKFwD7AWt/PsSr4MWmz6pUSdCRtBbwV+Ggh+TTgIkmzgPuAo3P6AuBQoIv0pNtxZZTJrF245mWbslKCTkQ8CuxQl7aK9DRbfd4Aji+jHGZm1l78GhwzM6uMX4NjmxTfxzFrLdd0zMysMq7pmLVQfc3LDxbYSOeajpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZfz0mo1I/j2OWXty0DFrI34vm410bl4zM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0zM6uMg46ZmVWmlKAjaYykiyX9TdIdkl4naXtJV0hanL/H5rySdJakLkk3S9qrjDKZmVnrlVXTORO4PCJeBrwKuAM4BbgqIqYCV+V+gEOAqfkzG/h+SWUyM7MWa/qPQyVtB7wJ+BBARDwFPCVpBrB/zjYXWAicDMwAzouIAK7PtaQJEbG82WWzkc1vITBrf2XUdHYBuoEfSbpR0g8lbQWMLwSSFcD43D0RWFIYf2lOMzOzEaaM1+CMAvYCPhERN0g6kw1NaQBEREiKgUxU0mxS8xs777xzs8pqw9xIrt00Wja/HseGszKCzlJgaUTckPsvJgWdB2rNZpImACvz8GXA5ML4k3LaRiJiDjAHoKOjY0ABy2wk8fvZbDhrevNaRKwAlkjaPSdNB24H5gMzc9pMYF7ung98MD/Fti+w1vdzzMxGprLeMv0J4HxJWwD3AMeRAtxFkmYB9wFH57wLgEOBLuCxnNfMzEagUoJORNwEdPQwaHoPeQM4voxy2Mg0ku/jmI10fiOBmZlVxkHHzMwq46BjZmaVcdAxM7PKOOiYmVllHHTMzKwyDjpmZlYZBx0zM6uMg46ZmVXGQcfMzCrjoGNmZpVx0DEzs8o46JiZWWXK+msDsyHzn5WZjTyu6ZiZWWVc07Fhwf+hYzYyOOiYDWNugrThppTmNUn3SrpF0k2SOnPa9pKukLQ4f4/N6ZJ0lqQuSTdL2quMMpmZWeuVeU/nLRExLSJqf1t9CnBVREwFrsr9AIcAU/NnNvD9EstkZmYtVOWDBDOAubl7LnBkIf28SK4HxkiaUGG5zMysImUFnQB+I2mRpNk5bXxELM/dK4DxuXsisKQw7tKcZmZmI0xZDxK8MSKWSXohcIWkvxUHRkRIioFMMAev2QA777xz80pqZmaVKaWmExHL8vdK4FJgb+CBWrNZ/l6Zsy8DJhdGn5TT6qc5JyI6IqJj3LhxZRTbzMxK1vSgI2krSdvUuoG3AbcC84GZOdtMYF7ung98MD/Fti+wttAMZ2ZmI0gZzWvjgUsl1ab/04i4XNKfgYskzQLuA47O+RcAhwJdwGPAcSWUyczM2kDTg05E3AO8qof0VcD0HtIDOL7Z5TAzs/bjNxKYjRB+O4ENB37hp5mZVcY1HbMRyLUea1eu6ZiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZfxGAmsrxV/Sm9nI45qOmZlVxkHHzMwq46BjZmaVcdAxM7
PKlBZ0JG0m6UZJl+X+XSTdIKlL0s8kbZHTR+f+rjx8SlllMjOz1iqzpvNJ4I5C/+nAGRGxG7AamJXTZwGrc/oZOZ+ZmY1ApQQdSZOAw4Af5n4BBwAX5yxzgSNz94zcTx4+Pec3M7MRpqyazneAzwLP5P4dgDURsT73LwUm5u6JwBKAPHxtzm9mZiNM04OOpMOBlRGxqMnTnS2pU1Jnd3d3MydtZmYVKaOm8wbg7ZLuBS4kNaudCYyRVHsDwiRgWe5eBkwGyMO3A1bVTzQi5kRER0R0jBs3roRim5lZ2ZoedCLicxExKSKmAMcAV0fE+4FrgKNytpnAvNw9P/eTh18dEdHscpmZWetV+Tudk4ETJXWR7tmck9PPAXbI6ScCp1RYJjMzq5CGY6Wio6MjOjs7W10MGwK/2LM17j3tsFYXwVpI0qKI6GhlGfyWabNNSDHYOwBZK/g1OGZmVhnXdMw2Ua71WCu4pmNmZpVx0DEzs8o46JiZWWUcdMzMrDIOOmZmVhkHHTMzq4yDjpmZVca/07FK+LU3Zgau6ZiZWYVc0zEzv53AKuOajpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjQ96EjaUtKfJP1V0m2SvpzTd5F0g6QuST+TtEVOH537u/LwKc0uk5mZtYcyajpPAgdExKuAacDBkvYFTgfOiIjdgNXArJx/FrA6p5+R85mZ2QjU9KATySO5d/P8CeAA4OKcPhc4MnfPyP3k4dMlqdnlMjOz1ivlno6kzSTdBKwErgDuBtZExPqcZSkwMXdPBJYA5OFrgR3KKJeZmbVWKUEnIv4ZEdOAScDewMuGOk1JsyV1Surs7u4echnNzKx6pT69FhFrgGuA1wFjJNXegDAJWJa7lwGTAfLw7YBVPUxrTkR0RETHuHHjyiy2mZmVpOmvwZE0Dng6ItZIej7wVtLDAdcARwEXAjOBeXmU+bn/j3n41RERzS6XVc8v+TSzemW8e20CMFfSZqSa1EURcZmk24ELJX0VuBE4J+c/B/ixpC7gIeCYEspkZmZtoOlBJyJuBl7dQ/o9pPs79elPAO9udjnMzKz9+I0EZmZWGf+1gTWV7+OYWW9c0zEzs8o46JiZWWUcdMzMrDK+p2NmG/FfV1uZXNMxM7PKOOiYmVll3LxmZg25qc2azTUdMzOrjIOOmZlVxs1rNihudjGzwXDQMbN+8YWGNYOb18zMrDIOOmZmVhk3r9mQ+c3SZtZfrumYmVllmh50JE2WdI2k2yXdJumTOX17SVdIWpy/x+Z0STpLUpekmyXt1ewymZlZeyijprMeOCki9gD2BY6XtAdwCnBVREwFrsr9AIcAU/NnNvD9EspkZmZtoOlBJyKWR8Rfcvc64A5gIjADmJuzzQWOzN0zgPMiuR4YI2lCs8tlZmatV+o9HUlTgFcDNwDjI2J5HrQCGJ+7JwJLCqMtzWlmZjbClBZ0JG0NXAJ8KiIeLg6LiABigNObLalTUmd3d3cTS2pmZlUp5ZFpSZuTAs75EfGLnPyApAkRsTw3n63M6cuAyYXRJ+W0jUTEHGAOQEdHx4ACljWHH402s6FqetCRJOAc4I6I+HZh0HxgJnBa/p5XSD9B0oXAPsDaQjOcmbU5vx7HBqKMms4bgGOBWyTdlNM+Two2F0maBdwHHJ2HLQAOBbqAx4DjSiiTmTWRa702WE0POhHxO0ANBk/vIX8Axze7HGZm1n78RgIzM6uMg46ZmVXGQcfMzCrjt0xbr3zD2MyayTUdMzOrjGs6ZtY0/s2O9cU1HTMzq4yDjpmZVcZBx8zMKuOgY2ZmlfGDBPYcfkzazMrioGNmpfCTbNYTBx0DXLsxs2r4no6ZmVXGNR0zK52b2qzGQWcT5iY1M6uam9fMzKwyTQ86kv5b0kpJtxbStpd0haTF+XtsTpeksyR1SbpZ0l7NLo+ZmbWPMmo65wIH16WdAlwVEVOBq3I/wCHA1PyZDXy/hPKYmVmbaHrQiYjfAg/VJc8A5u
buucCRhfTzIrkeGCNpQrPLZGZm7aGqBwnGR8Ty3L0CGJ+7JwJLCvmW5rTlmNmI56faNj2VP0gQEQHEQMeTNFtSp6TO7u7uEkpmZmZlq6qm84CkCRGxPDefrczpy4DJhXyTctpzRMQcYA5AR0fHgIOWJX5M2sxaqaqgMx+YCZyWv+cV0k+QdCGwD7C20AxnZiOQL3w2bU0POpIuAPYHdpS0FPgSKdhcJGkWcB9wdM6+ADgU6AIeA45rdnnMzKx9ND3oRMR7Gwya3kPeAI5vdhlsY76yNLN24TcSmJlZZfzuNTNrC73VyP049cjhmo6ZmVXGNZ1hzj+uM7PhxDUdMzOrjIOOmZlVxkHHzMwq43s6I4jv75hZu3PQGaH8g1AbSXxBNXI46AxDDihmNlw56JjZsNLooss1oOHBQWeYcO3GzEYCB50247ZrMxvJHHTamGs3Zv3nC7bhwb/TMTOzyjjomJlZZdy81gbcjGZWnkbNbm6Oaw2lP+9scSGkg4Ezgc2AH0bEab3l7+joiM7OzkrKVhYHGrP2NJIDkKRFEdHRyjK0vKYjaTPgP4G3AkuBP0uaHxG3t7Zkg+OrJzOzxloedIC9ga6IuAdA0oXADKCtg05/gotrM2bDz0CPW19cDkw7BJ2JwJJC/1Jgn7JmNtCaSH92QAcXs03XUN6QsCm2jLRD0OkXSbOB2bn3EUl3Dnmapw91CkO2I/BgqwtRIi/f8OblG4KBnl9KOB/1tHwvbvpcBqgdgs4yYHKhf1JO20hEzAHmVFWoKkjqbPVNvTJ5+YY3L9/w1q7L1w6/0/kzMFXSLpK2AI4B5re4TGZmVoKW13QiYr2kE4Bfkx6Z/u+IuK3FxTIzsxK0POgARMQCYEGry9ECI6q5sAdevuHNyze8teXytcWPQ83MbNPQDvd0zMxsE+GgUzJJ20u6QtLi/D22Qb7LJa2RdFld+i6SbpDUJeln+WGLtjGA5ZuZ8yyWNLOQvlDSnZJuyp8XVlf6xiQdnMvVJemUHoaPztujK2+fKYVhn8vpd0o6qMpy99dgl0/SFEmPF7bX2VWXvS/9WLY3SfqLpPWSjqob1uN+2k6GuHz/LGy71jywFRH+lPgBvgGckrtPAU5vkG86cARwWWjJ8xoAAANRSURBVF36RcAxufts4OOtXqaBLh+wPXBP/h6bu8fmYQuBjlYvR115NwPuBnYFtgD+CuxRl+dfgbNz9zHAz3L3Hjn/aGCXPJ3NWr1MTVy+KcCtrV6GIS7bFOCVwHnAUf3ZT9vlM5Tly8MeafUyuKZTvhnA3Nw9Fziyp0wRcRWwrpgmScABwMV9jd9C/Vm+g4ArIuKhiFgNXAEcXFH5BuPZVzNFxFNA7dVMRcXlvhiYnrfXDODCiHgyIv4OdOXptZOhLF+763PZIuLeiLgZeKZu3OGwnw5l+dqCg075xkfE8ty9Ahg/gHF3ANZExPrcv5T02qB20p/l6+lVR8Xl+FGu7n+xTU5sfZV3ozx5+6wlba/+jNtqQ1k+gF0k3SjpWkn7lV3YARrK+h8p2643W0rqlHS9pJZcwLbFI9PDnaQrgRf1MOgLxZ6ICEnD7nHBkpfv/RGxTNI2wCXAsaRmAWtPy4GdI2KVpNcAv5S0Z0Q83OqCWb+8OB9vuwJXS7olIu6usgAOOk0QEQc2GibpAUkTImK5pAnAygFMehUwRtKofLXZ4yuCytaE5VsG7F/on0S6l0NELMvf6yT9lNR80Oqg059XM9XyLJU0CtiOtL369VqnFhv08kW6MfAkQEQsknQ38FKgXf7gaijrv+F+2kaGtH8Vjrd7JC0EXk26R1QZN6+Vbz5QewpmJjCvvyPmA/waoPYEyoDGr0h/lu/XwNskjc1Pt70N+LWkUZJ2BJC0OXA4cGsFZe5Lf17NVFzuo4Cr8/aaDxyTn/7aBZgK/KmicvfXoJdP0jil/8AiXy1PJd1wbxdDea1Wj/tpSeUcrEEvX16u0bl7R+ANtOIvZFr9JMNI/5
Dawa8CFgNXAtvn9A7Sv6TW8l0HdAOPk9ppD8rpu5JOWl3Az4HRrV6mQS7fh/MydAHH5bStgEXAzcBt5H+PbfUy5bIdCtxFugr8Qk77CvD23L1l3h5defvsWhj3C3m8O4FDWr0szVw+4F15W90E/AU4otXLMohle20+xh4l1U5v620/bbfPYJcPeD1wC+mJt1uAWa0ov99IYGZmlXHzmpmZVcZBx8zMKuOgY2ZmlXHQMTOzyjjomJlZZRx0zMysMg46ZmZWGQcdMzOrzP8HEJOSe4XvcpwAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IafUGD_VGeLh",
"colab_type": "text"
},
"source": [
"# Feed forward\n",
"\n",
"Feed forward once before training to check accuracy"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cd6jGroQGdOF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 134
},
"outputId": "c059fa47-5c37-442c-8207-9412d7be7eca"
},
"source": [
"def feed_forward_sample(sample, y):\n",
"    \"\"\" Feeds a sample forward through the neural network.\n",
"    Parameters:\n",
"    sample: numpy array, flattened to 1D before use. The input sample (an MNIST digit).\n",
"    y: An integer from 0 to 9. The sample's label.\n",
"\n",
"    Returns: A tuple (loss, one_hot_guess): the cross entropy loss and the\n",
"    network's prediction as a one hot encoded array.\n",
"    \"\"\"\n",
"    activation = sample.flatten()\n",
"    output_layer = len(weights) - 1\n",
"\n",
"    for layer, w in enumerate(weights):\n",
"        z = np.matmul(w, activation) + biases[layer]\n",
"        # Sigmoid on hidden layers, softmax on the output layer\n",
"        activation = softmax(z) if layer == output_layer else sigmoid(z)\n",
"\n",
"    # Loss of the output distribution against the one hot encoded label\n",
"    loss = cross_entropy_loss(integer_to_one_hot(y, 10), activation)\n",
"\n",
"    # One hot encode the most probable class\n",
"    one_hot_guess = np.zeros_like(activation)\n",
"    one_hot_guess[np.argmax(activation)] = 1\n",
"\n",
"    return loss, one_hot_guess\n",
"\n",
"\n",
"def feed_forward_dataset(x, y):\n",
"    \"\"\" Feeds a whole dataset forward and prints the average loss and accuracy.\n",
"    Parameters:\n",
"    x: numpy array of samples, indexed along its first axis.\n",
"    y: 1D numpy array of integer labels (0 to 9), one per sample.\n",
"    \"\"\"\n",
"    losses = np.empty(x.shape[0])\n",
"    one_hot_guesses = np.empty((x.shape[0], 10))\n",
"\n",
"    for i in range(x.shape[0]):\n",
"        if i == 0 or ((i + 1) % 10000 == 0):\n",
"            print(i + 1, \"/\", x.shape[0], \"(\", format(((i + 1) / x.shape[0]) * 100, \".2f\"), \"%)\")\n",
"        # The per-sample function is feed_forward_sample; the previous call to\n",
"        # the undefined name feed_forward raised a NameError on a fresh kernel\n",
"        losses[i], one_hot_guesses[i] = feed_forward_sample(x[i], y[i])\n",
"\n",
"    print(\"\\nAverage loss:\", np.round(np.average(losses), decimals=2))\n",
"\n",
"    # One hot encode all labels at once\n",
"    y_one_hot = np.zeros((y.size, 10))\n",
"    y_one_hot[np.arange(y.size), y] = 1\n",
"\n",
"    # A guess is correct where its one hot vector overlaps the label's.\n",
"    # An untrained network is expected to get about 1 in 10 right.\n",
"    correct_guesses = np.sum(y_one_hot * one_hot_guesses)\n",
"    correct_guess_percent = format((correct_guesses / y.shape[0]) * 100, \".2f\")\n",
"    print(\"Accuracy (# of correct guesses):\", correct_guesses, \"/\", y.shape[0], \"(\", correct_guess_percent, \"%)\")\n",
"\n",
"def feed_forward_training_data():\n",
"    \"\"\"Evaluates the network on the training set, printing loss and accuracy.\"\"\"\n",
"    print(\"Feeding forward all training data...\")\n",
"    feed_forward_dataset(x_train, y_train)\n",
"    print()\n",
"\n",
"def feed_forward_test_data():\n",
"    \"\"\"Evaluates the network on the test set, printing loss and accuracy.\"\"\"\n",
"    print(\"Feeding forward all test data...\")\n",
"    feed_forward_dataset(x_test, y_test)\n",
"    print()\n",
"\n",
"feed_forward_test_data()"
],
"execution_count": 181,
"outputs": [
{
"output_type": "stream",
"text": [
"Feeding forward all test data...\n",
"1 / 10000 ( 0.01 %)\n",
"10000 / 10000 ( 100.00 %)\n",
"\n",
"Average loss: 2.36\n",
"Accuracy (# of correct guesses): 994.0 / 10000 ( 9.94 %)\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sSrlc2VLOi8L",
"colab_type": "text"
},
"source": [
"# Training\n",
"\n",
"1. Feedforward one sample, storing layer activations\n",
"2. Calculate gradient\n",
"3. Update weights & biases according to learning rate\n",
"4. Repeat\n",
"\n",
"More details about this implementation [here](https://www.mathcha.io/editor/vrmV3C1KFnvu2Dx3ewh7rgr54fBOvJL2TzoNWNe)\n",
"\n",
"The maximum accuracy reached was ~85%, after training for a few epochs. The network then starts to overfit.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BLLEsVdcOgzi",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 991
},
"outputId": "4dd03490-34a4-4646-8a16-733f72601816"
},
"source": [
"def train_one_sample(sample, y, learning_rate=0.003):\n",
"    \"\"\" Runs one step of stochastic gradient descent on a single sample.\n",
"    Updates the global weights and biases in place.\n",
"\n",
"    Parameters:\n",
"    sample: numpy array, flattened to 1D before use. The input sample (an MNIST digit).\n",
"    y: An integer from 0 to 9. The sample's label.\n",
"    learning_rate: Step size applied to every gradient.\n",
"\n",
"    Returns: A tuple (loss, correct_guess), both measured before the update:\n",
"    the cross entropy loss and whether the most probable class matched y.\n",
"\n",
"    NOTE(review): the sample is used exactly as passed in, matching\n",
"    feed_forward_sample. Raw 0-255 pixel values can saturate the first sigmoid\n",
"    layer; consider scaling the inputs for both training and evaluation.\n",
"    \"\"\"\n",
"    output_layer = len(weights) - 1\n",
"\n",
"    # layer_inputs[i] is what layer i receives: the sample itself for the first\n",
"    # layer, the previous layer's activations for every other layer\n",
"    layer_inputs = []\n",
"    a = sample.flatten()\n",
"\n",
"    # Feedforward\n",
"    for i, w in enumerate(weights):  # Each w is a layer's 2D weight matrix\n",
"        layer_inputs.append(a)\n",
"        z = np.matmul(w, a) + biases[i]\n",
"        if i < output_layer:\n",
"            a = sigmoid(z)\n",
"        else:\n",
"            a = softmax(z)  # softmax on last layer\n",
"\n",
"    # Calculate loss\n",
"    one_hot_y = integer_to_one_hot(y, 10)\n",
"    loss = cross_entropy_loss(one_hot_y, a)\n",
"\n",
"    # Check whether guess was correct\n",
"    correct_guess = bool(np.argmax(a) == y)\n",
"\n",
"    weight_gradients = [None] * len(weights)\n",
"    bias_gradients = [None] * len(weights)\n",
"\n",
"    # Gradient of the loss with respect to the output layer's weighted input.\n",
"    # For softmax combined with cross entropy this is simply (output - target).\n",
"    delta = a - one_hot_y\n",
"\n",
"    # Backpropagation\n",
"    for i in range(output_layer, -1, -1):  # Traverse layers in reverse\n",
"        weight_gradients[i] = np.outer(delta, layer_inputs[i])\n",
"        bias_gradients[i] = delta\n",
"\n",
"        if i > 0:\n",
"            # Push the error back through layer i's weights, then through the\n",
"            # previous layer's sigmoid. That layer's activations already are\n",
"            # sigmoid(z), so the derivative is a * (1 - a); calling dsigmoid on\n",
"            # them would apply the sigmoid a second time.\n",
"            a_prev = layer_inputs[i]\n",
"            delta = np.matmul(weights[i].T, delta) * a_prev * (1 - a_prev)\n",
"\n",
"    # Update weights & biases only once every gradient is known, so that no\n",
"    # gradient is computed from already-updated weights\n",
"    for i in range(len(weights)):\n",
"        weights[i] -= weight_gradients[i] * learning_rate\n",
"        biases[i] -= bias_gradients[i] * learning_rate\n",
"\n",
"    return loss, correct_guess\n",
"\n",
"def train_one_epoch(learning_rate=0.003):\n",
"    \"\"\"Trains on every training sample once, in dataset order.\n",
"\n",
"    Prints progress on the first sample and after every 10000 samples.\n",
"    \"\"\"\n",
"    print(\"Training for one epoch over the training dataset...\")\n",
"    total = x_train.shape[0]\n",
"    for seen in range(1, total + 1):\n",
"        if seen == 1 or seen % 10000 == 0:\n",
"            completion_percent = format((seen / total) * 100, \".2f\")\n",
"            print(seen, \"/\", total, \"(\", completion_percent, \"%)\")\n",
"        train_one_sample(x_train[seen - 1], y_train[seen - 1], learning_rate)\n",
"    print(\"Finished training.\\n\")\n",
"\n",
"# Train and check accuracy before & after each epoch\n",
"\n",
"feed_forward_test_data()\n",
"\n",
"def test_and_train():\n",
" train_one_epoch()\n",
" feed_forward_test_data()\n",
"\n",
"for i in range(3): # Adjust number of epochs here\n",
" test_and_train()"
],
"execution_count": 182,
"outputs": [
{
"output_type": "stream",
"text": [
"Feeding forward all test data...\n",
"1 / 10000 ( 0.01 %)\n",
"10000 / 10000 ( 100.00 %)\n",
"\n",
"Average loss: 2.36\n",
"Accuracy (# of correct guesses): 994.0 / 10000 ( 9.94 %)\n",
"\n",
"Training for one epoch over the training dataset...\n",
"1 / 60000 ( 0.00 %)\n",
"10000 / 60000 ( 16.67 %)\n",
"20000 / 60000 ( 33.33 %)\n",
"30000 / 60000 ( 50.00 %)\n",
"40000 / 60000 ( 66.67 %)\n",
"50000 / 60000 ( 83.33 %)\n",
"60000 / 60000 ( 100.00 %)\n",
"Finished training.\n",
"\n",
"Feeding forward all test data...\n",
"1 / 10000 ( 0.01 %)\n",
"10000 / 10000 ( 100.00 %)\n",
"\n",
"Average loss: 0.63\n",
"Accuracy (# of correct guesses): 8023.0 / 10000 ( 80.23 %)\n",
"\n",
"Training for one epoch over the training dataset...\n",
"1 / 60000 ( 0.00 %)\n",
"10000 / 60000 ( 16.67 %)\n",
"20000 / 60000 ( 33.33 %)\n",
"30000 / 60000 ( 50.00 %)\n",
"40000 / 60000 ( 66.67 %)\n",
"50000 / 60000 ( 83.33 %)\n",
"60000 / 60000 ( 100.00 %)\n",
"Finished training.\n",
"\n",
"Feeding forward all test data...\n",
"1 / 10000 ( 0.01 %)\n",
"10000 / 10000 ( 100.00 %)\n",
"\n",
"Average loss: 0.51\n",
"Accuracy (# of correct guesses): 8481.0 / 10000 ( 84.81 %)\n",
"\n",
"Training for one epoch over the training dataset...\n",
"1 / 60000 ( 0.00 %)\n",
"10000 / 60000 ( 16.67 %)\n",
"20000 / 60000 ( 33.33 %)\n",
"30000 / 60000 ( 50.00 %)\n",
"40000 / 60000 ( 66.67 %)\n",
"50000 / 60000 ( 83.33 %)\n",
"60000 / 60000 ( 100.00 %)\n",
"Finished training.\n",
"\n",
"Feeding forward all test data...\n",
"1 / 10000 ( 0.01 %)\n",
"10000 / 10000 ( 100.00 %)\n",
"\n",
"Average loss: 0.49\n",
"Accuracy (# of correct guesses): 8563.0 / 10000 ( 85.63 %)\n",
"\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment