An implementation of EfficientNet in PyTorch
PyTorch
B0:
- Animals 138: 84%
B1:
- Animals 138: 85%
An implementation of EfficientNet in PyTorch
PyTorch
B0: 84%
B1: 85%
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "bd75bfa0-8ff3-4a56-a222-0a73068bb4ff", | |
"metadata": {}, | |
"source": [ | |
"# EfficientNet in PyTorch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "0e55ee58-7dcc-485b-91ec-5329b248bf06", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
import math
import os
import tempfile

import torch
from torch import nn, optim
from torchinfo import summary
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "09d41e73-3720-4f39-916b-7602e40fedaf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def conv_block(in_channels, out_channels, kernel_size=3,
               stride=1, padding=0, groups=1,
               bias=False, bn=True, act=True):
    """Build a Conv2d -> (BatchNorm2d) -> (SiLU) unit as one nn.Sequential.

    The result always has exactly three children (indices 0, 1, 2), so
    state_dict keys are stable: when ``bn`` or ``act`` is False the
    corresponding slot is an ``nn.Identity`` placeholder.

    Args:
        in_channels / out_channels: channel counts of the convolution.
        kernel_size, stride, padding, groups, bias: forwarded to ``nn.Conv2d``.
        bn: if True, insert ``nn.BatchNorm2d(out_channels)`` after the conv.
        act: if True, finish with an ``nn.SiLU`` activation.

    Returns:
        nn.Sequential(conv, norm, activation)
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding,
                     groups=groups, bias=bias)

    if bn:
        norm = nn.BatchNorm2d(out_channels)
    else:
        norm = nn.Identity()

    if act:
        activation = nn.SiLU()
    else:
        activation = nn.Identity()

    return nn.Sequential(conv, norm, activation)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "bf59c00f-f320-45be-86f8-6023c2d1f749", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The input is globally average-pooled to one value per channel ("squeeze").
    That vector goes through a 1x1-conv bottleneck with SiLU and a sigmoid gate
    ("excitation"). The input channels are then rescaled by the resulting gate.

    Args:
        c: number of input (and output) channels.
        r: reduction ratio; the bottleneck has ``max(1, c // r)`` channels.
    """
    def __init__(self, c, r=24):
        super(SEBlock, self).__init__()
        # SE / EfficientNet squeeze with global *average* pooling; max pooling
        # yields a different, non-standard channel descriptor.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Guard against a zero-width bottleneck when c < r (possible for narrow
        # or down-scaled models), which would create an empty 1x1 conv.
        hidden = max(1, c // r)
        self.excitation = nn.Sequential(
            nn.Conv2d(c, hidden, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden, c, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale each channel of ``x`` (an (N, c, H, W) feature map) by a gate in (0, 1).

        The output has the same shape as ``x``.
        """
        s = self.squeeze(x)
        e = self.excitation(s)
        return x * e
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "df21a513-e747-4d9f-a2ae-4378ba6e3c91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class MBConv(nn.Module):
    """Mobile inverted-bottleneck block (MBConv) with squeeze-and-excitation.

    Pipeline:
      1. 1x1 expansion ``n_in -> expansion * n_in``, skipped when ``expansion == 1``.
      2. Depthwise ``kernel_size`` x ``kernel_size`` conv with the given stride and
         "same"-style padding ``(kernel_size - 1) // 2``.
      3. SE attention on the expanded channels (reduction ratio ``r``).
      4. 1x1 projection to ``n_out`` with BatchNorm but no activation.

    When ``n_in == n_out`` and ``stride == 1`` the input is added back as a
    residual. In that case the branch is regularised with per-sample stochastic
    depth ("drop connect"). In training mode the whole branch is zeroed for a
    random ``dropout`` fraction of the batch, and the surviving samples are
    rescaled by ``1 / (1 - dropout)``. In eval mode the branch is always kept.

    Args:
        n_in, n_out: input / output channels.
        expansion: channel expansion factor for the hidden layer.
        kernel_size: depthwise kernel size.
        stride: depthwise stride.
        r: SE reduction ratio.
        dropout: stochastic-depth drop probability in [0, 1].

    Raises:
        ValueError: if ``dropout`` is outside [0, 1].
    """
    def __init__(self, n_in, n_out, expansion, kernel_size=3, stride=1, r=24, dropout=0.1):
        super(MBConv, self).__init__()
        if dropout < 0 or dropout > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {dropout}")
        self.skip_connection = (n_in == n_out) and (stride == 1)
        padding = (kernel_size - 1) // 2
        expanded = expansion * n_in

        self.expand_pw = nn.Identity() if expansion == 1 else conv_block(n_in, expanded, kernel_size=1)
        self.depthwise = conv_block(expanded, expanded, kernel_size=kernel_size,
                                    stride=stride, padding=padding, groups=expanded)
        self.se = SEBlock(expanded, r=r)
        # act=False: the projection is linear (BatchNorm only).
        self.reduce_pw = conv_block(expanded, n_out, kernel_size=1, act=False)
        self.drop_rate = float(dropout)

    def _drop_path(self, x):
        """Per-sample stochastic depth on the residual branch ``x`` (identity in eval mode)."""
        if not self.training or self.drop_rate == 0.0:
            return x
        keep_prob = 1.0 - self.drop_rate
        if keep_prob == 0.0:
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        return x * mask / keep_prob

    def forward(self, x):
        residual = x
        x = self.expand_pw(x)
        x = self.depthwise(x)
        x = self.se(x)
        x = self.reduce_pw(x)
        if self.skip_connection:
            x = self._drop_path(x) + residual
        return x
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "777c11c8-e042-4b52-ba2a-a57146e225d1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def mbconv1(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):
    """Create an MBConv block with expansion factor 1 (no 1x1 expansion conv).

    Every other argument is passed through to ``MBConv`` unchanged.
    """
    expansion_factor = 1
    return MBConv(
        n_in,
        n_out,
        expansion_factor,
        kernel_size=kernel_size,
        stride=stride,
        r=r,
        dropout=dropout,
    )
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "115dc875-517d-4cf5-9441-0344f144b69d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def mbconv6(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):
    """Create an MBConv block with expansion factor 6.

    Every other argument is passed through to ``MBConv`` unchanged.
    """
    expansion_factor = 6
    return MBConv(
        n_in,
        n_out,
        expansion_factor,
        kernel_size=kernel_size,
        stride=stride,
        r=r,
        dropout=dropout,
    )
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "90de185f-a60f-4bd3-a22f-4bb090d22538", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def create_stage(n_in, n_out, num_layers, layer=mbconv6,
                 kernel_size=3, stride=1, r=24, ps=0):
    """Stack blocks built by ``layer`` into a single nn.Sequential stage.

    Only the first block changes the channel count (``n_in -> n_out``) and
    receives ``stride``. The remaining ``num_layers - 1`` blocks map
    ``n_out -> n_out`` and are built without a stride argument, so they use
    ``layer``'s own default. At least one block is always created. ``ps`` is
    forwarded to every block as its ``dropout`` argument.
    """
    head_block = layer(n_in, n_out, kernel_size=kernel_size,
                       stride=stride, r=r, dropout=ps)
    blocks = [head_block]
    blocks.extend(
        layer(n_out, n_out, kernel_size=kernel_size, r=r, dropout=ps)
        for _ in range(num_layers - 1)
    )
    return nn.Sequential(*blocks)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "08539591-434f-4457-824e-8840a97a6ccf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def scale_width(w, w_factor, divisor=8, min_width=None):
    """Scale a channel count by ``w_factor`` and round it to a multiple of ``divisor``.

    Rounding goes to the nearest multiple of ``divisor``, with halves rounded up.
    The result is never below ``min_width`` and never more than 10% below the
    unrounded scaled width.

    Args:
        w: base channel count.
        w_factor: width multiplier.
        divisor: channel counts are rounded to multiples of this (default 8).
        min_width: lower bound on the result; defaults to ``divisor``.

    Returns:
        The scaled, rounded channel count as an ``int``.
    """
    if min_width is None:
        min_width = divisor
    scaled = w * w_factor
    new_w = max(min_width, int(scaled + divisor / 2) // divisor * divisor)
    # Rounding must not shrink the layer by more than 10%; bump up one step.
    if new_w < 0.9 * scaled:
        new_w += divisor
    return int(new_w)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d9ad5a73-ed3d-4cfa-849a-d292782cd6c4", | |
"metadata": {}, | |
"source": [ | |
    "EfficientNet-B0 base structure (from the paper)\n",                                                                                         | |
    "\n",                                                                                                                                      | |
    "| Stage (i) | Layer | Resolution | Channels | Layers |\n",                                                                                | |
    "|-----------|-----------|------------|----------|--------|\n",                                                                           | |
    "| 1 | Conv 3x3 | 224 x 224 | 32 | 1 |\n",                                                                                                  | |
    "| 2 | `mbconv1` | 112 x 112 | 16 | 1 |\n",                                                                                                 | |
    "| 3 | `mbconv6` | 112 x 112 | 24 | 2 |\n",                                                                                                 | |
    "| 4 | `mbconv6` | 56 x 56 | 40 | 2 |\n",                                                                                                   | |
    "| 5 | `mbconv6` | 28 x 28 | 80 | 3 |\n",                                                                                                   | |
    "| 6 | `mbconv6` | 14 x 14 | 112 | 3 |\n",                                                                                                  | |
    "| 7 | `mbconv6` | 14 x 14 | 192 | 4 |\n",                                                                                                  | |
    "| 8 | `mbconv6` | 7 x 7 | 320 | 1 |\n",                                                                                                    | |
    "| 9 | Conv 1x1 & Pooling & FC | 7 x 7 | 1280 | 1 |"                                                                                          | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "83b2baa2-4923-4a05-8f33-9429214baa3f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Baseline (B0) stage hyperparameters, taken from the EfficientNet paper
# (see the structure table above).
#
# Index i of each per-stage list below describes MBConv stage i + 2 of the table.
# `base_widths` has one extra trailing entry: the final 1x1 conv (320 -> 1280).
base_widths = [
    (32, 16),     # stage 2
    (16, 24),     # stage 3
    (24, 40),     # stage 4
    (40, 80),     # stage 5
    (80, 112),    # stage 6
    (112, 192),   # stage 7
    (192, 320),   # stage 8
    (320, 1280),  # stage 9: 1x1 conv before pooling
]

#               s2  s3  s4  s5  s6  s7  s8
base_depths  = [1,  2,  2,  3,  3,  4,  1]
kernel_sizes = [3,  3,  5,  3,  5,  5,  3]
strides      = [1,  2,  2,  2,  1,  2,  1]

# Per-stage drop rate handed to the blocks as `dropout`;
# numerically these are 0.2 * i / 7 rounded to three decimals.
ps = [0, 0.029, 0.057, 0.086, 0.114, 0.143, 0.171]
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "22099a0a-2579-4268-ac29-52db221b5366", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_gen(w_factor=1, d_factor=1):
    """Apply compound scaling to the baseline stage tables.

    Returns:
        ``(scaled_widths, scaled_depths)``. Each ``(in, out)`` pair of
        ``base_widths`` is passed through ``scale_width`` with ``w_factor``.
        Each entry of ``base_depths`` is multiplied by ``d_factor`` and rounded up.
    """
    widths = []
    for pair in base_widths:
        widths.append((scale_width(pair[0], w_factor),
                       scale_width(pair[1], w_factor)))

    depths = list(map(lambda d: math.ceil(d_factor * d), base_depths))
    return widths, depths
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"id": "08cdce48-2c2e-4bc7-bbd0-c8ab6bc02113", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class EfficientNet(nn.Module):
    """EfficientNet image classifier built from the baseline stage tables.

    Layout: stride-2 3x3 stem -> MBConv stages -> 1x1 conv -> global average
    pool + flatten -> (optional dropout) -> linear classifier.

    Args:
        w_factor: width multiplier passed to ``efficientnet_gen``.
        d_factor: depth multiplier passed to ``efficientnet_gen``.
        n_classes: number of output logits.
        dropout: dropout probability on the pooled features before the final
            linear layer. Defaults to 0.0 (disabled), which reproduces the
            original model exactly. NOTE(review): the EfficientNet paper trains
            with a nonzero rate here; choose one per variant when training.
    """
    def __init__(self, w_factor=1, d_factor=1, n_classes=1000, dropout=0.0):
        super(EfficientNet, self).__init__()
        scaled_widths, scaled_depths = efficientnet_gen(w_factor=w_factor, d_factor=d_factor)
        # Derive the stage count from the tables instead of hard-coding it.
        n_stages = len(scaled_depths)

        # Stem: 3 input channels -> first stage width, stride 2.
        self.conv1 = conv_block(3, scaled_widths[0][0], stride=2, padding=1)
        stages = [
            create_stage(*scaled_widths[i], scaled_depths[i],
                         layer=mbconv1 if i == 0 else mbconv6,
                         kernel_size=kernel_sizes[i], stride=strides[i],
                         # mbconv1 does not expand, so r=4 keeps its SE bottleneck
                         # at the same n_in / 4 width that r=24 gives mbconv6
                         # (whose SE sees 6 * n_in channels).
                         r=4 if i == 0 else 24, ps=ps[i])
            for i in range(n_stages)
        ]
        self.stages = nn.Sequential(*stages)
        self.pre = conv_block(*scaled_widths[-1], kernel_size=1)
        self.pool_flatten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        # Parameter-free, so adding it leaves state_dict keys unchanged.
        self.head_dropout = nn.Dropout(dropout)
        self.head = nn.Sequential(
            nn.Linear(scaled_widths[-1][1], n_classes)
        )

    def forward(self, x):
        """Map a batch of RGB images to class logits of shape (N, n_classes)."""
        x = self.conv1(x)
        x = self.stages(x)
        x = self.pre(x)
        x = self.pool_flatten(x)
        x = self.head_dropout(x)
        return self.head(x)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"id": "9fe3126b-d63a-4fdb-b081-698f07092cc1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b0(n_classes=1000):
    """EfficientNet-B0: the unscaled baseline (width x1.0, depth x1.0)."""
    return EfficientNet(w_factor=1, d_factor=1, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"id": "335c91f7-60a7-4397-b2ea-ba9755e0a259", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b1(n_classes=1000):
    """EfficientNet-B1: width x1.0, depth x1.1."""
    return EfficientNet(w_factor=1, d_factor=1.1, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"id": "a9ae8295-66e0-4788-861e-af81cf33e099", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b2(n_classes=1000):
    """EfficientNet-B2: width x1.1, depth x1.2."""
    return EfficientNet(w_factor=1.1, d_factor=1.2, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "227832fc-48b7-4c75-ae40-86213f008c91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b3(n_classes=1000):
    """EfficientNet-B3: width x1.2, depth x1.4."""
    return EfficientNet(w_factor=1.2, d_factor=1.4, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "cdfaff84-9f4a-4ba6-94f5-aa07ab81d22f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b4(n_classes=1000):
    """EfficientNet-B4: width x1.4, depth x1.8."""
    return EfficientNet(w_factor=1.4, d_factor=1.8, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "05634592-a1c0-4692-88da-fba8ac2ea017", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b5(n_classes=1000):
    """EfficientNet-B5: width x1.6, depth x2.2."""
    return EfficientNet(w_factor=1.6, d_factor=2.2, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "d6a66f21-9dea-4310-afc9-4c28b118ef52", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b6(n_classes=1000):
    """EfficientNet-B6: width x1.8, depth x2.6."""
    return EfficientNet(w_factor=1.8, d_factor=2.6, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "24462c7e-994e-494e-baba-d619055f816c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def efficientnet_b7(n_classes=1000):
    """EfficientNet-B7: width x2.0, depth x3.1."""
    return EfficientNet(w_factor=2, d_factor=3.1, n_classes=n_classes)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"id": "ce69f255-c00b-4dc3-a044-37f1cc1be178", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Build all eight variants (B0-B7) once; later cells compare their sizes.
b0, b1, b2, b3, b4, b5, b6, b7 = (
    efficientnet_b0(),
    efficientnet_b1(),
    efficientnet_b2(),
    efficientnet_b3(),
    efficientnet_b4(),
    efficientnet_b5(),
    efficientnet_b6(),
    efficientnet_b7(),
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"id": "15858249-5a11-4e41-afe5-bd071b6c1417", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1000])" | |
] | |
}, | |
"execution_count": 61, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Smoke test: one random 224x224 RGB image through B0 should give 1000 logits.
# Run in eval mode without autograd so the forward pass does not update b0's
# BatchNorm running statistics (b0 is reused below); then restore its mode.
inp = torch.randn(1, 3, 224, 224)
was_training = b0.training
b0.eval()
with torch.no_grad():
    out = b0(inp)
b0.train(was_training)
out.shape
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"id": "43caafea-ef4e-4dd2-bff4-e0268b84a68d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in MB (1 MB = 1e6 bytes).

    The state dict is written with ``torch.save`` into a private temporary
    directory. No file in the current working directory is created or
    overwritten, and the file is removed even if saving raises.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name: torch's zip archive embeds it, so the
        # reported size stays identical to the previous implementation.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        print('Size (MB):', os.path.getsize(path)/1e6)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"id": "a79c046d-882f-439b-9b60-9c0997baebdf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Size (MB): 21.446577\n", | |
"Size (MB): 31.600841\n", | |
"Size (MB): 36.885449\n", | |
"Size (MB): 49.479621\n", | |
"Size (MB): 78.111933\n", | |
"Size (MB): 122.546261\n", | |
"Size (MB): 173.400525\n", | |
"Size (MB): 267.054441\n" | |
] | |
} | |
], | |
"source": [ | |
"print_size_of_model(b0)\n", | |
"print_size_of_model(b1)\n", | |
"print_size_of_model(b2)\n", | |
"print_size_of_model(b3)\n", | |
"print_size_of_model(b4)\n", | |
"print_size_of_model(b5)\n", | |
"print_size_of_model(b6)\n", | |
"print_size_of_model(b7)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "c3259af2-9648-4788-849e-d9878b6d90a5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def fmat(n):
    """Format a count as millions with two decimals, e.g. 5288548 -> '5.29M'."""
    millions = n / 1_000_000
    return f"{millions:.2f}M"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "32a21dc3-3bd4-48d6-902f-b0fd22fbf261", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def params(model, f=True):
    """Count the trainable parameters (``requires_grad``) of ``model``.

    Returns the count formatted with ``fmat`` (e.g. '5.29M') when ``f`` is
    true, otherwise the raw integer.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    if not f:
        return total
    return fmat(total)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "994369d9-3be0-4896-838a-b5ff7b94c0a2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('5.29M', '7.79M', '9.11M', '12.23M', '19.34M', '30.39M', '43.04M', '66.35M')" | |
] | |
}, | |
"execution_count": 53, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Trainable parameter counts for B0 -> B7. These roughly match the sizes
# reported in the paper: (5.3M, 7.8M, 9.2M, 12M, 19M, 30M, 43M, 66M).
tuple(params(model) for model in (b0, b1, b2, b3, b4, b5, b6, b7))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"id": "d8f7b707-8cc1-43f7-be6a-30bd865bc6eb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"====================================================================================================\n", | |
"Layer (type:depth-idx) Output Shape Param #\n", | |
"====================================================================================================\n", | |
"EfficientNet -- --\n", | |
"├─Sequential: 1-1 [1, 32, 112, 112] --\n", | |
"│ └─Conv2d: 2-1 [1, 32, 112, 112] 864\n", | |
"│ └─BatchNorm2d: 2-2 [1, 32, 112, 112] 64\n", | |
"│ └─SiLU: 2-3 [1, 32, 112, 112] --\n", | |
"├─Sequential: 1-2 [1, 320, 7, 7] --\n", | |
"│ └─Sequential: 2-4 [1, 16, 112, 112] --\n", | |
"│ │ └─MBConv: 3-1 [1, 16, 112, 112] 1,448\n", | |
"│ └─Sequential: 2-5 [1, 24, 56, 56] --\n", | |
"│ │ └─MBConv: 3-2 [1, 24, 56, 56] 6,004\n", | |
"│ │ └─MBConv: 3-3 [1, 24, 56, 56] 10,710\n", | |
"│ └─Sequential: 2-6 [1, 40, 28, 28] --\n", | |
"│ │ └─MBConv: 3-4 [1, 40, 28, 28] 15,350\n", | |
"│ │ └─MBConv: 3-5 [1, 40, 28, 28] 31,290\n", | |
"│ └─Sequential: 2-7 [1, 80, 14, 14] --\n", | |
"│ │ └─MBConv: 3-6 [1, 80, 14, 14] 37,130\n", | |
"│ │ └─MBConv: 3-7 [1, 80, 14, 14] 102,900\n", | |
"│ │ └─MBConv: 3-8 [1, 80, 14, 14] 102,900\n", | |
"│ └─Sequential: 2-8 [1, 112, 14, 14] --\n", | |
"│ │ └─MBConv: 3-9 [1, 112, 14, 14] 126,004\n", | |
"│ │ └─MBConv: 3-10 [1, 112, 14, 14] 208,572\n", | |
"│ │ └─MBConv: 3-11 [1, 112, 14, 14] 208,572\n", | |
"│ └─Sequential: 2-9 [1, 192, 7, 7] --\n", | |
"│ │ └─MBConv: 3-12 [1, 192, 7, 7] 262,492\n", | |
"│ │ └─MBConv: 3-13 [1, 192, 7, 7] 587,952\n", | |
"│ │ └─MBConv: 3-14 [1, 192, 7, 7] 587,952\n", | |
"│ │ └─MBConv: 3-15 [1, 192, 7, 7] 587,952\n", | |
"│ └─Sequential: 2-10 [1, 320, 7, 7] --\n", | |
"│ │ └─MBConv: 3-16 [1, 320, 7, 7] 717,232\n", | |
"├─Sequential: 1-3 [1, 1280, 7, 7] --\n", | |
"│ └─Conv2d: 2-11 [1, 1280, 7, 7] 409,600\n", | |
"│ └─BatchNorm2d: 2-12 [1, 1280, 7, 7] 2,560\n", | |
"│ └─SiLU: 2-13 [1, 1280, 7, 7] --\n", | |
"├─Sequential: 1-4 [1, 1280] --\n", | |
"│ └─AdaptiveAvgPool2d: 2-14 [1, 1280, 1, 1] --\n", | |
"│ └─Flatten: 2-15 [1, 1280] --\n", | |
"├─Sequential: 1-5 [1, 1000] --\n", | |
"│ └─Linear: 2-16 [1, 1000] 1,281,000\n", | |
"====================================================================================================\n", | |
"Total params: 5,288,548\n", | |
"Trainable params: 5,288,548\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (M): 385.87\n", | |
"====================================================================================================\n", | |
"Input size (MB): 0.60\n", | |
"Forward/backward pass size (MB): 107.89\n", | |
"Params size (MB): 21.15\n", | |
"Estimated Total Size (MB): 129.64\n", | |
"====================================================================================================" | |
] | |
}, | |
"execution_count": 64, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Layer-by-layer summary for one batch of a 224x224 RGB image; swap in any of b0-b7.
summary(b0, input_size=(1, 3, 224, 224))
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e6fc8d9c-a52b-4b89-9f19-8344b70c418b", | |
"metadata": {}, | |
"source": [ | |
"End of notebook" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |