Last active
October 6, 2019 20:50
-
-
Save mkolod/956069e8ee345ba11b0679cf864695c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"import numpy as np\n", | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Manual conv2d implementation via im2col (valid cross-correlation: stride 1, no padding)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def im2col(data, filter_h, filter_w):
    """Unroll every filter_h x filter_w window of a 2-D array into a column.

    Windows are visited left to right, then top to bottom. Each window is
    flattened in row-major order into one column of the result.

    Args:
        data: 2-D array of shape (h, w).
        filter_h: window height.
        filter_w: window width.

    Returns:
        Array of shape (filter_h * filter_w,
        (h - filter_h + 1) * (w - filter_w + 1)) with the same dtype as `data`.
    """
    rows, cols = data.shape
    n_down = rows - filter_h + 1
    n_right = cols - filter_w + 1
    patch_len = filter_h * filter_w

    columns = np.zeros((patch_len, n_down * n_right)).astype(data.dtype)

    # Row-major sweep over all valid top-left corners of the window.
    positions = ((r, c) for r in range(n_down) for c in range(n_right))
    for idx, (r, c) in enumerate(positions):
        window = data[r:r + filter_h, c:c + filter_w]
        columns[:, idx] = window.ravel()
    return columns
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def conv2d(data, filt):
    """2-D "valid" cross-correlation of `data` with `filt`, computed via im2col.

    Like deep-learning conv2d ops, the filter is not flipped. The result
    corresponds to stride 1, no padding, and a single input/output channel.

    Args:
        data: 2-D array of shape (data_h, data_w).
        filt: 2-D array of shape (filt_h, filt_w). It must not be larger than
            `data` in either dimension.

    Returns:
        2-D array of shape (data_h - filt_h + 1, data_w - filt_w + 1).

    Raises:
        ValueError: if `filt` is larger than `data` in either dimension.
    """
    data_h, data_w = data.shape
    filt_h, filt_w = filt.shape
    if filt_h > data_h or filt_w > data_w:
        raise ValueError(
            f"filter shape {filt.shape} is larger than data shape {data.shape}")

    # One column per filter position; a single row-vector product then
    # computes every output value at once.
    unrolled = im2col(data, filt_h, filt_w)
    mul = np.dot(filt.reshape(1, filt.size), unrolled)

    # A valid (unpadded, stride-1) correlation has one output per filter
    # position: (data_h - filt_h + 1) x (data_w - filt_w + 1). Using
    # `data_h - filt_h // 2` only coincides with this for 1- or 2-wide filters.
    new_dim_h = data_h - filt_h + 1
    new_dim_w = data_w - filt_w + 1
    return mul.reshape(new_dim_h, new_dim_w)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Seed the global NumPy RNG so the random inputs are reproducible across runs.
np.random.seed(seed=42)

# 4x4 input and 2x2 filter drawn from a standard normal distribution
# (the tuple is evaluated left to right, so `data` is drawn first).
data, filt = np.random.randn(4, 4), np.random.randn(2, 2)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[-3.24734738e-03, -1.67415333e+00, -2.69520520e+00],\n", | |
" [-1.76384602e-01, 8.95233552e-01, -2.79765292e-01],\n", | |
" [ 3.12842152e+00, 3.47826289e+00, 2.68339721e+00]])" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Run the im2col-based implementation on the sample input; the bare name on
# the last line makes the cell display the result.
manual_conv2d = conv2d(data=data, filt=filt)
manual_conv2d
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### PyTorch implementation (reference: torch.nn.functional.conv2d)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Wrap the NumPy arrays as tensors. torch.from_numpy shares memory with the
# source array and keeps its dtype (float64 here).
data_t, filt_t = (torch.from_numpy(arr) for arr in (data, filt))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[[-3.2473e-03, -1.6742e+00, -2.6952e+00],\n", | |
" [-1.7638e-01, 8.9523e-01, -2.7977e-01],\n", | |
" [ 3.1284e+00, 3.4783e+00, 2.6834e+00]]]], dtype=torch.float64)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# F.conv2d expects a 4-D input (batch, channels, H, W) and 4-D weights
# (out_channels, in_channels, kH, kW), so give each 2-D array two leading
# singleton dimensions.
pytorch_conv2d = torch.nn.functional.conv2d(
    data_t.unsqueeze(0).unsqueeze(0),
    filt_t.unsqueeze(0).unsqueeze(0),
)
pytorch_conv2d
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Does the manual implementation agree with PyTorch?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# np.allclose broadcasts its arguments, so it would also "match" a
# wrongly-shaped manual result that happens to broadcast against the
# (1, 1, H, W) tensor. Strip the batch/channel dims and compare shapes first.
pytorch_result = pytorch_conv2d[0, 0].numpy()
assert manual_conv2d.shape == pytorch_result.shape, (
    f"shape mismatch: {manual_conv2d.shape} vs {pytorch_result.shape}")
assert np.allclose(manual_conv2d, pytorch_result), "NumPy and PyTorch results don't match"
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment