Skip to content

Instantly share code, notes, and snippets.

@SharanSMenon
Last active November 10, 2023 04:00
Show Gist options
  • Save SharanSMenon/19fdcfadea8fe9ed6a059f940fbbdbb6 to your computer and use it in GitHub Desktop.
Save SharanSMenon/19fdcfadea8fe9ed6a059f940fbbdbb6 to your computer and use it in GitHub Desktop.
EfficientNet Implementation PyTorch

EfficientNet Implementation

An implementation of EfficientNet (B0–B7) in PyTorch.

Results

Pytorch

B0:

  • Animals 138: 84%

B1:

  • Animals 138: 85%

Sources

Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "bd75bfa0-8ff3-4a56-a222-0a73068bb4ff",
"metadata": {},
"source": [
"# EfficientNet in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e55ee58-7dcc-485b-91ec-5329b248bf06",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"import math\n",
"import os\n",
"from torchinfo import summary"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "09d41e73-3720-4f39-916b-7602e40fedaf",
"metadata": {},
"outputs": [],
"source": [
"def conv_block(in_channels, out_channels, kernel_size=3, \n",
" stride=1, padding=0, groups=1,\n",
" bias=False, bn=True, act = True):\n",
" layers = [\n",
" nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \n",
" padding=padding, groups=groups, bias=bias),\n",
" nn.BatchNorm2d(out_channels) if bn else nn.Identity(),\n",
" nn.SiLU() if act else nn.Identity()\n",
" ]\n",
" return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bf59c00f-f320-45be-86f8-6023c2d1f749",
"metadata": {},
"outputs": [],
"source": [
"class SEBlock(nn.Module):\n",
" def __init__(self, c, r=24):\n",
" super(SEBlock, self).__init__()\n",
" self.squeeze = nn.AdaptiveMaxPool2d(1)\n",
" self.excitation = nn.Sequential(\n",
" nn.Conv2d(c, c // r, kernel_size=1),\n",
" nn.SiLU(),\n",
" nn.Conv2d(c // r, c, kernel_size=1),\n",
" nn.Sigmoid()\n",
" )\n",
" def forward(self, x):\n",
" s = self.squeeze(x)\n",
" e = self.excitation(s)\n",
" return x * e"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "df21a513-e747-4d9f-a2ae-4378ba6e3c91",
"metadata": {},
"outputs": [],
"source": [
"class MBConv(nn.Module):\n",
" def __init__(self, n_in, n_out, expansion, kernel_size=3, stride=1, r=24, dropout=0.1):\n",
" super(MBConv, self).__init__()\n",
" self.skip_connection = (n_in == n_out) and (stride == 1)\n",
" padding = (kernel_size-1)//2\n",
" expanded = expansion*n_in\n",
" \n",
" self.expand_pw = nn.Identity() if expansion == 1 else conv_block(n_in, expanded, kernel_size=1)\n",
" self.depthwise = conv_block(expanded, expanded, kernel_size=kernel_size, \n",
" stride=stride, padding=padding, groups=expanded)\n",
" self.se = SEBlock(expanded, r=r)\n",
" self.reduce_pw = conv_block(expanded, n_out, kernel_size=1, act=False)\n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, x):\n",
" residual = x\n",
" x = self.expand_pw(x)\n",
" x = self.depthwise(x)\n",
" x = self.se(x)\n",
" x = self.reduce_pw(x)\n",
" if self.skip_connection:\n",
" x = self.dropout(x)\n",
" x = x + residual\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "777c11c8-e042-4b52-ba2a-a57146e225d1",
"metadata": {},
"outputs": [],
"source": [
"def mbconv1(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):\n",
" return MBConv(n_in, n_out, 1, kernel_size=kernel_size, stride=stride, r=r, dropout=dropout)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "115dc875-517d-4cf5-9441-0344f144b69d",
"metadata": {},
"outputs": [],
"source": [
"def mbconv6(n_in, n_out, kernel_size=3, stride=1, r=24, dropout=0.1):\n",
" return MBConv(n_in, n_out, 6, kernel_size=kernel_size, stride=stride, r=r, dropout=dropout)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "90de185f-a60f-4bd3-a22f-4bb090d22538",
"metadata": {},
"outputs": [],
"source": [
"def create_stage(n_in, n_out, num_layers, layer=mbconv6, \n",
" kernel_size=3, stride=1, r=24, ps=0):\n",
" layers = [layer(n_in, n_out, kernel_size=kernel_size,\n",
" stride=stride, r=r, dropout=ps)]\n",
" layers += [layer(n_out, n_out, kernel_size=kernel_size,\n",
" r=r, dropout=ps) for _ in range(num_layers-1)]\n",
" return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "08539591-434f-4457-824e-8840a97a6ccf",
"metadata": {},
"outputs": [],
"source": [
"def scale_width(w, w_factor):\n",
" w *= w_factor\n",
" new_w = (int(w+4) // 8) * 8\n",
" new_w = max(8, new_w)\n",
" if new_w < 0.9*w:\n",
" new_w += 8\n",
" return int(new_w)"
]
},
{
"cell_type": "markdown",
"id": "d9ad5a73-ed3d-4cfa-849a-d292782cd6c4",
"metadata": {},
"source": [
"EfficientNet Base structure\n",
"\n",
"| Stage (i) | Layer | Resolution | Channels | Layers |\n",
"|-----------|-----------|------------|----------|--------|\n",
"| 1 | `mbconv1` | 224 x 224 | 32 | 1 |\n",
"| 2 | `mbconv6` | 112 x 112 | 16 | 1 |\n",
"| 3 | `mbconv6` | 112 x 112 | 24 | 2 |\n",
"| 4 | `mbconv6` | 56 x 56 | 40 | 2 |\n",
"| 5 | `mbconv6` | 28 x 28 | 80 | 3 |\n",
"| 6 | `mbconv6` | 14 x 14 | 112 | 3 |\n",
"| 7 | `mbconv6` | 14 x 14 | 192 | 4 |\n",
"| 8 | `mbconv6` | 7 x 7 | 320 | 1 |\n",
"| 9 | `mbconv6` | 7 x 7 | 1080 | 1 |"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "83b2baa2-4923-4a05-8f33-9429214baa3f",
"metadata": {},
"outputs": [],
"source": [
"### Obtained from Paper ###\n",
"base_widths = [(32, 16), (16, 24), (24, 40),\n",
" (40, 80), (80, 112), (112, 192),\n",
" (192, 320), (320, 1280)]\n",
"base_depths = [1, 2, 2, 3, 3, 4, 1]\n",
"kernel_sizes = [3, 3, 5, 3, 5, 5, 3]\n",
"strides = [1, 2, 2, 2, 1, 2, 1]\n",
"ps = [0, 0.029, 0.057, 0.086, 0.114, 0.143, 0.171]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "22099a0a-2579-4268-ac29-52db221b5366",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_gen(w_factor=1, d_factor=1):\n",
" scaled_widths = [(scale_width(w[0], w_factor), scale_width(w[1], w_factor)) \n",
" for w in base_widths]\n",
" scaled_depths = [math.ceil(d_factor*d) for d in base_depths]\n",
" return scaled_widths, scaled_depths"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "08cdce48-2c2e-4bc7-bbd0-c8ab6bc02113",
"metadata": {},
"outputs": [],
"source": [
"class EfficientNet(nn.Module):\n",
" def __init__(self, w_factor=1, d_factor=1, n_classes=1000):\n",
" super(EfficientNet, self).__init__()\n",
" scaled_widths, scaled_depths = efficientnet_gen(w_factor=w_factor, d_factor=d_factor)\n",
" \n",
" self.conv1 = conv_block(3, scaled_widths[0][0], stride=2, padding=1)\n",
" stages = [\n",
" create_stage(*scaled_widths[i], scaled_depths[i], layer= mbconv1 if i==0 else mbconv6, \n",
" kernel_size=kernel_sizes[i], stride=strides[i], r= 4 if i==0 else 24, ps=ps[i]) for i in range(7)\n",
" ]\n",
" self.stages = nn.Sequential(*stages)\n",
" self.pre = conv_block(*scaled_widths[-1], kernel_size=1)\n",
" self.pool_flatten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())\n",
" self.head = nn.Sequential(\n",
" nn.Linear(scaled_widths[-1][1], n_classes)\n",
" )\n",
" \n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.stages(x)\n",
" x = self.pre(x)\n",
" x = self.pool_flatten(x)\n",
" x = self.head(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "9fe3126b-d63a-4fdb-b081-698f07092cc1",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b0(n_classes=1000):\n",
" return EfficientNet(n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "335c91f7-60a7-4397-b2ea-ba9755e0a259",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b1(n_classes=1000):\n",
" return EfficientNet(1, 1.1, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "a9ae8295-66e0-4788-861e-af81cf33e099",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b2(n_classes=1000):\n",
" return EfficientNet(1.1, 1.2, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "227832fc-48b7-4c75-ae40-86213f008c91",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b3(n_classes=1000):\n",
" return EfficientNet(1.2, 1.4, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "cdfaff84-9f4a-4ba6-94f5-aa07ab81d22f",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b4(n_classes=1000):\n",
" return EfficientNet(1.4, 1.8, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "05634592-a1c0-4692-88da-fba8ac2ea017",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b5(n_classes=1000):\n",
" return EfficientNet(1.6, 2.2, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "d6a66f21-9dea-4310-afc9-4c28b118ef52",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b6(n_classes=1000):\n",
" return EfficientNet(1.8, 2.6, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "24462c7e-994e-494e-baba-d619055f816c",
"metadata": {},
"outputs": [],
"source": [
"def efficientnet_b7(n_classes=1000):\n",
" return EfficientNet(2, 3.1, n_classes=n_classes)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "ce69f255-c00b-4dc3-a044-37f1cc1be178",
"metadata": {},
"outputs": [],
"source": [
"b0 = efficientnet_b0()\n",
"b1 = efficientnet_b1()\n",
"b2 = efficientnet_b2()\n",
"b3 = efficientnet_b3()\n",
"b4 = efficientnet_b4()\n",
"b5 = efficientnet_b5()\n",
"b6 = efficientnet_b6()\n",
"b7 = efficientnet_b7()"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "15858249-5a11-4e41-afe5-bd071b6c1417",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1000])"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inp = torch.randn(1, 3, 224, 224)\n",
"b0(inp).shape"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "43caafea-ef4e-4dd2-bff4-e0268b84a68d",
"metadata": {},
"outputs": [],
"source": [
"def print_size_of_model(model):\n",
" torch.save(model.state_dict(), \"temp.p\")\n",
" print('Size (MB):', os.path.getsize(\"temp.p\")/1e6)\n",
" os.remove('temp.p')"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "a79c046d-882f-439b-9b60-9c0997baebdf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size (MB): 21.446577\n",
"Size (MB): 31.600841\n",
"Size (MB): 36.885449\n",
"Size (MB): 49.479621\n",
"Size (MB): 78.111933\n",
"Size (MB): 122.546261\n",
"Size (MB): 173.400525\n",
"Size (MB): 267.054441\n"
]
}
],
"source": [
"print_size_of_model(b0)\n",
"print_size_of_model(b1)\n",
"print_size_of_model(b2)\n",
"print_size_of_model(b3)\n",
"print_size_of_model(b4)\n",
"print_size_of_model(b5)\n",
"print_size_of_model(b6)\n",
"print_size_of_model(b7)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "c3259af2-9648-4788-849e-d9878b6d90a5",
"metadata": {},
"outputs": [],
"source": [
"def fmat(n):\n",
" return \"{:.2f}M\".format(n / 1_000_000)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "32a21dc3-3bd4-48d6-902f-b0fd22fbf261",
"metadata": {},
"outputs": [],
"source": [
"def params(model, f=True):\n",
" s = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
" return fmat(s) if f else s"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "994369d9-3be0-4896-838a-b5ff7b94c0a2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('5.29M', '7.79M', '9.11M', '12.23M', '19.34M', '30.39M', '43.04M', '66.35M')"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"params(b0),params(b1), params(b2), params(b3), params(b4), params(b5), params(b6), params(b7)\n",
"# roughly equivalent to the params mentioned in paper \n",
"# (5.3M, 7.8M, 9.2M, 12M, 19M, 30M, 43M, 66M) <- param sizes in the paper"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "d8f7b707-8cc1-43f7-be6a-30bd865bc6eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"====================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"====================================================================================================\n",
"EfficientNet -- --\n",
"├─Sequential: 1-1 [1, 32, 112, 112] --\n",
"│ └─Conv2d: 2-1 [1, 32, 112, 112] 864\n",
"│ └─BatchNorm2d: 2-2 [1, 32, 112, 112] 64\n",
"│ └─SiLU: 2-3 [1, 32, 112, 112] --\n",
"├─Sequential: 1-2 [1, 320, 7, 7] --\n",
"│ └─Sequential: 2-4 [1, 16, 112, 112] --\n",
"│ │ └─MBConv: 3-1 [1, 16, 112, 112] 1,448\n",
"│ └─Sequential: 2-5 [1, 24, 56, 56] --\n",
"│ │ └─MBConv: 3-2 [1, 24, 56, 56] 6,004\n",
"│ │ └─MBConv: 3-3 [1, 24, 56, 56] 10,710\n",
"│ └─Sequential: 2-6 [1, 40, 28, 28] --\n",
"│ │ └─MBConv: 3-4 [1, 40, 28, 28] 15,350\n",
"│ │ └─MBConv: 3-5 [1, 40, 28, 28] 31,290\n",
"│ └─Sequential: 2-7 [1, 80, 14, 14] --\n",
"│ │ └─MBConv: 3-6 [1, 80, 14, 14] 37,130\n",
"│ │ └─MBConv: 3-7 [1, 80, 14, 14] 102,900\n",
"│ │ └─MBConv: 3-8 [1, 80, 14, 14] 102,900\n",
"│ └─Sequential: 2-8 [1, 112, 14, 14] --\n",
"│ │ └─MBConv: 3-9 [1, 112, 14, 14] 126,004\n",
"│ │ └─MBConv: 3-10 [1, 112, 14, 14] 208,572\n",
"│ │ └─MBConv: 3-11 [1, 112, 14, 14] 208,572\n",
"│ └─Sequential: 2-9 [1, 192, 7, 7] --\n",
"│ │ └─MBConv: 3-12 [1, 192, 7, 7] 262,492\n",
"│ │ └─MBConv: 3-13 [1, 192, 7, 7] 587,952\n",
"│ │ └─MBConv: 3-14 [1, 192, 7, 7] 587,952\n",
"│ │ └─MBConv: 3-15 [1, 192, 7, 7] 587,952\n",
"│ └─Sequential: 2-10 [1, 320, 7, 7] --\n",
"│ │ └─MBConv: 3-16 [1, 320, 7, 7] 717,232\n",
"├─Sequential: 1-3 [1, 1280, 7, 7] --\n",
"│ └─Conv2d: 2-11 [1, 1280, 7, 7] 409,600\n",
"│ └─BatchNorm2d: 2-12 [1, 1280, 7, 7] 2,560\n",
"│ └─SiLU: 2-13 [1, 1280, 7, 7] --\n",
"├─Sequential: 1-4 [1, 1280] --\n",
"│ └─AdaptiveAvgPool2d: 2-14 [1, 1280, 1, 1] --\n",
"│ └─Flatten: 2-15 [1, 1280] --\n",
"├─Sequential: 1-5 [1, 1000] --\n",
"│ └─Linear: 2-16 [1, 1000] 1,281,000\n",
"====================================================================================================\n",
"Total params: 5,288,548\n",
"Trainable params: 5,288,548\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 385.87\n",
"====================================================================================================\n",
"Input size (MB): 0.60\n",
"Forward/backward pass size (MB): 107.89\n",
"Params size (MB): 21.15\n",
"Estimated Total Size (MB): 129.64\n",
"===================================================================================================="
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(b0, (1, 3, 224, 224)) # pick a model."
]
},
{
"cell_type": "markdown",
"id": "e6fc8d9c-a52b-4b89-9f19-8344b70c418b",
"metadata": {},
"source": [
"End of notebook"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment