Gist by @izmailovpavel · created March 2, 2018
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import nn as nn\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.hooks.RemovableHandle at 0x7f23101cdc18>"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grads = {}\n",
"def save_grad(name):\n",
" def hook(grad):\n",
" grads[name] = grad\n",
" return hook\n",
"\n",
"def extract_grad(var):\n",
" print(var)\n",
" print(var.shape)\n",
" return var\n",
"\n",
"n_feat = 10\n",
"n_obj = 25\n",
"X = np.random.normal(size=(n_obj, n_feat))\n",
"y = np.random.randint(low=0, high=10, size=(n_obj))\n",
"X_ = Variable(torch.from_numpy(X), requires_grad=True)\n",
"y_ = Variable(torch.from_numpy(y))\n",
"lsm = nn.LogSoftmax(dim=1)(X_)\n",
"l = nn.NLLLoss()(lsm, y_)\n",
"\n",
"lsm.register_hook(save_grad('lsm'))"
]
},
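{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lsm` is a non-leaf Variable, so `backward()` will not populate its `.grad`; the hook above is one way to capture the gradient flowing into it. As a sketch (not part of the original gist), the same hook factory can also be attached to the leaf `X_`; after the backward pass, the stored `grads['X']` should coincide with `X_.grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical extra hook (not in the original notebook): capture d l / d X_ as well.\n",
"# After l.backward() below, grads['X'] should match X_.grad.\n",
"X_.register_hook(save_grad('X'))"
]
},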
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"l.backward()"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"1.00000e-02 *\n",
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000\n",
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
" -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n",
"[torch.DoubleTensor of size 25x10]"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grads['lsm']"
]
},
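{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of the captured gradient has a single $-0.04 = -1/25$ entry at the target class. This is what `NLLLoss` should produce here: $-1/N$ at the target class $y_i$ of each row and $0$ elsewhere (here $N = 25$), assuming the default averaging over objects. A quick check of that closed form (a sketch, not part of the original gist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expected gradient of the mean NLL loss w.r.t. the log-probabilities:\n",
"# -1/N at the target class of each object, 0 elsewhere.\n",
"expected_lsm_grad = np.zeros((n_obj, n_feat))\n",
"expected_lsm_grad[np.arange(n_obj), y] = -1.0 / n_obj\n",
"np.linalg.norm(expected_lsm_grad - grads['lsm'].data.numpy())"
]
},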
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"1.00000e-02 *\n",
" 0.8155 0.2794 -3.8968 0.5208 0.1211 0.1943 0.5007 0.3851 0.1434 0.9365\n",
" 0.5790 0.1383 0.5933 0.6330 0.4394 0.6952 0.1665 0.0873 -3.5001 0.1682\n",
" 0.5062 0.0578 0.0737 1.0623 1.3683 0.1590 0.1494 0.5018 0.0833 -3.9618\n",
" 0.1304 1.4510 -3.7681 0.0102 0.1075 0.6720 0.5830 0.4104 0.2180 0.1856\n",
" 0.1847 0.4890 -3.7672 0.2188 0.9282 0.1364 0.0864 0.4260 0.4284 0.8693\n",
" 0.0223 -3.5900 1.6741 0.1133 0.0790 0.3652 0.1131 0.0791 0.7763 0.3677\n",
" 0.1598 0.1576 0.2000 0.1278 -3.8787 0.7640 0.4842 0.5352 1.2654 0.1847\n",
" 0.6389 0.1485 0.1226 -3.6593 0.1309 0.1685 0.3405 1.6761 0.1646 0.2687\n",
" 0.6277 -3.3878 0.2061 0.1202 1.0843 0.2062 0.4640 0.0806 0.4340 0.1647\n",
" 0.6289 0.2413 0.0680 0.2713 0.2365 0.8447 0.0867 -2.9684 0.2418 0.3492\n",
" 0.6839 0.0435 0.2797 0.1037 -3.7896 0.5422 0.6385 0.8123 0.0335 0.6521\n",
" 0.2857 0.2350 0.7899 0.2513 0.8590 0.0355 -3.4502 0.2588 0.2522 0.4828\n",
" 0.2759 0.1700 0.1678 0.2723 0.0284 0.0710 1.5839 0.7363 -3.4987 0.1932\n",
" 0.3499 0.9402 0.4786 0.2367 0.7633 -3.6291 0.3840 0.0831 0.2335 0.1597\n",
" 0.2092 0.1013 0.7504 0.2398 0.0702 -3.7920 0.2251 0.3815 1.5150 0.2994\n",
" 0.1143 0.0623 0.2665 0.3133 0.5581 -3.8609 0.9103 0.4654 0.4638 0.7068\n",
" 0.1304 0.2630 0.3659 1.7599 0.1529 -3.8142 0.3330 0.4531 0.2820 0.0741\n",
" 0.0638 0.5095 -3.7011 0.2375 0.0971 0.0885 0.1205 0.0939 2.4365 0.0538\n",
" 0.6268 -3.7864 0.1416 1.0899 0.7207 0.3281 0.0347 0.2979 0.2298 0.3171\n",
" 0.4224 0.7358 -3.3912 0.2569 0.2338 0.3163 0.3436 0.4604 0.3109 0.3111\n",
" 0.5334 0.0474 0.3382 0.8208 0.4691 -3.7323 0.4889 0.1613 0.4568 0.4165\n",
" 0.4489 -3.5636 0.4416 0.0672 0.0471 0.8899 0.7920 0.6428 0.1531 0.0810\n",
" 0.1111 0.1431 0.1120 -2.9565 0.6451 0.9736 0.0359 0.5117 0.2584 0.1655\n",
" 0.3391 0.9316 0.2202 -3.4725 0.0128 0.1691 0.2743 0.4158 0.8600 0.2497\n",
" -3.8255 0.6958 0.3361 0.3637 0.0943 0.3866 0.4016 0.6418 0.2407 0.6649\n",
"[torch.DoubleTensor of size 25x10]"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_.grad"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"lsm.grad"
]
},
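{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lsm.grad` is empty because autograd does not retain gradients of non-leaf Variables by default; only leaves such as `X_` get a populated `.grad`. Besides hooks, `retain_grad()` asks autograd to keep a non-leaf gradient as well. A minimal sketch of that alternative (assuming a PyTorch version that provides `retain_grad`; it rebuilds the graph on a detached copy of the input so `X_.grad`, which is compared against below, is not accumulated into a second time):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: keep a non-leaf gradient without a hook.\n",
"X_copy = Variable(X_.data.clone(), requires_grad=True)\n",
"lsm2 = nn.LogSoftmax(dim=1)(X_copy)\n",
"lsm2.retain_grad()  # ask autograd to store d l / d lsm2\n",
"nn.NLLLoss()(lsm2, y_).backward()\n",
"lsm2.grad  # now populated; should match grads['lsm']"
]
},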
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"dlsm_dXs = []\n",
"for i in range(n_obj):\n",
" denom = np.sum(np.exp(X[i]))\n",
" dlsm_dXs.append(np.eye(n_feat) - np.exp(X[i][:, None]) / denom)\n",
"dlsm_dX = np.array(dlsm_dXs)"
]
},
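{
"cell_type": "markdown",
"metadata": {},
"source": [
"For $\\mathrm{lsm}_j = x_j - \\log\\sum_k e^{x_k}$ the Jacobian is $\\partial\\,\\mathrm{lsm}_j / \\partial x_k = \\delta_{jk} - \\mathrm{softmax}(x)_k$. The loop above actually builds the transpose of this matrix (its softmax term varies with the row index $j$ rather than the column index $k$), which combines with the `einsum` contraction over the second index below to give the correct vector-Jacobian product. A vectorized construction of the same array, as a sketch (not part of the original gist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: build all 25 matrices at once with broadcasting and check against the loop.\n",
"softmax_np = np.exp(X) / np.exp(X).sum(axis=1, keepdims=True)\n",
"dlsm_dX_vec = np.eye(n_feat)[None, :, :] - softmax_np[:, :, None]\n",
"np.allclose(dlsm_dX, dlsm_dX_vec)"
]
},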
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"dlsm_dX_ = Variable(torch.from_numpy(dlsm_dX))"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"ans = np.einsum('ijk, ik -> ij', dlsm_dX, grads['lsm'].data.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.5407775362056137e-17"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.linalg.norm(ans - X_.grad.data.numpy())"
]
},
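{
"cell_type": "markdown",
"metadata": {},
"source": [
"The mismatch is at machine precision, so the hand-built Jacobians applied to the captured upstream gradient reproduce autograd's `X_.grad`. The same vector-Jacobian product also has a simple closed form, $g_k - \\mathrm{softmax}(x)_k \\sum_j g_j$ with $g = \\partial l / \\partial\\,\\mathrm{lsm}$; a quick check of this (a sketch, reusing the names defined above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Closed-form backward of log-softmax: g - softmax(X) * sum_k g_k, applied row-wise.\n",
"g = grads['lsm'].data.numpy()\n",
"sm = np.exp(X) / np.exp(X).sum(axis=1, keepdims=True)\n",
"np.linalg.norm(g - sm * g.sum(axis=1, keepdims=True) - X_.grad.data.numpy())"
]
},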
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}