#! /usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tf_parameter_mgr
import monitor_cb
from tensorflow.python.framework import ops
#from monitor_cb import CMonitor
from datetime import datetime
import os.path
import time
import sys, os
import threading
import glob
import numpy as np
import tensorflow as tf
from preprocessing import preprocessing_factory
from nets import nets_factory
from tensorflow.python.ops import variables
# Choose your image preprocessing
tf.app.flags.DEFINE_string('preprocessing_type', 'inception_v3', 'image preprocessing type')
# Choose the network structure
tf.app.flags.DEFINE_string('network_type', 'inception_v3', 'image feature extraction network type')
# Set the number of classes
tf.app.flags.DEFINE_integer('number_classes', 5, 'number of classes')
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_string('weights', None, 'initialize with pretrained model weights')
tf.app.flags.DEFINE_float('weight_decay', 0.0, 'l2_regularizer weight_decay parameter')
tf.app.flags.DEFINE_string('fixed_parameters_names', None,
                           'filename listing the names of parameters to keep fixed')
tf.app.flags.DEFINE_string('restore_parameters_names', None,
                           'filename listing the names of parameters to restore')
tf.app.flags.DEFINE_string('train_dir', 'train/',
                           """Directory where to write event logs """
                           """and checkpoints.""")
tf.app.flags.DEFINE_integer('test_interval', 50, 'test interval')
tf.app.flags.DEFINE_integer('eval_topk', 1, 'evaluate top-k accuracy')
# Parameters for distributed training, no need to modify
tf.app.flags.DEFINE_string('job_name', '', 'One of "ps", "worker"')
tf.app.flags.DEFINE_string('ps_hosts', '',
                           """Comma-separated list of hostname:port for the """
                           """parameter server jobs, e.g. """
                           """'machine1:2222,machine2:1111,machine2:2222'""")
tf.app.flags.DEFINE_string('worker_hosts', '',
                           """Comma-separated list of hostname:port for the """
                           """worker jobs, e.g. """
                           """'machine1:2222,machine2:1111,machine2:2222'""")
tf.app.flags.DEFINE_integer('task_id', 0, 'Task ID of the worker/replica running the training.')
tf.app.flags.DEFINE_bool('log_device_placement', False, 'log_device_placement')

FLAGS = tf.app.flags.FLAGS
FLAGS.batch_size = tf_parameter_mgr.getTrainBatchSize()

image_size = 299
#image_size = 224   # inception_v1
#image_size = 4200
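
# Hypothetical single-node invocation (tf_parameter_mgr supplies the data
# paths and hyperparameters from its own config, so only the model flags are
# shown; the script filename is a placeholder):
#
#   python <this_script>.py --network_type=inception_v3 \
#       --preprocessing_type=inception_v3 --number_classes=5 \
#       --weights=/path/to/pretrained_ckpt_dir --train_dir=train/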

def get_restore_fixed_var_list():
    var2restore = []
    var2fixed = []
    fixed_parameters_names = FLAGS.fixed_parameters_names
    restore_parameters_names = FLAGS.restore_parameters_names
    if fixed_parameters_names is not None:
        try:
            with open(fixed_parameters_names) as file:
                # strip() drops the trailing newline so the last name parses cleanly
                param = file.readline().strip()
            var2fixed = param.split(',')
            var2fixed = [var + ":0" for var in var2fixed]
        except Exception as e:
            print("Error while reading", fixed_parameters_names, ":", e)
    if restore_parameters_names is not None:
        try:
            with open(restore_parameters_names) as file:
                param = file.readline().strip()
            var2restore = param.split(',')
            var2restore = [var + ":0" for var in var2restore]
        except Exception as e:
            print("Error while reading", restore_parameters_names, ":", e)
    else:
        var2restore = [var.name for var in tf.trainable_variables()]
    return var2restore, var2fixed
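
# Both name-list files are expected to hold a single comma-separated line of
# variable names without the ":0" tensor suffix (the suffix is appended
# above). A hypothetical example line:
#
#   InceptionV3/Conv2d_1a_3x3/weights,InceptionV3/Logits/Conv2d_1c_1x1/weights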

def get_train_op(total_loss, global_step, return_grad=False):
    lr = tf_parameter_mgr.getLearningRate(global_step)
    _, var2fixed = get_restore_fixed_var_list()
    # Compute gradients.
    with tf.name_scope('user_optimizer'):
        opt = tf_parameter_mgr.getOptimizer(lr)
        # Exclude fixed variables from training.
        trainable_var = []
        for var in ops.get_collection("trainable_variables") + ops.get_collection("trainable_resource_variables"):
            if var.name not in var2fixed:
                trainable_var.append(var)
        grads = opt.compute_gradients(total_loss, var_list=trainable_var)
        # Apply gradients, and also run UPDATE_OPS (e.g. batch-norm moving-average updates).
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        with tf.control_dependencies([apply_gradient_op] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.no_op(name='train')
    if return_grad:
        return train_op, grads
    return train_op
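
# Design note: layers are "frozen" simply by omitting their variables from the
# optimizer's var_list, so fixed layers still run in the forward pass but get
# no gradient updates. Combined with restore_parameters_names, this yields the
# usual fine-tuning setup: restore a pretrained backbone, train only the rest.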

def get_loss(logits, labels):
    # Calculate the average cross entropy loss across the batch.
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)
    # The total loss is defined as the cross entropy loss plus all of the weight
    # decay terms (L2 loss).
    if FLAGS.weight_decay > 0:
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    else:
        regularization_losses = []
    total_loss = tf.add_n(tf.get_collection('losses') + regularization_losses, name='total_loss')
    return total_loss
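
# In effect: total_loss = mean(softmax cross-entropy) + sum(regularization
# terms), where slim-style layers built with an l2_regularizer each contribute
# weight_decay * ||W||^2 / 2 to GraphKeys.REGULARIZATION_LOSSES.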

def get_accuracy(logits, labels):
    # Fraction of examples whose true label is among the top-k predictions.
    # (The original np.sum over a symbolic tensor was a bug; averaging the
    # boolean matches cast to float gives the accuracy directly.)
    top_k_op = tf.nn.in_top_k(logits, labels, FLAGS.eval_topk)
    accuracy = tf.reduce_mean(tf.cast(top_k_op, tf.float32))
    return accuracy
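
# tf.nn.in_top_k returns one boolean per example, so the mean of the cast
# values is the top-k accuracy. For instance, with eval_topk=1,
# logits=[[0.1, 0.9], [0.8, 0.2]] and labels=[1, 1], in_top_k yields
# [True, False] and the accuracy is 0.5.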

def get_data(is_train=True):
    batch_size = FLAGS.batch_size
    if is_train:
        filenames = tf_parameter_mgr.getTrainData()
    else:
        filenames = tf_parameter_mgr.getTestData()
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    label = features['label']
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.reshape(image, tf.stack([height, width, depth]))
    preprocessing_type = FLAGS.preprocessing_type
    preprocessor = preprocessing_factory.get_preprocessing(preprocessing_type, is_training=is_train)
    network_type = FLAGS.network_type
    default_size = image_size  # nets_factory.get_default_size(network_type)
    image = preprocessor(image, output_height=default_size, output_width=default_size)
    image.set_shape([default_size, default_size, 3])
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=4, capacity=32)
    return image_batch, label_batch
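
# A minimal sketch of the writer side that would produce records this parser
# accepts (hypothetical snippet, not part of this script; `writer` would be a
# tf.python_io.TFRecordWriter and `img` a uint8 HxWxC numpy array):
#
#   def _int64(v): return tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))
#   def _bytes(v): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'height': _int64(img.shape[0]), 'width': _int64(img.shape[1]),
#       'depth': _int64(img.shape[2]), 'label': _int64(label),
#       'image_raw': _bytes(img.tobytes())}))
#   writer.write(example.SerializeToString())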

def setup_distribute():
    global FLAGS
    worker_hosts = []
    ps_hosts = []
    spec = {}
    if FLAGS.worker_hosts is not None and FLAGS.worker_hosts != '':
        worker_hosts = FLAGS.worker_hosts.split(',')
        spec.update({'worker': worker_hosts})
    if FLAGS.ps_hosts is not None and FLAGS.ps_hosts != '':
        ps_hosts = FLAGS.ps_hosts.split(',')
        spec.update({'ps': ps_hosts})
    if len(worker_hosts) > 0:
        print('Cluster spec: ', spec)
        cluster = tf.train.ClusterSpec(spec)
        # Create and start a server for the local task.
        server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_id)
        if FLAGS.job_name == "ps":
            server.join()
    else:
        cluster = None
        server = tf.train.Server.create_local_server()
        # Enforce a task_id for single-node mode.
        FLAGS.task_id = 0
    return cluster, server
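
# Hypothetical two-process launch on one machine (this same script started
# twice; host:port values are placeholders):
#
#   python <this_script>.py --job_name=ps     --task_id=0 \
#       --ps_hosts=localhost:2222 --worker_hosts=localhost:2223
#   python <this_script>.py --job_name=worker --task_id=0 \
#       --ps_hosts=localhost:2222 --worker_hosts=localhost:2223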

def train():
    cluster, server = setup_distribute()
    is_chief = (FLAGS.task_id == 0)
    # Try to infer the number of classes from the labels.txt next to the data.
    filenames = tf_parameter_mgr.getTrainData()
    if len(filenames) > 0:
        try:
            data_path = os.path.dirname(os.path.dirname(os.path.abspath(filenames[0])))
            label_filename = data_path + '/labels.txt'
            if os.path.exists(label_filename):
                cnts = 0
                with open(label_filename) as label_file:
                    for line in label_file:
                        cnts += 1
                if cnts > 0:
                    FLAGS.number_classes = cnts
                    print('FLAGS.number_classes', FLAGS.number_classes)
        except Exception as e:
            print(e)

    # Build the model graph, replicated across workers via replica_device_setter.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_id,
            cluster=cluster)):
        global_step = tf.train.get_or_create_global_step()
        is_training = tf.placeholder_with_default(False, shape=[])
        i_train, l_train = get_data(is_train=True)
        i_test, l_test = get_data(is_train=False)
        images, labels = tf.cond(is_training, lambda: (i_train, l_train), lambda: (i_test, l_test))
        network_type = FLAGS.network_type
        print('FLAGS.number_classes', FLAGS.number_classes)
        # l2_regularizer weight_decay parameter
        weight_decay = FLAGS.weight_decay
        embedding_network = nets_factory.get_network_fn(network_type, weight_decay=weight_decay,
                                                        num_classes=FLAGS.number_classes,
                                                        is_training=is_training)
        logits, end_points = embedding_network(images)
        total_loss = get_loss(logits, labels)
        accuracy = get_accuracy(logits, labels)
        train_op, grads = get_train_op(total_loss, global_step, True)

        from tensorflow.contrib.lms import LMS
        lms_model = LMS({'user_optimizer'}, lb=1)
        lms_model.run(tf.get_default_graph())
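        # Note: tensorflow.contrib.lms is IBM's Large Model Support, shipped
        # with PowerAI TensorFlow rather than stock TF. Given the optimizer
        # name scope ('user_optimizer' above) and lb=1, it rewrites the graph
        # so intermediate tensors can be swapped out to host memory, trading
        # transfer bandwidth for the ability to fit larger models or batches
        # in GPU memory.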

    class _LoggerHook(tf.train.SessionRunHook):
        """Periodically evaluates test loss and accuracy on the chief."""

        def begin(self):
            self._next_trigger_step = FLAGS.test_interval
            self._trigger = False

        def before_run(self, run_context):
            if self._trigger:
                self._trigger = False
            return tf.train.SessionRunArgs({'global_step': global_step})  # Asks for the global step.

        def after_run(self, run_context, run_values):
            gs = run_values.results['global_step']
            if gs >= self._next_trigger_step:
                self._trigger = True
                self._next_trigger_step += FLAGS.test_interval
            if self._trigger:
                lossVal, accval = run_context.session.run([total_loss, accuracy], feed_dict={is_training: False})
                print("Iteration {}: tag test_loss, simple_value {}".format(gs, lossVal))
                print("Iteration {}: tag test_accuracy, simple_value {}".format(gs, accval))

        def end(self, session):
            lossVal, accval = session.run([total_loss, accuracy], feed_dict={is_training: False})
            print("Iteration {}: tag test_loss, simple_value {}".format(tf_parameter_mgr.getMaxSteps(), lossVal))
            print("Iteration {}: tag test_accuracy, simple_value {}".format(tf_parameter_mgr.getMaxSteps(), accval))

    hooks = [tf.train.StopAtStepHook(last_step=tf_parameter_mgr.getMaxSteps()),
             tf.train.NanTensorHook(total_loss)]
    if is_chief:
        hooks.append(_LoggerHook())

    var2restore, _ = get_restore_fixed_var_list()
    pretrained_model = FLAGS.weights
    var_list = []
    for var in tf.global_variables() + tf.local_variables():
        if var.name in var2restore:
            var_list.append(var)
    if len(var_list) > 0 and pretrained_model is not None:
        print("------------------------------")
        print('will restore ', var_list)
        saver = tf.train.Saver(var_list=var_list)

    mon_sess = tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief,
                                                 checkpoint_dir=FLAGS.train_dir, hooks=hooks,
                                                 config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))

    if pretrained_model is not None and len(var_list) > 0:
        ckpt = tf.train.get_checkpoint_state(pretrained_model)
        print("Restore pre-trained checkpoint:", ckpt)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(mon_sess, ckpt.model_checkpoint_path)
            print("Successfully restored checkpoint:", ckpt.model_checkpoint_path)
        else:
            # No checkpoint state file; fall back to the first *.ckpt file found.
            files = os.listdir(pretrained_model)
            for f in files:
                if f.endswith('ckpt'):
                    model_checkpoint_path = pretrained_model + "/" + f
                    try:
                        saver.restore(mon_sess, model_checkpoint_path)
                        print("Successfully restored checkpoint:", model_checkpoint_path)
                    except Exception as e:
                        print("Failed to restore", model_checkpoint_path, 'with message', e)
                    break

    steps = 0
    t1 = time.time()
    while not mon_sess.should_stop():
        _, lossval, accval = mon_sess.run([train_op, total_loss, accuracy], feed_dict={is_training: True})
        #print('Training iteration %d: loss=%f, acc=%f' % (steps, lossval, accval))
        steps += 1
        if steps % FLAGS.test_interval == 0:
            print("Iteration {}: tag train_loss, simple_value {}".format(steps, lossval))
            print("Iteration {}: tag train_accuracy, simple_value {}".format(steps, accval))
    print('time cost:', time.time() - t1)
    print('%d steps executed on worker %d.' % (steps, FLAGS.task_id))

    # TF 1.10 sometimes fails to close the session and hangs; wait 30 seconds
    # here for the session to close. This is just a workaround until there is
    # a formal fix in TF.
    start_time = time.time()
    close = threading.Thread(target=mon_sess.close, args=())
    close.setDaemon(True)
    close.start()
    close.join(30)
    duration = time.time() - start_time
    if close.isAlive():
        print('WARN: The tf training session failed to close')
    print('close session cost:', duration)

def main(argv=None):
    train()

if __name__ == '__main__':
    tf.app.run()