charlesbaer/4b73c291109710da61dd9f5e2740feb3
train_experiment function in TensorFlow
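```python
from datetime import datetime

import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.estimator)


def train_experiment(training_steps, clean_start, incremental, run_config):
    """Train an Estimator and report wall-clock timing.

    create_estimator() and train_input_fn are defined elsewhere (not shown here).
    """
    if clean_start:
        # Delete existing checkpoints so training restarts from global step 0
        if tf.gfile.Exists(run_config.model_dir):
            print("Removing previous artefacts...")
            tf.gfile.DeleteRecursively(run_config.model_dir)

    print("")
    estimator = create_estimator(run_config)
    print("")
    time_start = datetime.utcnow()
    print("Experiment started at {}".format(time_start.strftime("%H:%M:%S")))
    print(".......................................")

    if incremental:
        # steps: train for training_steps more steps past the restored global step
        estimator.train(train_input_fn, steps=training_steps)
    else:
        # max_steps: stop once the global step reaches training_steps
        # (skipped entirely if the checkpoint is already at or beyond it)
        estimator.train(train_input_fn, max_steps=training_steps)

    time_end = datetime.utcnow()
    print(".......................................")
    print("Experiment finished at {}".format(time_end.strftime("%H:%M:%S")))
    print("")
    time_elapsed = time_end - time_start
    print("Experiment elapsed time: {} seconds".format(time_elapsed.total_seconds()))
    return estimator
```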
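The two training branches in train_experiment only behave differently when model_dir already holds a checkpoint: `steps` is counted from the restored global step, while `max_steps` is an absolute cap. The pure-Python sketch below models that rule, assuming the documented `Estimator.train` semantics; `target_global_step` is a hypothetical helper for illustration and does not call TensorFlow.

```python
from typing import Optional


def target_global_step(current_step: int,
                       steps: Optional[int] = None,
                       max_steps: Optional[int] = None) -> int:
    """Global step at which a train() call would stop (pure-Python model).

    Hypothetical helper that mimics Estimator.train's steps/max_steps rules;
    it does not call TensorFlow.
    """
    if steps is not None and max_steps is not None:
        raise ValueError("pass either steps or max_steps, not both")
    if steps is not None:
        if steps <= 0:
            raise ValueError("steps must be > 0")
        # Relative: train `steps` more steps past the restored global step.
        return current_step + steps
    if max_steps is not None:
        if max_steps <= 0:
            raise ValueError("max_steps must be > 0")
        # Absolute: stop at max_steps; a no-op if the checkpoint is already there.
        return max(current_step, max_steps)
    # The real train() would run until input_fn is exhausted; not modeled here.
    raise ValueError("this sketch needs steps or max_steps")


if __name__ == "__main__":
    step = 0  # clean start: no checkpoint yet
    for kwargs in ({"max_steps": 1000}, {"max_steps": 1000}, {"steps": 1000}):
        step = target_global_step(step, **kwargs)
        print(kwargs, "->", step)
```

Run in sequence, the three simulated calls end at global steps 1000, 1000 (the second `max_steps` call has nothing left to do) and 2000.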
We then made several calls to train_experiment to highlight the difference between the steps and max_steps parameters. The first call starts from a clean model directory (clean_start=True) and uses max_steps to train up to global step 1000.
```python
train_experiment(
    training_steps=1000,
    clean_start=True,
    incremental=False,
    run_config=run_config
)
```
```
Experiment started at 16:52:35
.......................................
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into models/census/dnn_classifier/model.ckpt.
INFO:tensorflow:loss = 31132416.0, step = 1
INFO:tensorflow:global_step/sec: 82.2509
INFO:tensorflow:loss = 12670475.0, step = 101 (1.218 sec)
INFO:tensorflow:global_step/sec: 169.578
INFO:tensorflow:loss = 11341477.0, step = 201 (0.590 sec)
INFO:tensorflow:global_step/sec: 177.614
INFO:tensorflow:loss = 12852321.0, step = 301 (0.563 sec)
INFO:tensorflow:global_step/sec: 163.928
INFO:tensorflow:loss = 13684520.0, step = 401 (0.610 sec)
INFO:tensorflow:global_step/sec: 169.234
INFO:tensorflow:loss = 12090486.0, step = 501 (0.591 sec)
INFO:tensorflow:global_step/sec: 187.021
INFO:tensorflow:loss = 13600504.0, step = 601 (0.534 sec)
INFO:tensorflow:global_step/sec: 167.494
INFO:tensorflow:loss = 14767286.0, step = 701 (0.597 sec)
INFO:tensorflow:global_step/sec: 145.886
INFO:tensorflow:loss = 10702760.0, step = 801 (0.685 sec)
INFO:tensorflow:global_step/sec: 153.083
INFO:tensorflow:loss = 13668747.0, step = 901 (0.654 sec)
INFO:tensorflow:Saving checkpoints for 1000 into models/census/dnn_classifier/model.ckpt.
INFO:tensorflow:Loss for final step: 14524872.0.
.......................................
Experiment finished at 16:52:47
Experiment elapsed time: 12.109969 seconds
```
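The timings in parentheses in this log each cover a 100-step interval, and together they add up to roughly 6.0 seconds, about half of the 12.1 seconds reported. One plausible reading, which is an interpretation rather than something the log states, is that the remainder goes to building the graph, initializing the session, writing the step-0 and step-1000 checkpoints, and the steps after 901 that have no logged interval. The sketch below extracts those figures with a regular expression; the pattern and the helper name `interval_seconds` are assumptions made for illustration, tuned to this log's line format.

```python
import re
from typing import List, Tuple

# The loss lines copied from the log above
LOG = """\
INFO:tensorflow:loss = 31132416.0, step = 1
INFO:tensorflow:loss = 12670475.0, step = 101 (1.218 sec)
INFO:tensorflow:loss = 11341477.0, step = 201 (0.590 sec)
INFO:tensorflow:loss = 12852321.0, step = 301 (0.563 sec)
INFO:tensorflow:loss = 13684520.0, step = 401 (0.610 sec)
INFO:tensorflow:loss = 12090486.0, step = 501 (0.591 sec)
INFO:tensorflow:loss = 13600504.0, step = 601 (0.534 sec)
INFO:tensorflow:loss = 14767286.0, step = 701 (0.597 sec)
INFO:tensorflow:loss = 10702760.0, step = 801 (0.685 sec)
INFO:tensorflow:loss = 13668747.0, step = 901 (0.654 sec)
"""

# Matches e.g. "loss = 12670475.0, step = 101 (1.218 sec)";
# the step-1 line has no timing and is skipped.
PATTERN = re.compile(r"loss = ([\d.]+), step = (\d+) \(([\d.]+) sec\)")


def interval_seconds(log: str) -> List[Tuple[int, float, float]]:
    """Return (step, loss, seconds) for every loss line that carries a timing."""
    return [(int(step), float(loss), float(sec))
            for loss, step, sec in PATTERN.findall(log)]


if __name__ == "__main__":
    rows = interval_seconds(LOG)
    timed = sum(sec for _, _, sec in rows)
    print(f"{len(rows)} timed intervals, {timed:.3f} s of logged step time")
    print(f"average rate: {100 * len(rows) / timed:.1f} steps/sec")
```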