# Deriving gradient for batch norm (cs231n)
{"cells":[{"cell_type":"code","execution_count":null,"id":"ongoing-appraisal","metadata":{"id":"ongoing-appraisal","outputId":"f8661c05-71f4-4782-d141-f74ff57b6698"},"outputs":[{"name":"stdout","output_type":"stream","text":["The autoreload extension is already loaded. To reload it, use:\n"," %reload_ext autoreload\n"]}],"source":["import numpy as np\n","import matplotlib.pyplot as plt\n","from cs231n.classifiers.fc_net import *\n","from cs231n.data_utils import get_CIFAR10_data\n","from cs231n.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array\n","from cs231n.solver import Solver\n","\n","\n","%load_ext autoreload\n","%autoreload 2"]},{"cell_type":"markdown","id":"flush-veteran","metadata":{"id":"flush-veteran"},"source":["In [assignment 2 of cs231n](https://cs231n.github.io/assignments2021/assignment2/#q2-batch-normalization-34) we need to compute forward and backward passes of batch normalization using a few different approaches (specifically I'm talking about Spring 2021 version of the class). Moreover we just don't have any formulas to do this (see [here](https://cs231n.github.io/neural-networks-2/)):\n","\n","> We do not expand on this technique here because it is well described in the linked paper ..."]},{"cell_type":"markdown","id":"posted-leonard","metadata":{"id":"posted-leonard"},"source":["We have to use at least 2 approaches to do this (specified in the description of functions):\n","\n","- For this implementation, you should write out a computation graph for \n"," batch normalization on paper and propagate gradients backward through\n"," intermediate nodes.\n","- For this implementation you should work out the derivatives for the batch\n"," normalizaton backward pass on paper and simplify as much as possible. You\n"," should be able to derive a simple expression for the backward pass.\n"," See the jupyter notebook for more hints.\n"," \n","I would start from yet another approach - let's implement formulas from the [paper](https://arxiv.org/abs/1502.03167)."]},{"cell_type":"markdown","id":"exact-sunrise","metadata":{"id":"exact-sunrise"},"source":["## 01 - formulas (no simplification)"]},{"cell_type":"markdown","id":"italic-missile","metadata":{"id":"italic-missile"},"source":["Suppose $x_i \\in \\mathbb{R}^D$ and we have $N$ training examples (size of our batch; not $m$ as in the paper). So our matrix $X$ has the size $N \\times D$. \n","\n","What is the size of the batch mean (and variance)? Well we average over training examples, so $\\mu \\in \\mathbb{R}^D\\ (\\sigma^2 \\in \\mathbb{R}^D)$:\n","\n","$$\\mu = \\frac{1}{N} \\sum_i{x_i}$$\n","$$\\sigma^2 = \\frac{1}{N} \\sum_i{(x_i - \\mu)^2}$$"]},{"cell_type":"markdown","id":"nearby-weapon","metadata":{"id":"nearby-weapon"},"source":["### 01-1 gradients with respect to $\\gamma$ and $\\beta$"]},{"cell_type":"markdown","id":"billion-record","metadata":{"id":"billion-record"},"source":["Let's start from the easiest gradients - $\\partial L / \\partial \\gamma$ and $\\partial L / \\partial \\beta$ (we're considering only the former; all the derivations for the latter are quite similar). We may find in the paper that:\n","\n","$$\n","\\frac{\\partial L}{\\partial \\gamma} = \n","\\sum_i{\\frac{\\partial L}{\\partial y_i} \\hat{x}_i}\n","$$"]},{"cell_type":"markdown","id":"helpful-wells","metadata":{"id":"helpful-wells"},"source":["First of all this is an equation between vectors in $\\mathbb{R}^D$. 
### 01-1 gradients with respect to $\gamma$ and $\beta$

Let's start from the easiest gradients - $\partial L / \partial \gamma$ and $\partial L / \partial \beta$ (we're considering only the former; all the derivations for the latter are quite similar). We may find in the paper that:

$$
\frac{\partial L}{\partial \gamma} =
\sum_i{\frac{\partial L}{\partial y_i} \hat{x}_i}
$$

First of all, this is an equation between vectors in $\mathbb{R}^D$. So we don't have a scalar product here (otherwise we would have a sum of scalars) - rather, we have an element-by-element product, as in many other places:

$$
\frac{\partial L}{\partial \gamma} =
\sum_i{\frac{\partial L}{\partial y_i} * \hat{x}_i}
$$

To prove this, let's consider the derivative of $L$ with respect to $\gamma_k$:

$$
\frac{\partial L}{\partial \gamma_k} =
\sum_{ij}{\frac{\partial L}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial \gamma_k}} =
\sum_{i}{\frac{\partial L}{\partial y_{ik}} \frac{\partial y_{ik}}{\partial \gamma_k}} =
\sum_{i}{\frac{\partial L}{\partial y_{ik}} \hat{x}_{ik}}
$$

How can we compute this in `python`? We first need to compute an element-by-element product of the $N \times D$ arrays `dout` and `x_hat`, and then sum over the rows (which are our training examples).

```python
dgamma = np.sum(dout * x_hat, axis=0)
```
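The derivation for $\partial L / \partial \beta$ follows the same pattern, with $\partial y_{ik} / \partial \beta_k = 1$ instead of $\hat{x}_{ik}$, so a one-line sketch (my own addition, not the assignment text) is:

```python
# dL/dbeta_k = sum_i dL/dy_ik, i.e. just sum the incoming gradient over the batch
dbeta = np.sum(dout, axis=0)
```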
### 01-2 gradient with respect to $\hat{x}$

Let's do this derivation with almost all the details for one of the gradients.

Next question - what is the size of $\gamma$ and $\beta$? It turns out that they are vectors in $\mathbb{R}^D$ as well. In the paper we may find the formula:

$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$

Here $k=1, ..., D$ stands for a component of a vector from $\mathbb{R}^D$. So in terms of vectors we have (where $*$ means element-by-element product):

$$y = \gamma * \hat{x} + \beta$$

Let's find $\partial L / \partial \hat{x}_{ij}$ for $i=1, ..., N$, $j=1, ..., D$ (we're using lower indices - the 1st is for the training example, the 2nd for its component):

$$y_i = \gamma * \hat{x}_i + \beta$$
$$y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j$$
$$
\frac{\partial y_{ij}}{\partial \hat{x}_{ij}} = \gamma_j \qquad
\frac{\partial y_{ks}}{\partial \hat{x}_{ij}} = 0 \quad (k, s) \neq (i, j)
$$

To find the gradient with respect to $\hat{x}_{ij}$ we need to sum over all $y_{ks}$, but all of the terms are $0$ except the $ij$ one:

$$
\frac{\partial L}{\partial \hat{x}_{ij}} =
\sum_{ks}{\frac{\partial L}{\partial y_{ks}} \frac{\partial y_{ks}}{\partial \hat{x}_{ij}}} =
\frac{\partial L}{\partial y_{ij}} \frac{\partial y_{ij}}{\partial \hat{x}_{ij}} =
\frac{\partial L}{\partial y_{ij}} \gamma_j
$$

$$
\frac{\partial L}{\partial \hat{x}_{i}} =
\frac{\partial L}{\partial y_{i}} * \gamma
$$

The last equality is between vectors in $\mathbb{R}^D$ and we again have element-by-element multiplication. It's important to notice that we have the same $\gamma$ for all of our training examples $x_i$. We also have to notice that $\partial L / \partial y_i$ is a row of the incoming derivative `dout`.

Finally, how can we compute $\partial L / \partial \hat{X}$, where $\hat{X}$ is an $N \times D$ matrix? We have to multiply each row of the incoming gradient `dout` by $\gamma$, and we may do this using broadcasting:

```python
dx_hat = gamma * dout
```

```python
dout = np.arange(6).reshape(2, 3)
dout
# array([[0, 1, 2],
#        [3, 4, 5]])

gamma = np.array([10, 20, 30])
gamma * dout
# array([[  0,  20,  60],
#        [ 30,  80, 150]])

gamma * dout[0]
# array([ 0, 20, 60])

gamma * dout[1]
# array([ 30,  80, 150])
```

### 01-3 gradient with respect to $x_i$

So suppose we have the gradients $\partial L / \partial \hat{x}_i$, $\partial L / \partial \sigma^2$ and $\partial L / \partial \mu$. How can we get the gradient with respect to $x_i$? We may consider $L$ as a function of $\hat{x}_i$, $\sigma^2$ and $\mu$:

$$
\frac{\partial L}{\partial x_i} =
\frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} +
\frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i} +
\frac{\partial L}{\partial \mu} \frac{\partial \mu}{\partial x_i} =
\frac{\partial L}{\partial \hat{x}_i} \frac{1}{\sqrt{\sigma^2 + \epsilon}} +
\frac{\partial L}{\partial \sigma^2} \frac{2(x_i - \mu)}{N} +
\frac{\partial L}{\partial \mu} \frac{1}{N}
$$
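Putting the pieces of section 01 together, here is a sketch of the straight-from-the-paper backward pass. The expressions for `dvar` and `dmu` below are the paper's formulas for $\partial L / \partial \sigma^2$ and $\partial L / \partial \mu$ (they are not derived in this note), and the function and variable names are my own, not the assignment's API:

```python
def batchnorm_backward_paper(dout, x, x_hat, gamma, mu, var, eps=1e-5):
    """Backward pass implementing the batch norm paper's formulas directly (sketch)."""
    N = x.shape[0]
    std = np.sqrt(var + eps)

    dgamma = np.sum(dout * x_hat, axis=0)   # dL/dgamma
    dbeta = np.sum(dout, axis=0)            # dL/dbeta
    dx_hat = gamma * dout                   # dL/dx_hat

    # dL/dvar and dL/dmu, taken from the paper
    dvar = np.sum(dx_hat * (x - mu), axis=0) * (-0.5) * (var + eps) ** (-1.5)
    dmu = -np.sum(dx_hat, axis=0) / std + dvar * np.mean(-2.0 * (x - mu), axis=0)

    # chain rule for dL/dx_i (the formula above)
    dx = dx_hat / std + dvar * 2.0 * (x - mu) / N + dmu / N
    return dx, dgamma, dbeta
```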
## 02 - formulas (simplification)

The main idea behind the simplification is to notice that:

$$\sum_i{(x_i - \mu)} = 0$$

We have 2 main formulas for the simplified version:

$$
\frac{\partial L}{\partial x_i} =
\frac{1}{\sqrt{\sigma^2 + \epsilon}}
\left( \frac{\partial L}{\partial \hat{x}_i} -
\frac{1}{N}\sum_s{\frac{\partial L}{\partial \hat{x}_s}} -
\frac{\hat{x}_i}{N}\sum_s{\frac{\partial L}{\partial \hat{x}_s} \hat{x}_s}
\right)
$$

$$
\frac{\partial L}{\partial x_i} =
\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
\left( \frac{\partial L}{\partial y_i} -
\frac{1}{N}
\left(
\frac{\partial L}{\partial \gamma} \hat{x}_i +
\frac{\partial L}{\partial \beta}
\right)
\right)
$$
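As a sanity check of the first formula, here is a sketch of how it could be written in `numpy`, assuming `dx_hat`, `x_hat`, `var` and `eps` are already available (the names are mine):

```python
# first simplified formula, written in terms of dL/dx_hat
N = dx_hat.shape[0]
dx = (dx_hat
      - dx_hat.mean(axis=0)                    # (1/N) * sum_s dL/dx_hat_s
      - x_hat * (dx_hat * x_hat).mean(axis=0)  # (x_hat_i/N) * sum_s dL/dx_hat_s * x_hat_s
      ) / np.sqrt(var + eps)
```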
Let's prove that the 2nd term is our gradient with respect to $\mu$:

$$
\frac{1}{N} \frac{\partial L}{\partial \mu} =
-\frac{1}{N \sqrt{\sigma^2 + \epsilon}}
\sum_s{\frac{\partial L}{\partial \hat{x}_s}}
$$

If we use the idea mentioned above, we'll get what we need (the formula is from the paper):

$$
\frac{\partial L}{\partial \mu} =
-\frac{1}{\sqrt{\sigma^2 + \epsilon}}
\sum_s{\frac{\partial L}{\partial \hat{x}_s}}
$$

Finally, let's see why the 2nd formula is correct. Let's look at the first term. We have to show that:

$$
\frac{\partial L}{\partial \hat{x}_{i}} =
\frac{\partial L}{\partial y_{i}} * \gamma
$$

But that's exactly what we proved above.
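To close the loop, here is a sketch of the second simplified formula, which only needs `dout`, the cached `x_hat`, `gamma` and `var`, plus a quick random-data check against the straight-from-the-paper version sketched earlier. All of this block is my own illustration (hypothetical names, not the assignment's functions):

```python
def batchnorm_backward_alt_sketch(dout, x_hat, gamma, var, eps=1e-5):
    """Backward pass via the second simplified formula (sketch)."""
    N = dout.shape[0]
    dgamma = np.sum(dout * x_hat, axis=0)
    dbeta = np.sum(dout, axis=0)
    dx = gamma / np.sqrt(var + eps) * (dout - (dgamma * x_hat + dbeta) / N)
    return dx, dgamma, dbeta

# quick consistency check on random data
np.random.seed(0)
N, D, eps = 4, 5, 1e-5
x = np.random.randn(N, D)
gamma = np.random.randn(D)
dout = np.random.randn(N, D)

mu, var = x.mean(axis=0), x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)

dx1, _, _ = batchnorm_backward_paper(dout, x, x_hat, gamma, mu, var, eps)
dx2, _, _ = batchnorm_backward_alt_sketch(dout, x_hat, gamma, var, eps)
print(np.allclose(dx1, dx2))  # the simplification is exact, so this should print True
```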