@talolard
Last active February 19, 2018 14:04
Example of using a dataset and iterators to switch between training and validation data
import tensorflow as tf

if __name__ == "__main__":
    # Make the iterators and the next-element op
    next_element, training_init_op, validation_init_op = prepare_dataset_iterators(batch_size=32)
    ...
    for epoch in range(1000):
        # Initialize the iterator to consume training data
        sess.run(training_init_op)
        while True:
            # Keep running training steps until the iterator is exhausted
            try:
                _, summary, gs = sess.run([M.train, M.write_op, M.gs], feed_dict={M.lr: lr, M.keep_prob: keep_prob})
            except tf.errors.OutOfRangeError:
                # Do stuff at the end of a training epoch here
                break
        # Initialize the iterator to provide validation data
        sess.run(validation_init_op)
        # Store the loss from each batch so we can average them
        losses = []
        while True:
            # Keep running validation steps until the iterator is exhausted
            try:
                loss, summary, gs, _ = sess.run([M.total_loss, M.write_op, M.gs, M.increment_gs], feed_dict={M.lr: lr, M.keep_prob: 1})
                losses.append(loss)
            except tf.errors.OutOfRangeError:
                # Do stuff at the end of a validation run here
                break
        # Average loss over all validation batches
        mean_val_loss = sum(losses) / len(losses)
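The gist does not show `prepare_dataset_iterators`, `sess` or the model `M`. To make the control flow concrete without TensorFlow, here is a minimal pure-Python analogue. It assumes that each init op simply rewinds one shared iterator onto a different dataset and that exhaustion is signaled by an exception, the way `tf.errors.OutOfRangeError` ends the loops above. `OutOfRange`, `ReinitializableIterator`, `make_iterators` and `run_epochs` are hypothetical names for this sketch, not TensorFlow's API, and the per-batch "loss" is a toy value.

```python
# Pure-Python analogue of the reinitializable-iterator pattern (no TensorFlow).
# OutOfRange, ReinitializableIterator, make_iterators and run_epochs are
# hypothetical names chosen for this sketch.


class OutOfRange(Exception):
    """Stand-in for tf.errors.OutOfRangeError: signals the data is exhausted."""


class ReinitializableIterator:
    """A single 'next element' source that can be re-pointed at any dataset."""

    def __init__(self):
        self._batches = None

    def initializer_for(self, dataset, batch_size):
        """Return a zero-argument callable that plays the role of an init op."""
        def init():
            # (Re)start iteration over `dataset` from its first batch.
            self._batches = iter(
                [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
            )
        return init

    def get_next(self):
        if self._batches is None:
            raise RuntimeError("iterator has not been initialized")
        try:
            return next(self._batches)
        except StopIteration:
            raise OutOfRange() from None


def make_iterators(train_data, val_data, batch_size):
    """Mirror the (next_element, training_init_op, validation_init_op) triple."""
    iterator = ReinitializableIterator()
    return (
        iterator.get_next,
        iterator.initializer_for(train_data, batch_size),
        iterator.initializer_for(val_data, batch_size),
    )


def run_epochs(num_epochs, train_data, val_data, batch_size=2):
    """Alternate training and validation passes; return per-epoch stats."""
    next_element, training_init_op, validation_init_op = make_iterators(
        train_data, val_data, batch_size
    )
    history = []
    for _ in range(num_epochs):
        training_init_op()  # point the shared iterator at the training data
        train_batches = 0
        while True:
            try:
                next_element()  # a real loop would run a training step here
                train_batches += 1
            except OutOfRange:
                break  # end of the training epoch
        validation_init_op()  # point the same iterator at the validation data
        losses = []
        while True:
            try:
                batch = next_element()
                losses.append(sum(batch) / len(batch))  # toy per-batch "loss"
            except OutOfRange:
                break  # end of the validation run
        history.append((train_batches, sum(losses) / len(losses)))
    return history


print(run_epochs(2, [1, 2, 3, 4, 5], [10, 20, 30]))  # [(3, 22.5), (3, 22.5)]
```

Because every call to an init op restarts its dataset from the first batch, the same two init ops can be reused every epoch, and the code that consumes `next_element` never needs to know which dataset it is currently reading.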