Skip to content

Instantly share code, notes, and snippets.

@ericyue
Created February 21, 2017 19:21
Show Gist options
  • Save ericyue/e694a90338b9fadf9996025719005c9d to your computer and use it in GitHub Desktop.
Save ericyue/e694a90338b9fadf9996025719005c9d to your computer and use it in GitHub Desktop.
def main():
    """Build the TFRecord input pipeline and, when MODE == "train", train and export.

    Reads serialized examples from the files matching
    ``FLAGS.train_tfrecords_file`` for ``EPOCH_NUMBER`` passes. Shuffles them
    into batches of ``FLAGS.batch_size``. Parses each batch into a float32
    ``label`` plus variable-length ``ids`` (int64) and ``values`` (float32).

    Inside a session it does the following:

    - writes the graph to an event file under ``OUTPUT_PATH``;
    - initializes global and local variables;
    - in ``"train"`` mode, restores from ``LATEST_CHECKPOINT`` (if set), runs
      ``train_op`` until the input queue is exhausted, and then exports the
      model to ``FLAGS.model_path``.

    Other modes do nothing beyond initialization in this function.

    NOTE(review): ``FLAGS``, the upper-case constants, ``saver``, ``train_op``,
    ``loss``, ``global_step``, ``model_signature`` and ``export_model`` are not
    defined in this function. They are presumably module-level names, with the
    model built from the batch tensors below. Confirm against the full script.
    """

    def read_and_decode(filename_queue):
        """Read and return one serialized record from *filename_queue*."""
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        return serialized_example

    # Read TFRecords files for training.
    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(FLAGS.train_tfrecords_file),
        num_epochs=EPOCH_NUMBER)
    serialized_example = read_and_decode(filename_queue)
    batch_serialized_example = tf.train.shuffle_batch(
        [serialized_example],
        batch_size=FLAGS.batch_size,
        num_threads=BATCH_THREAD_NUMBER,
        capacity=BATCH_CAPACITY,
        min_after_dequeue=MIN_AFTER_DEQUEUE)
    # VarLenFeature entries ("ids", "values") parse to SparseTensors;
    # "label" is a dense scalar per example.
    features = tf.parse_example(batch_serialized_example,
                                features={
                                    "label": tf.FixedLenFeature([], tf.float32),
                                    "ids": tf.VarLenFeature(tf.int64),
                                    "values": tf.VarLenFeature(tf.float32),
                                })
    # NOTE(review): these tensors are not consumed in this function;
    # presumably the model producing train_op/loss is built from them elsewhere.
    batch_labels = features["label"]
    batch_ids = features["ids"]
    batch_values = features["values"]

    # Create session to run
    with tf.Session() as sess:
        logging.info("Start to run with mode: {}".format(MODE))
        # Passing sess.graph records the graph definition in the event file.
        writer = tf.summary.FileWriter(OUTPUT_PATH, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            # match_filenames_once and the num_epochs counter of
            # string_input_producer are local variables, so they must be
            # initialized before the queue runners start.
            sess.run(tf.local_variables_initializer())
            if MODE == "train":
                # Restore after initialization so checkpoint values overwrite
                # the freshly initialized ones.
                restore_session_from_checkpoint(sess, saver, LATEST_CHECKPOINT)
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord, sess=sess)
                try:
                    while not coord.should_stop():
                        _, loss_value, step = sess.run(
                            [train_op, loss, global_step])
                except tf.errors.OutOfRangeError:
                    # Raised once the input queue has delivered all
                    # EPOCH_NUMBER epochs: training completed, so export.
                    # Any other exception propagates without an export.
                    export_model(sess, saver, model_signature, FLAGS.model_path,
                                 FLAGS.model_version)
                finally:
                    # Stop and join the queue-runner threads on every exit path.
                    coord.request_stop()
                    coord.join(threads)
        finally:
            # close() flushes pending events to disk and releases the file.
            writer.close()
def restore_session_from_checkpoint(sess, saver, checkpoint):
    """Load variable values from *checkpoint* into *sess* when one is given.

    Args:
        sess: the session whose variables are restored.
        saver: object whose ``restore(sess, checkpoint)`` performs the load.
        checkpoint: checkpoint path; any falsy value (``None``, ``""``) means
            there is nothing to restore.

    Returns:
        ``True`` if ``saver.restore`` was called, ``False`` otherwise.
    """
    # Guard clause: no checkpoint means there is nothing to log or restore.
    if not checkpoint:
        return False

    logging.info("Restore session from checkpoint: {}".format(checkpoint))
    saver.restore(sess, checkpoint)
    return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment