# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import time
from datasets import dataset_factory
from deployment import model_deploy
from nets import nets_factory
from preprocessing import preprocessing_factory
from tensorflow.python.client import timeline
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.lib.io import file_io
from losses import multilabel_losses, ghm_loss
from tensorflow.python.ops.losses import util
import math
#import cv2
import numpy as np
from tqdm import tqdm
import os
slim = tf.contrib.slim
# seed = 123
# tf.random.set_random_seed(seed)
# np.random.seed(seed)
tf.app.flags.DEFINE_string(
'master', '', 'The address of the TensorFlow master to use.')
tf.app.flags.DEFINE_string(
'train_dir', '/tmp/tfmodel/',
'Directory where checkpoints and event logs are written to.')
tf.app.flags.DEFINE_integer('num_clones', 1,
'Number of model clones to deploy. Note: for '
'historical reasons, the loss from all clones is '
'averaged and learning rate decay happens per '
'clone epoch.')
tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
'Use CPUs to deploy clones.')
tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
tf.app.flags.DEFINE_integer(
'num_ps_tasks', 0,
'The number of parameter servers. If the value is 0, then the parameters '
'are handled locally by the worker.')
tf.app.flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer(
'num_preprocessing_threads', 6,
'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
'log_every_n_steps', 10,
'The frequency with which logs are printed.')
tf.app.flags.DEFINE_integer(
'save_summaries_secs', 600,
'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
'save_interval_secs', 600,
'The frequency with which the model is saved, in seconds.')
tf.app.flags.DEFINE_integer(
'val_interval', 50, 'run val every val_interval.')
tf.app.flags.DEFINE_boolean(
'use_focal_loss', False, 'Whether to use focal loss.')
tf.app.flags.DEFINE_boolean(
'use_ghm_loss', False, 'Whether to use gradient harmonized (GHM) loss.')
tf.app.flags.DEFINE_float(
'dropout_keep_rate', 0.8, 'The dropout keep rate. '
'Only takes effect when using resnet_dropout.')
tf.app.flags.DEFINE_integer(
'task', 0, 'Task id of the replica running the training.')
######################
# Optimization Flags #
######################
tf.app.flags.DEFINE_float(
'alpha', 0.75, 'The alpha value of focal loss.')
tf.app.flags.DEFINE_float(
'weight_decay', 0.00004, 'The weight decay on the model weights.')
tf.app.flags.DEFINE_string(
'optimizer', 'rmsprop',
'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
'"ftrl", "momentum", "sgd" or "rmsprop".')
tf.app.flags.DEFINE_float(
'adadelta_rho', 0.95,
'The decay rate for adadelta.')
tf.app.flags.DEFINE_float(
'adagrad_initial_accumulator_value', 0.1,
'Starting value for the AdaGrad accumulators.')
tf.app.flags.DEFINE_float(
'adam_beta1', 0.9,
'The exponential decay rate for the 1st moment estimates.')
tf.app.flags.DEFINE_float(
'adam_beta2', 0.999,
'The exponential decay rate for the 2nd moment estimates.')
tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
'The learning rate power.')
tf.app.flags.DEFINE_float(
'ftrl_initial_accumulator_value', 0.1,
'Starting value for the FTRL accumulators.')
tf.app.flags.DEFINE_float(
'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
tf.app.flags.DEFINE_float(
'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
tf.app.flags.DEFINE_float(
'momentum', 0.9,
'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
#######################
# Learning Rate Flags #
#######################
tf.app.flags.DEFINE_string(
'learning_rate_decay_type',
'exponential',
'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
' or "polynomial"')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
tf.app.flags.DEFINE_float(
'end_learning_rate', 0.0001,
'The minimal end learning rate used by a polynomial decay learning rate.')
tf.app.flags.DEFINE_float(
'label_smoothing', 0.0, 'The amount of label smoothing.')
tf.app.flags.DEFINE_float(
'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
tf.app.flags.DEFINE_float(
'num_epochs_per_decay', 2.0,
'Number of epochs after which learning rate decays. Note: this flag counts '
'epochs per clone but aggregates per sync replicas. So 1.0 means that '
'each clone will go over full epoch individually, but replicas will go '
'once across all replicas.')
tf.app.flags.DEFINE_bool(
'sync_replicas', False,
'Whether or not to synchronize the replicas during training.')
tf.app.flags.DEFINE_integer(
'replicas_to_aggregate', 1,
'The Number of gradients to collect before updating params.')
tf.app.flags.DEFINE_float(
'moving_average_decay', None,
'The decay to use for the moving average.'
'If left as None, then moving averages are not used.')
#######################
# Dataset Flags #
#######################
tf.app.flags.DEFINE_string(
'dataset_name', 'imagenet', 'The name of the dataset to load.')
tf.app.flags.DEFINE_string(
'dataset_split_name', 'train', 'The name of the train/test split.')
tf.app.flags.DEFINE_string(
'val_dataset_split_name', 'test', 'The name of the validation split.')
tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_integer(
'labels_offset', 1,
'An offset for the labels in the dataset. This flag is primarily used to '
'evaluate the VGG and ResNet architectures which do not use a background '
'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
'model_name', 'inception_v3', 'The name of the architecture to train.')
tf.app.flags.DEFINE_string(
'preprocessing_name', 'resnet_cuplid', 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
tf.app.flags.DEFINE_integer(
'batch_size', 32, 'The number of samples in each batch.')
tf.app.flags.DEFINE_integer(
'val_batch_size', 5, 'The number of samples in each val batch.')
tf.app.flags.DEFINE_integer(
'train_image_height', 400, 'Train image height') #384
tf.app.flags.DEFINE_integer(
'train_image_width', 400, 'Train image width') #512
tf.app.flags.DEFINE_integer('max_number_of_steps', None,
'The maximum number of training steps.')
tf.app.flags.DEFINE_boolean(
'use_more_augmentation', True,
'Using more augmentation methods.')
tf.app.flags.DEFINE_boolean(
'use_batch_preprocessing', False,
'Using batch only preprocessing_fn.')
#####################
# Fine-Tuning Flags #
#####################
tf.app.flags.DEFINE_string(
'checkpoint_path', None,
'The path to a checkpoint from which to fine-tune.')
tf.app.flags.DEFINE_string(
'checkpoint_exclude_scopes', None,
'Comma-separated list of scopes of variables to exclude when restoring '
'from a checkpoint.')
tf.app.flags.DEFINE_string(
'trainable_scopes', None,
'Comma-separated list of scopes to filter the set of variables to train.'
'By default, None would train all the variables.')
tf.app.flags.DEFINE_boolean(
'ignore_missing_vars', False,
'When restoring a checkpoint would ignore missing variables.')
#####################
# Mode Flags #
#####################
tf.app.flags.DEFINE_boolean(
'is_binary_cls', True,
'Using Binary Classifier or Multi-label Classifier')
FLAGS = tf.app.flags.FLAGS
def _configure_learning_rate(num_samples_per_epoch, global_step):
"""Configures the learning rate.
Args:
num_samples_per_epoch: The number of samples in each epoch of training.
global_step: The global_step tensor.
Returns:
A `Tensor` representing the learning rate.
Raises:
ValueError: if `learning_rate_decay_type` is not recognized.
"""
# Note: when num_clones is > 1, this will actually have each clone to go
# over each epoch FLAGS.num_epochs_per_decay times. This is different
# behavior from sync replicas and is expected to produce different results.
decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
FLAGS.batch_size)
if FLAGS.sync_replicas:
decay_steps /= FLAGS.replicas_to_aggregate
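# Illustrative arithmetic (assuming, e.g., 10,000 samples per epoch with the
# default batch_size=32 and num_epochs_per_decay=2.0): decay_steps =
# int(10000 * 2.0 / 32) = 625, so the exponential schedule multiplies the rate
# by learning_rate_decay_factor (0.94) every 625 steps (staircase).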
if FLAGS.learning_rate_decay_type == 'exponential':
return tf.train.exponential_decay(FLAGS.learning_rate,
global_step,
decay_steps,
FLAGS.learning_rate_decay_factor,
staircase=True,
name='exponential_decay_learning_rate')
elif FLAGS.learning_rate_decay_type == 'fixed':
return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
elif FLAGS.learning_rate_decay_type == 'polynomial':
return tf.train.polynomial_decay(FLAGS.learning_rate,
global_step,
decay_steps,
FLAGS.end_learning_rate,
power=1.0,
cycle=False,
name='polynomial_decay_learning_rate')
else:
raise ValueError('learning_rate_decay_type [%s] was not recognized' %
FLAGS.learning_rate_decay_type)
def _configure_optimizer(learning_rate):
"""Configures the optimizer used for training.
Args:
learning_rate: A scalar or `Tensor` learning rate.
Returns:
An instance of an optimizer.
Raises:
ValueError: if FLAGS.optimizer is not recognized.
"""
if FLAGS.optimizer == 'adadelta':
optimizer = tf.train.AdadeltaOptimizer(
learning_rate,
rho=FLAGS.adadelta_rho,
epsilon=FLAGS.opt_epsilon)
elif FLAGS.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(
learning_rate,
initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
elif FLAGS.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate,
beta1=FLAGS.adam_beta1,
beta2=FLAGS.adam_beta2,
epsilon=FLAGS.opt_epsilon)
elif FLAGS.optimizer == 'ftrl':
optimizer = tf.train.FtrlOptimizer(
learning_rate,
learning_rate_power=FLAGS.ftrl_learning_rate_power,
initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
l1_regularization_strength=FLAGS.ftrl_l1,
l2_regularization_strength=FLAGS.ftrl_l2)
elif FLAGS.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(
learning_rate,
momentum=FLAGS.momentum,
name='Momentum')
elif FLAGS.optimizer == 'rmsprop':
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
decay=FLAGS.rmsprop_decay,
momentum=FLAGS.rmsprop_momentum,
epsilon=FLAGS.opt_epsilon)
elif FLAGS.optimizer == 'sgd':
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
else:
raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
return optimizer
def _get_init_fn():
"""Returns a function run by the chief worker to warm-start the training.
Note that the init_fn is only run when initializing the model during the very
first global step.
Returns:
An init function run by the supervisor.
"""
if FLAGS.checkpoint_path is None:
return None
# Warn the user if a checkpoint exists in the train_dir. Then we'll be
# ignoring the checkpoint anyway.
if tf.train.latest_checkpoint(FLAGS.train_dir):
tf.logging.info(
'Ignoring --checkpoint_path because a checkpoint already exists in %s'
% FLAGS.train_dir)
return None
exclusions = []
if FLAGS.checkpoint_exclude_scopes:
exclusions = [scope.strip()
for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
# TODO(sguada) variables.filter_variables()
variables_to_restore = []
for var in slim.get_model_variables():
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
break
else:
variables_to_restore.append(var)
if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
else:
checkpoint_path = FLAGS.checkpoint_path
tf.logging.info('Fine-tuning from %s' % checkpoint_path)
return slim.assign_from_checkpoint_fn(
checkpoint_path,
variables_to_restore,
ignore_missing_vars=FLAGS.ignore_missing_vars)
def _get_variables_to_train():
"""Returns a list of variables to train.
Returns:
A list of variables to train by the optimizer.
"""
if FLAGS.trainable_scopes is None:
return tf.trainable_variables()
else:
scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
variables_to_train = []
for scope in scopes:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
variables_to_train.extend(variables)
return variables_to_train
def train_step(sess, train_op, global_step, train_step_kwargs):
"""Function that takes a gradient step and specifies whether to stop.
Args:
sess: The current session.
train_op: An `Operation` that evaluates the gradients and returns the total
loss.
global_step: A `Tensor` representing the global training step.
train_step_kwargs: A dictionary of keyword arguments.
Returns:
The total loss and a boolean indicating whether or not to stop training.
Raises:
ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
"""
start_time = time.time()
trace_run_options = None
run_metadata = None
if 'should_trace' in train_step_kwargs:
if 'logdir' not in train_step_kwargs:
raise ValueError('logdir must be present in train_step_kwargs when '
'should_trace is present')
if sess.run(train_step_kwargs['should_trace']):
trace_run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
total_loss, np_global_step = sess.run([train_op, global_step],
options=trace_run_options,
run_metadata=run_metadata)
time_elapsed = time.time() - start_time
if run_metadata is not None:
tl = timeline.Timeline(run_metadata.step_stats)
trace = tl.generate_chrome_trace_format()
trace_filename = os.path.join(train_step_kwargs['logdir'],
'tf_trace-%d.json' % np_global_step)
tf.logging.info('Writing trace to %s', trace_filename)
file_io.write_string_to_file(trace_filename, trace)
if 'summary_writer' in train_step_kwargs:
train_step_kwargs['summary_writer'].add_run_metadata(
run_metadata, 'run_metadata-%d' % np_global_step)
if 'should_log' in train_step_kwargs:
if sess.run(train_step_kwargs['should_log']):
tf.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
np_global_step, total_loss, time_elapsed)
if np_global_step % FLAGS.val_interval == 0:
# Run validation.
tf.logging.info('======START VALIDATION======')
val_list = [v for v in tf.get_collection('validation')]
val_list_name = [v.op.name for v in val_list]
agg_val_loss = 0.0
agg_val_acc = 0.0
total_val_g_hit = 0
total_val_ng_hit = 0
num_g = 0
num_ng = 0
data_size = 2251
ng_weight = 2
num_batches = math.ceil(data_size / float(FLAGS.val_batch_size))
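# With the hard-coded data_size of 2251 and the default val_batch_size of 5,
# num_batches = ceil(2251 / 5) = 451 validation batches per pass.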
for i in tqdm(range(num_batches)):
np_val_list = sess.run(val_list)
_val_g_mask, _val_ng_mask = sess.run(['val_g_mask:0', 'val_ng_mask:0'])
agg_val_loss += np_val_list[0]
agg_val_acc += np_val_list[1]
total_val_g_hit += np_val_list[val_list_name.index('val_g_hit')]
total_val_ng_hit += np_val_list[val_list_name.index('val_ng_hit')]
num_g += np.sum(_val_g_mask)
num_ng += np.sum(_val_ng_mask)
# Debug prints for the per-batch validation counters.
print('total_val_g_hit:', total_val_g_hit)
print('total_val_ng_hit:', total_val_ng_hit)
print('np_val_list[val_list_name.index(val_g_hit)]:', np_val_list[val_list_name.index('val_g_hit')])
print('np_val_list[2]:', np_val_list[2])
print('np_val_list[val_list_name.index(val_ng_hit)]:', np_val_list[val_list_name.index('val_ng_hit')])
print('np_val_list[3]:', np_val_list[3])
print('num_g:', num_g)
print('num_ng:', num_ng)
# input()  # Uncomment to pause after each batch for manual inspection (blocks training).
with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
__val_loss = agg_val_loss/num_batches
__val_g_recall = float(total_val_g_hit)/num_g
__val_ng_recall = float(total_val_ng_hit)/num_ng
__val_acc = float(total_val_g_hit+total_val_ng_hit)/(num_g+num_ng)
__wt_val_acc = float(total_val_g_hit+total_val_ng_hit*ng_weight)/(num_g+num_ng*ng_weight)
__size = num_g+num_ng
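# wt_val_acc counts each NG (defect) sample ng_weight (=2) times; e.g. with
# num_g=100, num_ng=50, total_val_g_hit=90, total_val_ng_hit=40:
# wt_val_acc = (90 + 40*2) / (100 + 50*2) = 170 / 200 = 0.85.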
out_w = 'global_step: %d | %s: %.3f | %s: %.3f | %s: %.3f | %s: %.3f | %s: %.3f | %s: %d\n' % \
(np_global_step,
'__val_loss', __val_loss,
'val_g_recall', __val_g_recall,
'val_ng_recall', __val_ng_recall,
'val_acc', __val_acc,
'wt_val_acc', __wt_val_acc,
'size ', __size)
tf.logging.info(out_w)
f.write(out_w)
tf.logging.info('======END VALIDATION======')
'''
# This section of code is used for debugging the TFRecord and Augmentation
debug_list = [d for d in tf.get_collection('debug')]
for d in debug_list:
imgs = sess.run(d)
print(np.shape(imgs))
for idx,img in enumerate(imgs):
print(np.shape(img))
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
img[:,:,0]+=_R_MEAN
img[:,:,1]+=_G_MEAN
img[:,:,2]+=_B_MEAN
print(img)
img = img.astype(np.uint8)
cv2.imwrite(os.path.join('./',str(idx)+'.jpg'),img)
'''
# TODO(nsilberman): figure out why we can't put this into sess.run. The
# issue right now is that the stop check depends on the global step. The
# increment of global step often happens via the train op, which is
# created using optimizer.apply_gradients.
#
# Since running `train_op` causes the global step to be incremented, one
# would expect that using a control dependency would allow the
# should_stop check to be run in the same session.run call:
#
# with ops.control_dependencies([train_op]):
# should_stop_op = ...
#
# However, this actually seems not to work on certain platforms.
if 'should_stop' in train_step_kwargs:
should_stop = sess.run(train_step_kwargs['should_stop'])
else:
should_stop = False
return total_loss, should_stop
def main(_):
if not FLAGS.dataset_dir:
raise ValueError('You must supply the dataset directory with --dataset_dir')
tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default():
#######################
# Config model_deploy #
#######################
deploy_config = model_deploy.DeploymentConfig(
num_clones=FLAGS.num_clones,
clone_on_cpu=FLAGS.clone_on_cpu,
replica_id=FLAGS.task,
num_replicas=FLAGS.worker_replicas,
num_ps_tasks=FLAGS.num_ps_tasks)
# Create global_step
with tf.device(deploy_config.variables_device()):
global_step = slim.create_global_step()
######################
# Select the dataset #
######################
dataset = dataset_factory.get_dataset(
FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
val_dataset = dataset_factory.get_dataset(
FLAGS.dataset_name, FLAGS.val_dataset_split_name, FLAGS.dataset_dir)
######################
# Select the network #
######################
if FLAGS.is_binary_cls:
network_num_classes = 1
num_classes = dataset.num_classes - FLAGS.labels_offset #11-1 = 10
else:
num_classes = dataset.num_classes #11, 0 as perfect class
network_num_classes = num_classes
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(network_num_classes),
weight_decay=FLAGS.weight_decay,
is_training=True)
val_network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(network_num_classes),
weight_decay=FLAGS.weight_decay,
is_training=False)
#####################################
# Select the preprocessing function #
#####################################
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
preprocessing_name,
is_training=True)
image_preprocessing_fn_val = preprocessing_factory.get_preprocessing(
preprocessing_name,
is_training=False)
##############################################################
# Create a dataset provider that loads data from the dataset #
##############################################################
with tf.device(deploy_config.inputs_device()):
######### Training ##########
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=FLAGS.num_readers,
common_queue_capacity=20 * FLAGS.batch_size,
common_queue_min=10 * FLAGS.batch_size)
# provider is a tensor_dict
[image, bboxes, labels, filename] = provider.get(
['image', 'bboxes', 'labels', 'filename'])
######### Validation ##########
val_provider = slim.dataset_data_provider.DatasetDataProvider(
val_dataset,
num_readers=FLAGS.num_readers,
common_queue_capacity= 2*FLAGS.val_batch_size,
common_queue_min= 1*FLAGS.val_batch_size)
[val_image, val_bboxes, val_labels, val_filename] = val_provider.get(
['image', 'bboxes', 'labels', 'filename'])
######### Labels ##########
if FLAGS.is_binary_cls:
labels -= FLAGS.labels_offset
val_labels -= FLAGS.labels_offset
#else:
# labels -= FLAGS.labels_offset
# val_labels -= FLAGS.labels_offset
train_image_height = (FLAGS.train_image_height or
network_fn.default_image_size)
train_image_width = (FLAGS.train_image_width or
network_fn.default_image_size)
image, label = image_preprocessing_fn(image, train_image_height,
train_image_width, bboxes=bboxes, labels=labels, filename=filename,
num_classes=num_classes,
use_more_augmentation=FLAGS.use_more_augmentation)
val_image, val_label = image_preprocessing_fn_val(val_image,
train_image_height, train_image_width, bboxes=val_bboxes,
labels=val_labels, filename=val_filename, num_classes=num_classes)
if FLAGS.is_binary_cls:
# we assume the label only contains 1 and 0
# 0 as perfect and 1 as defect
label = tf.reduce_max(label, keepdims=True) #[num_classes] => [1]
val_label = tf.reduce_max(val_label, keepdims=True)
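# Example of the collapse above: a multi-hot label like [0, 1, 0, ..., 0]
# becomes [1] (at least one defect) and an all-zero label becomes [0]
# (perfect), so the binary head only separates defect from perfect.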
else:
# add idx 0 as the perfect class: it should be 1 only if all idx > 0 are 0
# since idx 0 of the raw label is always zero, we can perform reduce_max directly
perfect_label = 1 - tf.reduce_max(label, keepdims=True) #[num_classes] => [1]
val_perfect_label = 1 - tf.reduce_max(val_label,keepdims=True)
label = tf.concat([perfect_label, label[1:]],axis=0) #[1] concat [10]
val_label = tf.concat([val_perfect_label, val_label[1:]],axis=0)
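# In the multi-label case index 0 becomes the "perfect" class, set to 1 only
# when every defect entry is 0: e.g. [0, 0, ..., 0] -> [1, 0, ..., 0] and
# [0, 1, 0, ...] -> [0, 1, 0, ...].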
images, labels, filenames = tf.train.shuffle_batch(
[image, label, filename],
batch_size=FLAGS.batch_size,
num_threads=FLAGS.num_preprocessing_threads,
capacity=10 * FLAGS.batch_size,
min_after_dequeue=2)
val_images, val_labels = tf.train.shuffle_batch(
[val_image, val_label],
batch_size=FLAGS.val_batch_size,
num_threads=FLAGS.num_preprocessing_threads,
capacity=2 * FLAGS.val_batch_size,
min_after_dequeue=1)
if FLAGS.use_batch_preprocessing:
batch_preprocessing_fn = \
preprocessing_factory.get_batch_preprocessing(
FLAGS.preprocessing_name)
images, labels = batch_preprocessing_fn(
images, labels, filenames=filenames)
#labels = slim.one_hot_encoding(
# labels, dataset.num_classes - FLAGS.labels_offset)
batch_queue = slim.prefetch_queue.prefetch_queue(
[images, labels], capacity=4 * deploy_config.num_clones)
val_batch_queue = slim.prefetch_queue.prefetch_queue(
[val_images, val_labels], capacity=1 * deploy_config.num_clones)
####################
# Define the model #
####################
def clone_fn(batch_queue):
"""Allows data parallelism by creating multiple clones of network_fn."""
images, labels = batch_queue.dequeue()
tf.add_to_collection('debug',images)
val_images, val_labels = val_batch_queue.dequeue()
#with tf.variable_scope("model") as scope:
if FLAGS.model_name == 'resnet_dropout':
logits, end_points = network_fn(images,
keep_prob = FLAGS.dropout_keep_rate)
val_logits, _ = val_network_fn(val_images, reuse=True,
keep_prob = FLAGS.dropout_keep_rate)
else:
logits, end_points = network_fn(images)
val_logits, _ = val_network_fn(val_images, reuse=True)
#############################
# Specify the loss function #
#############################
if 'AuxLogits' in end_points:
if FLAGS.use_focal_loss:
aux_loss = multilabel_losses.focal_loss(end_points['AuxLogits'],
labels, alpha=FLAGS.alpha, label_smoothing=FLAGS.label_smoothing,
scope='aux_loss')
util.add_loss(aux_loss, tf.GraphKeys.LOSSES)
else:
slim.losses.sigmoid_cross_entropy(
end_points['AuxLogits'], labels,
label_smoothing=FLAGS.label_smoothing, weights=0.4,
scope='aux_loss')
if FLAGS.use_focal_loss:
loss = multilabel_losses.focal_loss(logits, labels, alpha=FLAGS.alpha,
label_smoothing=FLAGS.label_smoothing)
util.add_loss(loss, tf.GraphKeys.LOSSES)
else:
if FLAGS.use_ghm_loss:
labels = tf.cast(labels, tf.float32)
weights = ghm_loss.get_ghm_weight(logits, labels)
else:
weights = 1.0
slim.losses.sigmoid_cross_entropy(
logits, labels, label_smoothing=FLAGS.label_smoothing, weights=weights)
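# Note on the GHM branch above: ghm_loss.get_ghm_weight is assumed to return
# per-example weights that down-weight examples whose gradient norms fall in
# densely populated bins (gradient harmonizing mechanism); those weights are
# then passed to the sigmoid cross-entropy loss.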
#################################
# Specify the val loss function #
#################################
# tf.slim automatically casts labels to float32
val_labels = tf.cast(val_labels, dtype=tf.float32)
#val_loss = (
# tf.losses.sigmoid_cross_entropy(val_logits, val_labels,
# loss_collection="validation_loss"))
if FLAGS.use_focal_loss:
val_loss = multilabel_losses.focal_loss(
val_logits, val_labels, weights=1.0, name='val_loss')
else:
val_loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=val_logits, labels=val_labels, name= 'val_sig_xetropy')
val_loss = tf.reduce_sum(val_loss, axis=1)
val_loss = tf.reduce_mean(val_loss, name='val_loss')
tf.add_to_collection('validation', val_loss)
predictions = tf.nn.sigmoid(val_logits)
predictions = tf.cast(tf.greater(predictions,0.5),tf.float32)
val_labels = tf.reshape(
val_labels, [FLAGS.val_batch_size, network_num_classes])
predictions = tf.reshape(
predictions, [FLAGS.val_batch_size, network_num_classes])
#val_labels = tf.squeeze(val_labels)
correct = tf.reduce_min(tf.cast(tf.equal(predictions, val_labels),tf.float32),axis=1)
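# Exact-match correctness: reduce_min over the per-class equality means a
# sample is counted as correct only when every class prediction (after the
# 0.5 sigmoid threshold above) matches its label.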
accuracy = tf.multiply(tf.reduce_mean(correct), tf.constant(100.0),
name='val_acc')
tf.add_to_collection('validation', accuracy)
g_mask = tf.cast(tf.equal(val_labels, 0), tf.float32, name="val_g_mask")
ng_mask = tf.cast(tf.equal(val_labels, 1), tf.float32, name="val_ng_mask")
correct_tensor = tf.cast(tf.equal(predictions, val_labels), tf.float32)
g_hit = tf.reduce_sum(correct_tensor*g_mask, name="val_g_hit")
ng_hit = tf.reduce_sum(correct_tensor*ng_mask, name="val_ng_hit")
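# g_hit / ng_hit count correct predictions on perfect (label 0) and defect
# (label 1) entries; train_step divides them by num_g / num_ng to report the
# per-class recalls val_g_recall and val_ng_recall.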
tf.add_to_collection('validation', g_hit)
tf.add_to_collection('validation', ng_hit)
return end_points
# Gather initial summaries.
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
first_clone_scope = deploy_config.clone_scope(0)
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by network_fn.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
# Add summaries for end_points.
end_points = clones[0].outputs
for end_point in end_points:
x = end_points[end_point]
summaries.add(tf.summary.histogram('activations/' + end_point, x))
summaries.add(tf.summary.scalar('sparsity/' + end_point,
tf.nn.zero_fraction(x)))
# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
for val_scalar in tf.get_collection('validation'):
summaries.add(
tf.summary.scalar('validation/%s' % val_scalar.op.name, val_scalar))
# Add summaries for variables.
for variable in slim.get_model_variables():
summaries.add(tf.summary.histogram(variable.op.name, variable))
#################################
# Configure the moving averages #
#################################
if FLAGS.moving_average_decay:
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, global_step)
else:
moving_average_variables, variable_averages = None, None
#########################################
# Configure the optimization procedure. #
#########################################
with tf.device(deploy_config.optimizer_device()):
learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
optimizer = _configure_optimizer(learning_rate)
summaries.add(tf.summary.scalar('learning_rate', learning_rate))
if FLAGS.sync_replicas:
# If sync_replicas is enabled, the averaging will be done in the chief
# queue runner.
optimizer = tf.train.SyncReplicasOptimizer(
opt=optimizer,
replicas_to_aggregate=FLAGS.replicas_to_aggregate,
total_num_replicas=FLAGS.worker_replicas,
variable_averages=variable_averages,
variables_to_average=moving_average_variables)
elif FLAGS.moving_average_decay:
# Update ops executed locally by trainer.
update_ops.append(variable_averages.apply(moving_average_variables))
# Variables to train.
variables_to_train = _get_variables_to_train()
# Compute the total loss and the gradients across all clones.
total_loss, clones_gradients = model_deploy.optimize_clones(
clones,
optimizer,
var_list=variables_to_train)
# Add total_loss to summary.
summaries.add(tf.summary.scalar('total_loss', total_loss))
# Create gradient updates.
grad_updates = optimizer.apply_gradients(clones_gradients,
global_step=global_step)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
train_tensor = tf.identity(total_loss, name='train_op')
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
first_clone_scope))
# Merge all summaries together.
summary_op = tf.summary.merge(list(summaries), name='summary_op')
###########################
# Kicks off the training. #
###########################
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True
slim.learning.train(
train_tensor,
logdir=FLAGS.train_dir,
train_step_fn=train_step,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
init_fn=_get_init_fn(),
summary_op=summary_op,
number_of_steps=FLAGS.max_number_of_steps,
log_every_n_steps=FLAGS.log_every_n_steps,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs,
sync_optimizer=optimizer if FLAGS.sync_replicas else None,
session_config=session_config)
if __name__ == '__main__':
tf.app.run()