# ResNet-18/34/50/101/152 model builders for Keras.
# Source: GitHub gist piyush2896/01c4156d7b370ec4d1fa5441a4f9d2e2 (created November 8, 2018 15:29).
from keras import layers | |
from keras.models import Model | |
def _after_conv(in_tensor):
    """Apply BatchNormalization and then a ReLU activation to ``in_tensor``.

    Args:
        in_tensor: Keras tensor, typically the output of a Conv2D layer.

    Returns:
        The normalized and ReLU-activated Keras tensor.
    """
    normalized = layers.BatchNormalization()(in_tensor)
    activated = layers.Activation('relu')(normalized)
    return activated
def conv1(in_tensor, filters):
    """1x1 convolution (stride 1, Keras default padding) followed by BN + ReLU.

    Args:
        in_tensor: Input Keras tensor.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and unchanged spatial size.
    """
    pointwise = layers.Conv2D(filters, 1, strides=1)
    return _after_conv(pointwise(in_tensor))
def conv1_downsample(in_tensor, filters):
    """Strided (stride 2) 1x1 convolution followed by BN + ReLU.

    Args:
        in_tensor: Input Keras tensor.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels; spatial size is reduced by the stride.
    """
    strided_pointwise = layers.Conv2D(filters, 1, strides=2)
    return _after_conv(strided_pointwise(in_tensor))
def conv3(in_tensor, filters):
    """3x3 convolution (stride 1, 'same' padding) followed by BN + ReLU.

    Args:
        in_tensor: Input Keras tensor.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and unchanged spatial size.
    """
    spatial = layers.Conv2D(filters, 3, strides=1, padding='same')
    return _after_conv(spatial(in_tensor))
def conv3_downsample(in_tensor, filters):
    """3x3 convolution with stride 2 and 'same' padding, followed by BN + ReLU.

    Args:
        in_tensor: Input Keras tensor.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels; spatial size is reduced by the stride.
    """
    strided_spatial = layers.Conv2D(filters, 3, strides=2, padding='same')
    return _after_conv(strided_spatial(in_tensor))
def resnet_block_wo_bottlneck(in_tensor, filters, downsample=False):
    """Basic residual block (two 3x3 convolutions), as used by ResNet-18/34.

    Residual branch: conv3x3-BN-ReLU -> conv3x3-BN.
    Shortcut: identity, or (when ``downsample``) a stride-2 1x1 conv + BN
    projection. ReLU is applied only after the addition, so the residual
    branch is free to contribute negative values.

    Args:
        in_tensor: Input Keras tensor. Without ``downsample`` the identity
            shortcut is used, so its channel count must already equal
            ``filters`` for the Add to be valid.
        filters: Number of output channels of both convolutions.
        downsample: If True, the first conv and the shortcut use stride 2.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    if downsample:
        branch = conv3_downsample(in_tensor, filters)
    else:
        branch = conv3(in_tensor, filters)
    # Final branch conv: BN but no ReLU before the addition.
    branch = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(branch)
    branch = layers.BatchNormalization()(branch)

    shortcut = in_tensor
    if downsample:
        # Projection shortcut: strided 1x1 conv + BN, no activation.
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=2)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)

    result = layers.Add()([branch, shortcut])
    return layers.Activation('relu')(result)
def resnet_block_w_bottlneck(in_tensor,
                             filters,
                             downsample=False,
                             change_channels=False):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1), as used by ResNet-50/101/152.

    Residual branch: conv1x1(filters // 4)-BN-ReLU -> conv3x3(filters // 4)-BN-ReLU
    -> conv1x1(filters)-BN. Any stride sits in the first 1x1 conv.
    Shortcut: a 1x1 conv + BN projection when ``downsample`` (stride 2) or
    ``change_channels`` (stride 1), otherwise the identity. ReLU is applied
    after the addition, matching the basic block.

    Args:
        in_tensor: Input Keras tensor. With neither flag set, its channel count
            must already equal ``filters`` for the Add to be valid.
        filters: Number of output channels; the inner convs use ``filters // 4``.
        downsample: If True, the first 1x1 conv and the shortcut use stride 2.
        change_channels: If True (and not ``downsample``), project the shortcut
            to ``filters`` channels with a stride-1 1x1 conv.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    bottleneck = filters // 4
    if downsample:
        branch = conv1_downsample(in_tensor, bottleneck)
    else:
        branch = conv1(in_tensor, bottleneck)
    branch = conv3(branch, bottleneck)
    # Expansion conv: BN but no ReLU before the addition.
    branch = layers.Conv2D(filters, kernel_size=1, strides=1)(branch)
    branch = layers.BatchNormalization()(branch)

    shortcut = in_tensor
    if downsample:
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=2)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)
    elif change_channels:
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=1)(in_tensor)
        shortcut = layers.BatchNormalization()(shortcut)

    result = layers.Add()([branch, shortcut])
    # Post-addition ReLU (was missing; the basic block already applies it).
    return layers.Activation('relu')(result)
def _pre_res_blocks(in_tensor):
    """Network stem: 7x7 stride-2 conv (64 filters) -> BN -> ReLU -> 3x3 stride-2 max-pool.

    Args:
        in_tensor: Input image tensor.

    Returns:
        Keras tensor after the stem; both the conv and the pool use 'same' padding.
    """
    stem = _after_conv(layers.Conv2D(64, 7, strides=2, padding='same')(in_tensor))
    return layers.MaxPool2D(pool_size=3, strides=2, padding='same')(stem)
def _post_res_blocks(in_tensor, n_classes):
    """Classification head: global average pooling followed by a softmax Dense layer.

    Args:
        in_tensor: Output tensor of the last residual stage.
        n_classes: Number of output classes.

    Returns:
        Keras tensor of softmax class probabilities with ``n_classes`` units.
    """
    pooled = layers.GlobalAvgPool2D()(in_tensor)
    classifier = layers.Dense(n_classes, activation='softmax')
    return classifier(pooled)
def convx_wo_bottleneck(in_tensor, filters, n_times, downsample_1=False):
    """Stack ``n_times`` basic residual blocks into one ResNet stage.

    Only the first block may downsample (controlled by ``downsample_1``).
    The remaining blocks keep the resolution and channel count.

    Args:
        in_tensor: Input Keras tensor.
        filters: Channel count for every block in the stage.
        n_times: Number of blocks; values below 1 return ``in_tensor`` unchanged.
        downsample_1: Whether the first block uses stride 2.

    Returns:
        Output Keras tensor of the stage.
    """
    if n_times < 1:
        return in_tensor
    out = resnet_block_wo_bottlneck(in_tensor, filters, downsample_1)
    for _ in range(n_times - 1):
        out = resnet_block_wo_bottlneck(out, filters)
    return out
def convx_w_bottleneck(in_tensor, filters, n_times, downsample_1=False):
    """Stack ``n_times`` bottleneck residual blocks into one ResNet stage.

    The first block always projects its shortcut. When ``downsample_1`` is
    True it uses a stride-2 projection; otherwise it uses a stride-1 projection
    (``change_channels=True``) so the channel count can change to ``filters``.
    Later blocks use identity shortcuts.

    Args:
        in_tensor: Input Keras tensor.
        filters: Output channel count for every block in the stage.
        n_times: Number of blocks; values below 1 return ``in_tensor`` unchanged.
        downsample_1: Whether the first block uses stride 2.

    Returns:
        Output Keras tensor of the stage.
    """
    if n_times < 1:
        return in_tensor
    out = resnet_block_w_bottlneck(in_tensor, filters, downsample_1, not downsample_1)
    for _ in range(n_times - 1):
        out = resnet_block_w_bottlneck(out, filters)
    return out
def _resnet(in_shape=(224, 224, 3),
            n_classes=1000,
            opt='sgd',
            convx=(64, 128, 256, 512),
            n_convx=(2, 2, 2, 2),
            convx_fn=convx_wo_bottleneck):
    """Build and compile a four-stage ResNet classifier.

    Args:
        in_shape: Input shape passed to ``layers.Input``.
        n_classes: Number of softmax output units.
        opt: Optimizer passed to ``model.compile``.
        convx: Output channels for stages conv2_x .. conv5_x (first 4 entries used).
        n_convx: Number of residual blocks per stage (first 4 entries used).
        convx_fn: Stage builder, e.g. ``convx_wo_bottleneck`` or ``convx_w_bottleneck``.

    Returns:
        A Keras ``Model`` compiled with categorical cross-entropy and accuracy.
    """
    # Tuple defaults instead of lists: avoids the shared-mutable-default pitfall;
    # callers passing lists still work since the values are only indexed.
    in_layer = layers.Input(in_shape)
    downsampled = _pre_res_blocks(in_layer)
    # conv2_x keeps the stem's resolution; each later stage downsamples in its first block.
    conv2x = convx_fn(downsampled, convx[0], n_convx[0])
    conv3x = convx_fn(conv2x, convx[1], n_convx[1], True)
    conv4x = convx_fn(conv3x, convx[2], n_convx[2], True)
    conv5x = convx_fn(conv4x, convx[3], n_convx[3], True)
    preds = _post_res_blocks(conv5x, n_classes)

    model = Model(in_layer, preds)
    model.compile(loss="categorical_crossentropy", optimizer=opt,
                  metrics=["accuracy"])
    return model
def resnet18(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-18 (basic blocks, 2-2-2-2 per stage, 64..512 channels)."""
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt)
def resnet34(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-34 (basic blocks, 3-4-6-3 per stage, 64..512 channels)."""
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt,
                   n_convx=blocks_per_stage)
def resnet50(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-50 (bottleneck blocks, 3-4-6-3 per stage, 256..2048 channels)."""
    stage_channels = [256, 512, 1024, 2048]
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt,
                   convx=stage_channels, n_convx=blocks_per_stage,
                   convx_fn=convx_w_bottleneck)
def resnet101(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-101 (bottleneck blocks, 3-4-23-3 per stage, 256..2048 channels)."""
    stage_channels = [256, 512, 1024, 2048]
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt,
                   convx=stage_channels, n_convx=blocks_per_stage,
                   convx_fn=convx_w_bottleneck)
def resnet152(in_shape=(224,224,3), n_classes=1000, opt='sgd'):
    """Build a compiled ResNet-152 (bottleneck blocks, 3-8-36-3 per stage, 256..2048 channels)."""
    stage_channels = [256, 512, 1024, 2048]
    blocks_per_stage = [3, 8, 36, 3]
    return _resnet(in_shape=in_shape, n_classes=n_classes, opt=opt,
                   convx=stage_channels, n_convx=blocks_per_stage,
                   convx_fn=convx_w_bottleneck)
if __name__ == '__main__':
    # Build ResNet-50 with default settings and print its layer summary.
    model = resnet50()
    # Model.summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment