Skip to content

Instantly share code, notes, and snippets.

@rreece
Last active September 13, 2018 20:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rreece/943cccd19825366844422aa4137792ef to your computer and use it in GitHub Desktop.
Save rreece/943cccd19825366844422aa4137792ef to your computer and use it in GitHub Desktop.
import tensorflow as tf
from google.protobuf import text_format
##______________________________________________________________________________
def model_fn(features, labels, mode, params, pbtxt):
    """
    Estimator model function: build the Adam training op and return the spec.

    The model/loss construction is elided (``...``) in this snippet; ``loss``
    is expected to be defined by that omitted part.

    Args:
        features, labels: not used by the visible part of this snippet.
        mode: forwarded unchanged to ``tf.estimator.EstimatorSpec``.
        params: mapping; ``params['lr']`` is used as the Adam learning rate.
        pbtxt: if truthy, passed to ``save_pbtxt`` as the output file path
            for a text dump of the training graph.
    """
    ...
    # Minimize the loss with Adam, advancing the graph's global step.
    global_step = tf.train.get_global_step()
    train_op = tf.train.AdamOptimizer(learning_rate=params['lr']).minimize(
        loss, global_step=global_step)

    # Optionally dump the graph that owns train_op before handing it back.
    if pbtxt:
        save_pbtxt(train_op, pbtxt)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
##______________________________________________________________________________
def save_pbtxt(train_op, pbtxt):
    """
    Save a protobuf txt file of the train_op graph_def for XLA extraction.

    Args:
        train_op: op whose owning graph (``train_op.graph``) is serialized.
        pbtxt: path of the text-format protobuf file to write (overwritten).
    """
    # `tf.variable_scope` is itself the context manager; it has no nested
    # `.variable_scope` attribute, so it must be called directly.
    # NOTE(review): the graph is already built when this runs, so these
    # scopes presumably do not alter the ops being serialized -- confirm
    # whether they are still needed here.
    with tf.variable_scope("save_pbtxt", use_resource=True):
        with tf.device("/job:localhost/replica:0/task:0/device:XLA_CPU:0"):
            # add_shapes=True asks as_graph_def to include shape information
            # in the exported GraphDef.
            graph_def = train_op.graph.as_graph_def(add_shapes=True)
            with open(pbtxt, 'w') as f:
                f.write(text_format.MessageToString(graph_def))
    print('==> %s written.' % pbtxt)
@rreece
Copy link
Author

rreece commented Sep 12, 2018

Vishal says to make sure to pass add_shapes=True when exporting the graph def:

train_op.graph.as_graph_def(add_shapes=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment