Skip to content

Instantly share code, notes, and snippets.

@PWhiddy
Created October 5, 2021 02:01
Show Gist options
  • Save PWhiddy/689155a8fb292d62ddc66e8cf53bcf56 to your computer and use it in GitHub Desktop.
Save PWhiddy/689155a8fb292d62ddc66e8cf53bcf56 to your computer and use it in GitHub Desktop.
Improved implementation of bbox_to_mask
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "49dcc530",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@torch.jit.ignore\n",
"def validate_bbox(boxes: torch.Tensor) -> bool:\n",
" \"\"\"Validate whether 2D bounding boxes are usable. This function checks whether the boxes are rectangular.\n",
" Args:\n",
" boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape\n",
" of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,\n",
" bottom-left. The coordinates must be in the x, y order.\n",
" \"\"\"\n",
" if not (len(boxes.shape) == 3 and boxes.shape[1:] == torch.Size([4, 2])):\n",
" raise AssertionError(f\"Box shape must be (B, 4, 2). Got {boxes.shape}.\")\n",
"\n",
" if not torch.allclose((boxes[:, 1, 0] - boxes[:, 0, 0] + 1), (boxes[:, 2, 0] - boxes[:, 3, 0] + 1)):\n",
" raise ValueError(\n",
" \"Boxes must have be rectangular, while get widths %s and %s\"\n",
" % (str(boxes[:, 1, 0] - boxes[:, 0, 0] + 1), str(boxes[:, 2, 0] - boxes[:, 3, 0] + 1))\n",
" )\n",
"\n",
" if not torch.allclose((boxes[:, 2, 1] - boxes[:, 0, 1] + 1), (boxes[:, 3, 1] - boxes[:, 1, 1] + 1)):\n",
" raise ValueError(\n",
" \"Boxes must have be rectangular, while get heights %s and %s\"\n",
" % (str(boxes[:, 2, 1] - boxes[:, 0, 1] + 1), str(boxes[:, 3, 1] - boxes[:, 1, 1] + 1))\n",
" )\n",
"\n",
" return True\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "69c587df",
"metadata": {},
"outputs": [],
"source": [
"def bbox_to_mask_old(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:\n",
" \"\"\"Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.\n",
" Args:\n",
" boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape\n",
" of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right\n",
" and bottom-left. The coordinates must be in the x, y order.\n",
" width: width of the masked image.\n",
" height: height of the masked image.\n",
" Returns:\n",
" the output mask tensor.\n",
" Note:\n",
" It is currently non-differentiable.\n",
" Examples:\n",
" >>> boxes = torch.tensor([[\n",
" ... [1., 1.],\n",
" ... [3., 1.],\n",
" ... [3., 2.],\n",
" ... [1., 2.],\n",
" ... ]]) # 1x4x2\n",
" >>> bbox_to_mask_old(boxes, 5, 5)\n",
" tensor([[[0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]]])\n",
" \"\"\"\n",
" validate_bbox(boxes)\n",
" # zero padding the surroundings\n",
" mask = torch.zeros((len(boxes), height + 2, width + 2))\n",
" # push all points one pixel off\n",
" # in order to zero-out the fully filled rows or columns\n",
" boxes_shifted = boxes + 1\n",
"\n",
" mask_out = []\n",
" # TODO: Looking for a vectorized way\n",
" for m, box in zip(mask, boxes_shifted):\n",
" m = m.index_fill(1, torch.arange(box[0, 0].item(), box[1, 0].item() + 1, dtype=torch.long), torch.tensor(1))\n",
" m = m.index_fill(0, torch.arange(box[1, 1].item(), box[2, 1].item() + 1, dtype=torch.long), torch.tensor(1))\n",
" m = m.unsqueeze(dim=0)\n",
" m_out = (m == 1).all(dim=1) * (m == 1).all(dim=2).T\n",
" m_out = m_out[1:-1, 1:-1]\n",
" mask_out.append(m_out)\n",
"\n",
" return torch.stack(mask_out, dim=0).float()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def bbox_to_mask_new(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:\n",
" \"\"\"Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.\n",
" Args:\n",
" boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape\n",
" of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right\n",
" and bottom-left. The coordinates must be in the x, y order.\n",
" width: width of the masked image.\n",
" height: height of the masked image.\n",
" Returns:\n",
" the output mask tensor.\n",
" Note:\n",
" It is currently non-differentiable.\n",
" Examples:\n",
" >>> boxes = torch.tensor([[\n",
" ... [1., 1.],\n",
" ... [3., 1.],\n",
" ... [3., 2.],\n",
" ... [1., 2.],\n",
" ... ]]) # 1x4x2\n",
" >>> bbox_to_mask_new(boxes, 5, 5)\n",
" tensor([[[0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]]])\n",
" \"\"\"\n",
" validate_bbox(boxes)\n",
" # zero padding the surroundings\n",
" mask = torch.zeros((len(boxes), height + 2, width + 2), device=boxes.device)\n",
" # push all points one pixel off\n",
" # in order to zero-out the fully filled rows or columns\n",
" box_i = (boxes + 1).long()\n",
" # set all pixels within box to 1\n",
" mask[:, box_i[:, 0, 1]:box_i[:, 2, 1] + 1, \n",
" box_i[:, 0, 0]:box_i[:, 1, 0] + 1] = 1.0 \n",
" return mask[:, 1:-1, 1:-1]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6b955970",
"metadata": {},
"outputs": [],
"source": [
"test_dev = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "42f7a15d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"boxes = torch.tensor([[\n",
" [1., 1.],\n",
" [3., 1.],\n",
" [3., 2.],\n",
" [1., 2.],\n",
" ]], device=test_dev) # 1x4x2\n",
"bbox_to_mask_old(boxes, 5, 5)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "30f58ded",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0., 0., 0., 0., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 1., 1., 1., 0.],\n",
" [0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0.]]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"boxes = torch.tensor([[\n",
" [1., 1.],\n",
" [3., 1.],\n",
" [3., 2.],\n",
" [1., 2.],\n",
" ]], device=test_dev) # 1x4x2\n",
"bbox_to_mask_new(boxes, 5, 5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "30fccef8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.allclose(bbox_to_mask_old(boxes, 5, 5), bbox_to_mask_new(boxes, 5, 5))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def test_random_boxes(func):\n",
" random.seed(0)\n",
" w = random.randint(130, 180)\n",
" h = random.randint(70, 120)\n",
" bbox = torch.tensor([[\n",
" [10., 10.],\n",
" [w, 10.],\n",
" [w, h],\n",
" [10., h],\n",
" ]], device=test_dev)\n",
" return func(bbox, random.randint(200, 300), random.randint(200, 300))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.allclose(test_random_boxes(bbox_to_mask_old), test_random_boxes(bbox_to_mask_new))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bc60c111",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"624 µs ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"w = random.randint(130, 180)\n",
"h = random.randint(70, 120)\n",
"bbox = torch.tensor([[\n",
" [10., 10.],\n",
" [w, 10.],\n",
" [w, h],\n",
" [10., h],\n",
" ]], device=test_dev)\n",
"bbox_to_mask_old(bbox, random.randint(200, 300), random.randint(200, 300))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c7f704f0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"257 µs ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%%timeit -n 1000\n",
"w = random.randint(130, 180)\n",
"h = random.randint(70, 120)\n",
"bbox = torch.tensor([[\n",
" [10., 10.],\n",
" [w, 10.],\n",
" [w, h],\n",
" [10., h],\n",
" ]], device=test_dev)\n",
"bbox_to_mask_new(bbox, random.randint(200, 300), random.randint(200, 300))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "548a4217",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment