@gibiansky
Created September 1, 2016 17:49
Demo of wrapping TensorFlow ops with some timing info

import time
import tensorflow as tf

def print_time(name):
    """Create a new function that prints out the time, annotated with a
    marker.

    >>> print_time("Name")(1, 2, 3)
    Name 1472751445.5185795
    (1, 2, 3)

    Meant to be used with tf.py_func to add timing information to a graph in
    this demo.

    See py_func here:
    https://www.tensorflow.org/versions/r0.7/api_docs/python/script_ops.html#py_func
    """
    def f(*args):
        print(name, time.time())
        return args
    return f
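
# As a quick illustration (not part of the original gist), print_time can
# also be attached to a tensor directly with tf.py_func, without any graph
# subclassing; the names here are just for demonstration:
#
#     t = tf.constant([1.0, 2.0])
#     t_timed, = tf.py_func(print_time("const"), [t], [t.dtype])
#
# Evaluating t_timed prints "const <timestamp>" and yields the same values
# as t, since print_time's inner function passes its inputs through.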

class TimerGraph(tf.Graph):
    """A tf.Graph subclass which inserts nodes for timing. These nodes
    simply print out the time of completion for a few chosen operation
    types.

    This idea can be extended (probably?) to adding a custom timing op that
    records start AND end times to some sort of in-memory database, which
    then gets printed out at the end; a minimal sketch of this appears after
    the class definition below.

    Perhaps this would also be the right place to add "performance
    counters": effectively a large mapping between op type and performance
    metrics (floating point operations, bytes transferred, etc.), which then
    gets populated with data for the model in question in this subclass. A
    sketch of this appears at the end of the file.
    """

    def create_op(self, *args, **kwargs):
        """This method gets called whenever a new op is created; it also
        adds the op to the graph."""
        original_op = super(TimerGraph, self).create_op(*args, **kwargs)
        op_name = args[0]
        if op_name in ["Pow", "Add"]:
            # Our added op just passes through all data, so its output
            # dtypes are copied from the output dtypes of the original op.
            output_dtypes = [output.dtype for output in original_op.outputs]

            # Create a new op by calling a Python function which just
            # passes along all its inputs as outputs: a multi-argument
            # identity function that records some metadata. By making this
            # an op, we ensure that it gets called at the right time during
            # graph execution.
            wrap_op = tf.py_func(print_time(op_name), original_op.outputs,
                                 output_dtypes)

            # py_func returns the list of tensors output by the function.
            # We don't actually care about those tensors, only the op that
            # produced them, so we take the first tensor and return its op,
            # which is the py_func op we care about.
            return wrap_op[0].op
        else:
            # Just return the original op if this isn't an op we want to
            # time. For example, we don't want to time Placeholder ops...
            return original_op
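
# A minimal sketch (not part of the original gist) of the first extension
# mentioned in the TimerGraph docstring: instead of printing immediately,
# record completion times into an in-memory list and print them all at the
# end. The names TIMINGS and record_time are assumptions for illustration.
TIMINGS = []

def record_time(name):
    """Like print_time, but appends (name, timestamp) to TIMINGS instead
    of printing."""
    def f(*args):
        TIMINGS.append((name, time.time()))
        return args
    return f

# A TimerGraph variant could call record_time(op_name) in place of
# print_time(op_name) inside create_op, and the script could finish with:
#
#     for op_name, timestamp in TIMINGS:
#         print(op_name, timestamp)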

graph = TimerGraph()
with graph.as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    z = x + y + y + y + y + y
    w = z ** 2

    with tf.Session() as session:
        print(session.run(w, feed_dict={
            x: [i for i in range(10000000)],
            y: [i for i in range(10000000)],
        }))
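
# A hypothetical sketch (also not in the original gist) of the "performance
# counters" idea from the TimerGraph docstring: a mapping from op type to
# static per-element cost metrics, accumulated as ops are added to the
# graph. FLOPS_PER_ELEMENT and CounterGraph are invented names, and the
# cost numbers below are placeholders, not a real cost model.
FLOPS_PER_ELEMENT = {
    "Add": 1,  # one floating-point add per element
    "Pow": 1,  # treat pow as one floating-point operation per element
}

class CounterGraph(tf.Graph):
    """A tf.Graph subclass that tallies a rough per-element FLOP estimate
    for every matching op added to the graph."""

    def __init__(self):
        super(CounterGraph, self).__init__()
        self.flops_per_element = 0

    def create_op(self, *args, **kwargs):
        op = super(CounterGraph, self).create_op(*args, **kwargs)
        op_type = args[0]
        if op_type in FLOPS_PER_ELEMENT:
            # Accumulate the static estimate for this op type.
            self.flops_per_element += FLOPS_PER_ELEMENT[op_type]
        return op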