Skip to content

Instantly share code, notes, and snippets.

@gridcellcoder
Last active March 12, 2019 12:11
Show Gist options
  • Save gridcellcoder/db8e045bb920e066b2f6e2ac68ea0676 to your computer and use it in GitHub Desktop.
Save gridcellcoder/db8e045bb920e066b2f6e2ac68ea0676 to your computer and use it in GitHub Desktop.
TPU Cross Shard Optimizer
import tensorflow as tf
import numpy as np
from tensorflow.contrib.tpu.python.tpu import tpu_function
import os
import pprint
import tensorflow as tf
# Resolve the Colab TPU endpoint and list the devices it exposes.
#
# Fail fast when no TPU runtime is attached: tpu_address is needed by the
# training session opened later in this script, so merely printing a message
# here would only defer the failure to an unrelated NameError.
if 'COLAB_TPU_ADDR' not in os.environ:
    raise RuntimeError(
        'Not connected to a TPU runtime; please see the first cell in this '
        'notebook for instructions!'
    )

tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', tpu_address)

# Short-lived session used only to enumerate the devices behind the endpoint.
with tf.Session(tpu_address) as session:
    devices = session.list_devices()
    print('TPU devices:')
    pprint.pprint(devices)
# Set the shard count on the TPU context up front, before the graph is built.
# NOTE(review): 8 presumably matches the number of TPU cores — confirm.
tpu_function.get_tpu_context().set_number_of_shards(8)

# Feed points for the training data: the input x and its target y.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

# Trainable weights, initialised with starting "guesses".
# Only w[0], w[1] and w[2] are read by the model below; w[3] is not.
w = tf.Variable([1.0, 2.0, 3.0, 4.0], name="w")

# Model: y = w[0]*x + w[1] + w[2] + 3
y_model = tf.multiply(x, w[0]) + w[1] + w[2] + 3

# Squared difference between the target and the model's prediction.
residual = y - y_model
error = tf.square(residual)

# Adam (learning rate 0.01) wrapped in CrossShardOptimizer -- TPU change 1.
# `optimizer` ends up holding the op returned by minimize(), i.e. the
# training step that the session runs, not an optimizer object.
adam = tf.train.AdamOptimizer(0.01)
cross_shard_adam = tf.contrib.tpu.CrossShardOptimizer(adam)
optimizer = cross_shard_adam.minimize(error)

# Op that initialises all global variables; run once per session before
# training starts.
model = tf.global_variables_initializer()
# Train on the TPU and report the fitted weights.
with tf.Session(tpu_address) as session:
    session.run(tf.contrib.tpu.initialize_system())
    print('init')
    try:
        session.run(model)
        for i in range(10000):
            print(i)
            # One random sample per step, drawn from the line y = 2x + 14
            # (6 + 5 + 3 is the constant offset).
            x_value = np.random.rand()
            y_value = x_value * 2 + 6 + 5 + 3
            session.run(optimizer, feed_dict={x: x_value, y: y_value})

        # Read the weights back once, after training has finished.
        w_value = session.run(w)
        # NOTE(review): the labels in this message do not mirror y_model --
        # w[2] is added as a constant rather than multiplied by x, and w[3]
        # is not part of the model at all. The text is left unchanged.
        print("Predicted model: {a:.3f}x + {b:.3f}+{c:.3f}x + {d:.3f}".format(a=w_value[0], b=w_value[1], c=w_value[2], d=w_value[3]))
    finally:
        # The original called `tpu.shutdown_system()`, but no name `tpu` is
        # ever bound in this script; use the same tf.contrib.tpu namespace as
        # initialize_system() above. Running it in `finally` releases the TPU
        # even when training raises.
        session.run(tf.contrib.tpu.shutdown_system())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment