Skip to content

Instantly share code, notes, and snippets.

@rjpower
Created December 11, 2018 18:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rjpower/57cb75b82ef7bc28be86abd01d0d8900 to your computer and use it in GitHub Desktop.
Save rjpower/57cb75b82ef7bc28be86abd01d0d8900 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import sys
from tensorflow.contrib.tpu.python.tpu import session_support
def reset_tpu(name, timeout_ms=5000):
    """Request a shutdown of every worker belonging to the TPU `name`.

    Resolves the TPU's master address and cluster definition, opens a
    session against the master, and asks the heartbeat manager covering all
    worker devices to shut them down.

    Args:
        name: TPU name passed to `TPUClusterResolver(tpu=...)`.
        timeout_ms: Value forwarded to `WorkerHeartbeatManager.shutdown`.
            Defaults to 5000, the value this script always used.
    """
    print('Resetting: %s' % name)
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=name, job_name='worker')
    address = resolver.master()
    cluster_def = resolver.cluster_spec().as_cluster_def()
    print('Master %s. Cluster %s' % (address, cluster_def))

    # NOTE(review): the session is never closed explicitly. Presumably the
    # workers are going away after shutdown(), so a close() round-trip may
    # fail or block -- confirm before adding one.
    session = tf.Session(address, config=tf.ConfigProto(cluster_def=cluster_def))
    m = session_support.WorkerHeartbeatManager.from_devices(
        session, session_support.all_worker_devices(session)
    )
    # Format with `%`; passing `m` as a second print() argument would emit
    # the literal "%s" followed by the manager's repr.
    print('Resetting workers: %s' % m)
    m.shutdown(timeout_ms=timeout_ms)
if __name__ == '__main__':
    # The TPU name is the first command-line argument. Fail with a usage
    # message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: %s <tpu_name>' % sys.argv[0])
    reset_tpu(sys.argv[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment