# Skip to content

# Instantly share code, notes, and snippets.

# Embed
# What would you like to do?
"""Downsized version of Xception, without residual connections.
"""
from __future__ import print_function
from __future__ import absolute_import
from keras.models import Model
from keras.layers import Dense
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
def SmallXception(input_shape=(200, 200, 3),
                  classes=100):
    """Build the downsized, residual-free Xception-style classifier.

    The network is a strided convolutional stem followed by three stages
    of paired separable convolutions (128, 256 and 512 filters), global
    average pooling and a softmax classification layer.

    # Arguments
        input_shape: shape tuple forwarded as `shape` to the `Input` layer.
        classes: number of units in the final softmax `Dense` layer.

    # Returns
        A `Model` connecting the input tensor to the `predictions` layer.
        The model is returned as constructed; it is not compiled here.
    """
    img_input = Input(shape=input_shape)

    # Stem: strided 3x3 convolution, batch norm, ReLU, then max pooling.
    net = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
    net = BatchNormalization()(net)
    net = Activation('relu')(net)
    net = MaxPooling2D((3, 3), strides=(2, 2))(net)

    # Each stage is two separable convolutions of the same width. Only the
    # first one is followed by a ReLU; the second goes from batch norm
    # straight to the stage's pooling step.
    stage_widths = (128, 256, 512)
    last_stage = len(stage_widths) - 1
    for stage, width in enumerate(stage_widths):
        net = SeparableConv2D(width, (3, 3), use_bias=False)(net)
        net = BatchNormalization()(net)
        net = Activation('relu')(net)
        net = SeparableConv2D(width, (3, 3), use_bias=False)(net)
        net = BatchNormalization()(net)
        # The final stage feeds global average pooling instead of a
        # max-pooling layer.
        if stage != last_stage:
            net = MaxPooling2D((3, 3), strides=(2, 2))(net)

    # Classification head.
    net = GlobalAveragePooling2D(name='avg_pool')(net)
    net = Dense(classes, activation='softmax', name='predictions')(net)

    return Model(img_input, net)
if __name__ == '__main__':
    # When executed as a script, build the network with its default
    # arguments and print the layer-by-layer summary.
    model = SmallXception()
    model.summary()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
# You can’t perform that action at this time.