Last active
February 6, 2018 09:22
-
-
Save yaroslavvb/f7db732d0afc99ab98172ca2cf992245 to your computer and use it in GitHub Desktop.
Example of benchmarking session.run call
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
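# Example of profiling session.run overhead.
#
# Python-level profiling:
#   python -m cProfile -o session-run-benchmark-feed.prof session-run-benchmark.py feed_dict
#   python -m cProfile -o session-run-benchmark-variable.prof session-run-benchmark.py variable
#
# Visualizing the profiles:
#   pip install snakeviz
#   snakeviz session-run-benchmark-feed.prof
#   snakeviz session-run-benchmark-variable.prof
#
# Append "timelines" as a second argument (e.g. `session-run-benchmark.py feed_dict timelines`)
# to collect per-op step stats with a full trace instead of running the plain benchmark.
#
# Measured session.run overhead: with feed_dict, 147 usec per call; without feed_dict (variable), 71 usec per call.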
import os
import sys
import time

import numpy as np
import tensorflow as tf

# make sure our ops aren't getting optimized away
config = tf.ConfigProto(
    graph_options=tf.GraphOptions(
        optimizer_options=tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)))
sess = tf.Session(config=config)

n = 1024
# np.random.random returns float64. Cast to float32 so that:
#  * feeding x0 into the float32 placeholder needs no per-call dtype
#    conversion (which would otherwise be counted as feed_dict overhead), and
#  * the Variable built from x0 is float32 too, so both benchmark modes
#    square a tensor of the same dtype and shape.
x0 = np.random.random([1, n]).astype(np.float32)
x = tf.placeholder(tf.float32, shape=x0.shape)
x_cached = tf.Variable(x0)
simple_op_feed_dict = tf.square(x)  # input supplied through feed_dict each call
simple_op = tf.square(x_cached)     # input already resident in a Variable
sess.run(tf.global_variables_initializer())
num_iters = 100000
timelines = False

# Command line: <script> {feed_dict|variable} [timelines]
# Invalid usage exits with a message on stderr and status 1, instead of
# crashing with IndexError or silently falling back to the feed_dict mode.
usage = "usage: %s {feed_dict|variable} [timelines]" % sys.argv[0]
if len(sys.argv) < 2 or len(sys.argv) > 3 or sys.argv[1] not in ('feed_dict', 'variable'):
    sys.exit(usage)
use_feed_dict = sys.argv[1] == 'feed_dict'
if len(sys.argv) > 2:
    # Explicit check rather than assert: asserts are stripped under `python -O`.
    if sys.argv[2] != 'timelines':
        sys.exit(usage)
    timelines = True
# Trace mode: run a tenth of the iterations with FULL_TRACE, feed each step's
# stats into a StatSummarizer, print the aggregated per-op report and exit.
if timelines:
    run_metadata = tf.RunMetadata()
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    serialized_graph = tf.get_default_graph().as_graph_def().SerializeToString()
    ss = tf.contrib.stat_summarizer.NewStatSummarizer(serialized_graph)

    # Choose what to run once; feed_dict=None is Session.run's default.
    if use_feed_dict:
        target, feeds = simple_op_feed_dict.op, {x: x0}
    else:
        target, feeds = simple_op.op, None

    for _ in range(num_iters // 10):
        sess.run(target, feed_dict=feeds,
                 options=run_options,
                 run_metadata=run_metadata)
        ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())

    print(ss.GetOutputString())
    sys.exit()
# Plain benchmark: time num_iters session.run calls and report the mean cost.
if use_feed_dict:
    fetch, feed = simple_op_feed_dict.op, {x: x0}
    mode = 'feed_dict'
else:
    fetch, feed = simple_op.op, None  # None is Session.run's default feed_dict
    mode = 'variable'

# One untimed warm-up call. The first run of a given fetch/feed combination
# presumably does one-time setup inside TF, so it is kept out of the average.
sess.run(fetch, feed_dict=feed)

start = time.time()
for _ in range(num_iters):
    sess.run(fetch, feed_dict=feed)
elapsed = time.time() - start
print("%s: %.1f usec per session.run call" % (mode, elapsed / num_iters * 1e6))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment