Skip to content

Instantly share code, notes, and snippets.

@mkolod
Last active October 6, 2019 20:50
Show Gist options
  • Save mkolod/956069e8ee345ba11b0679cf864695c7 to your computer and use it in GitHub Desktop.
Save mkolod/956069e8ee345ba11b0679cf864695c7 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Manual conv2d implementation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def im2col(data, filter_h, filter_w):\n",
" filter_size = filter_h * filter_w\n",
" h, w = data.shape\n",
" filter_down = h - filter_h + 1\n",
" filter_right = w - filter_w + 1\n",
" output = np.zeros((filter_size, filter_down * filter_right)).astype(data.dtype)\n",
" ctr = 0\n",
" for row in range(0, filter_down):\n",
" for col in range(0, filter_right):\n",
" patch = data[row:(row+filter_h), col:(col+filter_w)]\n",
" patch = patch.reshape(1, filter_size)\n",
" output[:, ctr] = patch\n",
" ctr += 1\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def conv2d(data, filt):\n",
" data_h, data_w = data.shape\n",
" filt_h, filt_w = filt.shape\n",
" unrolled = im2col(data, filt_h, filt_w)\n",
" mul = np.dot(filt.reshape(1, filt.size), unrolled)\n",
" new_dim_h = data_h - filt_h//2\n",
" new_dim_w = data_w - filt_w//2\n",
" mul = mul.reshape(new_dim_h, new_dim_w)\n",
" return mul"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"data = np.random.randn(4, 4)\n",
"filt = np.random.randn(2, 2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-3.24734738e-03, -1.67415333e+00, -2.69520520e+00],\n",
" [-1.76384602e-01, 8.95233552e-01, -2.79765292e-01],\n",
" [ 3.12842152e+00, 3.47826289e+00, 2.68339721e+00]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"manual_conv2d = conv2d(data, filt)\n",
"manual_conv2d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PyTorch Implementation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"data_t = torch.from_numpy(data)\n",
"filt_t = torch.from_numpy(filt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[-3.2473e-03, -1.6742e+00, -2.6952e+00],\n",
" [-1.7638e-01, 8.9523e-01, -2.7977e-01],\n",
" [ 3.1284e+00, 3.4783e+00, 2.6834e+00]]]], dtype=torch.float64)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pytorch_conv2d = torch.nn.functional.conv2d(data_t[None, None, ...], filt_t[None, None, ...])\n",
"pytorch_conv2d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Does the manual implementation agree with PyTorch?"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"assert np.allclose(manual_conv2d, pytorch_conv2d.numpy()), \"NumPy and PyTorch results don't match\""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment