Created
September 29, 2019 14:41
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## theory" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's derive the gradient of the loss $J$ with respect to $W$ in the easy case of a single one-hot input." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Suppose $z = xW = [W_{s1} \\dots W_{sD}]$, where $x$ is a one-hot row vector with $x_s = 1$, so $z_k = W_{sk}$. By the chain rule, $\\frac{\\partial J}{\\partial W_{ij}} = \\sum_k{\\frac{\\partial J}{\\partial z_k} \\frac{\\partial z_k}{\\partial W_{ij}}}$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This is equal to $0$ if $i \\neq s$, because $z$ depends only on row $s$ of $W$. For $i = s$, since $z_k = W_{sk}$, we have $\\frac{\\partial z_k}{\\partial W_{sj}} = 0$ if $k \\neq j$ and $1$ otherwise. So $\\frac{\\partial J}{\\partial W_{sj}} = \\frac{\\partial J}{\\partial z_j}$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In other words, $\\frac{\\partial J}{\\partial W} = \\begin{bmatrix}0 & \\cdots & 0\\\\ \\frac{\\partial J}{\\partial z_1} & \\cdots & \\frac{\\partial J}{\\partial z_D} \\\\ 0 & \\cdots & 0 \\end{bmatrix}$, where the only nonzero row is row $s$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So we just need to fill in the $s$-th row of the gradient with the upstream gradient $\\frac{\\partial J}{\\partial z}$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
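"np.random.seed(42)\n", | |
"V, D = 10, 2  # W has V rows (one per one-hot position) and D columns\n", | |
"x = [3]  # index s of the hot entry in the one-hot input\n", | |
"W = np.random.randn(V, D)\n", | |
"dW = np.zeros_like(W)  # gradient w.r.t. W, same shape as W\n", | |
"dout = np.random.randn(1, D)  # upstream gradient dJ/dz" | |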
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1.46564877, -0.2257763 ]])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dout" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# unbuffered in-place add of dout into row x of dW\n", | |
"np.add.at(dW, x, dout)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 1.46564877, -0.2257763 ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ],\n", | |
" [ 0. , 0. ]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dW" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This concludes our short analysis in this simple case." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
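The derivation above can be checked numerically. The following is a minimal sketch, not part of the original notebook: it builds the one-hot vector explicitly and compares the dense gradient $x^T \cdot dout$ with the row-filling approach. The names `rng`, `x_onehot`, `dW_dense`, `dW_sparse`, `idx`, `douts`, `dW_at`, and `dW_buf` are introduced here for illustration only. The last part assumes a batch with a repeated index, a case the notebook does not cover, to show why `np.add.at` may be preferable to `+=` with fancy indexing.

```python
import numpy as np

rng = np.random.default_rng(0)
V, D = 10, 2
s = 3  # position of the hot entry, as in the notebook

W = rng.standard_normal((V, D))
dout = rng.standard_normal((1, D))  # upstream gradient dJ/dz

# Explicit one-hot row vector x with x_s = 1.
x_onehot = np.zeros((1, V))
x_onehot[0, s] = 1.0

# Forward pass: z = xW simply selects row s of W.
z = x_onehot @ W
assert np.allclose(z, W[[s]])

# Dense gradient from the matrix form: dJ/dW = x^T dout.
dW_dense = x_onehot.T @ dout

# Sparse version: fill only row s with the upstream gradient.
dW_sparse = np.zeros_like(W)
np.add.at(dW_sparse, [s], dout)

same = np.allclose(dW_dense, dW_sparse)
print("dense and sparse gradients agree:", same)

# Assumed extension: the same index appears twice in a batch.
idx = [3, 3]
douts = np.ones((2, D))

dW_at = np.zeros_like(W)
np.add.at(dW_at, idx, douts)  # unbuffered: both contributions are summed

dW_buf = np.zeros_like(W)
dW_buf[idx] += douts          # buffered: the repeated index is applied once

print("np.add.at row 3:", dW_at[3])
print("fancy += row 3: ", dW_buf[3])
```

With a single index, as in the notebook, `dW[x] += dout` and `np.add.at(dW, x, dout)` give the same result; the difference only shows up once indices repeat.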