{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Feature pyramid network and Region of Interest Align"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.7.0\n",
"0.8.0a0+45f960c\n"
]
}
],
"source": [
"from math import sqrt\n",
"\n",
"import torch\n",
"import torchvision as tv\n",
"\n",
"print(torch.__version__)\n",
"print(tv.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Backbone\n",
"\n",
"Backbone network, e.g. ResNet, at every layer the spatial resolution is halved and the number of channels is doubled."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input: (2, 3, 800, 1280)\n",
"\n",
"Backbone (B, C, H, W):\n",
"- block1: (2, 16, 400, 640)\n",
"- block2: (2, 32, 200, 320)\n",
"- block3: (2, 64, 100, 160)\n"
]
}
],
"source": [
"height = 800\n",
"width = 1280\n",
"batch_size = 2\n",
"print(f\"Input: {(batch_size, 3, height, width)}\\n\")\n",
"\n",
"level_sizes = {\n",
" \"block1\": 16,\n",
" \"block2\": 32,\n",
" \"block3\": 64,\n",
"}\n",
"backbone_features = {\n",
" level: torch.rand(batch_size, size, height // 2 ** (i + 1), width // 2 ** (i + 1))\n",
" for i, (level, size) in enumerate(level_sizes.items())\n",
"}\n",
"\n",
"print(\"Backbone (B, C, H, W):\")\n",
"for level, feats in backbone_features.items():\n",
" print(f\"- {level}: {tuple(feats.shape)}\")"
]
},
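{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the dummy tensors above could come from a toy backbone built from strided convolutions. This is only a sketch (not a real ResNet): each block halves the resolution with `stride=2` and doubles the channel count."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class ToyBackbone(nn.Module):\n",
"    # Stand-in for a real backbone: three conv blocks, each halving H and W\n",
"    # and doubling the channel count (3 -> 16 -> 32 -> 64).\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.block1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)\n",
"        self.block2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)\n",
"        self.block3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)\n",
"\n",
"    def forward(self, x):\n",
"        feats = {}\n",
"        feats[\"block1\"] = x = self.block1(x)\n",
"        feats[\"block2\"] = x = self.block2(x)\n",
"        feats[\"block3\"] = x = self.block3(x)\n",
"        return feats\n",
"\n",
"\n",
"toy_backbone = ToyBackbone()\n",
"for level, feats in toy_backbone(torch.rand(batch_size, 3, height, width)).items():\n",
"    print(f\"- {level}: {tuple(feats.shape)}\")"
]
},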
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature pyramid network\n",
"\n",
"A feature pyramid network attaches a `1x1` conv to each layer of the backbone (lateral connection) and produces a pyramid of features by means of upsampling and summation.\n",
"Each pyramid layer is further processed though a `3x3` convolution with padding 1.\n",
"\n",
"When we run the unintialized FPN we get random features, but then we fill the output with predictable values:\n",
"- units correspond to the image index\n",
"- tens correspond to the pyramid level"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feature pyramid (B, C, H, W):\n",
"- block1: (2, 16, 400, 640) [10, 11]\n",
"- block2: (2, 16, 200, 320) [20, 21]\n",
"- block3: (2, 16, 100, 160) [30, 31]\n"
]
}
],
"source": [
"fpn = tv.ops.FeaturePyramidNetwork(\n",
" in_channels_list=list(level_sizes.values()),\n",
" out_channels=16,\n",
")\n",
"fpn_features = fpn.forward(backbone_features)\n",
"\n",
"fpn_features = {\n",
" level: torch.stack(\n",
" [\n",
" torch.full_like(t, fill_value=10 * (lvl_num + 1) + img_idx)\n",
" for img_idx, t in enumerate(fpn_features[level].unbind(dim=0))\n",
" ],\n",
" dim=0,\n",
" )\n",
" for lvl_num, level in enumerate(level_sizes.keys())\n",
"}\n",
"\n",
"print(\"Feature pyramid (B, C, H, W):\")\n",
"for level, feats in fpn_features.items():\n",
" print(f\"- {level}: {tuple(feats.shape)} {[t[0,0,0].int().item() for t in feats]}\")"
]
},
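{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the same pyramid construction can be written out by hand. This is only a sketch of the mechanism described above (lateral `1x1` convs, top-down upsampling with summation, then a `3x3` smoothing conv), not torchvision's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"out_channels = 16\n",
"lateral_convs = {\n",
"    level: nn.Conv2d(channels, out_channels, kernel_size=1)\n",
"    for level, channels in level_sizes.items()\n",
"}\n",
"smooth_convs = {\n",
"    level: nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
"    for level in level_sizes\n",
"}\n",
"\n",
"# Lateral 1x1 convs bring every backbone level to the same channel count.\n",
"laterals = {level: lateral_convs[level](feats) for level, feats in backbone_features.items()}\n",
"\n",
"# Top-down pathway: start from the coarsest map, upsample it and add the finer lateral.\n",
"levels = list(level_sizes.keys())  # fine to coarse: block1, block2, block3\n",
"pyramid = {levels[-1]: laterals[levels[-1]]}\n",
"for upper, lower in zip(reversed(levels[1:]), reversed(levels[:-1])):\n",
"    upsampled = F.interpolate(pyramid[upper], size=laterals[lower].shape[-2:], mode=\"nearest\")\n",
"    pyramid[lower] = laterals[lower] + upsampled\n",
"\n",
"# Each merged map goes through a 3x3 smoothing conv.\n",
"pyramid = {level: smooth_convs[level](pyramid[level]) for level in levels}\n",
"\n",
"for level in levels:\n",
"    print(f\"- {level}: {tuple(pyramid[level].shape)}\")"
]
},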
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RoI align on top of pyramid features"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"boxes_list = [\n",
" # Image 0: boxes of areas from 1% to 60% of the image\n",
" (\n",
" torch.tensor([0, 0, height, width])[None, :]\n",
" * torch.tensor([0.01, 0.06, 0.08, 0.15, 0.24, 0.26, 0.50, 0.60]).sqrt()[:, None]\n",
" ),\n",
" # Image 1: just a couple of boxes\n",
" torch.tensor(\n",
" [\n",
" [0, 0, 10, 900], # very thin long horiz box\n",
" [0, 0, 450, 20], # very thin long vert box\n",
" ]\n",
" ).float(),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Default `canonical_scale` and `canonical_level`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image 0\n",
"[ 0. 0. 80. 128.] 1% (16, 3, 3) -> pooled from block 2\n",
"[ 0. 0. 196. 314.] 6% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 226. 362.] 8% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 310. 496.] 15% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 392. 627.] 24% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 408. 653.] 26% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 566. 905.] 50% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 620. 991.] 60% (16, 3, 3) -> pooled from block 3\n",
"\n",
"Image 1\n",
"[ 0. 0. 10. 900.] 1% (16, 3, 3) -> pooled from block 2\n",
"[ 0. 0. 450. 20.] 1% (16, 3, 3) -> pooled from block 2\n",
"\n"
]
}
],
"source": [
"roi_align = tv.ops.MultiScaleRoIAlign(\n",
" featmap_names=list(level_sizes.keys()), output_size=(3, 3), sampling_ratio=2\n",
")\n",
"\n",
"box_features = roi_align(\n",
" fpn_features,\n",
" boxes_list,\n",
" image_shapes=[(height, width)] * batch_size,\n",
")\n",
"box_features = torch.split(\n",
" box_features, split_size_or_sections=[len(b) for b in boxes_list]\n",
")\n",
"\n",
"\n",
"for img_idx, (boxes, box_feats) in enumerate(zip(boxes_list, box_features)):\n",
" print(f\"Image {img_idx}\")\n",
" areas = tv.ops.box_area(boxes)\n",
" for b, a, f in zip(boxes, areas, box_feats):\n",
" ff = f[0, 0, 0].int().item()\n",
" assert img_idx == ff % 10\n",
" level_idx = ff // 10\n",
" print(\n",
" f\"{b.numpy().round(0)} {a/(height*width):>4.0%} {tuple(f.shape)} \"\n",
" f\"-> pooled from block {level_idx}\"\n",
" )\n",
" print()"
]
},
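{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pyramid level chosen for each box follows the heuristic from the FPN paper, roughly `k = floor(canonical_level + log2(sqrt(box_area) / canonical_scale))`, clamped to the available levels. Below is a sketch of that computation for the boxes of image 0, assuming torchvision's defaults (`canonical_scale=224`, `canonical_level=4`) and that `block1`/`block2`/`block3` sit at levels `k = 1, 2, 3` (strides 2, 4, 8). The same formula explains the next cell, where the canonical scale is forced to the image size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import floor, log2, sqrt\n",
"\n",
"# Assumed values: torchvision's defaults and the pyramid levels inferred above.\n",
"canonical_scale = 224\n",
"canonical_level = 4\n",
"k_min, k_max = 1, 3  # block1, block2, block3\n",
"\n",
"for area in tv.ops.box_area(boxes_list[0]):\n",
"    k = floor(canonical_level + log2(sqrt(area.item()) / canonical_scale))\n",
"    k = min(max(k, k_min), k_max)\n",
"    print(f\"{area.item() / (height * width):>4.0%} of the image -> pooled from block {k}\")"
]
},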
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Forcing the desired `canonical_scale` and `canonical_level`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image 0\n",
"[ 0. 0. 80. 128.] 1% (16, 3, 3) -> pooled from block 1\n",
"[ 0. 0. 196. 314.] 6% (16, 3, 3) -> pooled from block 1\n",
"[ 0. 0. 226. 362.] 8% (16, 3, 3) -> pooled from block 2\n",
"[ 0. 0. 310. 496.] 15% (16, 3, 3) -> pooled from block 2\n",
"[ 0. 0. 392. 627.] 24% (16, 3, 3) -> pooled from block 2\n",
"[ 0. 0. 408. 653.] 26% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 566. 905.] 50% (16, 3, 3) -> pooled from block 3\n",
"[ 0. 0. 620. 991.] 60% (16, 3, 3) -> pooled from block 3\n",
"\n",
"Image 1\n",
"[ 0. 0. 10. 900.] 1% (16, 3, 3) -> pooled from block 1\n",
"[ 0. 0. 450. 20.] 1% (16, 3, 3) -> pooled from block 1\n",
"\n"
]
}
],
"source": [
"roi_align = tv.ops.MultiScaleRoIAlign(\n",
" featmap_names=list(level_sizes.keys()), output_size=(3, 3), sampling_ratio=2\n",
")\n",
"roi_align.setup_scales(\n",
" list(fpn_features.values()), image_shapes=[(height, width)] * batch_size\n",
")\n",
"roi_align.map_levels = tv.ops.poolers.initLevelMapper(\n",
" k_min=roi_align.map_levels.k_min,\n",
" k_max=roi_align.map_levels.k_max,\n",
" canonical_scale=sqrt(width * height),\n",
" canonical_level=4,\n",
")\n",
"\n",
"box_features = roi_align(\n",
" fpn_features,\n",
" boxes_list,\n",
" image_shapes=[(height, width)] * batch_size,\n",
")\n",
"box_features = torch.split(\n",
" box_features, split_size_or_sections=[len(b) for b in boxes_list]\n",
")\n",
"\n",
"\n",
"for img_idx, (boxes, box_feats) in enumerate(zip(boxes_list, box_features)):\n",
" print(f\"Image {img_idx}\")\n",
" areas = tv.ops.box_area(boxes)\n",
" for b, a, f in zip(boxes, areas, box_feats):\n",
" ff = f[0, 0, 0].int().item()\n",
" assert img_idx == ff % 10\n",
" level_idx = ff // 10\n",
" print(\n",
" f\"{b.numpy().round(0)} {a/(height*width):>4.0%} {tuple(f.shape)} \"\n",
" f\"-> pooled from block {level_idx}\"\n",
" )\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}