#! /usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tf_parameter_mgr
import monitor_cb
from tensorflow.python.framework import ops
#from monitor_cb import CMonitor
from datetime import datetime
import os.path
import time
import sys, os
import threading
import glob
import numpy as np
import tensorflow as tf
from preprocessing import preprocessing_factory
from nets import nets_factory
from tensorflow.python.ops import variables
# Choose your image preprocessing
tf.app.flags.DEFINE_string('preprocessing_type', 'inception_v3', 'image processing type')
# Choose the network structure
tf.app.flags.DEFINE_string('network_type', 'inception_v3', 'image feature extraction network type')
# Set the number of classes
tf.app.flags.DEFINE_integer('number_classes', 5, 'number of classes')
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_string('weights', None, 'initialize with pretrained model weights')
tf.app.flags.DEFINE_float('weight_decay', 0.0, 'l2_regularizer weight_decay parameter')
tf.app.flags.DEFINE_string("fixed_parameters_names", None,
                           "fixed parameters' names' list filename")
tf.app.flags.DEFINE_string("restore_parameters_names", None,
                           "restore parameters' names' list filename")
tf.app.flags.DEFINE_string('train_dir', 'train/',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('test_interval', 50, 'test_interval')
tf.app.flags.DEFINE_integer('eval_topk', 1, 'accuracy of evaluation of top-k')
# Parameters for distributed training, no need to modify
tf.app.flags.DEFINE_string('job_name', '', 'One of "ps", "worker"')
tf.app.flags.DEFINE_string('ps_hosts', '',
                           """Comma-separated list of hostname:port for the """
                           """parameter server jobs. e.g. """
                           """'machine1:2222,machine2:1111,machine2:2222'""")
tf.app.flags.DEFINE_string('worker_hosts', '',
                           """Comma-separated list of hostname:port for the """
                           """worker jobs. e.g. """
                           """'machine1:2222,machine2:1111,machine2:2222'""")
tf.app.flags.DEFINE_integer('task_id', 0, 'Task ID of the worker/replica running the training.')
tf.app.flags.DEFINE_bool('log_device_placement', False, 'log_device_placement')
FLAGS = tf.app.flags.FLAGS
FLAGS.batch_size = tf_parameter_mgr.getTrainBatchSize()
image_size = 299
# image_size = 224  # inception_v1
# image_size = 4200
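# Example invocations (illustrative sketch only: the script name, hostnames, ports and
# checkpoint path below are hypothetical, and the actual dataset/learning-rate settings
# come from tf_parameter_mgr rather than from flags):
#   Single node:
#     python train.py --network_type inception_v3 --preprocessing_type inception_v3 \
#                     --train_dir train/ --weights /path/to/pretrained_checkpoint_dir
#   Distributed (one parameter server, two workers):
#     python train.py --job_name ps     --task_id 0 --ps_hosts host1:2222 --worker_hosts host2:2222,host3:2222
#     python train.py --job_name worker --task_id 0 --ps_hosts host1:2222 --worker_hosts host2:2222,host3:2222
#     python train.py --job_name worker --task_id 1 --ps_hosts host1:2222 --worker_hosts host2:2222,host3:2222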
def get_restore_fixed_var_list():
    var2restore = []
    var2fixed = []
    fixed_parameters_names = FLAGS.fixed_parameters_names
    restore_parameters_names = FLAGS.restore_parameters_names
    if fixed_parameters_names is not None:
        try:
            with open(fixed_parameters_names) as file:
                param = file.readline()
                var2fixed = param.split(',')
                var2fixed = [var + ":0" for var in var2fixed]
        except Exception as e:
            print("Error while reading", fixed_parameters_names)
    if restore_parameters_names is not None:
        try:
            with open(restore_parameters_names) as file:
                param = file.readline()
                var2restore = param.split(',')
                var2restore = [var + ":0" for var in var2restore]
        except Exception as e:
            print("Error while reading", restore_parameters_names)
    else:
        var2restore = [var.name for var in tf.trainable_variables()]
    return var2restore, var2fixed
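# Note: the files passed via --fixed_parameters_names / --restore_parameters_names are read
# above as a single comma-separated line of variable names, to which ":0" is appended.
# An illustrative (hypothetical) line might look like:
#   InceptionV3/Conv2d_1a_3x3/weights,InceptionV3/Conv2d_1a_3x3/BatchNorm/beta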
def get_train_op(total_loss, global_step, return_grad=False):
    lr = tf_parameter_mgr.getLearningRate(global_step)
    _, var2fixed = get_restore_fixed_var_list()
    # Compute gradients.
    with tf.name_scope('user_optimizer'):
        opt = tf_parameter_mgr.getOptimizer(lr)
        # Exclude fixed variables.
        trainable_var = []
        for var in ops.get_collection("trainable_variables") + ops.get_collection("trainable_resource_variables"):
            if var.name not in var2fixed:
                trainable_var.append(var)
        grads = opt.compute_gradients(total_loss, var_list=trainable_var)
        # Apply gradients.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
    with tf.control_dependencies([apply_gradient_op] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = tf.no_op(name='train')
    if return_grad:
        return train_op, grads
    return train_op
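# get_train_op() wires: loss -> gradients over the non-fixed variables -> apply_gradients,
# then returns a no-op 'train' op that also depends on UPDATE_OPS, so any moving statistics
# (e.g. batch norm, if the selected network uses it) are refreshed on every step.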
def get_loss(logits, labels):
    # Calculate the average cross entropy loss across the batch.
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)
    # The total loss is defined as the cross entropy loss plus all of the weight
    # decay terms (L2 loss).
    if FLAGS.weight_decay > 0:
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    else:
        regularization_losses = []
    total_loss = tf.add_n(tf.get_collection('losses') + regularization_losses, name='total_loss')
    return total_loss
def get_accuracy(logits, labels):
    # in_top_k returns one bool per example; averaging the cast gives batch accuracy.
    top_k_op = tf.nn.in_top_k(logits, labels, FLAGS.eval_topk)
    accuracy = tf.reduce_mean(tf.cast(top_k_op, tf.float32))
    return accuracy
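# The accuracy above is the fraction of the batch whose true label falls within the
# top --eval_topk logits (eval_topk=1 gives plain top-1 accuracy).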
def get_data(is_train=True):
    batch_size = FLAGS.batch_size
    if is_train:
        filenames = tf_parameter_mgr.getTrainData()
    else:
        filenames = tf_parameter_mgr.getTestData()
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    label = features['label']
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.reshape(image, tf.stack([height, width, depth]))
    preprocessing_type = FLAGS.preprocessing_type
    preprocessor = preprocessing_factory.get_preprocessing(preprocessing_type, is_training=is_train)
    network_type = FLAGS.network_type
    default_size = image_size  # nets_factory.get_default_size(network_type)
    image = preprocessor(image, output_height=default_size, output_width=default_size)
    image.set_shape([default_size, default_size, 3])
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=4, capacity=32)
    return image_batch, label_batch
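# The TFRecords consumed above are assumed to hold one tf.train.Example per image, with
# int64 'height'/'width'/'depth'/'label' features and the raw uint8 pixels in 'image_raw'.
# A compatible writer could look roughly like this (illustrative sketch, not part of this script):
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
#       'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
#       'depth': tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
#       'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
#       'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
#   }))
#   writer.write(example.SerializeToString())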
def setup_distribute():
    global FLAGS
    worker_hosts = []
    ps_hosts = []
    spec = {}
    if FLAGS.worker_hosts is not None and FLAGS.worker_hosts != '':
        worker_hosts = FLAGS.worker_hosts.split(',')
        spec.update({'worker': worker_hosts})
    if FLAGS.ps_hosts is not None and FLAGS.ps_hosts != '':
        ps_hosts = FLAGS.ps_hosts.split(',')
        spec.update({'ps': ps_hosts})
    if len(worker_hosts) > 0:
        print('Cluster spec: ', spec)
        cluster = tf.train.ClusterSpec(spec)
        # Create and start a server for the local task.
        server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_id)
        if FLAGS.job_name == "ps":
            server.join()
    else:
        cluster = None
        server = tf.train.Server.create_local_server()
        # Enforce a task_id for single node mode.
        FLAGS.task_id = 0
    return cluster, server
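# With --worker_hosts / --ps_hosts set, this builds a tf.train.ClusterSpec and starts a
# tf.train.Server for the local task (parameter servers block in server.join()); with no
# hosts given it falls back to an in-process local server and forces task_id to 0.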
def train():
    cluster, server = setup_distribute()
    is_chief = (FLAGS.task_id == 0)
    # Try to infer the number of classes from the dataset's labels.txt, if present.
    filenames = tf_parameter_mgr.getTrainData()
    if len(filenames) > 0:
        try:
            data_path = os.path.dirname(os.path.dirname(os.path.abspath(filenames[0])))
            label_filename = data_path + '/labels.txt'
            if os.path.exists(label_filename):
                cnts = 0
                with open(label_filename) as label_file:
                    for line in label_file:
                        cnts += 1
                if cnts > 0:
                    FLAGS.number_classes = cnts
                    print('FLAGS.number_classes', FLAGS.number_classes)
        except Exception as e:
            print(e)
    # setup MAO and Tensorboard monitoring
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_id,
            cluster=cluster)):
        global_step = tf.train.get_or_create_global_step()
        is_training = tf.placeholder_with_default(False, shape=[])
        i_train, l_train = get_data(is_train=True)
        i_test, l_test = get_data(is_train=False)
        images, labels = tf.cond(is_training, lambda: (i_train, l_train), lambda: (i_test, l_test))
        network_type = FLAGS.network_type
        print('FLAGS.number_classes', FLAGS.number_classes)
        # l2_regularizer weight_decay parameter
        weight_decay = FLAGS.weight_decay
        embedding_network = nets_factory.get_network_fn(network_type, weight_decay=weight_decay,
                                                        num_classes=FLAGS.number_classes,
                                                        is_training=is_training)
        logits, end_points = embedding_network(images)
        total_loss = get_loss(logits, labels)
        accuracy = get_accuracy(logits, labels)
        train_op, grads = get_train_op(total_loss, global_step, True)
        from tensorflow.contrib.lms import LMS
        lms_model = LMS({'user_optimizer'}, lb=1)
        lms_model.run(tf.get_default_graph())
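        # tensorflow.contrib.lms is IBM's Large Model Support (available in PowerAI builds of
        # TensorFlow, not in stock TF). It rewrites the graph so large activation tensors can be
        # swapped out to host memory; 'user_optimizer' is the optimizer scope created in
        # get_train_op(), and lb is (to our understanding) TFLMS's lower-bound knob for how early
        # swapped-out tensors are brought back before they are needed.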
        class _LoggerHook(tf.train.SessionRunHook):
            """Periodically evaluates and logs test loss and accuracy."""

            def begin(self):
                self._next_trigger_step = FLAGS.test_interval
                self._trigger = False

            def before_run(self, run_context):
                args = {'global_step': global_step}
                if self._trigger:
                    self._trigger = False
                return tf.train.SessionRunArgs(args)  # Ask for the current global step.

            def after_run(self, run_context, run_values):
                gs = run_values.results['global_step']
                if gs >= self._next_trigger_step:
                    self._trigger = True
                    self._next_trigger_step += FLAGS.test_interval
                if self._trigger:
                    lossVal, accval = run_context.session.run([total_loss, accuracy], feed_dict={is_training: False})
                    print("Iteration {}: tag test_loss, simple_value {}".format(gs, lossVal))
                    print("Iteration {}: tag test_accuracy, simple_value {}".format(gs, accval))

            def end(self, session):
                lossVal, accval = session.run([total_loss, accuracy], feed_dict={is_training: False})
                print("Iteration {}: tag test_loss, simple_value {}".format(tf_parameter_mgr.getMaxSteps(), lossVal))
                print("Iteration {}: tag test_accuracy, simple_value {}".format(tf_parameter_mgr.getMaxSteps(), accval))
        hooks = [tf.train.StopAtStepHook(last_step=tf_parameter_mgr.getMaxSteps()),
                 tf.train.NanTensorHook(total_loss)]
        if is_chief:
            hooks.append(_LoggerHook())

        var2restore, _ = get_restore_fixed_var_list()
        pretrained_model = FLAGS.weights
        var_list = []
        for var in tf.global_variables() + tf.local_variables():
            if var.name in var2restore:
                var_list.append(var)
        if len(var_list) > 0 and pretrained_model is not None:
            print("------------------------------")
            print('will restore ', var_list)
            saver = tf.train.Saver(var_list=var_list)

        mon_sess = tf.train.MonitoredTrainingSession(
            master=server.target, is_chief=is_chief, checkpoint_dir=FLAGS.train_dir, hooks=hooks,
            config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))

        if pretrained_model is not None and len(var_list) > 0:
            ckpt = tf.train.get_checkpoint_state(pretrained_model)
            print("Restore pre-trained checkpoint:", ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(mon_sess, ckpt.model_checkpoint_path)
                print("Successfully restored checkpoint:", ckpt.model_checkpoint_path)
            else:
                # No CheckpointState file; fall back to the first *.ckpt file in the directory.
                files = os.listdir(pretrained_model)
                for f in files:
                    if f.endswith('ckpt'):
                        model_checkpoint_path = pretrained_model + "/" + f
                        try:
                            saver.restore(mon_sess, model_checkpoint_path)
                            print("Successfully restored checkpoint:", model_checkpoint_path)
                        except Exception as e:
                            print("Fail to restore ", model_checkpoint_path, 'with message', e)
                        break
        steps = 0
        t1 = time.time()
        while not mon_sess.should_stop():
            _, lossval, accval = mon_sess.run([train_op, total_loss, accuracy], feed_dict={is_training: True})
            # print('Training iteration %d: loss=%f, acc=%f' % (steps, lossval, accval))
            steps += 1
            if steps % FLAGS.test_interval == 0:
                print("Iteration {}: tag train_loss, simple_value {}".format(steps, lossval))
                print("Iteration {}: tag train_accuracy, simple_value {}".format(steps, accval))
        print('time cost:', time.time() - t1)
        print('%d steps executed on worker %d.' % (steps, FLAGS.task_id))

        # TF 1.10 sometimes fails to close the session and hangs; wait up to 30 seconds here
        # for the session to close. This is just a workaround until there is a formal fix in TF.
        start_time = time.time()
        close = threading.Thread(target=mon_sess.close, args=())
        close.setDaemon(True)
        close.start()
        close.join(30)
        duration = time.time() - start_time
        if close.isAlive():
            print('WARN: The tf training session failed to close')
        print('close session cost:', duration)
def main(argv=None):
    train()


if __name__ == '__main__':
    tf.app.run()