My extended Xception (in TF 1.x)
# ref: https://github.com/keras-team/keras-applications/blob/master/keras_applications/xception.py
# ref: https://arxiv.org/pdf/2010.02178.pdf
import tensorflow as tf
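
# NOTE: se_block() is called below but is not defined in the gist. The following
# is a hypothetical, minimal squeeze-and-excitation block, sketched in so the
# file runs end to end; the author's actual implementation may differ. It
# assumes channels_last, which the tf.pad calls below also hard-code.
def se_block(inputs, filters, ratio):
    # Squeeze: collapse each feature map to one per-channel statistic.
    se = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    # Excitation: a bottleneck MLP (reduction factor `ratio`) producing
    # per-channel gates in (0, 1).
    se = tf.keras.layers.Dense(filters // ratio, activation='relu')(se)
    se = tf.keras.layers.Dense(filters, activation='sigmoid')(se)
    se = tf.keras.layers.Reshape((1, 1, filters))(se)
    # Scale: reweight the input feature maps channel by channel.
    return tf.keras.layers.multiply([inputs, se])
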
def ExtendXception(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # NOTE: `weights` and `**kwargs` are accepted for signature compatibility
    # with keras_applications.xception but are never used; no pretrained
    # weights are loaded. `input_tensor` is required, since no default Input
    # layer is created when it is None.
    img_input = input_tensor
    channel_axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
    # Entry flow, block 1: two plain convolutions.
    x = tf.keras.layers.Conv2D(32, (3, 3),
                               strides=(2, 2),
                               use_bias=False,
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block1_conv1_act')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block1_conv2_act')(x)

    # Entry flow, block 2. Two changes relative to the reference Xception:
    # an SE block after the second separable convolution, and reflection
    # padding + 'valid' max pooling in place of padding='same' pooling
    # (same output size for 3x3/stride-2 pooling, but border pixels see
    # reflected values instead of zeros).
    residual = tf.keras.layers.Conv2D(128, (1, 1),
                                      strides=(2, 2),
                                      padding='valid',
                                      use_bias=False)(x)
    residual = tf.keras.layers.BatchNormalization(axis=channel_axis)(residual)

    x = tf.keras.layers.SeparableConv2D(128, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block2_sepconv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block2_sepconv2_act')(x)
    x = tf.keras.layers.SeparableConv2D(128, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block2_sepconv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

    x = se_block(x, 128, 8)
    x = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    x = tf.keras.layers.MaxPooling2D((3, 3),
                                     strides=(2, 2),
                                     padding='valid',
                                     name='block2_pool')(x)
    x = tf.keras.layers.add([x, residual])
    residual = tf.keras.layers.Conv2D(256, (1, 1), strides=(2, 2),
                                      padding='valid', use_bias=False)(x)
    residual = tf.keras.layers.BatchNormalization(axis=channel_axis)(residual)

    x = tf.keras.layers.Activation('relu', name='block3_sepconv1_act')(x)
    x = tf.keras.layers.SeparableConv2D(256, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block3_sepconv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block3_sepconv2_act')(x)
    x = tf.keras.layers.SeparableConv2D(256, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block3_sepconv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

    x = se_block(x, 256, 8)
    x = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    x = tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2),
                                     padding='valid',
                                     name='block3_pool')(x)
    x = tf.keras.layers.add([x, residual])
    residual = tf.keras.layers.Conv2D(728, (1, 1),
                                      strides=(2, 2),
                                      padding='valid',
                                      use_bias=False)(x)
    residual = tf.keras.layers.BatchNormalization(axis=channel_axis)(residual)

    x = tf.keras.layers.Activation('relu', name='block4_sepconv1_act')(x)
    x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block4_sepconv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block4_sepconv2_act')(x)
    x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block4_sepconv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

    x = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    x = tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2),
                                     padding='valid',
                                     name='block4_pool')(x)
    x = tf.keras.layers.add([x, residual])
    # Middle flow: eight identical residual blocks (block5..block12), each
    # with three separable convolutions.
    for i in range(8):
        residual = x
        prefix = 'block' + str(i + 5)

        x = tf.keras.layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
        x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                            padding='same',
                                            use_bias=False,
                                            name=prefix + '_sepconv1')(x)
        x = tf.keras.layers.BatchNormalization(axis=channel_axis,
                                               name=prefix + '_sepconv1_bn')(x)
        x = tf.keras.layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
        x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                            padding='same',
                                            use_bias=False,
                                            name=prefix + '_sepconv2')(x)
        x = tf.keras.layers.BatchNormalization(axis=channel_axis,
                                               name=prefix + '_sepconv2_bn')(x)
        x = tf.keras.layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
        x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                            padding='same',
                                            use_bias=False,
                                            name=prefix + '_sepconv3')(x)
        x = tf.keras.layers.BatchNormalization(axis=channel_axis,
                                               name=prefix + '_sepconv3_bn')(x)

        x = tf.keras.layers.add([x, residual])
    residual = tf.keras.layers.Conv2D(1024, (1, 1), strides=(2, 2),
                                      padding='valid', use_bias=False)(x)
    residual = tf.keras.layers.BatchNormalization(axis=channel_axis)(residual)

    x = tf.keras.layers.Activation('relu', name='block13_sepconv1_act')(x)
    x = tf.keras.layers.SeparableConv2D(728, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block13_sepconv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block13_sepconv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block13_sepconv2_act')(x)
    x = tf.keras.layers.SeparableConv2D(1024, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block13_sepconv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block13_sepconv2_bn')(x)

    x = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    x = tf.keras.layers.MaxPooling2D((3, 3),
                                     strides=(2, 2),
                                     padding='valid',
                                     name='block13_pool')(x)
    x = tf.keras.layers.add([x, residual])
    x = tf.keras.layers.SeparableConv2D(1536, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block14_sepconv1')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block14_sepconv1_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block14_sepconv1_act')(x)
    x = tf.keras.layers.SeparableConv2D(2048, (3, 3),
                                        padding='same',
                                        use_bias=False,
                                        name='block14_sepconv2')(x)
    x = tf.keras.layers.BatchNormalization(axis=channel_axis, name='block14_sepconv2_bn')(x)
    x = tf.keras.layers.Activation('relu', name='block14_sepconv2_act')(x)
    if include_top:
        x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = tf.keras.layers.Dense(classes, activation='softmax', name='predictions')(x)
    else:
        if pooling == 'avg':
            x = tf.keras.layers.GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = tf.keras.layers.GlobalMaxPooling2D()(x)

    model = tf.keras.models.Model(img_input, x, name='xception')
    return model
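
# Example usage: a minimal sketch, not part of the original gist. The input
# shape and class count here are illustrative; the network assumes
# channels_last data, since the tf.pad calls above hard-code NHWC padding.
if __name__ == '__main__':
    inputs = tf.keras.Input(shape=(299, 299, 3))
    model = ExtendXception(input_tensor=inputs, classes=1000)
    model.summary()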