@hristian-carabulea
Forked from jzuern/model_fn.py
Created October 29, 2019 17:44
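import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in TF 2.x

FLAGS = tf.flags.FLAGS  # assumes a boolean --use_tpu flag is defined elsewhere


def model_fn(features, labels, mode, params):
    #
    # .... body of model_fn (computes `loss` and `predictions`)
    #
    optimizer = tf.train.AdamOptimizer()
    if FLAGS.use_tpu:
        # On TPU, wrap the optimizer so gradients are aggregated across shards
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    # return tf.estimator.EstimatorSpec(  # CPU or GPU estimator
    #     mode=mode,
    #     loss=loss,
    #     train_op=train_op,
    #     predictions=predictions)
    return tf.contrib.tpu.TPUEstimatorSpec(  # TPU estimator
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)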