Skip to content

Instantly share code, notes, and snippets.

@lucasjinreal
Created May 11, 2017 07:11
Show Gist options
  • Save lucasjinreal/6ac9e5b134ee191bffb5ccbf599531bf to your computer and use it in GitHub Desktop.
Save lucasjinreal/6ac9e5b134ee191bffb5ccbf599531bf to your computer and use it in GitHub Desktop.
Skeleton of a TensorFlow training pipeline.
import os
import sys
import numpy as np
import tensorflow as tf
import logging
# NOTE(review): `sys` and `np` are not used anywhere in the visible code —
# confirm they are needed before removing.
# Root-logger configuration for the whole script: timestamped messages with
# file name and line number, DEBUG level and up.
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(filename)s line:%(lineno)d %(levelname)s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
# Command-line flags (TF 1.x flag API); values are read via FLAGS below.
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('checkpoints_dir', os.path.abspath('./checkpoints/poems/'), 'checkpoints save path.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
FLAGS = tf.app.flags.FLAGS
def run_training(is_train):
    """Train the model (or, when `is_train` is falsy, run inference).

    Restores the latest checkpoint from FLAGS.checkpoints_dir if one exists
    and resumes training from the epoch encoded in the checkpoint name
    (expected form ``<prefix>-<epoch>``). Saves a checkpoint every 6 epochs
    and on KeyboardInterrupt.

    Args:
        is_train: truthy to train, falsy to run the (unimplemented)
            inference branch.
    """
    # Create the checkpoint directory and its parent if missing.
    if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
        os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
    if not os.path.exists(FLAGS.checkpoints_dir):
        os.mkdir(FLAGS.checkpoints_dir)

    # NOTE(review): DataLoader/Model are project types not shown here;
    # assumed to expose num_chunks, batch_inputs(), make_train_inputs(),
    # loss and train_op — confirm against their definitions.
    data_loader = DataLoader()
    model = Model()

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        epoch = 0  # keep defined so the KeyboardInterrupt handler can save
        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            logging.info("restore from the checkpoint {0}".format(checkpoint))
            # Checkpoint names end in "-<global_step>"; resume from there.
            start_epoch += int(checkpoint.split('-')[-1])
        logging.info('start training...')

        if is_train:
            try:
                for epoch in range(start_epoch, FLAGS.epochs):
                    # BUG FIX: original referenced an undefined `poems`;
                    # the constructed `data_loader` is the intended object.
                    for batch in range(data_loader.num_chunks):
                        inputs, labels = data_loader.batch_inputs()
                        feed_dict = model.make_train_inputs(inputs, labels)
                        loss, _ = sess.run([
                            model.loss,
                            model.train_op,
                        ], feed_dict=feed_dict)
                        logging.info('epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                    if epoch % 6 == 0:
                        saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
            except KeyboardInterrupt:
                # Best-effort save so an interrupted run can be resumed.
                logging.info('interrupt manually, try saving checkpoint for now...')
                saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=epoch)
                logging.info('last epoch were saved, next time will start from epoch {}.'.format(epoch))
        else:
            # Inference path is a stub in this skeleton.
            logging.info('start inference...')
            pass
def main(is_train):
    """Entry point invoked by tf.app.run().

    NOTE(review): tf.app.run() passes the parsed argv list as the first
    positional argument, so `is_train` receives a non-empty list (always
    truthy) rather than a bool — the training branch always runs. Confirm
    this is intended.
    """
    run_training(is_train)
# Script entry: tf.app.run() parses the flags defined above, then calls
# this module's main() with the remaining argv.
if __name__ == '__main__':
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment