Skip to content

Instantly share code, notes, and snippets.

@piyush2896
Created January 30, 2019 06:25
Show Gist options
  • Save piyush2896/28597f933cbe3db4c08d4210fc5d9e03 to your computer and use it in GitHub Desktop.
Save piyush2896/28597f933cbe3db4c08d4210fc5d9e03 to your computer and use it in GitHub Desktop.
import keras.layers as ls #import Conv2D, DepthwiseConv2D, BatchNormalization, Dense
from keras.models import Model
from functools import partial
def depthwise_separable_conv(in_tensor, filters_pw, strides):
    """Apply one depthwise-separable convolution unit to ``in_tensor``.

    The unit is a 3x3 depthwise convolution followed by a 1x1 (pointwise)
    convolution. Each of the two convolutions is followed by batch
    normalization and a ReLU activation.

    Args:
        in_tensor: Keras tensor to transform.
        filters_pw: Number of output filters of the pointwise convolution.
        strides: Stride of the depthwise convolution (the pointwise
            convolution always uses stride 1).

    Returns:
        The output tensor of the final ReLU activation.
    """
    # Depthwise stage: per-channel 3x3 filtering; this is the only place
    # where spatial down-sampling (via ``strides``) can happen.
    x = ls.DepthwiseConv2D(
        kernel_size=3,
        strides=strides,
        padding='same',
        depth_multiplier=1,
    )(in_tensor)
    x = ls.BatchNormalization()(x)
    x = ls.Activation('relu')(x)

    # Pointwise stage: 1x1 convolution mixing channels into ``filters_pw``.
    x = ls.Conv2D(
        filters=filters_pw,
        kernel_size=1,
        strides=1,
        padding='same',
    )(x)
    x = ls.BatchNormalization()(x)
    return ls.Activation('relu')(x)
def conv(in_tensor, filters, strides=2):
    """Apply a 3x3 convolution, batch normalization and ReLU to ``in_tensor``.

    Args:
        in_tensor: Keras tensor to transform.
        filters: Number of output filters of the convolution.
        strides: Stride of the convolution. Defaults to 2.

    Returns:
        The output tensor of the ReLU activation.
    """
    x = ls.Conv2D(
        filters=filters,
        kernel_size=3,
        strides=strides,
        padding='same',
    )(in_tensor)
    x = ls.BatchNormalization()(x)
    return ls.Activation('relu')(x)
# Stride-specialised shortcuts for ``depthwise_separable_conv``:
# ``dws_conv_s1`` keeps the spatial resolution (stride 1), while
# ``dws_conv_s2`` uses stride 2 in the depthwise convolution.
dws_conv_s1, dws_conv_s2 = (
    partial(depthwise_separable_conv, strides=stride) for stride in (1, 2)
)
def mobilenet(in_shape=(224, 224, 3), include_top=True, classes=1000):
    """Build a MobileNet (v1) style network as a Keras ``Model``.

    The network is a stride-2 3x3 convolution stem followed by 13
    depthwise-separable convolution units. Five stages use stride 2 (the
    stem and four of the separable units), so the final feature map is
    down-sampled 32x relative to the input.

    Args:
        in_shape: Shape of a single input sample, without the batch
            dimension. Defaults to ``(224, 224, 3)``.
        include_top: If true, append the classification head (global
            average pooling followed by a softmax ``Dense`` layer). If
            false, the model outputs the final 1024-channel feature map.
        classes: Number of output classes of the classification head.
            Only used when ``include_top`` is true. Defaults to 1000.

    Returns:
        A Keras ``Model`` mapping the input tensor to either the class
        probabilities of shape ``(batch, classes)`` or the final feature
        map.
    """
    in_ = ls.Input(in_shape)

    # Stem: standard 3x3 convolution with the default stride of 2.
    x = conv(in_, 32)

    # Body: (unit builder, pointwise filters) for each separable unit, in
    # order. Stride-2 units halve the spatial resolution.
    body = (
        [
            (dws_conv_s1, 64),
            (dws_conv_s2, 128),
            (dws_conv_s1, 128),
            (dws_conv_s2, 256),
            (dws_conv_s1, 256),
            (dws_conv_s2, 512),
        ]
        + [(dws_conv_s1, 512)] * 5
        + [
            (dws_conv_s2, 1024),
            (dws_conv_s1, 1024),
        ]
    )
    for unit, filters_pw in body:
        x = unit(x, filters_pw)

    if include_top:
        # Average (not max) pooling, as in the MobileNet v1 architecture.
        # The pooled tensor is already (batch, 1024), so Dense is applied
        # directly; reshaping to (1, 1, 1024) first would produce a 4-D
        # (batch, 1, 1, classes) output that does not match class labels.
        pooled = ls.GlobalAveragePooling2D()(x)
        res = ls.Dense(classes, activation='softmax')(pooled)
    else:
        res = x

    return Model(in_, res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment