Skip to content

Instantly share code, notes, and snippets.

@gngdb
Created June 23, 2021 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gngdb/c1d227ede6aea2ac62afae64bb390590 to your computer and use it in GitHub Desktop.
Save gngdb/c1d227ede6aea2ac62afae64bb390590 to your computer and use it in GitHub Desktop.
Notebook checking that an einops-based channel max-pooling module matches a reference view/permute implementation
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "af0d71c1",
"metadata": {},
"outputs": [],
"source": [
"from torch.nn import MaxPool1d\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class ChannelPool(MaxPool1d):\n",
"    \"\"\"Max-pool across the channel dimension of a 4-D tensor.\n",
"\n",
"    Reuses MaxPool1d's kernel_size/stride/padding/dilation/ceil_mode, but\n",
"    slides the window over channels instead of the last axis: an\n",
"    (n, c, w, h) input becomes (n, c_out, w, h).\n",
"    \"\"\"\n",
"\n",
"    def forward(self, input):\n",
"        if self.return_indices:\n",
"            # F.max_pool1d would return a (values, indices) tuple here,\n",
"            # which the permute/reshape below cannot handle\n",
"            raise NotImplementedError(\n",
"                \"ChannelPool does not support return_indices=True\"\n",
"            )\n",
"        n, c, w, h = input.size()\n",
"        # reshape (not view) so non-contiguous inputs such as\n",
"        # x.transpose(2, 3) or strided slices also work\n",
"        flat = input.reshape(n, c, w * h).permute(0, 2, 1)  # (n, w*h, c)\n",
"        pooled = F.max_pool1d(\n",
"            flat,\n",
"            self.kernel_size,\n",
"            self.stride,\n",
"            self.padding,\n",
"            self.dilation,\n",
"            self.ceil_mode,\n",
"        )\n",
"        c_out = pooled.size(-1)\n",
"        return pooled.permute(0, 2, 1).reshape(n, c_out, w, h)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1ca5fd19",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ff163bfd",
"metadata": {},
"outputs": [],
"source": [
"# 4 input channels; a kernel of 2 pools them down to 2 output channels\n",
"x = torch.randn(2, 4, 16, 16)\n",
"pool = ChannelPool(kernel_size=2)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7d2123e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 2, 16, 16])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = pool(x)  # reference output, compared against the einops version later\n",
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5bb44b47",
"metadata": {},
"outputs": [],
"source": [
"from torch.nn import MaxPool1d\n",
"import torch.nn.functional as F\n",
"from einops import rearrange\n",
"\n",
"\n",
"class ChannelPool(MaxPool1d):\n",
"    \"\"\"Max-pool across the channel dimension of a 4-D tensor, via einops.\n",
"\n",
"    An (n, c, w, h) input is rearranged so channels are the last axis,\n",
"    pooled with MaxPool1d's kernel_size/stride/padding/dilation/ceil_mode,\n",
"    and rearranged back to (n, c_out, w, h).\n",
"    \"\"\"\n",
"\n",
"    def forward(self, input):\n",
"        if self.return_indices:\n",
"            # F.max_pool1d would return a (values, indices) tuple here,\n",
"            # which the rearrange below cannot handle\n",
"            raise NotImplementedError(\n",
"                \"ChannelPool does not support return_indices=True\"\n",
"            )\n",
"        n, _, w, h = input.size()\n",
"        # F.max_pool1d pools over the last axis, so move channels there\n",
"        flat = rearrange(input, \"n c w h -> n (w h) c\")\n",
"        pooled = F.max_pool1d(\n",
"            flat,\n",
"            self.kernel_size,\n",
"            self.stride,\n",
"            self.padding,\n",
"            self.dilation,\n",
"            self.ceil_mode,\n",
"        )\n",
"        # w and h are needed to split the merged spatial axis back apart\n",
"        return rearrange(pooled, \"n (w h) c -> n c w h\", n=n, w=w, h=h)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b34dc47e",
"metadata": {},
"outputs": [],
"source": [
"pool = ChannelPool(2)\n",
"_y = pool(x)\n",
"# max pooling only selects existing values, so both implementations must\n",
"# agree exactly; torch.equal also rejects shape mismatches that y - _y\n",
"# would silently broadcast\n",
"assert torch.equal(y, _y)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment