Last active
November 9, 2020 18:04
-
-
Save adimyth/3be8ca7c1bdd8cfc778aee52074d66d5 to your computer and use it in GitHub Desktop.
OCR Attention - TF 2.3 compatible files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Functions to read, decode and pre-process input data for the Model. | |
""" | |
import collections | |
import functools | |
import tensorflow as tf | |
import tf_slim as slim | |
import inception_preprocessing | |
# Input-pipeline outputs handed to the model. Every field holds a tensor:
#   images: preprocessed input images, shape [batch_size x H x W x 3];
#   images_orig: the same examples as returned by the data provider, i.e.
#     before preprocess_image() was applied, batched alongside `images`;
#   labels: ground truth label ids, shape [batch_size x seq_length];
#   labels_one_hot: labels in one-hot encoding,
#     shape [batch_size x seq_length x num_char_classes].
InputEndpoints = collections.namedtuple(
    'InputEndpoints', 'images images_orig labels labels_one_hot')

# Settings for shuffled batch fetching:
#   num_batching_threads: number of parallel threads that fetch data;
#   queue_capacity: maximum number of elements held in the batching queue;
#   min_after_dequeue: minimum number of elements that must remain in the
#     queue after a dequeue, which guarantees a level of mixing.
ShuffleBatchConfig = collections.namedtuple(
    'ShuffleBatchConfig',
    'num_batching_threads queue_capacity min_after_dequeue')

# Configuration used by get_data() when the caller passes no shuffle_config.
DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
    min_after_dequeue=1000,
    num_batching_threads=8,
    queue_capacity=3000,
)
def augment_image(image):
  """Applies a random crop/rescale and a random colour distortion to an image.

  Args:
    image: input Tensor image of rank 3, with the last dimension
      of size 3. Height and width must be known statically, because the
      random crop is resized back to them.

  Returns:
    Distorted Tensor image of the same shape, with values clipped to
    [-1.5, 1.5].
  """
  # TF2's tf.image.resize only understands ResizeMethod names and raises
  # ValueError for the integer method codes that TF1's resize_images took.
  # apply_with_random_selector presumably hands the lambda the integer case
  # index (0..num_cases-1) as `method`, so translate that index using the
  # TF1 enum order (0=bilinear, 1=nearest, 2=bicubic, 3=area).
  resize_methods = (
      tf.image.ResizeMethod.BILINEAR,
      tf.image.ResizeMethod.NEAREST_NEIGHBOR,
      tf.image.ResizeMethod.BICUBIC,
      tf.image.ResizeMethod.AREA,
  )
  with tf.compat.v1.variable_scope('AugmentImage'):
    height, width = image.get_shape().as_list()[:2]

    # Random crop cut from the street sign image, resized to the same size.
    # Ensures that the crop covers at least 0.8 of the input image's area.
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        image_size=tf.shape(input=image),
        bounding_boxes=tf.zeros([0, 0, 4]),
        min_object_covered=0.8,
        aspect_ratio_range=[0.8, 1.2],
        area_range=[0.8, 1.0],
        use_image_if_no_bounding_boxes=True)
    distorted_image = tf.slice(image, bbox_begin, bbox_size)

    # Randomly chooses one of the 4 interpolation methods.
    distorted_image = inception_preprocessing.apply_with_random_selector(
        distorted_image,
        lambda x, method: tf.image.resize(x, [height, width],
                                          resize_methods[method]),
        num_cases=len(resize_methods))
    distorted_image.set_shape([height, width, 3])

    # Color distortion.
    distorted_image = inception_preprocessing.apply_with_random_selector(
        distorted_image,
        functools.partial(
            inception_preprocessing.distort_color, fast_mode=False),
        num_cases=4)
    distorted_image = tf.clip_by_value(distorted_image, -1.5, 1.5)
  return distorted_image
def central_crop(image, crop_size):
  """Cuts a fixed-size window out of the middle of an image.

  Args:
    image: A tensor with shape [height, width, channels].
    crop_size: A tuple (crop_width, crop_height) -- note the width-first
      order, unlike the returned tensor's shape.

  Returns:
    A tensor of shape [crop_height, crop_width, channels]. The offset
    computation depends on tf.Assert ops that fail if the image is smaller
    than the requested crop in either dimension.
  """
  with tf.compat.v1.variable_scope('CentralCrop'):
    crop_width, crop_height = crop_size
    dynamic_shape = tf.shape(input=image)
    image_height = dynamic_shape[0]
    image_width = dynamic_shape[1]

    height_is_enough = tf.Assert(
        tf.greater_equal(image_height, crop_height),
        ['image_height < target_height', image_height, crop_height])
    width_is_enough = tf.Assert(
        tf.greater_equal(image_width, crop_width),
        ['image_width < target_width', image_width, crop_width])

    with tf.control_dependencies([height_is_enough, width_is_enough]):
      # Both differences are non-negative once the asserts hold, so floor
      # division matches truncating a halved value.
      top = (image_height - crop_height) // 2
      left = (image_width - crop_width) // 2
      return tf.image.crop_to_bounding_box(
          image, top, left, crop_height, crop_width)
def preprocess_image(image, augment=False, central_crop_size=None,
                     num_towers=4):
  """Converts an image to float32 and optionally crops/augments each view.

  The input may hold `num_towers` views of the same scene placed side by
  side along the width axis; cropping and augmentation are applied to each
  view separately and the views are then concatenated back together.

  Args:
    image: a [H x W x 3] uint8 tensor.
    augment: optional, if True do random image distortion.
    central_crop_size: A tuple (crop_width, crop_height) for the whole,
      multi-view image; each view gets 1/num_towers of the width.
    num_towers: optional, number of shots of the same image in the input image.

  Returns:
    A float32 tensor with RGB values rescaled by convert_image_dtype
    (for uint8 input this is the range [0, 1]).
  """
  with tf.compat.v1.variable_scope('PreprocessImage'):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    if not (augment or central_crop_size):
      return image

    if num_towers == 1:
      views = [image]
    else:
      views = tf.split(value=image, num_or_size_splits=num_towers, axis=1)

    if central_crop_size:
      per_view_crop = (int(central_crop_size[0] / num_towers),
                       central_crop_size[1])
      views = [central_crop(view, per_view_crop) for view in views]

    if augment:
      views = [augment_image(view) for view in views]

    return tf.concat(views, 1)
def get_data(dataset,
             batch_size,
             augment=False,
             central_crop_size=None,
             shuffle_config=None,
             shuffle=True):
  """Wraps calls to DatasetDataProviders and the batching queue.

  For more details about supported Dataset objects refer to datasets/fsns.py.

  Args:
    dataset: a slim.data.dataset.Dataset object.
    batch_size: number of samples per batch.
    augment: optional, if True does random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    shuffle_config: A namedtuple ShuffleBatchConfig. Falls back to
      DEFAULT_SHUFFLE_CONFIG when not given.
    shuffle: if True use data shuffling (both in the data provider and in
      the batching queue); if False examples are batched in the order the
      provider yields them.

  Returns:
    An InputEndpoints namedtuple holding batched `images`, `images_orig`,
    `labels` and `labels_one_hot` tensors.
  """
  if not shuffle_config:
    shuffle_config = DEFAULT_SHUFFLE_CONFIG

  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=shuffle,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size)
  image_orig, label = provider.get(['image', 'label'])

  image = preprocess_image(
      image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
  label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)

  tensors = [image, image_orig, label, label_one_hot]
  if shuffle:
    batched = tf.compat.v1.train.shuffle_batch(
        tensors,
        batch_size=batch_size,
        num_threads=shuffle_config.num_batching_threads,
        capacity=shuffle_config.queue_capacity,
        min_after_dequeue=shuffle_config.min_after_dequeue)
  else:
    # shuffle_batch would mix examples even when shuffle=False was
    # requested. A FIFO batch fed by a single enqueue thread keeps the order
    # in which the provider yields examples (several threads would interleave).
    batched = tf.compat.v1.train.batch(
        tensors,
        batch_size=batch_size,
        num_threads=1,
        capacity=shuffle_config.queue_capacity)
  images, images_orig, labels, labels_one_hot = batched

  return InputEndpoints(
      images=images,
      images_orig=images_orig,
      labels=labels,
      labels_one_hot=labels_one_hot)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Script to train the Attention OCR model. | |
A simple usage example: | |
python train.py | |
""" | |
import collections | |
import logging | |
import tensorflow as tf | |
import tf_slim as slim | |
from tensorflow.compat.v1 import app | |
from tensorflow.python.platform import flags | |
import data_provider | |
import common_flags | |
FLAGS = flags.FLAGS

# common_flags.define() presumably registers the flags this script reads but
# does not define below (train_log_dir, learning_rate, optimizer, batch_size,
# ...) -- confirm against common_flags.
common_flags.define()

# yapf: disable
# --- Distributed-training topology ---
flags.DEFINE_integer(
    name='task',
    default=0,
    help='The Task ID. This value is used when training with '
         'multiple workers to identify each worker.')
flags.DEFINE_integer(
    name='ps_tasks',
    default=0,
    help='The number of parameter servers. If the value is 0, then'
         ' the parameters are handled locally by the worker.')

# --- Summaries, checkpoints and run length ---
flags.DEFINE_integer(
    name='save_summaries_secs',
    default=60,
    help='The frequency with which summaries are saved, in '
         'seconds.')
flags.DEFINE_integer(
    name='save_interval_secs',
    default=600,
    help='Frequency in seconds of saving the model.')
flags.DEFINE_integer(
    name='max_number_of_steps',
    default=int(1e10),
    help='The maximum number of gradient steps.')
flags.DEFINE_string(
    name='checkpoint_inception',
    default='',
    help='Checkpoint to recover inception weights from.')

# --- Optimisation ---
flags.DEFINE_float(
    name='clip_gradient_norm',
    default=2.0,
    help='If greater than 0 then the gradients would be clipped by '
         'it.')

# --- Synchronous replica training ---
flags.DEFINE_bool(
    name='sync_replicas',
    default=False,
    help='If True will synchronize replicas during training.')
flags.DEFINE_integer(
    name='replicas_to_aggregate',
    default=1,
    help='The number of gradients updates before updating params.')
flags.DEFINE_integer(
    name='total_num_replicas',
    default=1,
    help='Total number of worker replicas.')
# NOTE(review): train() in this file always passes startup_delay_steps=0 to
# slim and never reads this flag.
flags.DEFINE_integer(
    name='startup_delay_steps',
    default=15,
    help='Number of training steps between replicas startup.')

# --- Housekeeping ---
flags.DEFINE_boolean(
    name='reset_train_dir',
    default=False,
    help='If true will delete all files in the train_log_dir')
flags.DEFINE_boolean(
    name='show_graph_stats',
    default=False,
    help='Output model size stats to stderr.')
# yapf: enable
# Hyper parameters that control training. get_training_hparams() fills each
# field from the command-line flag of the same name.
TrainingHParams = collections.namedtuple(
    "TrainingHParams", "learning_rate optimizer momentum use_augment_input"
)
def get_training_hparams():
    """Builds a TrainingHParams from the parsed command-line flags.

    Returns:
        A TrainingHParams in which every field is read from the FLAGS
        attribute with the same name.
    """
    # The field names match the flag names one-to-one, so read them in the
    # namedtuple's declared order.
    flag_values = {field: getattr(FLAGS, field) for field in TrainingHParams._fields}
    return TrainingHParams(**flag_values)
def create_optimizer(hparams):
    """Creates an optimizer based on the specified hyper parameters.

    Args:
        hparams: an object exposing `optimizer` (one of "momentum", "adam",
            "adadelta", "adagrad", "rmsprop"), `learning_rate`, and
            `momentum` (read only for "momentum" and "rmsprop").

    Returns:
        A tf.compat.v1.train optimizer instance.

    Raises:
        ValueError: if `hparams.optimizer` is not one of the supported names.
            (Previously an unknown name fell through every branch and
            surfaced as an UnboundLocalError.)
    """
    name = hparams.optimizer
    if name == "momentum":
        return tf.compat.v1.train.MomentumOptimizer(
            hparams.learning_rate, momentum=hparams.momentum
        )
    if name == "adam":
        return tf.compat.v1.train.AdamOptimizer(hparams.learning_rate)
    if name == "adadelta":
        return tf.compat.v1.train.AdadeltaOptimizer(hparams.learning_rate)
    if name == "adagrad":
        return tf.compat.v1.train.AdagradOptimizer(hparams.learning_rate)
    if name == "rmsprop":
        return tf.compat.v1.train.RMSPropOptimizer(
            hparams.learning_rate, momentum=hparams.momentum
        )
    raise ValueError(
        "Unsupported optimizer %r; expected one of: momentum, adam, "
        "adadelta, adagrad, rmsprop." % (name,)
    )
def train(loss, init_fn, hparams):
    """Wraps slim.learning.train to run a training loop.

    Args:
        loss: a loss tensor; its graph is the one that gets trained.
        init_fn: A callable to be executed after all other initialization is done.
        hparams: the training hyper parameters (used to build the optimizer).
    """
    optimizer = create_optimizer(hparams)

    if FLAGS.sync_replicas:
        # `tf.LegacySyncReplicasOptimizer` does not exist in TF 2.x, so the old
        # code raised AttributeError whenever --sync_replicas was set. The
        # graph-mode replacement is tf.compat.v1.train.SyncReplicasOptimizer,
        # which has no replica_id argument.
        optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
            opt=optimizer,
            replicas_to_aggregate=FLAGS.replicas_to_aggregate,
            total_num_replicas=FLAGS.total_num_replicas,
        )
        sync_optimizer = optimizer
    else:
        sync_optimizer = None

    # Kept at 0 in both modes, as before; slim's training loop expects a zero
    # startup delay when a sync optimizer is supplied.
    startup_delay_steps = 0

    train_op = slim.learning.create_train_op(
        loss,
        optimizer,
        summarize_gradients=True,
        clip_gradient_norm=FLAGS.clip_gradient_norm,
    )

    slim.learning.train(
        train_op=train_op,
        logdir=FLAGS.train_log_dir,
        graph=loss.graph,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        number_of_steps=FLAGS.max_number_of_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        startup_delay_steps=startup_delay_steps,
        sync_optimizer=sync_optimizer,
        init_fn=init_fn,
    )
def prepare_training_dir():
    """Makes sure FLAGS.train_log_dir exists, wiping it first if requested.

    - Missing directory: it is created.
    - Existing directory with --reset_train_dir: it is deleted recursively
      and recreated empty.
    - Existing directory otherwise: it is reused untouched.
    """
    log_dir = FLAGS.train_log_dir
    if not tf.io.gfile.exists(log_dir):
        logging.info("Create a new training directory %s", log_dir)
    elif FLAGS.reset_train_dir:
        logging.info("Reset the training directory %s", log_dir)
        tf.io.gfile.rmtree(log_dir)
    else:
        logging.info("Use already existing training directory %s", log_dir)
        return
    tf.io.gfile.makedirs(log_dir)
def main(_):
    """Builds the input pipeline, model and loss, then runs training.

    The single positional argument supplied by app.run is ignored.
    """
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code,
    )
    hparams = get_training_hparams()

    # With ps_tasks == 0 everything stays on the local device; with remote
    # replicas the ReplicaDeviceSetter spreads variables over the devices.
    place_ops = tf.compat.v1.train.replica_device_setter(
        FLAGS.ps_tasks, merge_devices=True
    )
    with tf.device(place_ops):
        crop_size = common_flags.get_crop_size()
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=crop_size,
        )

        endpoints = model.create_base(data.images, data.labels_one_hot)
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(data, endpoints, dataset.charset, is_training=True)

        init_fn = model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception
        )
        train(total_loss, init_fn, hparams)
if __name__ == "__main__":
    # Parse the command-line flags, then dispatch to main() explicitly.
    app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment