Last active
November 4, 2021 10:38
`np.add.at` examples and explanation (based on assignment 3 of cs231n)
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "instructional-punch", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Setup cell.\n", | |
"import time, os, json\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"from cs231n.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array\n", | |
"from cs231n.rnn_layers import *\n", | |
"from cs231n.captioning_solver import CaptioningSolver\n", | |
"from cs231n.classifiers.rnn import CaptioningRNN\n", | |
"from cs231n.coco_utils import load_coco_data, sample_coco_minibatch, decode_captions\n", | |
"from cs231n.image_utils import image_from_url\n", | |
"\n", | |
"%matplotlib inline\n", | |
"plt.rcParams['figure.figsize'] = (10.0, 8.0) # Set default size of plots.\n", | |
"plt.rcParams['image.interpolation'] = 'nearest'\n", | |
"plt.rcParams['image.cmap'] = 'gray'\n", | |
"\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"\n", | |
"def rel_error(x, y):\n", | |
" \"\"\" returns relative error \"\"\"\n", | |
" return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "danish-breast", | |
"metadata": {}, | |
"source": [ | |
"## 01 - indexing and gradient" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "realistic-installation", | |
"metadata": {}, | |
"source": [ | |
"### 01-1 indexing" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "entitled-hunger", | |
"metadata": {}, | |
"source": [ | |
"What does the indexing `W[x]` do when `x` is an integer array? Let's look at the example we're interested in:\n", | |
"- `W[x]` means `W[x, :]`, so the last dimension of `W` is preserved (in our case `D=6`);\n", | |
"- the first 2 dimensions of `W[x]` are the same as the dimensions of `x`;\n", | |
"- finally, each entry of `x` selects a row of `W`; `x` is an array of row indices after all." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"id": "eleven-receptor", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"N, T, V, D = 2, 3, 7, 6\n", | |
"x = np.random.randint(V, size=(N, T))\n", | |
"W = np.arange(V*D).reshape(V, D)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"id": "activated-court", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0, 1, 2, 3, 4, 5],\n", | |
" [ 6, 7, 8, 9, 10, 11],\n", | |
" [12, 13, 14, 15, 16, 17],\n", | |
" [18, 19, 20, 21, 22, 23],\n", | |
" [24, 25, 26, 27, 28, 29],\n", | |
" [30, 31, 32, 33, 34, 35],\n", | |
" [36, 37, 38, 39, 40, 41]])" | |
] | |
}, | |
"execution_count": 61, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"id": "contrary-vampire", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[5, 0, 3],\n", | |
" [0, 4, 2]])" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"id": "specialized-router", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[30, 31, 32, 33, 34, 35],\n", | |
" [ 0, 1, 2, 3, 4, 5],\n", | |
" [18, 19, 20, 21, 22, 23]],\n", | |
"\n", | |
" [[ 0, 1, 2, 3, 4, 5],\n", | |
" [24, 25, 26, 27, 28, 29],\n", | |
" [12, 13, 14, 15, 16, 17]]])" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W[x]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"id": "intended-special", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((2, 3, 6), (2, 3))" | |
] | |
}, | |
"execution_count": 65, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W[x].shape, x.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "divided-smart", | |
"metadata": {}, | |
"source": [ | |
"### 01-2 gradient" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "prerequisite-elephant", | |
"metadata": {}, | |
"source": [ | |
"So how do we compute the gradient of this operation? Write the output as $out_{ijk} = W_{x_{ij}, k}$. Then we have to compute:\n", | |
"\n", | |
"$$\n", | |
"\\frac{\\partial out_{ijk}}{\\partial W_{vd}} = \\delta_{v x_{ij}} \\delta_{kd}\n", | |
"$$\n", | |
"\n", | |
"In other words it's 1 if $v = x_{ij}$ and $d = k$, and 0 otherwise. This means output position $(i, j)$ contributes to **all** elements of the row with index $x_{ij} \\in \\{0, ..., V-1\\}$ and to no other row. So it's similar to the gradient of a sum, except that here we add 1 to every element of a row. Finally, if an index is repeated in $x$, we have to add 1 to that row again." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "critical-imaging", | |
"metadata": {}, | |
"source": [ | |
"If we have an upstream gradient (with the same shape as `W[x]`), we have to propagate it to those rows: the slice at position $(i, j)$ of the upstream gradient is added to row $x_{ij}$ of the gradient with respect to `W`. Let's consider an example." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 120, | |
"id": "future-nudist", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(1)\n", | |
"N, T, V, D = 2, 3, 7, 6\n", | |
"x = np.random.randint(V, size=(N, T))\n", | |
"dW_man = np.zeros((V, D))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 121, | |
"id": "harmful-curtis", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((2, 3, 6), (2, 3))" | |
] | |
}, | |
"execution_count": 121, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW_man[x].shape, x.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 122, | |
"id": "processed-housing", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[5, 3, 4],\n", | |
" [0, 1, 3]])" | |
] | |
}, | |
"execution_count": 122, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 123, | |
"id": "fifteen-wages", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dout = np.arange(2*3*6).reshape(dW_man[x].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 124, | |
"id": "centered-roulette", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[ 0, 1, 2, 3, 4, 5],\n", | |
" [ 6, 7, 8, 9, 10, 11],\n", | |
" [12, 13, 14, 15, 16, 17]],\n", | |
"\n", | |
" [[18, 19, 20, 21, 22, 23],\n", | |
" [24, 25, 26, 27, 28, 29],\n", | |
" [30, 31, 32, 33, 34, 35]]])" | |
] | |
}, | |
"execution_count": 124, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dout" | |
] | |
}, | |
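{ | |
"cell_type": "markdown", | |
"id": "norwegian-incidence", | |
"metadata": {}, | |
"source": [ | |
"What should row 0 of `dW` be? It should be $[18, 19, ...]$, because `x[1, 0] = 0`. Where should $[0, 1, ...]$ be added? To row 5, because `x[0, 0] = 5`. Let's compute `dW` by hand. Note that index 3 appears twice in `x`, so row 3 should receive `dout[0, 1] + dout[1, 2]`; the plain assignments below keep only the last write." | |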
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 125, | |
"id": "technological-yukon", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dW_man[5] = dout[0, 0]\n", | |
"dW_man[3] = dout[0, 1]\n", | |
"dW_man[4] = dout[0, 2]\n", | |
"\n", | |
"dW_man[0] = dout[1, 0]\n", | |
"dW_man[1] = dout[1, 1]\n", | |
"dW_man[3] = dout[1, 2]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 126, | |
"id": "invalid-eagle", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[18., 19., 20., 21., 22., 23.],\n", | |
" [24., 25., 26., 27., 28., 29.],\n", | |
" [ 0., 0., 0., 0., 0., 0.],\n", | |
" [30., 31., 32., 33., 34., 35.],\n", | |
" [12., 13., 14., 15., 16., 17.],\n", | |
" [ 0., 1., 2., 3., 4., 5.],\n", | |
" [ 0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 126, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW_man" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "sunrise-monitoring", | |
"metadata": {}, | |
"source": [ | |
"## 02 - `np.add.at` " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fifth-given", | |
"metadata": {}, | |
"source": [ | |
"### 02-1 dimension" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "billion-edmonton", | |
"metadata": {}, | |
"source": [ | |
"Note: the 2-dimensional case is probably more relevant, so you may skip this part." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "urban-surrey", | |
"metadata": {}, | |
"source": [ | |
"Let's start with the simple case where `x` has a single row (`N=1`) and no repeated indices. In this case `np.add.at(W, x, 1)` adds to the rows of `W` specified in `x`. The scalar `1` is broadcast to the size of a row." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "parallel-ordering", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"N, T, V, D = 1, 3, 5, 6\n", | |
"x = np.random.randint(V, size=(N, T))\n", | |
"W = np.zeros((V, D))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "freelance-willow", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[1, 4, 2]])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "cardiovascular-deployment", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "romantic-armstrong", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.add.at(W, x, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "human-boston", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0., 0., 0., 0., 0., 0.],\n", | |
" [1., 1., 1., 1., 1., 1.],\n", | |
" [1., 1., 1., 1., 1., 1.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 1., 1., 1., 1., 1.]])" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "split-melissa", | |
"metadata": {}, | |
"source": [ | |
"Now let's allow repeated indices in `x`. As we can see, `1` is added to the same row (row 2) twice." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "proprietary-aspect", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = np.random.randint(V, size=(N, T))\n", | |
"W = np.zeros((V, D))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "thrown-commitment", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[2, 2, 1]])" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "posted-olive", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.add.at(W, x, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "vanilla-settlement", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0., 0., 0., 0., 0., 0.],\n", | |
" [1., 1., 1., 1., 1., 1.],\n", | |
" [2., 2., 2., 2., 2., 2.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "pretty-moment", | |
"metadata": {}, | |
"source": [ | |
"### 02-2 dimensions" | |
] | |
}, | |
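{ | |
"cell_type": "markdown", | |
"id": "retained-patrol", | |
"metadata": {}, | |
"source": [ | |
"Let's now try the same with a 2-dimensional `x` and an upstream gradient `dout`. Index 3 appears twice in `x`, so row 3 of `dW` should accumulate `dout[0, 1] + dout[1, 2]`." | |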
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 127, | |
"id": "challenging-blink", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(1)\n", | |
"N, T, V, D = 2, 3, 7, 6\n", | |
"x = np.random.randint(V, size=(N, T))\n", | |
"dW = np.zeros((V, D))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 128, | |
"id": "logical-vertical", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 128, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 129, | |
"id": "fleet-sport", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[5, 3, 4],\n", | |
" [0, 1, 3]])" | |
] | |
}, | |
"execution_count": 129, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 130, | |
"id": "inside-aspect", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dout = np.arange(2*3*6).reshape(dW[x].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 131, | |
"id": "handy-routine", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[[ 0, 1, 2, 3, 4, 5],\n", | |
" [ 6, 7, 8, 9, 10, 11],\n", | |
" [12, 13, 14, 15, 16, 17]],\n", | |
"\n", | |
" [[18, 19, 20, 21, 22, 23],\n", | |
" [24, 25, 26, 27, 28, 29],\n", | |
" [30, 31, 32, 33, 34, 35]]])" | |
] | |
}, | |
"execution_count": 131, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dout" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 132, | |
"id": "assumed-vision", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.add.at(dW, x, dout)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 133, | |
"id": "provincial-photographer", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[18., 19., 20., 21., 22., 23.],\n", | |
" [24., 25., 26., 27., 28., 29.],\n", | |
" [ 0., 0., 0., 0., 0., 0.],\n", | |
" [36., 38., 40., 42., 44., 46.],\n", | |
" [12., 13., 14., 15., 16., 17.],\n", | |
" [ 0., 1., 2., 3., 4., 5.],\n", | |
" [ 0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 133, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 134, | |
"id": "senior-technology", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((2, 3, 6), (7, 6))" | |
] | |
}, | |
"execution_count": 134, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dout.shape, dW.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 135, | |
"id": "other-legislation", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[18., 19., 20., 21., 22., 23.],\n", | |
" [24., 25., 26., 27., 28., 29.],\n", | |
" [ 0., 0., 0., 0., 0., 0.],\n", | |
" [30., 31., 32., 33., 34., 35.],\n", | |
" [12., 13., 14., 15., 16., 17.],\n", | |
" [ 0., 1., 2., 3., 4., 5.],\n", | |
" [ 0., 0., 0., 0., 0., 0.]])" | |
] | |
}, | |
"execution_count": 135, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW_man" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "olympic-attribute", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
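The accumulation behaviour shown in the notebook above can be checked in isolation. The following is a minimal sketch, assuming the same index array `x` as in the seeded run (hard-coded here rather than drawn from `np.random`). The names `dW_loop`, `dW_at` and `dW_buffered` are illustrative and not part of the original notebook. It compares an explicit accumulating loop, `np.add.at`, and the buffered form `dW[x] += dout`.

```python
import numpy as np

V, D = 7, 6
# Assumption: same indices as the seeded notebook run; row 3 appears twice.
x = np.array([[5, 3, 4],
              [0, 1, 3]])
dout = np.arange(x.size * D, dtype=float).reshape(x.shape + (D,))

# Reference: explicit loop that accumulates with += instead of assigning.
dW_loop = np.zeros((V, D))
for n in range(x.shape[0]):
    for t in range(x.shape[1]):
        dW_loop[x[n, t]] += dout[n, t]

# Unbuffered in-place add: contributions for repeated indices are summed.
dW_at = np.zeros((V, D))
np.add.at(dW_at, x, dout)

# Buffered fancy-index add: the repeated row 3 is updated only once.
dW_buffered = np.zeros((V, D))
dW_buffered[x] += dout

print(np.array_equal(dW_loop, dW_at))        # True
print(np.array_equal(dW_loop, dW_buffered))  # False
print(dW_at[3])                              # [36. 38. 40. 42. 44. 46.]
```

Row 3 of `dW_at` equals `dout[0, 1] + dout[1, 2]`, which matches the `np.add.at` output in the notebook. The buffered form loses one of the two contributions, which is why `np.add.at` is the safer choice whenever `x` may contain repeated indices.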