Created
January 30, 2019 11:04
-
-
Save piyush2896/c403e310c281e218cfee9324e6137476 to your computer and use it in GitHub Desktop.
Gist of SqueezeNet implementations for use in https://predictiveprogrammer.com/famous-convolutional-neural-network-architectures-2/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras import layers as ls | |
from keras.models import Model | |
def conv_relu(in_tensor, kernel_size, filters):
    """Apply a 2D convolution followed by a ReLU activation.

    The convolution uses stride 1, 'same' padding and He-normal kernel
    initialisation. No normalisation layer is inserted between the
    convolution and the activation.

    Args:
        in_tensor: Keras tensor to convolve.
        kernel_size: Kernel size forwarded to ``Conv2D`` (int or tuple).
        filters: Number of output channels of the convolution.

    Returns:
        The ReLU-activated output tensor.
    """
    conv_layer = ls.Conv2D(filters, kernel_size,
                           strides=1, padding='same',
                           kernel_initializer='he_normal')
    features = conv_layer(in_tensor)
    relu_layer = ls.Activation('relu')
    return relu_layer(features)
def _pre_squeezenet(in_tensor):
    """Build the SqueezeNet stem: the first convolution and the first max-pool.

    Args:
        in_tensor: Keras input tensor; presumably a batch of images in the
            backend's default (channels-last) layout -- confirm with callers.

    Returns:
        The tensor produced by a 96-filter 7x7 convolution (stride 2, 'same'
        padding) and ReLU, followed by a 3x3 max-pool with stride 2.
    """
    # He-normal initialisation matches every convolution built through
    # conv_relu and suits the ReLU that follows; without it Keras would fall
    # back to its default initializer for this one layer only.
    conv = ls.Conv2D(96, 7, strides=2, padding='same',
                     kernel_initializer='he_normal')(in_tensor)
    act = ls.Activation('relu')(conv)
    return ls.MaxPool2D(3, strides=2)(act)
def _post_squeezenet(in_tensor, n_classes):
    """Build the SqueezeNet classification head.

    The pipeline is: dropout (rate 0.5) -> 1x1 conv + ReLU with one channel
    per class -> global average pooling -> softmax.

    Args:
        in_tensor: Feature-map tensor produced by the last fire module.
        n_classes: Number of output classes.

    Returns:
        Tensor of per-class softmax probabilities.
    """
    x = ls.Dropout(0.5)(in_tensor)
    x = conv_relu(x, 1, n_classes)
    x = ls.GlobalAvgPool2D()(x)
    return ls.Activation('softmax')(x)
def fire_module(in_tensor, s_channels, e_channels):
    """Build a SqueezeNet fire module.

    A 1x1 "squeeze" conv (``s_channels`` filters) feeds two parallel "expand"
    convs, 1x1 and 3x3, each with ``e_channels`` filters. Their outputs are
    joined with ``Concatenate`` (default last axis), so with channels-last data
    the module emits ``2 * e_channels`` channels.

    Args:
        in_tensor: Input feature-map tensor.
        s_channels: Filter count of the squeeze layer.
        e_channels: Filter count of each expand branch.

    Returns:
        The concatenated expand-branch tensor.
    """
    squeezed = conv_relu(in_tensor, 1, s_channels)
    # 1x1 branch first, then 3x3 branch: the concatenation order follows this.
    branches = [conv_relu(squeezed, size, e_channels) for size in (1, 3)]
    return ls.Concatenate()(branches)
def squeezenet(in_shape, include_top=True, n_classes=1000):
    """Build the plain SqueezeNet model (no bypass connections).

    Args:
        in_shape: Shape passed to ``ls.Input`` (per-sample, no batch axis).
        include_top: When True, append the dropout/conv/pool/softmax head;
            otherwise the model outputs the fire9 feature map.
        n_classes: Number of classes for the head; unused when
            ``include_top`` is False.

    Returns:
        A Keras ``Model``.
    """
    # (squeeze, expand) channels for fire2..fire9; None marks a default
    # MaxPool2D inserted after fire4 and after fire8.
    stages = [
        (16, 64), (16, 64), (32, 128),
        None,
        (32, 128), (48, 192), (48, 192), (64, 256),
        None,
        (64, 256),
    ]

    inputs = ls.Input(in_shape)
    x = _pre_squeezenet(inputs)
    for spec in stages:
        if spec is None:
            x = ls.MaxPool2D()(x)
        else:
            x = fire_module(x, *spec)

    outputs = _post_squeezenet(x, n_classes) if include_top else x
    return Model(inputs, outputs)
def squeezenet_sbypass(in_shape, include_top=True, n_classes=1000):
    """Build SqueezeNet with simple (identity) bypass connections.

    An element-wise ``Add`` joins the input and the output of fire3, fire5,
    fire7 and fire9, whose input and output channel counts already match.

    Args:
        in_shape: Shape passed to ``ls.Input`` (per-sample, no batch axis).
        include_top: When True, append the classification head; otherwise the
            model outputs the last bypass sum.
        n_classes: Number of classes for the head; unused when
            ``include_top`` is False.

    Returns:
        A Keras ``Model``.
    """
    def bypassed_fire(tensor, squeeze, expand):
        # Identity shortcut: sum of the fire module's input and its output.
        fired = fire_module(tensor, squeeze, expand)
        return ls.Add()([tensor, fired])

    inputs = ls.Input(in_shape)
    x = _pre_squeezenet(inputs)
    x = fire_module(x, 16, 64)       # fire2
    x = bypassed_fire(x, 16, 64)     # fire3 + bypass
    x = fire_module(x, 32, 128)      # fire4
    x = ls.MaxPool2D()(x)
    x = bypassed_fire(x, 32, 128)    # fire5 + bypass
    x = fire_module(x, 48, 192)      # fire6
    x = bypassed_fire(x, 48, 192)    # fire7 + bypass
    x = fire_module(x, 64, 256)      # fire8
    x = ls.MaxPool2D()(x)
    x = bypassed_fire(x, 64, 256)    # fire9 + bypass

    if not include_top:
        return Model(inputs, x)
    return Model(inputs, _post_squeezenet(x, n_classes))
def squeezenet_cbypass(in_shape, include_top=True, n_classes=1000):
    """Build SqueezeNet with both simple and complex bypass connections.

    fire3, fire5, fire7 and fire9 get an identity shortcut. fire2, fire4,
    fire6 and fire8 get a "complex" shortcut: their input is passed through a
    1x1 conv + ReLU with ``2 * expand`` filters (the fire output width: 1x1
    plus 3x3 expand branches) before being added to the fire output.

    Args:
        in_shape: Shape passed to ``ls.Input`` (per-sample, no batch axis).
        include_top: When True, append the classification head; otherwise the
            model outputs the last bypass sum.
        n_classes: Number of classes for the head; unused when
            ``include_top`` is False.

    Returns:
        A Keras ``Model``.
    """
    def fire_with_identity(tensor, squeeze, expand):
        # Simple bypass: input + fire(input).
        fired = fire_module(tensor, squeeze, expand)
        return ls.Add()([tensor, fired])

    def fire_with_projection(tensor, squeeze, expand):
        # Complex bypass: fire(input) + conv1x1_relu(input), the 1x1 conv
        # widening the shortcut to the fire module's 2 * expand channels.
        fired = fire_module(tensor, squeeze, expand)
        shortcut = conv_relu(tensor, 1, 2 * expand)
        return ls.Add()([fired, shortcut])

    inputs = ls.Input(in_shape)
    x = _pre_squeezenet(inputs)
    x = fire_with_projection(x, 16, 64)    # fire2 (shortcut: 128 filters)
    x = fire_with_identity(x, 16, 64)      # fire3
    x = fire_with_projection(x, 32, 128)   # fire4 (shortcut: 256 filters)
    x = ls.MaxPool2D()(x)
    x = fire_with_identity(x, 32, 128)     # fire5
    x = fire_with_projection(x, 48, 192)   # fire6 (shortcut: 384 filters)
    x = fire_with_identity(x, 48, 192)     # fire7
    x = fire_with_projection(x, 64, 256)   # fire8 (shortcut: 512 filters)
    x = ls.MaxPool2D()(x)
    x = fire_with_identity(x, 64, 256)     # fire9

    head = _post_squeezenet(x, n_classes) if include_top else x
    return Model(inputs, head)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment