DetNet backbone for retinanet
""" | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
import keras
from keras.models import Model
from keras.layers import Input, Dense, Flatten
from keras.layers import Conv2D, Add, ZeroPadding2D, MaxPooling2D, AveragePooling2D
from keras.layers import BatchNormalization, ReLU
from keras.utils import get_file

from . import retinanet
from . import Backbone
from ..utils.image import preprocess_image

class DetnetNetBackbone(Backbone):
    """ Describes backbone information and provides utility functions.
    """

    def retinanet(self, *args, **kwargs):
        """ Returns a retinanet model using the correct backbone.
        """
        return detnet_retinanet(*args, backbone=self.backbone, **kwargs)

    def download_imagenet(self):
        """ Downloads pre-trained weights for the specified backbone name.

        The file name follows the keras-retinanet convention
        {backbone}_weights_tf_dim_ordering_tf_kernels_notop.h5.
        Note that the release below hosts the official Keras application weights
        (DenseNet and friends); no DetNet-59 ImageNet weights are published there,
        so this call only succeeds if such a file has been made available at that origin.
        """
        origin    = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/'
        file_name = '{}_weights_tf_dim_ordering_tf_kernels_notop.h5'

        # load weights
        if keras.backend.image_data_format() == 'channels_first':
            raise ValueError('Weights for "channels_first" format are not available.')

        weights_url = origin + file_name.format(self.backbone)
        return get_file(file_name.format(self.backbone), weights_url, cache_subdir='models')

    def validate(self):
        """ Checks whether the backbone string is correct.
        """
        allowed_backbones = ['detnet59']
        backbone = self.backbone.split('_')[0]
        if backbone not in allowed_backbones:
            raise ValueError('Backbone (\'{}\') not in allowed backbones ({}).'.format(backbone, allowed_backbones))

    def preprocess_image(self, inputs):
        """ Takes as input an image and prepares it for being passed through the network.
        """
        return preprocess_image(inputs, mode='tf')

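# Integration sketch (an assumption, not part of this gist): if this file is saved as
# keras_retinanet/models/detnet.py, the models.backbone() factory in
# keras_retinanet/models/__init__.py would need a branch that maps names starting with
# 'detnet' to this class, along the lines of:
#
#     if 'detnet' in backbone_name:
#         from . import detnet as b
#         return b.DetnetNetBackbone(backbone_name)
#
# after which training could be launched with --backbone detnet59.

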
def detnet_retinanet(num_classes, backbone='detnet59', inputs=None, modifier=None, **kwargs):
    """ Constructs a retinanet model using a DetNet backbone.

    Args
        num_classes: Number of classes to predict.
        backbone: Which backbone to use (currently only 'detnet59' is supported).
        inputs: The inputs to the network (defaults to a Tensor of shape (None, None, 3)).
        modifier: A function handler which can modify the backbone before using it in retinanet (this can be used to freeze backbone layers for example).

    Returns
        RetinaNet model with a DetNet-59 backbone.
    """
    # choose default input
    if inputs is None:
        if keras.backend.image_data_format() == 'channels_first':
            inputs = keras.layers.Input(shape=(3, None, None))
        else:
            inputs = keras.layers.Input(shape=(None, None, 3))

    # bottleneck filter sizes and block counts per stage of DetNet-59
    filters_list = [[64],
                    [64, 64, 256],
                    [128, 128, 512],
                    [256, 256, 1024],
                    [256, 256, 256],
                    [256, 256, 256]]
    blocks_list = [1, 3, 4, 6, 3, 3]
    detnet = detnet_59(inputs=inputs, filters_list=filters_list, blocks_list=blocks_list, num_classes=num_classes, include_top=False, freeze_bn=True)

    # invoke modifier if given
    if modifier:
        detnet = modifier(detnet)

    # create the full model
    model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=detnet.outputs, **kwargs)

    return model

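# Minimal usage sketch, assuming the usual keras-retinanet training setup:
#
#     model = detnet_retinanet(num_classes=80)   # COCO-style class count
#     model.compile(
#         loss={'regression': losses.smooth_l1(), 'classification': losses.focal()},
#         optimizer=keras.optimizers.adam(lr=1e-5)
#     )
#
# where `losses` is keras_retinanet.losses.

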
def res_block(x, filters_list, strides=1, use_bias=True, name=None):
    '''
    y = f3(f2(f1(x))) + x

    # Conv2D default arguments:
        strides=1
        padding='valid'
        data_format='channels_last'
        dilation_rate=1
        activation=None
        use_bias=True
    '''
    out = Conv2D(filters=filters_list[0], kernel_size=1, strides=1, use_bias=False, name='%s_1'%(name))(x)
    out = BatchNormalization(name='%s_1_bn'%(name))(out)
    out = ReLU(name='%s_1_relu'%(name))(out)
    out = Conv2D(filters=filters_list[1], kernel_size=3, strides=1, padding='same', use_bias=False, name='%s_2'%(name))(out)
    out = BatchNormalization(name='%s_2_bn'%(name))(out)
    out = ReLU(name='%s_2_relu'%(name))(out)
    out = Conv2D(filters=filters_list[2], kernel_size=1, strides=1, use_bias=False, name='%s_3'%(name))(out)
    out = BatchNormalization(name='%s_3_bn'%(name))(out)
    out = Add(name='%s_add'%(name))([x, out])
    out = ReLU(name='%s_relu'%(name))(out)
    return out


def res_block_proj(x, filters_list, strides=2, use_bias=True, name=None):
    '''
    y = f3(f2(f1(x))) + proj(x)
    '''
    out = Conv2D(filters=filters_list[0], kernel_size=1, strides=strides, use_bias=False, name='%s_1'%(name))(x)
    out = BatchNormalization(name='%s_1_bn'%(name))(out)
    out = ReLU(name='%s_1_relu'%(name))(out)
    out = Conv2D(filters=filters_list[1], kernel_size=3, strides=1, padding='same', use_bias=False, name='%s_2'%(name))(out)
    out = BatchNormalization(name='%s_2_bn'%(name))(out)
    out = ReLU(name='%s_2_relu'%(name))(out)
    out = Conv2D(filters=filters_list[2], kernel_size=1, strides=1, use_bias=False, name='%s_3'%(name))(out)
    out = BatchNormalization(name='%s_3_bn'%(name))(out)
    x = Conv2D(filters=filters_list[2], kernel_size=1, strides=strides, use_bias=False, name='%s_proj'%(name))(x)
    x = BatchNormalization(name='%s_proj_bn'%(name))(x)
    out = Add(name='%s_add'%(name))([x, out])
    out = ReLU(name='%s_relu'%(name))(out)
    return out


def dilated_res_block(x, filters_list, strides=1, use_bias=True, name=None):
    '''
    y = f3(f2(f1(x))) + x
    '''
    out = Conv2D(filters=filters_list[0], kernel_size=1, strides=1, use_bias=False, name='%s_1'%(name))(x)
    out = BatchNormalization(name='%s_1_bn'%(name))(out)
    out = ReLU(name='%s_1_relu'%(name))(out)
    out = Conv2D(filters=filters_list[1], kernel_size=3, strides=1, padding='same', dilation_rate=2, use_bias=False, name='%s_2'%(name))(out)
    out = BatchNormalization(name='%s_2_bn'%(name))(out)
    out = ReLU(name='%s_2_relu'%(name))(out)
    out = Conv2D(filters=filters_list[2], kernel_size=1, strides=1, use_bias=False, name='%s_3'%(name))(out)
    out = BatchNormalization(name='%s_3_bn'%(name))(out)
    out = Add(name='%s_add'%(name))([x, out])
    out = ReLU(name='%s_relu'%(name))(out)
    return out


def dilated_res_block_proj(x, filters_list, strides=1, use_bias=True, name=None):
    '''
    y = f3(f2(f1(x))) + proj(x)
    '''
    out = Conv2D(filters=filters_list[0], kernel_size=1, strides=1, use_bias=False, name='%s_1'%(name))(x)
    out = BatchNormalization(name='%s_1_bn'%(name))(out)
    out = ReLU(name='%s_1_relu'%(name))(out)
    out = Conv2D(filters=filters_list[1], kernel_size=3, strides=1, padding='same', dilation_rate=2, use_bias=False, name='%s_2'%(name))(out)
    out = BatchNormalization(name='%s_2_bn'%(name))(out)
    out = ReLU(name='%s_2_relu'%(name))(out)
    out = Conv2D(filters=filters_list[2], kernel_size=1, strides=1, use_bias=False, name='%s_3'%(name))(out)
    out = BatchNormalization(name='%s_3_bn'%(name))(out)
    x = Conv2D(filters=filters_list[2], kernel_size=1, strides=1, use_bias=False, name='%s_proj'%(name))(x)
    x = BatchNormalization(name='%s_proj_bn'%(name))(x)
    out = Add(name='%s_add'%(name))([x, out])
    out = ReLU(name='%s_relu'%(name))(out)
    return out

def resnet_body(x, filters_list, num_blocks, strides=2, name=None):
    out = res_block_proj(x=x, filters_list=filters_list, strides=strides, name='%s_1'%(name))
    for i in range(1, num_blocks):
        out = res_block(x=out, filters_list=filters_list, name='%s_%s'%(name, str(i+1)))
    return out


def detnet_body(x, filters_list, num_blocks, strides=1, name=None):
    out = dilated_res_block_proj(x=x, filters_list=filters_list, name='%s_1'%(name))
    for i in range(1, num_blocks):
        out = dilated_res_block(x=out, filters_list=filters_list, name='%s_%s'%(name, str(i+1)))
    return out

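# DetNet-59 stage layout (following the DetNet paper): stages 1-4 mirror ResNet-50
# (overall strides 2, 4, 8 and 16), while stages 5 and 6 keep the stride at 16 and use
# dilated (rate 2) bottlenecks with 256 channels, so the backbone retains more spatial
# resolution for detection than a standard ResNet.

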
def detnet_59(inputs, filters_list, blocks_list, num_classes, include_top=True, freeze_bn=False):
    # stage 1
    inputs_pad = ZeroPadding2D(padding=3, name='inputs_pad')(inputs)
    conv1 = Conv2D(filters=filters_list[0][0], kernel_size=7, strides=2, use_bias=False, name='conv1')(inputs_pad)
    conv1 = BatchNormalization(name='conv1_bn')(conv1)
    conv1 = ReLU(name='conv1_relu')(conv1)

    # stage 2
    conv1_pad  = ZeroPadding2D(padding=1, name='conv1_pad')(conv1)
    conv1_pool = MaxPooling2D(pool_size=3, strides=2, name='conv1_maxpool')(conv1_pad)
    conv2_x = resnet_body(x=conv1_pool, filters_list=filters_list[1], num_blocks=blocks_list[1], strides=1, name='res2')

    # stage 3
    conv3_x = resnet_body(x=conv2_x, filters_list=filters_list[2], num_blocks=blocks_list[2], strides=2, name='res3')

    # stage 4
    conv4_x = resnet_body(x=conv3_x, filters_list=filters_list[3], num_blocks=blocks_list[3], strides=2, name='res4')

    # stage 5 (dilated, stride kept at 16)
    conv5_x = detnet_body(x=conv4_x, filters_list=filters_list[4], num_blocks=blocks_list[4], strides=1, name='dires5')

    # stage 6 (dilated, stride kept at 16)
    conv6_x = detnet_body(x=conv5_x, filters_list=filters_list[5], num_blocks=blocks_list[5], strides=1, name='dires6')

    if include_top:
        # classification head; pool_size=14 assumes 224x224 inputs (the final feature map is 14x14)
        out = AveragePooling2D(pool_size=14, strides=1, name='final_avepool')(conv6_x)
        out = Flatten(name='flatten')(out)
        out = Dense(units=num_classes, activation="softmax", kernel_initializer='he_normal', name='dense')(out)
        model = Model(inputs=inputs, outputs=out)
    else:
        # expose backbone feature maps for the RetinaNet pyramid; conv3 (stride 8),
        # conv4 (stride 16) and the final dilated stage conv6 (stride 16) play the
        # roles of C3/C4/C5 (one reasonable choice; adjust for a different pyramid layout)
        model = Model(inputs=inputs, outputs=[conv3_x, conv4_x, conv6_x])

    if freeze_bn:
        # freeze only the batch normalization layers, as the flag name suggests
        for layer in model.layers:
            if isinstance(layer, BatchNormalization):
                layer.trainable = False

    return model
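

if __name__ == '__main__':
    # Minimal smoke test for the classification variant (a sketch; because of the
    # relative imports above, run it as a module, e.g.
    # `python -m keras_retinanet.models.detnet`, assuming that is where this file lives).
    filters_list = [[64],
                    [64, 64, 256],
                    [128, 128, 512],
                    [256, 256, 1024],
                    [256, 256, 256],
                    [256, 256, 256]]
    blocks_list = [1, 3, 4, 6, 3, 3]

    # 224x224 inputs match the pool_size=14 classification head
    image = Input(shape=(224, 224, 3))
    model = detnet_59(inputs=image, filters_list=filters_list, blocks_list=blocks_list,
                      num_classes=1000, include_top=True)
    model.summary()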