PyTorch implementation of the "Searching for MobileNetV3" paper: https://arxiv.org/abs/1905.02244

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    """Hard sigmoid: ReLU6(x + 3) / 6, a piecewise-linear approximation of the sigmoid."""

    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3, inplace=self.inplace) / 6


class h_swish(nn.Module):
    """Hard swish: x * ReLU6(x + 3) / 6, used in place of swish in MobileNetV3."""

    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3, inplace=self.inplace) / 6


class SqueezeBlock(nn.Module):
    """Squeeze-and-Excite: global pooling, a bottleneck MLP built from 1x1 convs, then channel gating."""

    def __init__(self, in_size, reduction=4):
        super(SqueezeBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.squeeze_block = nn.Sequential(
            nn.Conv2d(
                in_size,
                in_size // reduction,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_size // reduction,
                in_size,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(in_size),
            h_sigmoid(),
        )

    def forward(self, x):
        # Squeeze to (N, C, 1, 1), excite, and broadcast the channel gate over the input.
        return x * self.squeeze_block(self.avg_pool(x))


class MobileBlock(nn.Module):
    """Inverted residual bottleneck: 1x1 expand -> depthwise conv -> 1x1 project, with optional SE."""

    def __init__(
        self, kernel_size, in_size, expand_size, out_size, nolinear, se_block, stride
    ):
        super(MobileBlock, self).__init__()
        self.stride = stride
        self.squeeze_block = se_block

        # Pointwise expansion.
        self.conv1 = nn.Conv2d(
            in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear

        # Depthwise convolution (groups == channels).
        self.conv2 = nn.Conv2d(
            expand_size,
            expand_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=expand_size,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear

        # Pointwise linear projection (no activation).
        self.conv3 = nn.Conv2d(
            expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_size)

        # 1x1 shortcut to match channels when the residual is used but sizes differ.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False
                ),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.squeeze_block is not None:
            out = self.squeeze_block(out)
        # Residual connection only when the spatial resolution is preserved.
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3(nn.Module):
    def __init__(self, variant="large", num_classes=1000):
        super(MobileNetV3, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = h_swish()

        if variant == "large":
            self.bneck = nn.Sequential(
                MobileBlock(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
                MobileBlock(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
                MobileBlock(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
                MobileBlock(5, 24, 72, 40, nn.ReLU(inplace=True), SqueezeBlock(40), 2),
                MobileBlock(5, 40, 120, 40, nn.ReLU(inplace=True), SqueezeBlock(40), 1),
                MobileBlock(5, 40, 120, 40, nn.ReLU(inplace=True), SqueezeBlock(40), 1),
                MobileBlock(3, 40, 240, 80, h_swish(), None, 2),
                MobileBlock(3, 80, 200, 80, h_swish(), None, 1),
                MobileBlock(3, 80, 184, 80, h_swish(), None, 1),
                MobileBlock(3, 80, 184, 80, h_swish(), None, 1),
                MobileBlock(3, 80, 480, 112, h_swish(), SqueezeBlock(112), 1),
                MobileBlock(3, 112, 672, 112, h_swish(), SqueezeBlock(112), 1),
                MobileBlock(5, 112, 672, 160, h_swish(), SqueezeBlock(160), 1),
                MobileBlock(5, 160, 672, 160, h_swish(), SqueezeBlock(160), 2),
                MobileBlock(5, 160, 960, 160, h_swish(), SqueezeBlock(160), 1),
            )
            self.conv2 = nn.Conv2d(
                160, 960, kernel_size=1, stride=1, padding=0, bias=False
            )
            self.bn2 = nn.BatchNorm2d(960)
            self.linear3 = nn.Linear(960, 1280)
        elif variant == "small":
            self.bneck = nn.Sequential(
                MobileBlock(3, 16, 16, 16, nn.ReLU(inplace=True), SqueezeBlock(16), 2),
                MobileBlock(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
                MobileBlock(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
                MobileBlock(5, 24, 96, 40, h_swish(), SqueezeBlock(40), 2),
                MobileBlock(5, 40, 240, 40, h_swish(), SqueezeBlock(40), 1),
                MobileBlock(5, 40, 240, 40, h_swish(), SqueezeBlock(40), 1),
                MobileBlock(5, 40, 120, 48, h_swish(), SqueezeBlock(48), 1),
                MobileBlock(5, 48, 144, 48, h_swish(), SqueezeBlock(48), 1),
                MobileBlock(5, 48, 288, 96, h_swish(), SqueezeBlock(96), 2),
                MobileBlock(5, 96, 576, 96, h_swish(), SqueezeBlock(96), 1),
                MobileBlock(5, 96, 576, 96, h_swish(), SqueezeBlock(96), 1),
            )
            self.conv2 = nn.Conv2d(
                96, 576, kernel_size=1, stride=1, padding=0, bias=False
            )
            self.bn2 = nn.BatchNorm2d(576)
            self.linear3 = nn.Linear(576, 1280)
        else:
            raise ValueError("variant must be 'large' or 'small'")

        self.hs2 = h_swish()
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = h_swish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)  # final feature map is 7x7 for 224x224 input
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out


def main():
    net = MobileNetV3(variant="large")
    inp = torch.randn(2, 3, 224, 224)
    out = net(inp)
    print(out.size())


if __name__ == "__main__":
    main()

cedrickchee commented May 14, 2019

Software requirements:

  • Python 3.6+
  • PyTorch 1.0.1

Usage

python model.py
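
The model can also be imported and run from another script. A minimal sketch, assuming the gist file is saved as model.py on your path:

import torch
from model import MobileNetV3

net = MobileNetV3(variant="small", num_classes=1000)
net.eval()  # use running BatchNorm statistics for single-image inference
with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])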

Results

Comparison with the results reported in the paper:

             MAdds   Parameters   Top-1 acc
Large        219 M   5.4 M        75.2%
Small        66 M    2.9 M        67.4%
Ours Large   263 M   3.7 M        75.454%
Ours Small   65 M    2.4 M        69.069%
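
The parameter counts above can be cross-checked directly from the model definition. A quick sketch, again assuming the gist file is saved as model.py (MAdds would need a separate profiling tool):

from model import MobileNetV3

def count_params(model):
    # Count trainable parameters only.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

for variant in ("large", "small"):
    net = MobileNetV3(variant=variant)
    print(variant, round(count_params(net) / 1e6, 2), "M parameters")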
