Skip to content

Instantly share code, notes, and snippets.

@wdhorton
Created September 18, 2018 02:39
Show Gist options
  • Save wdhorton/9fa988750bab88046168a81a813e468f to your computer and use it in GitHub Desktop.
Save wdhorton/9fa988750bab88046168a81a813e468f to your computer and use it in GitHub Desktop.
Testing ConstantPadNd backward
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"class Context:\n",
"    pad = (1, 1, -1, -1)\n",
"    value = 0\n",
"    input_size = torch.Size([4, 5])\n",
"    l_inp = 2\n",
"    pad_tup = ((-1, -1), (1, 1))\n",
"    l_pad = 2\n",
"    l_diff = 0"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.5822, -0.7529, 0.4968, -0.2158, -1.3026, 1.1133, 0.6431],\n",
" [-0.5976, -0.2949, 0.4010, -0.9164, 0.4081, -0.8147, 0.0278]])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ctx = Context()\n",
"grad_output = torch.randn(2, 7)\n",
"grad_output"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"def backward(ctx, grad_output):\n",
"    grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())\n",
"\n",
"    cg_input = grad_input\n",
"    for i, p in zip(range(ctx.l_inp)[-ctx.l_pad:], ctx.pad_tup):\n",
"        if p[0] < 0:\n",
"            cg_input = cg_input.narrow(i, -p[0], cg_input.size(i) + p[0])\n",
"        if p[1] < 0:\n",
"            cg_input = cg_input.narrow(i, 0, cg_input.size(i) + p[1])\n",
"\n",
"    cg_output = grad_output\n",
"    for i, p in zip(range(ctx.l_inp)[-ctx.l_pad:], ctx.pad_tup):\n",
"        if p[0] > 0:\n",
"            cg_output = cg_output.narrow(i, p[0], cg_output.size(i) - p[0])\n",
"        if p[1] > 0:\n",
"            cg_output = cg_output.narrow(i, 0, cg_output.size(i) - p[1])\n",
"    cg_input.copy_(cg_output)\n",
"    return grad_input"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [-0.7529, 0.4968, -0.2158, -1.3026, 1.1133],\n",
" [-0.2949, 0.4010, -0.9164, 0.4081, -0.8147],\n",
" [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"backward(ctx, grad_output)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"def backward_return_output(ctx, grad_output):\n",
"    grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())\n",
"\n",
"    cg_input = grad_input\n",
"    for i, p in zip(range(ctx.l_inp)[-ctx.l_pad:], ctx.pad_tup):\n",
"        if p[0] < 0:\n",
"            cg_input = cg_input.narrow(i, -p[0], cg_input.size(i) + p[0])\n",
"        if p[1] < 0:\n",
"            cg_input = cg_input.narrow(i, 0, cg_input.size(i) + p[1])\n",
"\n",
"    cg_output = grad_output\n",
"    for i, p in zip(range(ctx.l_inp)[-ctx.l_pad:], ctx.pad_tup):\n",
"        if p[0] > 0:\n",
"            cg_output = cg_output.narrow(i, p[0], cg_output.size(i) - p[0])\n",
"        if p[1] > 0:\n",
"            cg_output = cg_output.narrow(i, 0, cg_output.size(i) - p[1])\n",
"    return cg_output"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.7529, 0.4968, -0.2158, -1.3026, 1.1133],\n",
" [-0.2949, 0.4010, -0.9164, 0.4081, -0.8147]])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"backward_return_output(ctx, grad_output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment