OCR Attention - TF 2.3 compatible files

data_provider.py:

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to read, decode and pre-process input data for the Model.
"""
import collections
import functools
import tensorflow as tf
import tf_slim as slim
import inception_preprocessing

# Tuple to store input data endpoints for the Model.
# It has the following fields (tensors):
#   images: input images,
#     shape [batch_size x H x W x 3];
#   labels: ground truth label ids,
#     shape=[batch_size x seq_length];
#   labels_one_hot: labels in one-hot encoding,
#     shape [batch_size x seq_length x num_char_classes];
InputEndpoints = collections.namedtuple(
    'InputEndpoints', ['images', 'images_orig', 'labels', 'labels_one_hot'])

# A namedtuple to define a configuration for shuffled batch fetching.
#   num_batching_threads: A number of parallel threads to fetch data.
#   queue_capacity: a max number of elements in the batch shuffling queue.
#   min_after_dequeue: a min number of elements in the queue after a dequeue,
#     used to ensure a level of mixing of elements.
ShuffleBatchConfig = collections.namedtuple('ShuffleBatchConfig', [
    'num_batching_threads', 'queue_capacity', 'min_after_dequeue'
])

DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
    num_batching_threads=8, queue_capacity=3000, min_after_dequeue=1000)


def augment_image(image):
  """Augments the image with a random modification.

  Args:
    image: input Tensor image of rank 3, with the last dimension
           of size 3.

  Returns:
    Distorted Tensor image of the same shape.
  """
  with tf.compat.v1.variable_scope('AugmentImage'):
    height = image.get_shape().dims[0].value
    width = image.get_shape().dims[1].value

    # Random crop cut from the street sign image, resized to the same size.
    # Ensures that the crop covers at least 0.8 of the area of the input image.
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        image_size=tf.shape(input=image),
        bounding_boxes=tf.zeros([0, 0, 4]),
        min_object_covered=0.8,
        aspect_ratio_range=[0.8, 1.2],
        area_range=[0.8, 1.0],
        use_image_if_no_bounding_boxes=True)
    distorted_image = tf.slice(image, bbox_begin, bbox_size)

    # Randomly chooses one of the 4 interpolation methods.
    distorted_image = inception_preprocessing.apply_with_random_selector(
        distorted_image,
        lambda x, method: tf.image.resize(x, [height, width], method),
        num_cases=4)
    distorted_image.set_shape([height, width, 3])

    # Color distortion.
    distorted_image = inception_preprocessing.apply_with_random_selector(
        distorted_image,
        functools.partial(
            inception_preprocessing.distort_color, fast_mode=False),
        num_cases=4)
    distorted_image = tf.clip_by_value(distorted_image, -1.5, 1.5)

  return distorted_image
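
# Illustrative usage sketch (hypothetical shapes, TF1-style graph mode):
# augment_image keeps the static [height, width, 3] shape while applying a
# random crop, resize interpolation and color distortion.
#
#   img = tf.image.convert_image_dtype(tf.zeros([150, 150, 3], tf.uint8),
#                                      tf.float32)
#   distorted = augment_image(img)  # still [150, 150, 3]
#   with tf.compat.v1.Session() as sess:
#     print(sess.run(distorted).shape)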


def central_crop(image, crop_size):
  """Returns a central crop for the specified size of an image.

  Args:
    image: A tensor with shape [height, width, channels]
    crop_size: A tuple (crop_width, crop_height)

  Returns:
    A tensor of shape [crop_height, crop_width, channels].
  """
  with tf.compat.v1.variable_scope('CentralCrop'):
    target_width, target_height = crop_size
    image_height, image_width = tf.shape(
        input=image)[0], tf.shape(input=image)[1]
    assert_op1 = tf.Assert(
        tf.greater_equal(image_height, target_height),
        ['image_height < target_height', image_height, target_height])
    assert_op2 = tf.Assert(
        tf.greater_equal(image_width, target_width),
        ['image_width < target_width', image_width, target_width])
    with tf.control_dependencies([assert_op1, assert_op2]):
      offset_width = tf.cast((image_width - target_width) / 2, tf.int32)
      offset_height = tf.cast((image_height - target_height) / 2, tf.int32)
      return tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                           target_height, target_width)
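
# Illustrative usage sketch (hypothetical sizes): note the argument order,
# crop_size is (crop_width, crop_height) but the result is laid out as
# [crop_height, crop_width, channels].
#
#   img = tf.zeros([150, 600, 3], dtype=tf.float32)
#   cropped = central_crop(img, crop_size=(480, 120))  # -> [120, 480, 3]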


def preprocess_image(image, augment=False, central_crop_size=None,
                     num_towers=4):
  """Normalizes image to have values in a narrow range around zero.

  Args:
    image: a [H x W x 3] uint8 tensor.
    augment: optional, if True do random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    num_towers: optional, number of shots of the same image in the input image.

  Returns:
    A float32 tensor of shape [H x W x 3] with RGB values in the required
    range.
  """
  with tf.compat.v1.variable_scope('PreprocessImage'):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    if augment or central_crop_size:
      if num_towers == 1:
        images = [image]
      else:
        images = tf.split(value=image, num_or_size_splits=num_towers, axis=1)
      if central_crop_size:
        view_crop_size = (int(central_crop_size[0] / num_towers),
                          central_crop_size[1])
        images = [central_crop(img, view_crop_size) for img in images]
      if augment:
        images = [augment_image(img) for img in images]
      image = tf.concat(images, 1)

  return image
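
# Illustrative usage sketch (hypothetical FSNS-like shapes): a 4-view input
# strip is split width-wise into num_towers views, each view is cropped to
# central_crop_size[0] / num_towers pixels wide (and optionally augmented),
# then the views are concatenated back into one strip.
#
#   raw = tf.zeros([150, 600, 3], dtype=tf.uint8)       # 4 views of 150x150
#   processed = preprocess_image(raw, augment=True,
#                                central_crop_size=(600, 150), num_towers=4)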


def get_data(dataset,
             batch_size,
             augment=False,
             central_crop_size=None,
             shuffle_config=None,
             shuffle=True):
  """Wraps calls to DatasetDataProviders and shuffle_batch.

  For more details about supported Dataset objects refer to datasets/fsns.py.

  Args:
    dataset: a slim.data.dataset.Dataset object.
    batch_size: number of samples per batch.
    augment: optional, if True does random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    shuffle_config: A namedtuple ShuffleBatchConfig.
    shuffle: if True use data shuffling.

  Returns:
    An InputEndpoints tuple.
  """
  if not shuffle_config:
    shuffle_config = DEFAULT_SHUFFLE_CONFIG

  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=shuffle,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size)
  image_orig, label = provider.get(['image', 'label'])

  image = preprocess_image(
      image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
  label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)

  images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
      [image, image_orig, label, label_one_hot],
      batch_size=batch_size,
      num_threads=shuffle_config.num_batching_threads,
      capacity=shuffle_config.queue_capacity,
      min_after_dequeue=shuffle_config.min_after_dequeue))

  return InputEndpoints(
      images=images,
      images_orig=images_orig,
      labels=labels,
      labels_one_hot=labels_one_hot)
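
# Illustrative usage sketch, assuming a dataset object built from the
# datasets/fsns.py module mentioned above (module name, split name and
# directory are assumptions); running the tensors requires TF1-style queue
# runners started in a session.
#
#   dataset = fsns.get_split('train', dataset_dir='data/fsns')
#   data = get_data(dataset, batch_size=32, augment=True,
#                   central_crop_size=None)
#   # data.images: [32, H, W, 3] float32
#   # data.labels_one_hot: [32, seq_length, num_char_classes]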


train.py:

# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to train the Attention OCR model.
A simple usage example:
python train.py
"""
import collections
import logging
import tensorflow as tf
import tf_slim as slim
from tensorflow.compat.v1 import app
from tensorflow.python.platform import flags
import data_provider
import common_flags
FLAGS = flags.FLAGS
common_flags.define()
# yapf: disable
flags.DEFINE_integer('task', 0,
                     'The Task ID. This value is used when training with '
                     'multiple workers to identify each worker.')

flags.DEFINE_integer('ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, then'
                     ' the parameters are handled locally by the worker.')

flags.DEFINE_integer('save_summaries_secs', 60,
                     'The frequency with which summaries are saved, in '
                     'seconds.')

flags.DEFINE_integer('save_interval_secs', 600,
                     'Frequency in seconds of saving the model.')

flags.DEFINE_integer('max_number_of_steps', int(1e10),
                     'The maximum number of gradient steps.')

flags.DEFINE_string('checkpoint_inception', '',
                    'Checkpoint to recover inception weights from.')

flags.DEFINE_float('clip_gradient_norm', 2.0,
                   'If greater than 0 then the gradients would be clipped by '
                   'it.')

flags.DEFINE_bool('sync_replicas', False,
                  'If True will synchronize replicas during training.')

flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'The number of gradients updates before updating params.')

flags.DEFINE_integer('total_num_replicas', 1,
                     'Total number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')

flags.DEFINE_boolean('reset_train_dir', False,
                     'If true will delete all files in the train_log_dir')

flags.DEFINE_boolean('show_graph_stats', False,
                     'Output model size stats to stderr.')
# yapf: enable
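
# Illustrative invocation sketch (paths and step count are hypothetical;
# flags not defined in this file, such as train_log_dir, come from
# common_flags):
#
#   python train.py \
#     --train_log_dir=/tmp/attention_ocr/train \
#     --max_number_of_steps=100000 \
#     --checkpoint_inception=/tmp/inception_v3.ckpt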

TrainingHParams = collections.namedtuple(
    "TrainingHParams",
    [
        "learning_rate",
        "optimizer",
        "momentum",
        "use_augment_input",
    ],
)


def get_training_hparams():
  return TrainingHParams(
      learning_rate=FLAGS.learning_rate,
      optimizer=FLAGS.optimizer,
      momentum=FLAGS.momentum,
      use_augment_input=FLAGS.use_augment_input,
  )


def create_optimizer(hparams):
  """Creates an optimizer based on the specified hparams."""
  if hparams.optimizer == "momentum":
    optimizer = tf.compat.v1.train.MomentumOptimizer(
        hparams.learning_rate, momentum=hparams.momentum
    )
  elif hparams.optimizer == "adam":
    optimizer = tf.compat.v1.train.AdamOptimizer(hparams.learning_rate)
  elif hparams.optimizer == "adadelta":
    optimizer = tf.compat.v1.train.AdadeltaOptimizer(hparams.learning_rate)
  elif hparams.optimizer == "adagrad":
    optimizer = tf.compat.v1.train.AdagradOptimizer(hparams.learning_rate)
  elif hparams.optimizer == "rmsprop":
    optimizer = tf.compat.v1.train.RMSPropOptimizer(
        hparams.learning_rate, momentum=hparams.momentum
    )
  return optimizer
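
# Illustrative usage sketch (hypothetical values; the real ones come from the
# command-line flags via get_training_hparams):
#
#   hparams = TrainingHParams(learning_rate=0.004, optimizer="momentum",
#                             momentum=0.9, use_augment_input=True)
#   opt = create_optimizer(hparams)  # tf.compat.v1.train.MomentumOptimizer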


def train(loss, init_fn, hparams):
  """Wraps slim.learning.train to run a training loop.

  Args:
    loss: a loss tensor
    init_fn: A callable to be executed after all other initialization is done.
    hparams: model hyperparameters.
  """
  optimizer = create_optimizer(hparams)

  if FLAGS.sync_replicas:
    # Note: tf.LegacySyncReplicasOptimizer is not available in TF 2.x; the
    # compat.v1 SyncReplicasOptimizer takes no replica_id argument.
    optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        total_num_replicas=FLAGS.total_num_replicas,
    )
    sync_optimizer = optimizer
    startup_delay_steps = 0
  else:
    startup_delay_steps = 0
    sync_optimizer = None

  train_op = slim.learning.create_train_op(
      loss,
      optimizer,
      summarize_gradients=True,
      clip_gradient_norm=FLAGS.clip_gradient_norm,
  )

  slim.learning.train(
      train_op=train_op,
      logdir=FLAGS.train_log_dir,
      graph=loss.graph,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      number_of_steps=FLAGS.max_number_of_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      startup_delay_steps=startup_delay_steps,
      sync_optimizer=sync_optimizer,
      init_fn=init_fn,
  )


def prepare_training_dir():
  if not tf.io.gfile.exists(FLAGS.train_log_dir):
    logging.info("Create a new training directory %s", FLAGS.train_log_dir)
    tf.io.gfile.makedirs(FLAGS.train_log_dir)
  else:
    if FLAGS.reset_train_dir:
      logging.info("Reset the training directory %s", FLAGS.train_log_dir)
      tf.io.gfile.rmtree(FLAGS.train_log_dir)
      tf.io.gfile.makedirs(FLAGS.train_log_dir)
    else:
      logging.info(
          "Use already existing training directory %s", FLAGS.train_log_dir
      )


def main(_):
  prepare_training_dir()

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(
      dataset.num_char_classes,
      dataset.max_sequence_length,
      dataset.num_of_views,
      dataset.null_code,
  )
  hparams = get_training_hparams()

  # If ps_tasks is zero, the local device is used. When using multiple
  # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
  # across the different devices.
  device_setter = tf.compat.v1.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True
  )
  with tf.device(device_setter):
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size(),
    )
    endpoints = model.create_base(data.images, data.labels_one_hot)
    total_loss = model.create_loss(data, endpoints)
    model.create_summaries(data, endpoints, dataset.charset, is_training=True)
    init_fn = model.create_init_fn_to_restore(
        FLAGS.checkpoint, FLAGS.checkpoint_inception
    )
    train(total_loss, init_fn, hparams)


if __name__ == "__main__":
  app.run()