Skip to content

Instantly share code, notes, and snippets.

@piyush2896
Created November 8, 2018 15:29
Show Gist options
  • Save piyush2896/01c4156d7b370ec4d1fa5441a4f9d2e2 to your computer and use it in GitHub Desktop.
from keras import layers
from keras.models import Model
def _after_conv(in_tensor):
    """Apply BatchNormalization and then a ReLU activation to ``in_tensor``.

    Shared post-processing for every convolution helper in this module.

    Args:
        in_tensor: Keras tensor produced by a convolution layer.

    Returns:
        The batch-normalized, ReLU-activated Keras tensor.
    """
    normalized = layers.BatchNormalization()(in_tensor)
    activated = layers.Activation('relu')(normalized)
    return activated
def conv1(in_tensor, filters):
    """1x1 convolution with stride 1, followed by BatchNorm and ReLU.

    Args:
        in_tensor: input Keras tensor.
        filters: number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and unchanged spatial size.
    """
    pointwise = layers.Conv2D(filters, kernel_size=1, strides=1)
    return _after_conv(pointwise(in_tensor))
def conv1_downsample(in_tensor, filters):
    """1x1 convolution with stride 2, followed by BatchNorm and ReLU.

    Halves the spatial resolution (default 'valid' padding; with a 1x1
    kernel the output size is ceil(size / 2)).

    Args:
        in_tensor: input Keras tensor.
        filters: number of output channels.

    Returns:
        Keras tensor with ``filters`` channels at half the spatial size.
    """
    strided_pointwise = layers.Conv2D(filters, kernel_size=1, strides=2)
    return _after_conv(strided_pointwise(in_tensor))
def conv3(in_tensor, filters):
    """3x3 'same'-padded convolution with stride 1, then BatchNorm and ReLU.

    Args:
        in_tensor: input Keras tensor.
        filters: number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and unchanged spatial size.
    """
    spatial = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')
    return _after_conv(spatial(in_tensor))
def conv3_downsample(in_tensor, filters):
    """3x3 'same'-padded convolution with stride 2, then BatchNorm and ReLU.

    Args:
        in_tensor: input Keras tensor.
        filters: number of output channels.

    Returns:
        Keras tensor with ``filters`` channels at half the spatial size.
    """
    strided_spatial = layers.Conv2D(filters, kernel_size=3, strides=2,
                                    padding='same')
    return _after_conv(strided_spatial(in_tensor))
def resnet_block_wo_bottlneck(in_tensor, filters, downsample=False):
    """Basic residual block (two 3x3 convolutions), used by ResNet-18/34.

    Computes ``relu(F(x) + shortcut(x))`` where ``F`` is
    conv3x3 -> BN -> ReLU -> conv3x3 -> BN.

    The second convolution is deliberately NOT followed by a ReLU.
    Activating before the addition would force the residual branch to be
    non-negative, so the block could only ever add to its input.

    The shortcut is the identity, or, when downsampling, a stride-2
    1x1 convolution + BN with no ReLU, so it matches the branch's shape.

    Args:
        in_tensor: input Keras tensor. Without downsampling it is added
            directly to the branch output, so it must already have
            ``filters`` channels for the ``Add`` to succeed.
        filters: number of output channels of the block.
        downsample: if True, the first conv and the shortcut use stride 2.

    Returns:
        Keras tensor with ``filters`` channels after the final ReLU.
    """
    if downsample:
        conv1_rb = conv3_downsample(in_tensor, filters)
    else:
        conv1_rb = conv3(in_tensor, filters)
    # Second conv: BN only -- the ReLU is applied after the addition.
    conv2_rb = layers.Conv2D(filters, kernel_size=3, strides=1,
                             padding='same')(conv1_rb)
    conv2_rb = layers.BatchNormalization()(conv2_rb)
    shortcut = in_tensor
    if downsample:
        # Projection shortcut: match spatial size and channels, no ReLU.
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=2)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)
    result = layers.Add()([conv2_rb, shortcut])
    return layers.Activation('relu')(result)
def resnet_block_w_bottlneck(in_tensor,
                             filters,
                             downsample=False,
                             change_channels=False):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1), used by ResNet-50/101/152.

    Computes ``relu(F(x) + shortcut(x))`` where ``F`` is
    conv1x1(filters // 4) -> BN -> ReLU -> conv3x3(filters // 4) -> BN -> ReLU
    -> conv1x1(filters) -> BN.

    The previous version had two deviations:
    - it returned the sum without the final ReLU, unlike the basic block in
      this module;
    - it applied a ReLU to the expanding 1x1 conv and to the projection
      shortcut before the addition.
    Both are corrected here.

    Args:
        in_tensor: input Keras tensor. When neither ``downsample`` nor
            ``change_channels`` is set, it is added directly, so it must
            already have ``filters`` channels for the ``Add`` to succeed.
        filters: number of output channels; the inner convolutions use
            ``filters // 4`` channels.
        downsample: if True, the first 1x1 conv and the shortcut use stride 2.
        change_channels: if True (and not downsampling), project the
            shortcut to ``filters`` channels with a stride-1 1x1 conv.

    Returns:
        Keras tensor with ``filters`` channels after the final ReLU.
    """
    bottleneck = filters // 4
    if downsample:
        conv1_rb = conv1_downsample(in_tensor, bottleneck)
    else:
        conv1_rb = conv1(in_tensor, bottleneck)
    conv2_rb = conv3(conv1_rb, bottleneck)
    # Expanding 1x1 conv: BN only -- the ReLU is applied after the addition.
    conv3_rb = layers.Conv2D(filters, kernel_size=1, strides=1)(conv2_rb)
    conv3_rb = layers.BatchNormalization()(conv3_rb)
    shortcut = in_tensor
    if downsample:
        # Projection shortcut with stride 2; no ReLU on the shortcut path.
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=2)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)
    elif change_channels:
        # Projection shortcut that only changes the channel count.
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=1)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)
    result = layers.Add()([conv3_rb, shortcut])
    return layers.Activation('relu')(result)
def _pre_res_blocks(in_tensor):
    """Network stem that runs before the residual stages.

    Applies a 7x7 stride-2 convolution with 64 filters, BatchNorm + ReLU,
    then a 3x3 stride-2 max pool, all with 'same' padding.

    Args:
        in_tensor: the model's input Keras tensor.

    Returns:
        Keras tensor with 64 channels at 1/4 of the input spatial size.
    """
    stem = layers.Conv2D(filters=64, kernel_size=7, strides=2,
                         padding='same')(in_tensor)
    stem = _after_conv(stem)
    return layers.MaxPool2D(pool_size=3, strides=2, padding='same')(stem)
def _post_res_blocks(in_tensor, n_classes):
    """Classification head: global average pooling and a softmax Dense layer.

    Args:
        in_tensor: output tensor of the last residual stage.
        n_classes: number of output classes.

    Returns:
        Keras tensor of softmax class probabilities with ``n_classes`` units.
    """
    pooled = layers.GlobalAvgPool2D()(in_tensor)
    classifier = layers.Dense(n_classes, activation='softmax')
    return classifier(pooled)
def convx_wo_bottleneck(in_tensor, filters, n_times, downsample_1=False):
    """Stack ``n_times`` basic residual blocks into one ResNet stage.

    Only the first block may downsample (controlled by ``downsample_1``);
    the remaining blocks keep the spatial size. With ``n_times <= 0`` the
    input is returned unchanged.

    Args:
        in_tensor: input Keras tensor of the stage.
        filters: output channels of every block in the stage.
        n_times: number of blocks to stack.
        downsample_1: whether the first block uses stride 2.

    Returns:
        Output Keras tensor of the last block.
    """
    out = in_tensor
    for idx in range(n_times):
        first_downsample = downsample_1 if idx == 0 else False
        out = resnet_block_wo_bottlneck(out, filters, first_downsample)
    return out
def convx_w_bottleneck(in_tensor, filters, n_times, downsample_1=False):
    """Stack ``n_times`` bottleneck residual blocks into one ResNet stage.

    The first block either downsamples (``downsample_1=True``) or projects
    the shortcut to ``filters`` channels (``downsample_1=False``). Exactly
    one of the two happens, so the identity shortcut is never added to a
    tensor with the wrong channel count. Later blocks use plain identity
    shortcuts. With ``n_times <= 0`` the input is returned unchanged.

    Args:
        in_tensor: input Keras tensor of the stage.
        filters: output channels of every block in the stage.
        n_times: number of blocks to stack.
        downsample_1: whether the first block uses stride 2.

    Returns:
        Output Keras tensor of the last block.
    """
    out = in_tensor
    for idx in range(n_times):
        if idx:
            out = resnet_block_w_bottlneck(out, filters)
        else:
            out = resnet_block_w_bottlneck(out, filters, downsample_1,
                                           not downsample_1)
    return out
def _resnet(in_shape=(224, 224, 3),
            n_classes=1000,
            opt='sgd',
            convx=(64, 128, 256, 512),
            n_convx=(2, 2, 2, 2),
            convx_fn=convx_wo_bottleneck):
    """Assemble and compile a ResNet classifier.

    The network is stem -> four residual stages (conv2_x..conv5_x) ->
    global-average-pool + softmax head. Stages 3-5 downsample in their
    first block; stage 2 does not, because the stem already downsampled.

    ``convx`` and ``n_convx`` default to tuples instead of lists so the
    defaults are immutable and cannot be shared-and-mutated across calls.
    Lists are still accepted, since only indexing is used.

    Args:
        in_shape: input shape passed to ``layers.Input``.
        n_classes: number of output classes of the softmax head.
        opt: optimizer (name or instance) passed to ``Model.compile``.
        convx: output channels of the four stages; only the first four
            entries are used.
        n_convx: number of residual blocks in each of the four stages;
            only the first four entries are used.
        convx_fn: stage builder, ``convx_wo_bottleneck`` or
            ``convx_w_bottleneck``.

    Returns:
        A ``Model`` compiled with categorical cross-entropy loss and an
        accuracy metric.
    """
    in_layer = layers.Input(in_shape)
    downsampled = _pre_res_blocks(in_layer)
    conv2x = convx_fn(downsampled, convx[0], n_convx[0])
    conv3x = convx_fn(conv2x, convx[1], n_convx[1], True)
    conv4x = convx_fn(conv3x, convx[2], n_convx[2], True)
    conv5x = convx_fn(conv4x, convx[3], n_convx[3], True)
    preds = _post_res_blocks(conv5x, n_classes)
    model = Model(in_layer, preds)
    model.compile(loss="categorical_crossentropy", optimizer=opt,
                  metrics=["accuracy"])
    return model
def resnet18(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-18.

    Uses ``_resnet``'s defaults: basic (non-bottleneck) blocks, stages of
    2-2-2-2 blocks with 64-128-256-512 channels.
    """
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt)
def resnet34(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-34.

    Basic (non-bottleneck) blocks with the default 64-128-256-512 channels,
    arranged as 3-4-6-3 blocks per stage.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt,
                   n_convx=blocks_per_stage)
def resnet50(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-50.

    Bottleneck blocks with 256-512-1024-2048 output channels, arranged as
    3-4-6-3 blocks per stage.
    """
    return _resnet(in_shape=in_shape,
                   n_classes=n_classes,
                   opt=opt,
                   convx=[256, 512, 1024, 2048],
                   n_convx=[3, 4, 6, 3],
                   convx_fn=convx_w_bottleneck)
def resnet101(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-101.

    Bottleneck blocks with 256-512-1024-2048 output channels, arranged as
    3-4-23-3 blocks per stage.
    """
    return _resnet(in_shape=in_shape,
                   n_classes=n_classes,
                   opt=opt,
                   convx=[256, 512, 1024, 2048],
                   n_convx=[3, 4, 23, 3],
                   convx_fn=convx_w_bottleneck)
def resnet152(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-152.

    Bottleneck blocks with 256-512-1024-2048 output channels, arranged as
    3-8-36-3 blocks per stage.
    """
    return _resnet(in_shape=in_shape,
                   n_classes=n_classes,
                   opt=opt,
                   convx=[256, 512, 1024, 2048],
                   n_convx=[3, 8, 36, 3],
                   convx_fn=convx_w_bottleneck)
if __name__ == '__main__':
    # Build ResNet-50 with the default 224x224x3 input and print its layer
    # table. Model.summary() prints on its own and returns None, so wrapping
    # it in print() (as before) emitted a stray "None" line.
    model = resnet50()
    model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment