Skip to content

Instantly share code, notes, and snippets.

@ypwhs
Created April 16, 2019 11:30
Show Gist options
  • Save ypwhs/46acf83d0e31c4431c473914f8be3387 to your computer and use it in GitHub Desktop.
Save ypwhs/46acf83d0e31c4431c473914f8be3387 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(96, 192, 32, 64)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Layer hyperparameters. alpha_in / alpha_out are the fractions of the input /\n",
"# output channels assigned to the low-frequency (lf) branch; the rest are high-frequency (hf).\n",
"ch_in, ch_out = 128, 256\n",
"alpha_in, alpha_out = 0.25, 0.25\n",
"\n",
"stride = 1  # 1 or 2; defined once so editing it here actually takes effect\n",
"kernel_size = 3\n",
"padding = 1\n",
"\n",
"# Channel split: (1 - alpha) of the channels go to hf, the remainder to lf.\n",
"hf_ch_in = int(ch_in * (1 - alpha_in))\n",
"hf_ch_out = int(ch_out * (1 - alpha_out))\n",
"\n",
"lf_ch_in = ch_in - hf_ch_in\n",
"lf_ch_out = ch_out - hf_ch_out\n",
"\n",
"hf_ch_in, hf_ch_out, lf_ch_in, lf_ch_out"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Forward pass on dummy (all-zero) inputs. The lf tensors use half the spatial\n",
"# size of the hf tensors; four convolutions cover the paths hf->hf, hf->lf, lf->hf, lf->lf.\n",
"assert stride in (1, 2), 'only stride 1 or 2 is implemented'\n",
"\n",
"hf_data = torch.zeros((1, hf_ch_in, 32, 32))\n",
"lf_data = torch.zeros((1, lf_ch_in, 16, 16))\n",
"\n",
"pool = nn.AvgPool2d(2)  # 2x2 average pooling: halves height and width\n",
"\n",
"if stride == 2:\n",
"    hf_data = pool(hf_data)  # stride 2: halve the hf input before convolving\n",
"\n",
"conv_hh = nn.Conv2d(hf_ch_in, hf_ch_out, kernel_size, padding=padding)\n",
"conv_hl = nn.Conv2d(hf_ch_in, lf_ch_out, kernel_size, padding=padding)\n",
"conv_lh = nn.Conv2d(lf_ch_in, hf_ch_out, kernel_size, padding=padding)\n",
"conv_ll = nn.Conv2d(lf_ch_in, lf_ch_out, kernel_size, padding=padding)\n",
"\n",
"hf_conv = conv_hh(hf_data)\n",
"hf_pool_conv = conv_hl(pool(hf_data))  # pool first so the result lands on the lf grid\n",
"lf_conv = conv_lh(lf_data)\n",
"\n",
"if stride == 2:\n",
"    # hf_data was already halved, so lf_conv already matches hf_conv's size.\n",
"    lf_upsample = lf_conv\n",
"    lf_down = pool(lf_data)\n",
"else:\n",
"    # Upsample 2x (F.interpolate defaults to mode='nearest') to match hf_conv's size.\n",
"    lf_upsample = F.interpolate(lf_conv, scale_factor=2)\n",
"    lf_down = lf_data\n",
"\n",
"lf_down_conv = conv_ll(lf_down)\n",
"\n",
"out_h = hf_conv + lf_upsample\n",
"out_l = hf_pool_conv + lf_down_conv"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 192, 32, 32]), torch.Size([1, 64, 16, 16]))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Invariants: channel counts match the configured split, and the lf output is\n",
"# exactly half the spatial size of the hf output (for both stride 1 and 2).\n",
"assert out_h.shape[1] == hf_ch_out and out_l.shape[1] == lf_ch_out\n",
"assert tuple(out_h.shape[2:]) == tuple(2 * s for s in out_l.shape[2:])\n",
"out_h.shape, out_l.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([192, 96, 3, 3]),\n",
" torch.Size([64, 96, 3, 3]),\n",
" torch.Size([192, 32, 3, 3]),\n",
" torch.Size([64, 32, 3, 3]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Conv2d weights are laid out as (out_channels, in_channels, kernel_h, kernel_w); order: hh, hl, lh, ll.\n",
"tuple(conv.weight.shape for conv in (conv_hh, conv_hl, conv_lh, conv_ll))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment