@hollance
Created October 7, 2017 20:43
SE-ResNet-50 in Keras
# Convert SE-ResNet-50 from Caffe to Keras
# Using the model from https://github.com/shicai/SENet-Caffe

import os
import sys

import numpy as np

# The caffe module needs to be on the Python path; we'll add it here explicitly.
caffe_root = "/path/to/caffe"
sys.path.insert(0, os.path.join(caffe_root, "python"))

import caffe

model_root = "/path/to/SE-ResNet-50/"
model_def = model_root + 'se_resnet_50_v1_deploy.prototxt'
model_weights = model_root + 'se_resnet_50_v1.caffemodel'

if not os.path.isfile(model_weights):
    print("Model not found")

caffe.set_mode_cpu()
net = caffe.Net(model_def, model_weights, caffe.TEST)
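
# These Caffe layers are InnerProduct (fully-connected) layers: the
# squeeze/excitation pair in every residual block, plus the final
# classifier. Everything else is a convolution or a BatchNorm/Scale pair.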
fc_layers = [
    "fc2_1/sqz", "fc2_1/exc",
    "fc2_2/sqz", "fc2_2/exc",
    "fc2_3/sqz", "fc2_3/exc",
    "fc3_1/sqz", "fc3_1/exc",
    "fc3_2/sqz", "fc3_2/exc",
    "fc3_3/sqz", "fc3_3/exc",
    "fc3_4/sqz", "fc3_4/exc",
    "fc4_1/sqz", "fc4_1/exc",
    "fc4_2/sqz", "fc4_2/exc",
    "fc4_3/sqz", "fc4_3/exc",
    "fc4_4/sqz", "fc4_4/exc",
    "fc4_5/sqz", "fc4_5/exc",
    "fc4_6/sqz", "fc4_6/exc",
    "fc5_1/sqz", "fc5_1/exc",
    "fc5_2/sqz", "fc5_2/exc",
    "fc5_3/sqz", "fc5_3/exc",
    "fc6",
]
real_name = None
mean = None
variance = None
bias = None
params = {}

for layer_name, param in net.params.items():
    shapes = map(lambda x: x.data.shape, param)
    print(layer_name.ljust(25) + str(list(shapes)))

    # Dense layer with bias.
    if layer_name in fc_layers:
        # Caffe stores the weights as (outputChannels, inputChannels).
        c_o = param[0].data.shape[0]
        c_i = param[0].data.shape[1]

        # Keras on TensorFlow uses: (inputChannels, outputChannels).
        weights = np.array(param[0].data.data, dtype=np.float32).reshape(c_o, c_i)
        weights = weights.transpose(1, 0)
        bias = param[1].data
        params[layer_name] = [weights, bias]

    # These are the batch norm parameters.
    # Each BatchNorm layer has three blobs:
    #   0: mean
    #   1: variance
    #   2: scale factor
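    # Caffe accumulates the statistics with a running scale factor, so the
    # stored mean and variance must be divided by blobs[2][0] to recover
    # the true moving averages (this mirrors what Caffe does at inference).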
elif "/bn" in layer_name:
factor = param[2].data[0]
mean = np.array(param[0].data, dtype=np.float32) / factor
variance = np.array(param[1].data, dtype=np.float32) / factor
real_name = layer_name
# This is a scale layer. It always follows BatchNorm.
# A scale layer has two blobs:
# 0: scale (gamma)
# 1: bias (beta)
elif "/scale" in layer_name:
gamma = np.array(param[0].data, dtype=np.float32)
beta = np.array(param[1].data, dtype=np.float32)
if real_name is None: print("*** ERROR! ***")
if mean is None: print("*** ERROR! ***")
if variance is None: print("*** ERROR! ***")
params[real_name] = [gamma, beta, mean, variance]
real_name = None
mean = None
variance = None
bias = None
# Conv layer with batchnorm, no bias
else:
# The Caffe model stores the weights for each layer in this shape:
# (outputChannels, inputChannels, kernelHeight, kernelWidth)
c_o = param[0].data.shape[0]
c_i = param[0].data.shape[1]
h = param[0].data.shape[2]
w = param[0].data.shape[3]
# Keras on TensorFlow expects weights in the following shape:
# (kernelHeight, kernelWidth, inputChannels, outputChannels)
weights = np.array(param[0].data.data, dtype=np.float32).reshape(c_o, c_i, h, w)
weights = weights.transpose(2, 3, 1, 0)
params[layer_name] = [weights]
np.save("SENet_params.npy", params)
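
# Quick sanity check (a minimal sketch): reload the saved dictionary and
# verify a couple of shapes against what the Keras layers will expect.
# "conv1" and "fc6" are layer names from the SENet-Caffe model; the shapes
# follow from the transposes above.
check = np.load("SENet_params.npy")[()]
print(check["conv1"][0].shape)   # expected: (7, 7, 3, 64)
print(check["fc6"][0].shape)     # expected: (2048, 1000)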
# -*- coding: utf-8 -*-
"""SE-ResNet-50 model for Keras.
Based on https://github.com/fchollet/keras/blob/master/keras/applications/resnet50.py
"""
from __future__ import print_function
from __future__ import absolute_import
import numpy as np
from keras.layers import Input
from keras import layers
from keras.layers import Dense
from keras.layers import Activation
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import AveragePooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.layers import BatchNormalization
from keras.layers import Reshape
from keras.layers import Multiply
from keras.models import Model
from keras import backend as K
from keras.engine.topology import get_source_inputs
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import _obtain_input_shape
def preprocess_input(x):
    # 'RGB'->'BGR'
    x = x[..., ::-1]

    # Zero-center by mean pixel (the standard Caffe BGR means).
    x[..., 0] -= 103.939
    x[..., 1] -= 116.779
    x[..., 2] -= 123.68

    # Scale (the SENet-Caffe model was trained with inputs scaled by 0.017).
    x *= 0.017
    return x
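
# Example usage (a sketch): inputs are batched float32 RGB pixel arrays
# in [0, 255], e.g. of shape (1, 160, 160, 3):
#   img = preprocess_input(img)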
def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters

    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    bn_eps = 0.0001

    block_name = str(stage) + "_" + str(block)
    conv_name_base = "conv" + block_name
    relu_name_base = "relu" + block_name

    x = Conv2D(filters1, (1, 1), use_bias=False, name=conv_name_base + '_x1')(input_tensor)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x1_bn')(x)
    x = Activation('relu', name=relu_name_base + '_x1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', use_bias=False, name=conv_name_base + '_x2')(x)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x2_bn')(x)
    x = Activation('relu', name=relu_name_base + '_x2')(x)

    x = Conv2D(filters3, (1, 1), use_bias=False, name=conv_name_base + '_x3')(x)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x3_bn')(x)
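    # Squeeze-and-Excitation block: squeeze each channel to a single value
    # with global average pooling, model channel interdependencies with a
    # two-layer bottleneck (reduction ratio 16), then rescale the feature
    # map channel by channel with the resulting sigmoid weights.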
    se = GlobalAveragePooling2D(name='pool' + block_name + '_gap')(x)
    se = Dense(filters3 // 16, activation='relu', name='fc' + block_name + '_sqz')(se)
    se = Dense(filters3, activation='sigmoid', name='fc' + block_name + '_exc')(se)
    se = Reshape([1, 1, filters3])(se)
    x = Multiply(name='scale' + block_name)([x, se])

    x = layers.add([x, input_tensor], name='block_' + block_name)
    x = Activation('relu', name=relu_name_base)(x)
    return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters

    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    bn_eps = 0.0001

    block_name = str(stage) + "_" + str(block)
    conv_name_base = "conv" + block_name
    relu_name_base = "relu" + block_name

    x = Conv2D(filters1, (1, 1), use_bias=False, name=conv_name_base + '_x1')(input_tensor)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x1_bn')(x)
    x = Activation('relu', name=relu_name_base + '_x1')(x)

    x = Conv2D(filters2, kernel_size, strides=strides, padding='same', use_bias=False, name=conv_name_base + '_x2')(x)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x2_bn')(x)
    x = Activation('relu', name=relu_name_base + '_x2')(x)

    x = Conv2D(filters3, (1, 1), use_bias=False, name=conv_name_base + '_x3')(x)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_x3_bn')(x)

    # Squeeze-and-Excitation block, same structure as in identity_block.
    se = GlobalAveragePooling2D(name='pool' + block_name + '_gap')(x)
    se = Dense(filters3 // 16, activation='relu', name='fc' + block_name + '_sqz')(se)
    se = Dense(filters3, activation='sigmoid', name='fc' + block_name + '_exc')(se)
    se = Reshape([1, 1, filters3])(se)
    x = Multiply(name='scale' + block_name)([x, se])

    # The projection shortcut matches the dimensions of the residual branch.
    shortcut = Conv2D(filters3, (1, 1), strides=strides, use_bias=False, name=conv_name_base + '_prj')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name=conv_name_base + '_prj_bn')(shortcut)

    x = layers.add([x, shortcut], name='block_' + block_name)
    x = Activation('relu', name=relu_name_base)(x)
    return x
def SEResNet50(include_top=True, weights='imagenet',
               input_tensor=None, input_shape=None,
               pooling=None,
               classes=1000):
    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=225,
                                      min_size=160,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    bn_eps = 0.0001

    x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', use_bias=False, name='conv1')(img_input)
    x = BatchNormalization(axis=bn_axis, epsilon=bn_eps, name='conv1_bn')(x)
    x = Activation('relu', name='relu1')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), name='pool1')(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block=1, strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block=2)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block=3)

    x = conv_block(x, 3, [128, 128, 512], stage=3, block=1)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block=2)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block=3)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block=4)

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block=1)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block=2)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block=3)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block=4)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block=5)
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block=6)

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block=1)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block=2)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block=3)

    # The Caffe model ends with global average pooling before the classifier,
    # so reduce the feature map to a 2048-d vector here; a plain Flatten()
    # would not match the (2048, 1000) fc6 weights from the conversion.
    x = GlobalAveragePooling2D(name='pool5')(x)
    x = Dense(classes, activation='softmax', name='fc6')(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = Model(inputs, x, name='se-resnet50')
    return model
model = SEResNet50(weights=None, input_shape=(160, 160, 3), classes=1000)
model.summary()

params = np.load("SENet_params.npy")

for key in params[()].keys():
    layer_name = key.replace("/", "_")
    print(key, "-->", layer_name)
    layer = model.get_layer(layer_name)
    layer.set_weights(params[()][key])
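
# Finally, persist the converted model so it can be used without Caffe
# (a sketch; the filename is arbitrary and saving requires h5py):
model.save("SE-ResNet-50.h5")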
@desertnaut

Hi and many thanks for this! Could you possibly coordinate with @titu1994, so that we finally have pretrained weights in a ready-to-use Keras module?

https://github.com/titu1994/keras-squeeze-excite-network

Thanks in advance

@cipher009 commented Apr 5, 2018

@hollance I get the following error:

I0405 12:31:17.480696 17251 net.cpp:816] Ignoring source layer top5/acc
conv1                    [(64, 3, 7, 7)]
Traceback (most recent call last):
  File "convert_weights.py", line 111, in <module>
    weights = np.array(param[0].data.data, dtype=np.float32).reshape(c_o, c_i, h, w)
ValueError: cannot reshape array of size 37632 into shape (64,3,7,7)

I've used the model and weights mentioned here:
https://github.com/shicai/SENet-Caffe

Any ideas?

@limadm commented Oct 7, 2019

@cipher009 could you try with np.frombuffer instead of np.array in lines 64 and 111?
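
That change would look something like this (an untested sketch; np.frombuffer reinterprets the blob's raw buffer as float32 instead of converting it element by element):

    weights = np.frombuffer(param[0].data.data, dtype=np.float32).reshape(c_o, c_i, h, w)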
