Created
April 30, 2021 16:26
-
-
Save ProblemFactory/401c68b341ae6338fdb8798598c60814 to your computer and use it in GitHub Desktop.
OCR model for https://github.com/ProblemFactory/GenshinArtScanner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""MobileNet v3 models for Keras. | |
# Reference | |
[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244?context=cs) | |
""" | |
from keras.models import Model | |
from keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D | |
from keras.layers import Input, Activation, BatchNormalization, Add, Multiply, Reshape | |
from keras.utils.vis_utils import plot_model | |
from keras import backend as K | |
class MobileNetBase:
    """Shared building blocks for the MobileNetV3 models in this module.

    Subclasses override :meth:`build` to assemble a concrete network from
    ``_conv_block``, ``_bottleneck`` and ``_squeeze``.
    """

    def __init__(self, shape, n_class, alpha=1.0):
        """Init.

        # Arguments
            shape: Tuple/list of 3 integers, shape of the input tensor
                (without the batch dimension).
            n_class: Integer, number of classes.
            alpha: Float, width multiplier applied to the output channels of
                each bottleneck block.
        """
        self.shape = shape
        self.n_class = n_class
        self.alpha = alpha

    def _return_activation(self, x, nl):
        """Apply the nonlinearity named by ``nl`` to ``x``.

        # Arguments
            x: Tensor, input tensor.
            nl: String, nonlinearity type. ``'HS'`` is hard-swish
                (``x * relu6(x + 3) / 6``) and ``'RE'`` is ReLU6. Any other
                value returns ``x`` unchanged (linear).
        # Returns
            Output tensor.
        """
        if nl == 'HS':
            return x * K.relu(x + 3.0, max_value=6.0) / 6.0
        if nl == 'RE':
            return K.relu(x, max_value=6.0)
        return x

    def _conv_block(self, inputs, filters, kernel, strides, nl):
        """Convolution block: Conv2D -> BatchNormalization -> activation.

        # Arguments
            inputs: Tensor, input tensor of the conv layer.
            filters: Integer, number of output channels.
            kernel: An integer or tuple/list of 2 integers, the height and
                width of the 2D convolution window.
            strides: An integer or tuple/list of 2 integers, the strides of
                the convolution along height and width.
            nl: String, nonlinearity type (see ``_return_activation``).
        # Returns
            Output tensor.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
        x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs)
        x = BatchNormalization(axis=channel_axis)(x)
        return self._return_activation(x, nl)

    def _squeeze(self, inputs):
        """Squeeze-and-excitation: rescale channels by learned gates.

        Global-average-pools ``inputs`` and passes the result through two
        Dense layers (ReLU, then hard-sigmoid). Both layers are as wide as the
        channel count. ``inputs`` is then multiplied by the resulting
        per-channel weights. Assumes channels-last layout: it reads
        ``inputs.shape[-1]`` and reshapes the gates to ``(1, 1, C)``.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
        # Returns
            Tensor with the same shape as ``inputs``.
        """
        input_channels = int(inputs.shape[-1])
        x = GlobalAveragePooling2D()(inputs)
        x = Dense(input_channels, activation='relu')(x)
        x = Dense(input_channels, activation='hard_sigmoid')(x)
        x = Reshape((1, 1, input_channels))(x)
        x = Multiply()([inputs, x])
        return x

    def _bottleneck(self, inputs, filters, kernel, e, s, squeeze, nl):
        """Bottleneck (inverted residual) block.

        The block applies, in order: a 1x1 expansion conv, a depthwise conv,
        optional squeeze-and-excitation, and a 1x1 linear projection. An
        identity shortcut is added when the block keeps both the spatial size
        and the channel count.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, nominal number of output channels. The actual
                count is ``int(alpha * filters)``.
            kernel: An integer or tuple/list of 2 integers, the height and
                width of the depthwise convolution window.
            e: Integer, number of channels of the 1x1 expansion layer. This
                is an absolute channel count, not a multiplier.
            s: An integer or tuple/list of 2 integers, strides of the
                depthwise convolution along height and width.
            squeeze: Boolean, whether to apply squeeze-and-excitation.
            nl: String, nonlinearity type (see ``_return_activation``).
        # Returns
            Output tensor.
        # Raises
            ValueError: if ``s`` is not an int, tuple or list.
        """
        channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
        input_shape = K.int_shape(inputs)
        tchannel = int(e)
        cchannel = int(self.alpha * filters)

        if isinstance(s, int):
            s = (s, s)
        elif not isinstance(s, (tuple, list)):
            raise ValueError(f'Unknown Stride Type (Expected int/tuple/list) {repr(s)}')

        # Use the identity shortcut only when the block neither downsamples nor
        # changes the channel count. Compare against cchannel, which is what
        # the projection below actually emits. Comparing against `filters`
        # breaks the Add whenever alpha != 1. Read the channel count from
        # channel_axis instead of assuming index 3 (channels_last).
        r = s[0] == s[1] == 1 and input_shape[channel_axis] == cchannel

        x = self._conv_block(inputs, tchannel, (1, 1), (1, 1), nl)
        x = DepthwiseConv2D(kernel, strides=s, depth_multiplier=1, padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        x = self._return_activation(x, nl)
        if squeeze:
            x = self._squeeze(x)
        # Linear projection: BatchNorm with no activation afterwards.
        x = Conv2D(cchannel, (1, 1), strides=(1, 1), padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        if r:
            x = Add()([x, inputs])
        return x

    def build(self):
        """Build the Keras model. Subclasses must override this."""
        raise NotImplementedError(f'{type(self).__name__} must implement build()')
class MobileNetV3_Small(MobileNetBase):
    """MobileNetV3-Small style network assembled from MobileNetBase blocks."""

    def __init__(self, shape, n_class, alpha=1.0, include_top=True):
        """Init.

        # Arguments
            shape: Tuple/list of 3 integers, shape of the input tensor.
            n_class: Integer, number of classes. Only used when
                ``include_top`` is True.
            alpha: Float, width multiplier.
            include_top: Boolean, whether to append the pooling + 1x1-conv
                classification head.
        """
        super().__init__(shape, n_class, alpha)
        self.include_top = include_top

    def build(self, plot=False):
        """Build the MobileNetV3-Small model.

        Only the stem conv and the 2nd and 9th bottlenecks use stride 2, and
        all of them use 'same' padding. The feature map is therefore about
        1/8 of the input size in each spatial dimension before the final 1x1
        conv to 576 channels.

        # Arguments
            plot: Boolean, whether to save a diagram of the model to
                ``images/MobileNetv3_small.png``.
        # Returns
            model: Keras ``Model``.
        """
        # Columns: filters, kernel, expansion channels, stride,
        # squeeze-and-excitation, nonlinearity.
        bottleneck_specs = [
            (16, (3, 3), 16, 1, True, 'RE'),
            (24, (3, 3), 72, 2, False, 'RE'),
            (24, (3, 3), 88, 1, False, 'RE'),
            (40, (5, 5), 96, 1, True, 'HS'),
            (40, (5, 5), 240, 1, True, 'HS'),
            (40, (5, 5), 240, 1, True, 'HS'),
            (48, (5, 5), 120, 1, True, 'HS'),
            (48, (5, 5), 144, 1, True, 'HS'),
            (96, (5, 5), 288, 2, True, 'HS'),
            (96, (5, 5), 576, 1, True, 'HS'),
            (96, (5, 5), 576, 1, True, 'HS'),
        ]
        head_channels = 576

        img_input = Input(shape=self.shape)
        x = self._conv_block(img_input, 16, (3, 3), strides=(2, 2), nl='HS')
        for filters, kernel, expand, stride, use_se, act in bottleneck_specs:
            x = self._bottleneck(x, filters, kernel, e=expand, s=stride,
                                 squeeze=use_se, nl=act)
        x = self._conv_block(x, head_channels, (1, 1), strides=(1, 1), nl='HS')

        if self.include_top:
            # Classification head: pool to a 1x1 map, expand with a 1x1 conv,
            # then a per-class softmax conv flattened to (n_class,).
            x = GlobalAveragePooling2D()(x)
            x = Reshape((1, 1, head_channels))(x)
            x = Conv2D(1280, (1, 1), padding='same')(x)
            x = self._return_activation(x, 'HS')
            x = Conv2D(self.n_class, (1, 1), padding='same', activation='softmax')(x)
            x = Reshape((self.n_class,))(x)

        model = Model(img_input, x)
        if plot:
            plot_model(model, to_file='images/MobileNetv3_small.png', show_shapes=True)
        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mobilenetv3 import MobileNetV3_Small | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
@tf.autograph.experimental.do_not_convert
def ctc_loss(y_true, y_pred):
    """CTC loss usable as ``model.compile(loss=ctc_loss)``.

    Every sample is scored over the full time axis of ``y_pred``
    (``tf.shape(y_pred)[1]`` steps). Each label's length is the count of
    non-zero entries along its last axis.

    NOTE(review): this assumes y_true is (batch, max_label_len) and padded
    with 0, with no real character encoded as 0. Confirm against the label
    encoder.
    """
    # Column of int64 ones, one row per sample, used to broadcast the scalar
    # time length and to match the (batch, 1) shape ctc_batch_cost expects.
    per_sample = tf.ones(
        shape=(tf.cast(tf.shape(y_true)[0], dtype="int64"), 1), dtype="int64"
    )
    time_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    nonzero_labels = tf.math.count_nonzero(y_true, axis=-1, keepdims=True)
    return keras.backend.ctc_batch_cost(
        y_true,
        y_pred,
        time_steps * per_sample,
        nonzero_labels * per_sample,
    )
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    """Greedy-decode a batch of CTC outputs into strings.

    # Arguments
        pred: Model output. ``pred.shape[0]`` is the batch size and
            ``pred.shape[1]`` is the number of time steps. Every sample is
            decoded over the full time axis.
    # Returns
        List of decoded strings, one per sample, built from the first
        ``max_length`` decoded positions.

    Relies on the module-level ``max_length`` and ``num_to_char``, which are
    not defined in this file.
    NOTE(review): per the Keras docs, unused positions from ctc_decode come
    back as -1. How ``num_to_char`` renders -1 determines whether padding
    shows up in the returned text. Confirm.
    """
    # `np` is never imported at the top of this module, so import it here.
    import numpy as np

    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    decoded = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    results = decoded[:, :max_length]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        chars = num_to_char(res)
        output_text.append(tf.strings.reduce_join(chars).numpy().decode("utf-8"))
    return output_text
class CTCAccuracy(tf.keras.metrics.Metric):
    """Exact-match sequence accuracy for CTC outputs.

    A sample counts as correct when its greedy-decoded prediction string
    equals the string decoded from its label row. State lives in plain
    Python ints and tensors are converted with ``.numpy()``. The metric
    therefore only works when the model runs eagerly (``run_eagerly=True``).

    NOTE(review): label rows are decoded in full, including any padding,
    while predictions are truncated to ``max_length``. Check that both sides
    render padding identically, or matches will be under-counted.
    """

    def __init__(self, name='ctc_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.correct_count = 0
        self.all_count = 0

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate exact matches for one batch. ``sample_weight`` is ignored."""
        pred_text = decode_batch_predictions(y_pred)
        true_text = [
            tf.strings.reduce_join(num_to_char(row)).numpy().decode("utf-8")
            for row in y_true
        ]
        self.all_count += len(pred_text)
        self.correct_count += sum(p == t for p, t in zip(pred_text, true_text))

    def result(self):
        """Return the fraction of samples matched so far (0.0 before any sample)."""
        if self.all_count == 0:
            # Avoid ZeroDivisionError when result() is queried before any update.
            return 0.0
        return self.correct_count / self.all_count

    def reset_state(self):
        """Clear the running counts.

        Newer tf.keras/Keras reset metrics through ``reset_state``. The
        default implementation only zeroes tf variables, and this metric has
        none, so it must be overridden here.
        """
        self.correct_count = 0
        self.all_count = 0

    def reset_states(self):
        """Old spelling of ``reset_state``, kept for older TF and existing callers."""
        self.reset_state()
def build_model():
    """Build the OCR model: MobileNetV3-Small features -> BiLSTMs -> per-step softmax.

    Images are fed as ``(img_width, img_height, 1)``. After the backbone,
    axis 1 (derived from the width) becomes the sequence axis, and each step
    is the flattened ``(rows, channels)`` slice of the feature map.

    Relies on the module-level ``img_width``, ``img_height`` and
    ``characters``, which are not defined in this file.

    # Returns
        Uncompiled ``keras.Model`` named ``"ocr_model_v1"``.
    """
    input_shape = (img_width, img_height, 1)
    input_img = layers.Input(shape=input_shape, name="image", dtype="float32")
    mobilenet = MobileNetV3_Small(
        input_shape, 0, alpha=1.0, include_top=False
    ).build()
    x = mobilenet(input_img)

    # Derive the sequence shape from the backbone's actual output instead of
    # assuming `dim // 8` and 576 channels. The stride-2, 'same'-padded
    # layers round up, so `img_width // 8` is too small whenever the size is
    # not a multiple of 8 and the Reshape would fail. Multiples of 8 give
    # the same shape as before.
    _, time_steps, rows, channels = keras.backend.int_shape(x)
    x = layers.Reshape(target_shape=(time_steps, rows * channels), name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

    # Output layer: a softmax over len(characters) + 2 classes at every step.
    # NOTE(review): the +2 presumably reserves indices outside `characters`
    # (e.g. the CTC blank). Confirm against the label encoder.
    output = layers.Dense(len(characters) + 2, activation="softmax", name="dense2")(x)

    return keras.models.Model(inputs=[input_img], outputs=output, name="ocr_model_v1")
# Build the network and compile it with the CTC loss and exact-match metric.
# run_eagerly=True is required: CTCAccuracy keeps Python-side counters and
# calls .numpy() on tensors, which does not work inside a traced graph.
model = build_model()
opt = keras.optimizers.Adam()
model.compile(
    optimizer=opt,
    loss=ctc_loss,
    metrics=[CTCAccuracy('ctc_accu')],
    run_eagerly=True,
)
model.summary()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment