@noahtren, created July 29, 2019 15:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noahtren/3e61b7c278b5057ad2b3b319d54c4188 to your computer and use it in GitHub Desktop.
Save noahtren/3e61b7c278b5057ad2b3b319d54c4188 to your computer and use it in GitHub Desktop.
Small convolutional neural network with residual connections implemented with Keras
```python
"""Small convnet with residual connections.

Inspired by https://gist.github.com/mjdietzx/0cb95922aac14d446a6530f87b3a04ce,
which builds a full ResNet-50 or ResNeXt-50 model.
"""
from keras.layers import (BatchNormalization, Conv2D, Dense, Dropout, Flatten,
                          Input, LeakyReLU, MaxPool2D, add)
from keras.models import Model

NUM_CLASSES = 2


def add_common_layers(y):
    """Batch normalization followed by a LeakyReLU activation."""
    y = BatchNormalization()(y)
    y = LeakyReLU()(y)
    return y


def residual_block(y, nb_filters):
    """1x1 -> 3x3 -> 1x1 conv stages plus an identity shortcut.

    The shortcut is added without a projection, so the input must already
    have `nb_filters` channels; stride 1 and 'same' padding keep the spatial size.
    """
    shortcut = y
    y = Conv2D(nb_filters, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
    y = add_common_layers(y)
    y = Conv2D(nb_filters, kernel_size=(3, 3), strides=(1, 1), padding='same')(y)
    y = add_common_layers(y)
    y = Conv2D(nb_filters, kernel_size=(1, 1), strides=(1, 1), padding='same')(y)
    y = add_common_layers(y)
    # Element-wise sum of the block input and the transformed features, then activate.
    y = add([shortcut, y])
    y = LeakyReLU()(y)
    return y
```
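To make the shortcut arithmetic concrete, here is a framework-free sketch of what `residual_block` computes element-wise: `LeakyReLU(shortcut + F(shortcut))`. The functions `zero_branch` and `double_branch` are hypothetical stand-ins for the conv/BN/LeakyReLU stack `F`, and the slope of 0.3 is an assumption chosen to match the default of Keras's `LeakyReLU`, which the script does not override.

```python
# Framework-free sketch of the residual pattern in residual_block:
#   output = LeakyReLU(shortcut + F(shortcut)), applied element-wise.
# zero_branch / double_branch are hypothetical stand-ins for the conv stack F.


def leaky_relu(values, alpha=0.3):
    """Element-wise LeakyReLU; alpha=0.3 assumes Keras's default slope."""
    return [v if v > 0 else alpha * v for v in values]


def residual(x, branch):
    """Add the shortcut x to branch(x), then apply LeakyReLU."""
    fx = branch(x)
    return leaky_relu([a + b for a, b in zip(x, fx)])


def zero_branch(values):
    """A branch that has learned nothing: it outputs all zeros."""
    return [0.0] * len(values)


def double_branch(values):
    """A branch that outputs twice its input."""
    return [2.0 * v for v in values]


print(residual([1.0, 2.0, -1.0], zero_branch))   # [1.0, 2.0, -0.3]
print(residual([1.0, -1.0], double_branch))      # approximately [3.0, -0.9]
```

With a zero-output branch, positive activations pass through unchanged, so the convolutions only have to learn a correction to their input; this is the usual motivation for residual connections.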
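```python
# Grayscale 100x100 input images.
img_input = Input(shape=(100, 100, 1))

# Stem: 3x3 conv with the default 'valid' padding (100x100 -> 98x98), then BN + LeakyReLU.
x = Conv2D(8, (3, 3))(img_input)
x = add_common_layers(x)

# Three stages: 2x2 max pooling (halving the spatial size, rounding down), then a residual block.
x = MaxPool2D()(x)
x = residual_block(x, 8)
x = MaxPool2D()(x)
x = residual_block(x, 8)
x = MaxPool2D()(x)
x = residual_block(x, 8)

# Classifier head: flatten, 16-unit Dense (no activation), dropout, softmax output.
x = Flatten()(x)
x = Dense(16)(x)
x = Dropout(0.5)(x)
prediction = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs=[img_input], outputs=[prediction])
```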
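As a sanity check on the layer arithmetic, the plain-Python sketch below recomputes the feature-map shapes and the parameter count of `model` without importing Keras. It assumes the Keras defaults the script relies on implicitly: `Conv2D` uses 'valid' padding and a bias term unless told otherwise, `MaxPool2D()` pools 2x2 with stride 2 and 'valid' padding, and `BatchNormalization` keeps four values per channel (gamma, beta, moving mean, moving variance). All helper names are hypothetical and not part of Keras.

```python
# Hypothetical helpers mirroring the Keras defaults the script relies on.
# Shapes are (height, width, channels); parameter counts include biases.


def conv2d(shape, filters, kernel, padding="valid"):
    """Stride-1 Conv2D with bias: kernel*kernel*in_ch*filters + filters params."""
    h, w, in_ch = shape
    if padding == "valid":
        h, w = h - kernel + 1, w - kernel + 1
    return (h, w, filters), kernel * kernel * in_ch * filters + filters


def batch_norm_params(shape):
    """gamma, beta, moving mean and moving variance for each channel."""
    return 4 * shape[-1]


def max_pool(shape, pool=2):
    """Pooling with stride == pool and 'valid' padding."""
    h, w, ch = shape
    return ((h - pool) // pool + 1, (w - pool) // pool + 1, ch)


def residual_block_params(shape, filters):
    """1x1, 3x3, 1x1 'same' convs, each followed by BatchNormalization."""
    total = 0
    for kernel in (1, 3, 1):
        shape, params = conv2d(shape, filters, kernel, padding="same")
        total += params + batch_norm_params(shape)
    return shape, total  # LeakyReLU and add contribute no parameters


def dense_params(units_in, units_out):
    """Fully connected layer: weight matrix plus bias vector."""
    return units_in * units_out + units_out


def trace(input_shape=(100, 100, 1), filters=8, num_classes=2):
    shapes = []
    shape, total = conv2d(input_shape, filters, 3)  # stem, 'valid' padding
    total += batch_norm_params(shape)
    shapes.append(shape)
    for _ in range(3):  # MaxPool2D followed by residual_block, three times
        shape, params = residual_block_params(max_pool(shape), filters)
        total += params
        shapes.append(shape)
    flat = shape[0] * shape[1] * shape[2]
    total += dense_params(flat, 16) + dense_params(16, num_classes)
    return shapes, flat, total


shapes, flat, total = trace()
print(shapes)  # [(98, 98, 8), (49, 49, 8), (24, 24, 8), (12, 12, 8)]
print(flat)    # 1152
print(total)   # 21066
```

Under those assumptions the total of 21,066 should agree with `model.summary()`, with 160 of those being non-trainable BatchNormalization statistics (10 BN layers x 8 channels x 2). The `Dense(16)` layer after `Flatten` alone accounts for 18,448 parameters (roughly 88%), so the classifier head, not the convolutions, dominates the model size.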