Skip to content

Instantly share code, notes, and snippets.

@mvoelk
Last active March 11, 2022 08:06
Show Gist options
  • Star 21 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save mvoelk/ef4fc7fb905be7191cc2beb1421da37c to your computer and use it in GitHub Desktop.
Save mvoelk/ef4fc7fb905be7191cc2beb1421da37c to your computer and use it in GitHub Desktop.
Resnet-152 pre-trained model in TF Keras 2.x
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPool2D, AvgPool2D, Activation
from tensorflow.keras.layers import Layer, BatchNormalization, ZeroPadding2D, Flatten, add
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.models import Model
from tensorflow.keras import initializers
from tensorflow.python.keras.layers import InputSpec
from tensorflow.keras import backend as K
import sys
sys.setrecursionlimit(3000)
class Scale(Layer):
    '''Custom Layer for ResNet used after BatchNormalization.

    Learns a set of weights and biases used for scaling the input data.
    The output is an element-wise multiplication of the input by `gamma`
    plus an added `beta`, both broadcast along `axis`:

        out = in * gamma + beta,

    where 'gamma' and 'beta' are the weights and biases learned.

    # Arguments
        axis: integer, axis of the input that indexes the features; one
            gamma/beta pair is learned per entry along this axis. For
            instance, if your input tensor has shape
            (samples, channels, rows, cols), set axis to 1 to scale per
            feature map (channels axis).
        momentum: kept for interface compatibility. It is stored and
            returned by `get_config`, but no computation in this layer
            reads it.
        beta_init: initializer identifier (name, config dict or instance)
            for the shift parameter, resolved with `initializers.get`.
        gamma_init: initializer identifier (name, config dict or instance)
            for the scale parameter, resolved with `initializers.get`.
    '''

    def __init__(self, axis=-1, momentum=0.9, beta_init='zero', gamma_init='one', **kwargs):
        # Initialise the base Layer first so attribute tracking is set up
        # before any attribute is assigned on the instance.
        super(Scale, self).__init__(**kwargs)
        self.momentum = momentum
        self.axis = axis
        self.beta_initializer = initializers.get(beta_init)
        self.gamma_initializer = initializers.get(gamma_init)

    def build(self, input_shape):
        '''Create the `gamma` and `beta` weights, one value per feature.'''
        # Remember the build-time shape; `call` uses it to derive the
        # broadcast shape of the two weight vectors.
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)
        self.gamma = self.add_weight(
            name='%s_gamma'%self.name,
            shape=shape,
            initializer=self.gamma_initializer,
            trainable=True,
            dtype=self.dtype)
        self.beta = self.add_weight(
            name='%s_beta'%self.name,
            shape=shape,
            initializer=self.beta_initializer,
            trainable=True,
            dtype=self.dtype)
        self.built = True

    def call(self, x, mask=None):
        '''Return `gamma * x + beta` with both weights broadcast along `axis`.'''
        input_shape = self.input_spec[0].shape
        # All ones except along `axis`, so the 1-D weights line up with the
        # feature axis and broadcast over every other dimension.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]
        out = K.reshape(self.gamma, broadcast_shape) * x + K.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        '''Return the constructor arguments needed to re-create this layer.'''
        config = {
            "momentum": self.momentum,
            "axis": self.axis,
            # Serialize the initializers under the constructor's argument
            # names so `from_config` restores non-default initializers too.
            "beta_init": initializers.serialize(self.beta_initializer),
            "gamma_init": initializers.serialize(self.gamma_initializer),
        }
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def identity_block(input_tensor, kernel_size, filters, stage, block):
    '''The identity_block is the block that has no conv layer at shortcut.

    The main path is three conv -> BatchNormalization -> Scale groups
    (1x1, kernel_size x kernel_size, 1x1); its output is added to the
    unchanged `input_tensor` and passed through a ReLU.

    # Arguments
        input_tensor: input tensor
        kernel_size: integer, the kernel size of middle conv layer at main
            path (3 at every call site in this file). Odd values keep the
            spatial size unchanged, which the residual addition requires.
        filters: list of integers, the nb_filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names

    # Returns
        Output tensor of the block.
    '''
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    scale_name_base = 'scale' + str(stage) + block + '_branch'
    # The middle convolution uses the default 'valid' padding, so pad by half
    # the kernel to keep the spatial size (1 for the usual kernel_size of 3).
    pad = kernel_size // 2

    x = Conv2D(nb_filter1, (1, 1), name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2a')(x)
    x = Scale(name=scale_name_base + '2a')(x)
    x = Activation('relu', name=conv_name_base + '2a_relu')(x)

    x = ZeroPadding2D((pad, pad), name=conv_name_base + '2b_zeropadding')(x)
    x = Conv2D(nb_filter2, (kernel_size, kernel_size), name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2b')(x)
    x = Scale(name=scale_name_base + '2b')(x)
    x = Activation('relu', name=conv_name_base + '2b_relu')(x)

    x = Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2c')(x)
    x = Scale(name=scale_name_base + '2c')(x)

    # Residual connection: the shortcut is the untouched block input.
    x = add([x, input_tensor], name='res' + str(stage) + block)
    x = Activation('relu', name='res' + str(stage) + block + '_relu')(x)
    return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    '''conv_block is the block that has a conv layer at shortcut.

    The main path is three conv -> BatchNormalization -> Scale groups
    (1x1, kernel_size x kernel_size, 1x1). The shortcut is a 1x1 conv ->
    BatchNormalization -> Scale on `input_tensor`. Both are added and
    passed through a ReLU.

    # Arguments
        input_tensor: input tensor
        kernel_size: integer, the kernel size of middle conv layer at main
            path (3 at every call site in this file). Odd values keep the
            spatial size unchanged, which the residual addition requires.
        filters: list of integers, the nb_filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        strides: strides of the first conv layer at main path and of the
            shortcut conv layer.

    # Returns
        Output tensor of the block.

    Note that with the default strides=(2,2) the first conv layer at main
    path downsamples, and the shortcut uses the same strides so both
    branches keep matching shapes.
    '''
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    scale_name_base = 'scale' + str(stage) + block + '_branch'
    # The middle convolution uses the default 'valid' padding, so pad by half
    # the kernel to keep the spatial size (1 for the usual kernel_size of 3).
    pad = kernel_size // 2

    x = Conv2D(nb_filter1, (1, 1), strides=strides, name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2a')(x)
    x = Scale(name=scale_name_base + '2a')(x)
    x = Activation('relu', name=conv_name_base + '2a_relu')(x)

    x = ZeroPadding2D((pad, pad), name=conv_name_base + '2b_zeropadding')(x)
    x = Conv2D(nb_filter2, (kernel_size, kernel_size), name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2b')(x)
    x = Scale(name=scale_name_base + '2b')(x)
    x = Activation('relu', name=conv_name_base + '2b_relu')(x)

    x = Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, name=bn_name_base + '2c')(x)
    x = Scale(name=scale_name_base + '2c')(x)

    # Projection shortcut: brings the input to nb_filter3 channels and applies
    # the same strides as the main path.
    shortcut = Conv2D(nb_filter3, (1, 1), strides=strides, name=conv_name_base + '1', use_bias=False)(input_tensor)
    shortcut = BatchNormalization(epsilon=eps, name=bn_name_base + '1')(shortcut)
    shortcut = Scale(name=scale_name_base + '1')(shortcut)

    x = add([x, shortcut], name='res' + str(stage) + block)
    x = Activation('relu', name='res' + str(stage) + block + '_relu')(x)
    return x
def resnet152_model(input_shape=(224, 224, 3), weights_path=None):
    '''Build the ResNet152 network and optionally load stored weights.

    The network is a 7x7 convolutional stem, four residual stages (each one
    conv_block followed by identity_blocks), a 7x7 average pooling and a
    1000-way softmax layer.

    # Arguments
        input_shape: shape of the model input
        weights_path: path to pretrained weight file; when given, weights
            are loaded by layer name. Nothing is loaded when it is falsy.

    # Returns
        A Keras model instance.
    '''
    bn_epsilon = 1.1e-5

    inputs = Input(shape=input_shape, name='data')

    # Stem: pad -> 7x7/2 conv -> BN -> Scale -> ReLU -> 3x3/2 max pool.
    net = ZeroPadding2D((3, 3), name='conv1_zeropadding')(inputs)
    net = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(net)
    net = BatchNormalization(epsilon=bn_epsilon, name='bn_conv1')(net)
    net = Scale(name='scale_conv1')(net)
    net = Activation('relu', name='conv1_relu')(net)
    net = MaxPool2D((3, 3), strides=(2, 2), name='pool1')(net)

    # One row per residual stage:
    # (stage label, filters, strides of the leading conv_block,
    #  labels of the identity blocks that follow it).
    stage_table = (
        (2, [64, 64, 256], (1, 1), ['b', 'c']),
        (3, [128, 128, 512], (2, 2), ['b' + str(i) for i in range(1, 8)]),
        (4, [256, 256, 1024], (2, 2), ['b' + str(i) for i in range(1, 36)]),
        (5, [512, 512, 2048], (2, 2), ['b', 'c']),
    )
    for stage_id, stage_filters, first_strides, labels in stage_table:
        net = conv_block(net, 3, stage_filters, stage=stage_id, block='a', strides=first_strides)
        for label in labels:
            net = identity_block(net, 3, stage_filters, stage=stage_id, block=label)

    # Classification head.
    pooled = AvgPool2D((7, 7), name='avg_pool')(net)
    flat = Flatten()(pooled)
    probabilities = Dense(1000, activation='softmax', name='fc1000')(flat)

    model = Model(inputs, probabilities)

    # load weights
    if weights_path:
        model.load_weights(weights_path, by_name=True)

    return model
if __name__ == '__main__':
    # Smoke test: classify one image with the pretrained weights and print
    # the index of the highest-scoring class.
    input_shape = (224, 224, 3)
    weights_path = 'resnet152_weights_tf.h5'
    image_path = 'cat.jpg'

    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread signals a missing or unreadable file by returning None
        # rather than raising; fail here with a clear message instead of
        # deep inside cv2.resize.
        raise IOError('Could not read image: %s' % image_path)

    # cv2.resize takes the target size as (width, height), whereas
    # input_shape is (rows, cols, channels).
    im = cv2.resize(im, (input_shape[1], input_shape[0])).astype(np.float32)

    # Remove train image mean
    # NOTE(review): presumably per-channel means in OpenCV's BGR order -
    # confirm against the preprocessing used to train the weights.
    im -= [103.939, 116.779, 123.68]

    # Insert a new dimension for the batch_size
    im = np.expand_dims(im, axis=0)

    # Test pretrained model
    model = resnet152_model(input_shape, weights_path)
    sgd = SGD(lr=1e-2, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

    preds = model.predict(im)
    print(np.argmax(preds))
@mvoelk
Copy link
Author

mvoelk commented Oct 6, 2020

I updated the code to TF Keras 2.4.0 and removed the Theano support.

@mvoelk
Copy link
Author

mvoelk commented Oct 6, 2020

Code is based on https://gist.github.com/flyyufelix/7e2eafb149f72f4d38dd661882c554a6. You can find the weights there...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment