Skip to content

Instantly share code, notes, and snippets.

@AmirOfir
Created May 16, 2021 10:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AmirOfir/82882cc236745be2d4ea61900ded97c4 to your computer and use it in GitHub Desktop.
Save AmirOfir/82882cc236745be2d4ea61900ded97c4 to your computer and use it in GitHub Desktop.
An implementation of MobileNet v1 with libtorch c++
#pragma once
#include <torch/torch.h>
// Special thanks for wangvaton: https://github.com/wangvation/torch-mobilenet/blob/master/module/mobilenet.py
//# BSD 2 - Clause License
//
//# Copyright (c) 2019 wangvation. All rights reserved.
//
//# Redistribution and use in source and binary forms, with or without
//# modification, are permitted provided that the following conditions are met :
//
//# 1. Redistributions of source code must retain the above copyright notice,
//# this list of conditions and the following disclaimer.
//
//# 2. Redistributions in binary form must reproduce the above copyright notice,
//# this list of conditions and the following disclaimer in the documentation
//# and/or other materials provided with the distribution.
//
//# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
//# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
//# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
//# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
//# OR CONSEQUENTIAL DAMAGES(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// # CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE)
// # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// # POSSIBILITY OF SUCH DAMAGE.
//# ============================================================================
// Rounds `v` to the nearest multiple of `divisor`, clamped below by
// `minValue`, and never rounding down by more than 10% of `v`.
// Used to keep channel counts divisible by 8 (hardware-friendly widths).
// Ported from the original TF slim repo:
// https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
// NOTE: the identifier keeps the original (misspelled) name so existing
// callers keep working; "makeDivisible" would be the correct spelling.
// :param v:        value to round (typically a channel count)
// :param divisor:  alignment, typically 8
// :param minValue: lower bound on the result; -1 means "use divisor"
// :return:         the adjusted value (always an exact integer)
int makeDivisable(int v, int divisor, int minValue = -1)
{
	if (minValue == -1)
		minValue = divisor;
	// Integer round-to-nearest multiple: (v + divisor/2) / divisor * divisor.
	// (The original wrapped this in redundant int() casts and declared a
	// float return type even though the result is always integral.)
	int new_v = std::max(minValue, (v + divisor / 2) / divisor * divisor);
	// Make sure rounding down does not remove more than 10% of v.
	if (new_v < 0.9 * v)
		new_v += divisor;
	return new_v;
}
// Depthwise-separable convolution, the MobileNet v1 building block:
//   depthwise KxK conv -> BN -> ReLU -> pointwise 1x1 conv -> BN -> ReLU.
// Channel counts are scaled by `multiplier` and rounded to a multiple of 8
// via makeDivisable().
class DepthSepConvImpl : public torch::nn::Cloneable<DepthSepConvImpl>
{
	int _in_channels;   // multiplier-scaled, 8-aligned input channel count
	int _out_channels;  // multiplier-scaled, 8-aligned output channel count
	torch::nn::Conv2d _depthwise_conv;
	torch::nn::BatchNorm2d _bn1;
	torch::nn::Conv2d _pointwise_conv;
	torch::nn::BatchNorm2d _bn2;
public:
	// :param in_channels:  unscaled input channels (before width multiplier)
	// :param out_channels: unscaled output channels
	// :param kernel_dim:   depthwise kernel size (KxK)
	// :param stride:       depthwise stride
	// :param padding:      depthwise padding
	// :param multiplier:   MobileNet width multiplier (alpha), default 1
	DepthSepConvImpl(int in_channels, int out_channels, int kernel_dim, int stride, int padding, int multiplier = 1)
		: _in_channels(makeDivisable(in_channels * multiplier, 8)),
		_out_channels(makeDivisable(out_channels * multiplier, 8)),
		_depthwise_conv(register_module("depthwise_conv",
			torch::nn::Conv2d(
				// FIX: groups must equal the conv's *actual* channel count
				// (_in_channels, i.e. after multiplier scaling), not the raw
				// in_channels argument — groups(in_channels) breaks whenever
				// multiplier != 1. No bias: BN directly follows.
				torch::nn::Conv2dOptions(_in_channels, _in_channels, kernel_dim).stride(stride).padding(padding).groups(_in_channels).bias(false)))),
		_bn1(register_module("bn1",
			torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(_in_channels)))),
		_pointwise_conv(register_module("pointwise_conv",
			torch::nn::Conv2d(torch::nn::Conv2dOptions(_in_channels, _out_channels, 1).stride(1).groups(1).bias(false)))),
		_bn2(register_module("bn2",
			torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(_out_channels))))
	{
	}
	// Runs the depthwise -> BN -> ReLU -> pointwise -> BN -> ReLU pipeline.
	torch::Tensor forward(const torch::Tensor& input)
	{
		auto x = _depthwise_conv(input);
		x = _bn1(x);
		x = relu(x);
		x = _pointwise_conv(x);
		x = _bn2(x);
		x = relu(x);
		return x;
	}
	// NOTE(review): Cloneable::clone() default-constructs a copy and calls
	// reset() to recreate submodules; an empty reset() means clone() will not
	// duplicate the layers — confirm clone() is never used before relying on it.
	void reset() { }
	// FIX: the previous empty `train(bool)` override swallowed the base
	// Module::train() call, so BatchNorm layers never switched between
	// training and eval mode. Removing the override restores the default
	// recursive behavior.
};
TORCH_MODULE(DepthSepConv);
class MobileNetV1Impl : public torch::nn::Cloneable<MobileNetV1Impl>
{
	/*
	docstring for MobileNetV1
	MobileNetV1 Body Architecture
	| Type / Stride | Filter Shape        | Input Size      | Output Size     |
	| :------------ | :------------------ | :-------------- | :-------------- |
	| Conv / s2     | 3 x 3 x 3 x 32      | 224 x 224 x 3   | 112 x 112 x 32  |
	| Conv dw / s1  | 3 x 3 x 32 dw       | 112 x 112 x 32  | 112 x 112 x 32  |
	| Conv / s1     | 1 x 1 x 32 x 64     | 112 x 112 x 32  | 112 x 112 x 64  |
	| Conv dw / s2  | 3 x 3 x 64 dw       | 112 x 112 x 64  | 56 x 56 x 64    |
	| Conv / s1     | 1 x 1 x 64 x 128    | 56 x 56 x 64    | 56 x 56 x 128   |
	| Conv dw / s1  | 3 x 3 x 128 dw      | 56 x 56 x 128   | 56 x 56 x 128   |
	| Conv / s1     | 1 x 1 x 128 x 128   | 56 x 56 x 128   | 56 x 56 x 128   |
	| Conv dw / s2  | 3 x 3 x 128 dw      | 56 x 56 x 128   | 28 x 28 x 128   |
	| Conv / s1     | 1 x 1 x 128 x 256   | 28 x 28 x 128   | 28 x 28 x 256   |
	| Conv dw / s1  | 3 x 3 x 256 dw      | 28 x 28 x 256   | 28 x 28 x 256   |
	| Conv / s1     | 1 x 1 x 256 x 256   | 28 x 28 x 256   | 28 x 28 x 256   |
	| Conv dw / s2  | 3 x 3 x 256 dw      | 28 x 28 x 256   | 14 x 14 x 256   |
	| Conv / s1     | 1 x 1 x 256 x 512   | 14 x 14 x 256   | 14 x 14 x 512   |
	| 5x [ Conv dw / s1  3 x 3 x 512 dw ; Conv / s1  1 x 1 x 512 x 512 ]      |
	|               |                     | 14 x 14 x 512   | 14 x 14 x 512   |
	| Conv dw / s2  | 3 x 3 x 512 dw      | 14 x 14 x 512   | 7 x 7 x 512     |
	| Conv / s1     | 1 x 1 x 512 x 1024  | 7 x 7 x 512     | 7 x 7 x 1024    |
	| Conv dw / s1  | 3 x 3 x 1024 dw     | 7 x 7 x 1024    | 7 x 7 x 1024    |
	| Conv / s1     | 1 x 1 x 1024 x 1024 | 7 x 7 x 1024    | 7 x 7 x 1024    |
	| AvgPool / s1  | Pool 7 x 7          | 7 x 7 x 1024    | 1 x 1 x 1024    |
	| FC / s1       | 1024 x 1000         | 1 x 1 x 1024    | 1 x 1 x 1000    |
	| Softmax / s1  | Classifier          | 1 x 1 x 1000    | 1 x 1 x 1000    |
	*/
	int first_in_channel;   // stem output channels, multiplier-scaled
	int last_out_channel;   // final feature channels, multiplier-scaled
	int num_classes;        // classifier output size
	torch::nn::Sequential features;
	torch::nn::Sequential classifier;
public:
	// :param resolution:  input image side; must be a multiple of 32
	// :param num_classes: number of classifier outputs
	// :param multiplier:  MobileNet width multiplier (alpha)
	MobileNetV1Impl(int resolution = 224, int num_classes = 1000, int multiplier = 1)
		:
		first_in_channel(makeDivisable(32 * multiplier, 8)),
		last_out_channel(makeDivisable(1024 * multiplier, 8)),
		num_classes(num_classes),
		features(register_module("features", torch::nn::Sequential(
			// FIX: MobileNet v1's first layer is Conv -> BN -> ReLU; the
			// original emitted a bare biased Conv2d. Bias dropped because
			// the following BatchNorm makes it redundant.
			torch::nn::Conv2d(torch::nn::Conv2dOptions(3, first_in_channel, 3).stride(2).padding(1).bias(false)),
			torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(first_in_channel)),
			torch::nn::ReLU(),
			// DepthSepConv applies the width multiplier internally.
			DepthSepConv(32, 64, 3, 1, 1, multiplier),
			DepthSepConv(64, 128, 3, 2, 1, multiplier),
			DepthSepConv(128, 128, 3, 1, 1, multiplier),
			DepthSepConv(128, 256, 3, 2, 1, multiplier),
			DepthSepConv(256, 256, 3, 1, 1, multiplier),
			DepthSepConv(256, 512, 3, 2, 1, multiplier),
			DepthSepConv(512, 512, 3, 1, 1, multiplier),
			DepthSepConv(512, 512, 3, 1, 1, multiplier),
			DepthSepConv(512, 512, 3, 1, 1, multiplier),
			DepthSepConv(512, 512, 3, 1, 1, multiplier),
			DepthSepConv(512, 512, 3, 1, 1, multiplier),
			DepthSepConv(512, 1024, 3, 2, 1, multiplier),
			DepthSepConv(1024, 1024, 3, 1, 1, multiplier)
		))),
		classifier(register_module("classifier", torch::nn::Sequential(
			// Global average pool: feature map side is resolution/32 (e.g. 7).
			torch::nn::AvgPool2d(torch::nn::AvgPool2dOptions(resolution / 32)),
			// 1x1 conv acts as the fully-connected classifier layer.
			torch::nn::Conv2d(torch::nn::Conv2dOptions(last_out_channel, num_classes, 1)),
			torch::nn::Softmax2d()
		)))
	{
		assert(resolution % 32 == 0);
	}
	// Returns class probabilities of shape { batch, num_classes }.
	torch::Tensor forward(const torch::Tensor& input)
	{
		torch::Tensor x = features->forward(input);
		x = classifier->forward(x);
		x = x.view({ -1, num_classes });
		return x;
	}
	// NOTE(review): Cloneable::clone() default-constructs a copy and calls
	// reset() to recreate submodules; an empty reset() means clone() will not
	// duplicate the layers — confirm clone() is never used before relying on it.
	void reset() { }
	// FIX: the previous empty `train(bool)` override swallowed the base
	// Module::train() call, so BatchNorm layers (here and inside every
	// DepthSepConv) never switched between training and eval mode. Removing
	// the override restores the default recursive behavior.
};
TORCH_MODULE(MobileNetV1);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment