@jzuern
Created November 7, 2018 10:54
#
# .... body of model_fn
#
optimizer = tf.train.AdamOptimizer()
if FLAGS.use_tpu:
    # On TPU, wrap the optimizer so gradients are aggregated across shards.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
# return tf.estimator.EstimatorSpec(  # CPU or GPU estimator
#     mode=mode,
#     loss=loss,
#     train_op=train_op,
#     predictions=predictions)
return tf.contrib.tpu.TPUEstimatorSpec(  # TPU estimator
    mode=mode,
    loss=loss,
    train_op=train_op,
    predictions=predictions)