Skip to content

Instantly share code, notes, and snippets.

@siddMahen
Created July 30, 2016 14:03
Show Gist options
  • Save siddMahen/e45174fbf60a4df174af4a5d95a293f1 to your computer and use it in GitHub Desktop.
Save siddMahen/e45174fbf60a4df174af4a5d95a293f1 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import sys
import os
def read_and_decode(filename_queue):
    """Return a tensor that yields the next text line from *filename_queue*.

    Despite the name, nothing is decoded: the raw line string is returned
    as-is, and the reader's key output (which identifies the source
    file/line) is discarded.

    Args:
        filename_queue: a queue of filenames, e.g. the one built by
            ``tf.train.string_input_producer``.

    Returns:
        The string tensor holding the value produced by
        ``tf.TextLineReader.read``.
    """
    line_reader = tf.TextLineReader()
    _key, line_value = line_reader.read(filename_queue)
    return line_value
def inputs(filenames, batch_size, num_epochs, min_after_dequeue=10):
    """Build a shuffled, batched pipeline of text lines read from *filenames*.

    Args:
        filenames: list of paths to text files to read.
        batch_size: number of lines per batch. Because
            ``allow_smaller_final_batch=True``, the last batch of the input
            may contain fewer lines.
        num_epochs: number of passes over the files before the pipeline
            raises ``tf.errors.OutOfRangeError``. A non-None value creates a
            local epoch-counter variable, so callers must run
            ``tf.initialize_local_variables()`` before starting the queues.
        min_after_dequeue: minimum number of lines left in the shuffle
            buffer after each dequeue. Larger values shuffle more thoroughly
            but use more memory and delay the first batch. Defaults to 10,
            the value this function previously hard-coded.

    Returns:
        The result of ``tf.train.shuffle_batch([line], ...)``: batches of up
        to ``batch_size`` line strings. NOTE(review): whether this is a single
        tensor or a one-element list depends on the TF version; confirm
        against the installed release.
    """
    with tf.name_scope('input'):
        # shuffle=True randomizes the order of the files within each epoch;
        # shuffle_batch below additionally shuffles the individual lines.
        filename_queue = tf.train.string_input_producer(
            filenames, num_epochs=num_epochs, shuffle=True)
        line = read_and_decode(filename_queue)
        # Room for a few batches above the shuffle floor, following the
        # capacity rule of thumb from the TensorFlow "Reading data" guide.
        capacity = min_after_dequeue + 3 * batch_size
        line_batch = tf.train.shuffle_batch(
            [line],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            allow_smaller_final_batch=True)
    return line_batch
def _checkpoint_step(checkpoint_path, run_name):
    """Return the step of *checkpoint_path* if it was written for *run_name*.

    ``train`` saves checkpoints as ``model-<run_name>-<step>``. Matching the
    whole ``model-<run_name>-`` prefix, instead of splitting on ``-``, keeps
    run names that contain ``-`` working and never confuses run ``a`` with
    run ``a-b``. ``os.path.basename`` is used instead of ``split('/')`` so
    the platform's path separator is respected.

    Returns:
        The integer step, or ``None`` if the checkpoint belongs to another
        run or its suffix is not an integer.
    """
    prefix = "model-" + run_name + "-"
    basename = os.path.basename(checkpoint_path)
    if not basename.startswith(prefix):
        return None
    try:
        return int(basename[len(prefix):])
    except ValueError:
        return None


def train(run_name, filenames):
    """Read every line of *filenames* once, printing it and checkpointing.

    The graph holds a single dummy variable so that ``tf.train.Saver`` has
    something to save. If the most recent checkpoint in the current
    directory was written for *run_name*, its variables are restored and the
    step counter continues from it. Otherwise all variables are initialized
    from scratch. Local variables are always re-initialized after that, so a
    resumed run presumably re-reads the input files from the start (the
    ``num_epochs`` counter is a local variable).

    A checkpoint ``./model-<run_name>-<step>`` is written after every batch,
    where ``<step>`` is the cumulative number of batches processed. One more
    is written when the input is exhausted.

    Args:
        run_name: tag embedded in checkpoint filenames; also used to decide
            whether the latest checkpoint belongs to this run.
        filenames: paths of the text files to read (a single epoch).
    """
    checkpoint_dir = '.'
    ckpt_path = os.path.join(checkpoint_dir, "model-" + run_name)

    with tf.Graph().as_default():
        lines = inputs(filenames, batch_size=5, num_epochs=1)
        # Dummy variable: gives the Saver some state to checkpoint.
        tf.Variable(1.0)
        init_op = tf.initialize_all_variables()
        # Local variables (such as the num_epochs counter) are not handled by
        # the Saver, so they must be initialized on every start, restored or not.
        init_again = tf.initialize_local_variables()
        saver = tf.train.Saver()

        # Work out whether to resume before opening the session.
        step = 0
        restore_path = None
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saved_step = _checkpoint_step(ckpt.model_checkpoint_path, run_name)
            if saved_step is not None:
                step = saved_step
                restore_path = ckpt.model_checkpoint_path

        # `with` closes the session even if restoring or queue start-up raises.
        with tf.Session() as sess:
            if restore_path is None:
                sess.run(init_op)
            else:
                saver.restore(sess, restore_path)
            sess.run(init_again)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    batch = sess.run(lines)
                    for line in batch:
                        print(line)
                    # Count the batch before saving so a checkpoint's step is
                    # the number of batches processed, and a resumed run never
                    # overwrites the checkpoint it was restored from.
                    step += 1
                    save_path = saver.save(sess, ckpt_path, global_step=step)
                    print('Model saved to %s' % save_path)
            except tf.errors.OutOfRangeError:
                # The input queue is exhausted once num_epochs have been read.
                # Saving again repeats the last step's checkpoint (harmless)
                # and guarantees one exists even if no batch was read.
                print("Done training!")
                save_path = saver.save(sess, ckpt_path, global_step=step)
                print('Model saved to %s' % save_path)
            finally:
                coord.request_stop()
                coord.join(threads)
if __name__ == '__main__':
    # Usage: python train.py model_name input1.txt input2.txt
    # Fail fast with a usage message: otherwise a missing run name raises
    # IndexError, and an empty file list only fails later inside TensorFlow.
    if len(sys.argv) < 3:
        sys.exit('Usage: python %s model_name input1.txt [input2.txt ...]'
                 % sys.argv[0])
    run_name = sys.argv[1]
    filenames = sys.argv[2:]
    train(run_name, filenames)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment