"""MobileNet v3 models for Keras.
# Reference
[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244?context=cs)
"""
from keras.models import Model
from keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D
from keras.layers import Input, Activation, BatchNormalization, Add, Multiply, Reshape
from keras.utils.vis_utils import plot_model
from keras import backend as K


class MobileNetBase:
    def __init__(self, shape, n_class, alpha=1.0):
        """Init.

        # Arguments
            shape: An integer or tuple/list of 3 integers, shape
                of input tensor.
            n_class: Integer, number of classes.
            alpha: Float, width multiplier.
        """
        self.shape = shape
        self.n_class = n_class
        self.alpha = alpha
    # def _relu6(self, x):
    #     """ReLU6."""
    #     return K.relu(x, max_value=6.0)

    # def _hard_swish(self, x):
    #     """Hard swish."""
    #     return x * K.relu(x + 3.0, max_value=6.0) / 6.0
    def _return_activation(self, x, nl):
        """Apply the chosen activation.

        # Arguments
            x: Tensor, input tensor.
            nl: String, nonlinearity activation type ('HS' for hard swish,
                'RE' for ReLU6).

        # Returns
            Output tensor.
        """
        if nl == 'HS':
            # Hard swish: x * ReLU6(x + 3) / 6
            x = x * K.relu(x + 3.0, max_value=6.0) / 6.0
        if nl == 'RE':
            x = K.relu(x, max_value=6.0)
        return x
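
    # For reference, hard swish is x * ReLU6(x + 3) / 6, so hard_swish(-3) = 0,
    # hard_swish(0) = 0 and hard_swish(3) = 3; it approaches the identity for
    # large positive inputs and 0 for large negative inputs.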

    def _conv_block(self, inputs, filters, kernel, strides, nl):
        """Convolution block.

        This function defines a 2D convolution operation with BN and activation.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            strides: An integer or tuple/list of 2 integers, specifying the
                strides of the convolution along the width and height. Can be
                a single integer to specify the same value for all spatial
                dimensions.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
        x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs)
        x = BatchNormalization(axis=channel_axis)(x)
        return self._return_activation(x, nl)

    def _squeeze(self, inputs):
        """Squeeze-and-Excitation block.

        This function defines a squeeze-and-excitation structure.

        # Arguments
            inputs: Tensor, input tensor of conv layer.

        # Returns
            Output tensor.
        """
        input_channels = int(inputs.shape[-1])
        x = GlobalAveragePooling2D()(inputs)
        x = Dense(input_channels, activation='relu')(x)
        x = Dense(input_channels, activation='hard_sigmoid')(x)
        x = Reshape((1, 1, input_channels))(x)
        x = Multiply()([inputs, x])
        return x
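
    # Note (not in the original gist): the MobileNetV3 paper reduces the first
    # Dense layer of the SE block to 1/4 of the channel width; this
    # implementation keeps the full width in both Dense layers.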

    def _bottleneck(self, inputs, filters, kernel, e, s, squeeze, nl):
        """Bottleneck block.

        This function defines a basic bottleneck structure.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            e: Integer, number of channels in the expansion layer
                (an absolute value, not a multiplier of the input size).
            s: An integer or tuple/list of 2 integers, specifying the strides
                of the convolution along the width and height. Can be a single
                integer to specify the same value for all spatial dimensions.
            squeeze: Boolean, whether to use squeeze-and-excitation.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
        input_shape = K.int_shape(inputs)
        tchannel = int(e)
        cchannel = int(self.alpha * filters)
        # Use a residual connection only when the block keeps the spatial
        # resolution (stride 1) and the channel count is unchanged.
        if isinstance(s, int):
            r = s == 1 and input_shape[3] == filters
            s = (s, s)
        elif isinstance(s, (tuple, list)):
            r = s[0] == s[1] == 1 and input_shape[3] == filters
        else:
            raise ValueError(f'Unknown stride type (expected int/tuple/list): {repr(s)}')
        x = self._conv_block(inputs, tchannel, (1, 1), (1, 1), nl)
        x = DepthwiseConv2D(kernel, strides=s, depth_multiplier=1, padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        x = self._return_activation(x, nl)
        if squeeze:
            x = self._squeeze(x)
        x = Conv2D(cchannel, (1, 1), strides=(1, 1), padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        if r:
            x = Add()([x, inputs])
        return x

    def build(self):
        pass


class MobileNetV3_Small(MobileNetBase):
    def __init__(self, shape, n_class, alpha=1.0, include_top=True):
        """Init.

        # Arguments
            shape: An integer or tuple/list of 3 integers, shape
                of input tensor.
            n_class: Integer, number of classes.
            alpha: Float, width multiplier.
            include_top: Boolean, whether to include the classification layers.
        """
        super(MobileNetV3_Small, self).__init__(shape, n_class, alpha)
        self.include_top = include_top

    def build(self, plot=False):
        """Build MobileNetV3-Small.

        # Arguments
            plot: Boolean, whether to plot the model.

        # Returns
            model: Model, the Keras model.
        """
        inputs = Input(shape=self.shape)
        x = self._conv_block(inputs, 16, (3, 3), strides=(2, 2), nl='HS')
        # Note: the 40-channel blocks use stride 1 here (the paper uses stride 2
        # for the first of them), so the total downsampling factor of the backbone is 8.
        x = self._bottleneck(x, 16, (3, 3), e=16, s=1, squeeze=True, nl='RE')
        x = self._bottleneck(x, 24, (3, 3), e=72, s=2, squeeze=False, nl='RE')
        x = self._bottleneck(x, 24, (3, 3), e=88, s=1, squeeze=False, nl='RE')
        x = self._bottleneck(x, 40, (5, 5), e=96, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 40, (5, 5), e=240, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 40, (5, 5), e=240, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 48, (5, 5), e=120, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 48, (5, 5), e=144, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 96, (5, 5), e=288, s=2, squeeze=True, nl='HS')
        x = self._bottleneck(x, 96, (5, 5), e=576, s=1, squeeze=True, nl='HS')
        x = self._bottleneck(x, 96, (5, 5), e=576, s=1, squeeze=True, nl='HS')
        x = self._conv_block(x, 576, (1, 1), strides=(1, 1), nl='HS')
        if self.include_top:
            x = GlobalAveragePooling2D()(x)
            x = Reshape((1, 1, 576))(x)
            x = Conv2D(1280, (1, 1), padding='same')(x)
            x = self._return_activation(x, 'HS')
            x = Conv2D(self.n_class, (1, 1), padding='same', activation='softmax')(x)
            x = Reshape((self.n_class,))(x)
        model = Model(inputs, x)
        if plot:
            plot_model(model, to_file='images/MobileNetv3_small.png', show_shapes=True)
        return model
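

# Hedged usage sketch (not part of the original gist): a standalone classifier
# can be built with a hypothetical 224x224 RGB input and 10 classes, e.g.:
#
#   model = MobileNetV3_Small((224, 224, 3), 10).build()
#   model.summary()


# The code below is a separate script in the gist; it imports the class above
# as the module `mobilenetv3`.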
from mobilenetv3 import MobileNetV3_Small
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

@tf.autograph.experimental.do_not_convert
def ctc_loss(y_true, y_pred):
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    # Per-sample label length: count of non-zero entries in y_true, i.e.
    # label index 0 is treated as padding.
    label_length = tf.math.count_nonzero(y_true, axis=-1, keepdims=True)
    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    return keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
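
# Note (not in the original gist): keras.backend.ctc_batch_cost expects y_true
# of shape (batch, max_label_len) with integer labels, y_pred of shape
# (batch, time_steps, n_classes) with softmax probabilities, and per-sample
# input/label lengths of shape (batch, 1); the CTC blank is the last class
# index (n_classes - 1).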

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = num_to_char(res)
        res = tf.strings.reduce_join(res)
        res = res.numpy().decode("utf-8")
        output_text.append(res)
    return output_text
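
# Hedged sketch (assumption, not in the original gist): `characters`,
# `num_to_char`, `max_length`, `img_width` and `img_height` are expected to be
# defined elsewhere, e.g. along the lines of the Keras OCR example:
#
#   characters = sorted(set("".join(labels)))          # dataset vocabulary
#   char_to_num = layers.StringLookup(vocabulary=list(characters), mask_token=None)
#   num_to_char = layers.StringLookup(
#       vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)
#   max_length = max(len(label) for label in labels)
#   img_width, img_height = 200, 50                    # hypothetical values
#
# On older TF versions StringLookup lives under layers.experimental.preprocessing.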

class CTCAccuracy(tf.keras.metrics.Metric):
    """Exact-match accuracy between CTC-decoded predictions and labels."""

    def __init__(self, name='ctc_accuracy', **kwargs):
        super(CTCAccuracy, self).__init__(name=name, **kwargs)
        self.correct_count = 0
        self.all_count = 0

    def update_state(self, y_true, y_pred, sample_weight=None):
        pred_text = decode_batch_predictions(y_pred)
        self.all_count += len(pred_text)
        true_text = []
        for res in y_true:
            res = num_to_char(res)
            res = tf.strings.reduce_join(res)
            res = res.numpy().decode("utf-8")
            true_text.append(res)
        self.correct_count += sum([i == j for i, j in zip(pred_text, true_text)])

    def result(self):
        return self.correct_count / self.all_count

    def reset_states(self):
        self.correct_count = 0
        self.all_count = 0
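
# Note (not in the original gist): CTCAccuracy calls .numpy() inside
# update_state, so it only works when the model runs eagerly; that is why
# model.run_eagerly = True is set after compiling below.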

def build_model():
    # Inputs to the model
    input_img = layers.Input(
        shape=(img_width, img_height, 1), name="image", dtype="float32"
    )
    # MobileNetV3-Small backbone without the classification head. With the
    # strides used above, its output feature map has shape
    # (img_width // 8, img_height // 8, 576).
    mobilenet = MobileNetV3_Small(
        (img_width, img_height, 1), 0, alpha=1.0, include_top=False
    ).build()
    x = mobilenet(input_img)
    # Collapse the height axis into the feature axis so that the width axis
    # becomes the CTC time dimension.
    new_shape = ((img_width // 8), (img_height // 8) * 576)
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)
    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)
    # Output layer
    output = layers.Dense(len(characters) + 2, activation="softmax", name="dense2")(x)
    # Define the model
    model = keras.models.Model(inputs=[input_img], outputs=output, name="ocr_model_v1")
    return model

# Get the model
model = build_model()
opt = keras.optimizers.Adam()
# Compile the model with CTC loss and the exact-match CTC accuracy metric
model.compile(loss=ctc_loss, optimizer=opt, metrics=[CTCAccuracy('ctc_accu')])
# Run eagerly so CTCAccuracy can decode predictions with .numpy()
model.run_eagerly = True
model.summary()
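
# Hedged smoke test (assumption, not in the gist): with img_width, img_height
# and characters defined, a dummy forward pass should yield an output whose
# time dimension is img_width // 8, e.g.:
#
#   dummy = tf.zeros((1, img_width, img_height, 1))
#   print(model(dummy).shape)   # expected: (1, img_width // 8, len(characters) + 2)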