Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Last active February 6, 2018 09:22
Show Gist options
  • Save yaroslavvb/f7db732d0afc99ab98172ca2cf992245 to your computer and use it in GitHub Desktop.
Save yaroslavvb/f7db732d0afc99ab98172ca2cf992245 to your computer and use it in GitHub Desktop.
Example of benchmarking session.run call
# Example of profiling session.run overhead
# for python profiling
# python -m cProfile -o session-run-benchmark-feed.prof session-run-benchmark.py feed_dict
# python -m cProfile -o session-run-benchmark-variable.prof session-run-benchmark.py variable
# pip install snakeviz
# snakeviz session-run-benchmark-feed.prof
# snakeviz session-run-benchmark-variable.prof
#
#
# Measured session.run overhead per call: 147 usec with feed_dict, 71 usec without (variable input)
import tensorflow as tf
import numpy as np
import time, sys, os
# Run with the graph optimizer at level L0 so the tiny ops below are not
# optimized away (e.g. folded), which would hide the session.run overhead
# this script is trying to measure.
config = tf.ConfigProto(graph_options=tf.GraphOptions(
    optimizer_options=tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)))
sess = tf.Session(config=config)

n = 1024
# np.random.random returns float64. Cast to float32 so both variants square
# the same dtype: without the cast the Variable would be float64, and every
# feed into the float32 placeholder would need a float64->float32 conversion.
x0 = np.random.random([1, n]).astype(np.float32)
x = tf.placeholder(tf.float32, shape=x0.shape)
x_cached = tf.Variable(x0)
simple_op_feed_dict = tf.square(x)  # input supplied via feed_dict each run
simple_op = tf.square(x_cached)     # input read from a Variable
sess.run(tf.global_variables_initializer())
# Benchmark configuration.
num_iters = 100000
use_feed_dict = True  # mode used when no command-line argument is given
timelines = False

# Usage: session-run-benchmark.py [feed_dict|variable] [timelines]
# The first argument selects how the op's input is provided. The optional
# second argument switches to traced per-op statistics instead of timing.
if len(sys.argv) > 1:
    if sys.argv[1] == 'feed_dict':
        use_feed_dict = True
    elif sys.argv[1] == 'variable':
        use_feed_dict = False
    else:
        # Fail loudly rather than silently running the feed_dict benchmark.
        sys.exit("Error: unknown mode %r (expected 'feed_dict' or 'variable')"
                 % sys.argv[1])
if len(sys.argv) > 2:
    # Explicit check instead of `assert`, which `python -O` strips out.
    if sys.argv[2] != 'timelines':
        sys.exit("Error: unknown option %r (expected 'timelines')"
                 % sys.argv[2])
    timelines = True
if timelines:
    # Trace a tenth as many runs with FULL_TRACE and feed each run's step
    # stats to the stat summarizer. Print its report, then exit without
    # running the timed loop.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    graph_def_bytes = tf.get_default_graph().as_graph_def().SerializeToString()
    ss = tf.contrib.stat_summarizer.NewStatSummarizer(graph_def_bytes)
    # Trace only the op for the selected mode; the variable mode needs no feeds.
    if use_feed_dict:
        traced_fetch, traced_feeds = simple_op_feed_dict.op, {x: x0}
    else:
        traced_fetch, traced_feeds = simple_op.op, None
    for _ in range(num_iters // 10):
        sess.run(traced_fetch, feed_dict=traced_feeds,
                 options=run_options, run_metadata=run_metadata)
        ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())
    print(ss.GetOutputString())
    sys.exit()
# Timed benchmark: issue num_iters identical session.run calls and report the
# mean wall-clock cost per call. Pick the fetch and feeds once, outside the
# loop, so the loop body is only the session.run call being measured.
if use_feed_dict:
    fetch, feeds = simple_op_feed_dict.op, {x: x0}
else:
    fetch, feeds = simple_op.op, None
start = time.time()
for _ in range(num_iters):
    sess.run(fetch, feed_dict=feeds)
elapsed = time.time() - start
print("%s: %.1f usec per session.run call"
      % ('feed_dict' if use_feed_dict else 'variable',
         elapsed / num_iters * 1e6))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment