# Deriving gradient for batch norm (cs231n)
{"cells":[{"cell_type":"code","execution_count":null,"id":"ongoing-appraisal","metadata":{"id":"ongoing-appraisal","outputId":"f8661c05-71f4-4782-d141-f74ff57b6698"},"outputs":[{"name":"stdout","output_type":"stream","text":["The autoreload extension is already loaded. To reload it, use:\n"," %reload_ext autoreload\n"]}],"source":["import numpy as np\n","import matplotlib.pyplot as plt\n","from cs231n.classifiers.fc_net import *\n","from cs231n.data_utils import get_CIFAR10_data\n","from cs231n.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array\n","from cs231n.solver import Solver\n","\n","\n","%load_ext autoreload\n","%autoreload 2"]},{"cell_type":"markdown","id":"flush-veteran","metadata":{"id":"flush-veteran"},"source":["In [assignment 2 of cs231n](https://cs231n.github.io/assignments2021/assignment2/#q2-batch-normalization-34) we need to compute forward and backward passes of batch normalization using a few different approaches (specifically I'm talking about Spring 2021 version of the class). Moreover we just don't have any formulas to do this (see [here](https://cs231n.github.io/neural-networks-2/)):\n","\n","> We do not expand on this technique here because it is well described in the linked paper ..."]},{"cell_type":"markdown","id":"posted-leonard","metadata":{"id":"posted-leonard"},"source":["We have to use at least 2 approaches to do this (specified in the description of functions):\n","\n","- For this implementation, you should write out a computation graph for \n"," batch normalization on paper and propagate gradients backward through\n"," intermediate nodes.\n","- For this implementation you should work out the derivatives for the batch\n"," normalizaton backward pass on paper and simplify as much as possible. You\n"," should be able to derive a simple expression for the backward pass.\n"," See the jupyter notebook for more hints.\n"," \n","I would start from yet another approach - let's implement formulas from the [paper](https://arxiv.org/abs/1502.03167)."]},{"cell_type":"markdown","id":"exact-sunrise","metadata":{"id":"exact-sunrise"},"source":["## 01 - formulas (no simplification)"]},{"cell_type":"markdown","id":"italic-missile","metadata":{"id":"italic-missile"},"source":["Suppose $x_i \\in \\mathbb{R}^D$ and we have $N$ training examples (size of our batch; not $m$ as in the paper). So our matrix $X$ has the size $N \\times D$. \n","\n","What is the size of the batch mean (and variance)? Well we average over training examples, so $\\mu \\in \\mathbb{R}^D\\ (\\sigma^2 \\in \\mathbb{R}^D)$:\n","\n","$$\\mu = \\frac{1}{N} \\sum_i{x_i}$$\n","$$\\sigma^2 = \\frac{1}{N} \\sum_i{(x_i - \\mu)^2}$$"]},{"cell_type":"markdown","id":"nearby-weapon","metadata":{"id":"nearby-weapon"},"source":["### 01-1 gradients with respect to $\\gamma$ and $\\beta$"]},{"cell_type":"markdown","id":"billion-record","metadata":{"id":"billion-record"},"source":["Let's start from the easiest gradients - $\\partial L / \\partial \\gamma$ and $\\partial L / \\partial \\beta$ (we're considering only the former; all the derivations for the latter are quite similar). We may find in the paper that:\n","\n","$$\n","\\frac{\\partial L}{\\partial \\gamma} = \n","\\sum_i{\\frac{\\partial L}{\\partial y_i} \\hat{x}_i}\n","$$"]},{"cell_type":"markdown","id":"helpful-wells","metadata":{"id":"helpful-wells"},"source":["First of all this is an equation between vectors in $\\mathbb{R}^D$. 
### 01-1 gradients with respect to $\gamma$ and $\beta$

Let's start with the easiest gradients - $\partial L / \partial \gamma$ and $\partial L / \partial \beta$ (we're considering only the former; all the derivations for the latter are quite similar). We may find in the paper that:

$$
\frac{\partial L}{\partial \gamma} =
\sum_i{\frac{\partial L}{\partial y_i} \hat{x}_i}
$$

First of all, this is an equation between vectors in $\mathbb{R}^D$. So we don't have a scalar product here (otherwise we would have a sum of scalars) - rather we have an element-by-element product, like in many other places:

$$
\frac{\partial L}{\partial \gamma} =
\sum_i{\frac{\partial L}{\partial y_i} * \hat{x}_i}
$$

To prove this, let's consider the derivative of $L$ with respect to $\gamma_k$:

$$
\frac{\partial L}{\partial \gamma_k} =
\sum_{ij}{
\frac{\partial L}{\partial y_{ij}}
\frac{\partial y_{ij}}{\partial \gamma_k}
} =
\sum_{i}{
\frac{\partial L}{\partial y_{ik}}
\frac{\partial y_{ik}}{\partial \gamma_k}
} =
\sum_{i}{
\frac{\partial L}{\partial y_{ik}}
\hat{x}_{ik}
}
$$

How can we compute this in `python`? We first need to compute an element-by-element product of the $N \times D$ arrays `dout` and `x_hat`, and then sum over rows (which are our training examples):

```python
dgamma = np.sum(dout * x_hat, axis=0)
```
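The derivation for $\partial L / \partial \beta$ follows exactly the same pattern, except that $\partial y_{ik} / \partial \beta_k = 1$ instead of $\hat{x}_{ik}$; a sketch:

```python
# Since d y_ik / d beta_k = 1, we simply sum the incoming
# gradient over the batch dimension (rows).
dbeta = np.sum(dout, axis=0)  # shape (D,)
```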
### 01-2 gradient with respect to $\hat{x}$

Let's do this derivation with almost all the details for one of the gradients.

Next question - what are the sizes of $\gamma$ and $\beta$? It turns out that they are vectors in $\mathbb{R}^D$ as well. In the paper we may find the formula:

$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

Here $k=1, ..., D$ stands for a component of a vector from $\mathbb{R}^D$. So in terms of vectors we have (where $*$ means element-by-element product):

$$y = \gamma * \hat{x} + \beta$$

Let's find $\partial L / \partial \hat{x}_{ij}$ for $i=1, ..., N$ and $j=1, ..., D$ (we're using lower indices - the 1st is for the training example, the 2nd for its component):

$$y_i = \gamma * \hat{x}_i + \beta$$

$$y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j$$

$$
\frac{\partial y_{ij}}{\partial \hat{x}_{ij}} = \gamma_j, \qquad
\frac{\partial y_{ks}}{\partial \hat{x}_{ij}} = 0 \ \text{ for } (k, s) \neq (i, j)
$$

To find the gradient with respect to $\hat{x}_{ij}$ we need to sum over all $y_{ks}$, but all the terms are $0$ except the $ij$ one:

$$
\frac{\partial L}{\partial \hat{x}_{ij}} =
\sum_{ks}
 {\frac{\partial L}{\partial y_{ks}}}
 {\frac{\partial y_{ks}}{\partial \hat{x}_{ij}}} =
{\frac{\partial L}{\partial y_{ij}}}
{\frac{\partial y_{ij}}{\partial \hat{x}_{ij}}} =
{\frac{\partial L}{\partial y_{ij}}} \gamma_j
$$

$$
\frac{\partial L}{\partial \hat{x}_{i}} =
{\frac{\partial L}{\partial y_{i}}} * \gamma
$$

The last equality is between vectors in $\mathbb{R}^D$, and we again have element-by-element multiplication. It's important to notice that we have the same $\gamma$ for all of our training examples $x_i$. We also have to notice that $\partial L / \partial y_i$ is a row of the incoming derivative `dout`.

Finally, how can we compute $\partial L / \partial \hat{X}$, where $\hat{X}$ is an $N \times D$ matrix? We have to multiply each row of the incoming gradient `dout` by $\gamma$, and we may do this using broadcasting:

```python
dx_hat = gamma * dout
```

```python
dout = np.arange(6).reshape(2, 3)
dout
# array([[0, 1, 2],
#        [3, 4, 5]])

gamma = np.array([10, 20, 30])
gamma * dout
# array([[  0,  20,  60],
#        [ 30,  80, 150]])

gamma * dout[0]
# array([ 0, 20, 60])

gamma * dout[1]
# array([ 30,  80, 150])
```

### 01-3 gradient with respect to $x_i$

So suppose we have the gradients $\partial L / \partial \hat{x}_i$, $\partial L / \partial \sigma^2$ and $\partial L / \partial \mu$. How can we get the gradient with respect to $x_i$? We may consider $L$ as a function of $\hat{x}_i$, $\sigma^2$ and $\mu$:

$$
\frac{\partial L}{\partial x_i} =
\frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} +
\frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i} +
\frac{\partial L}{\partial \mu} \frac{\partial \mu}{\partial x_i} =
\frac{\partial L}{\partial \hat{x}_i} \frac{1}{\sqrt{\sigma^2 + \epsilon}} +
\frac{\partial L}{\partial \sigma^2} \frac{2(x_i - \mu)}{N} +
\frac{\partial L}{\partial \mu} \frac{1}{N}
$$
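Combining this chain rule with the paper's formulas for $\partial L / \partial \sigma^2$ and $\partial L / \partial \mu$, the whole unsimplified backward pass looks roughly like this (a sketch, assuming `x`, `x_hat`, `mu`, `var`, `gamma` and `eps` were cached during the forward pass):

```python
N = dout.shape[0]

dgamma = np.sum(dout * x_hat, axis=0)
dbeta = np.sum(dout, axis=0)

# Gradient w.r.t. the normalized inputs (section 01-2).
dx_hat = gamma * dout

# Gradients w.r.t. variance and mean (formulas from the paper).
dvar = np.sum(dx_hat * (x - mu), axis=0) * (-0.5) * (var + eps) ** (-1.5)
dmu = (-np.sum(dx_hat, axis=0) / np.sqrt(var + eps)
       + dvar * np.sum(-2 * (x - mu), axis=0) / N)

# The chain rule above: three paths from L back to x_i.
dx = dx_hat / np.sqrt(var + eps) + dvar * 2 * (x - mu) / N + dmu / N
```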
## 02 - formulas (simplification)

The main idea behind the simplification is to notice that:

$$\sum_i{(x_i - \mu)} = 0$$

We have 2 main formulas for the simplified version:

$$
\frac{\partial L}{\partial x_i} =
\frac{1}{\sqrt{\sigma^2 + \epsilon}}
\left( \frac{\partial L}{\partial \hat{x}_i} -
\frac{1}{N}\sum_s{\frac{\partial L}{\partial \hat{x}_s}} -
\frac{\hat{x}_i}{N}\sum_s{\frac{\partial L}{\partial \hat{x}_s} \hat{x}_s}
\right)
$$

$$
\frac{\partial L}{\partial x_i} =
\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\left( \frac{\partial L}{\partial y_i} -
\frac{1}{N}
\left(
\frac{\partial L}{\partial \gamma} \hat{x}_i +
\frac{\partial L}{\partial \beta}
\right)
\right)
$$
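Before justifying these, note that the first formula already translates directly into `numpy` (a sketch in the same notation as above; `np.mean` over the rows gives the $\frac{1}{N}\sum_s$ factors):

```python
# First simplified formula: everything is expressed through dx_hat.
dx_hat = gamma * dout
dx = (dx_hat
      - np.mean(dx_hat, axis=0)
      - x_hat * np.mean(dx_hat * x_hat, axis=0)) / np.sqrt(var + eps)
```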
Let's prove that the 2nd term is our gradient with respect to $\mu$ (divided by $N$):

$$
\frac{1}{N} \frac{\partial L}{\partial \mu} =
-\frac{1}{N \sqrt{\sigma^2 + \epsilon}}
\sum_s{\frac{\partial L}{\partial \hat{x}_s}}
$$

If we use the idea mentioned above, we'll get what we need (the formula is from the paper):

$$
\frac{\partial L}{\partial \mu} =
-\frac{1}{\sqrt{\sigma^2 + \epsilon}}
\sum_s{\frac{\partial L}{\partial \hat{x}_s}}
$$

Finally, let's see why the 2nd formula is correct. Let's look at the first term. We have to show that:

$$
\frac{\partial L}{\partial \hat{x}_{i}} =
{\frac{\partial L}{\partial y_{i}}} * \gamma
$$

But that's exactly what we proved above.
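With the 2nd formula, the whole simplified backward pass collapses into a few lines (again a sketch; in the assignment the result can be checked numerically with `eval_numerical_gradient_array`, imported at the top):

```python
N = dout.shape[0]
dgamma = np.sum(dout * x_hat, axis=0)
dbeta = np.sum(dout, axis=0)

# Second simplified formula: dx expressed through dout, dgamma and dbeta.
dx = gamma / np.sqrt(var + eps) * (dout - (dgamma * x_hat + dbeta) / N)
```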