
@dettmar
Last active April 6, 2023 09:12
Getting TensorFlow to run on TPUs in Google Colab
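Before running the main snippet, it can help to confirm that a TPU runtime is actually attached. The snippet builds the TPU address as `'grpc://' + os.environ['COLAB_TPU_ADDR']`, which raises a bare `KeyError` when that variable is missing. Below is a minimal sketch of a friendlier check. The helper name `colab_tpu_grpc_address` is hypothetical, and it assumes, as the main snippet does, that Colab exposes the TPU host:port through `COLAB_TPU_ADDR` only when the runtime has been switched to TPU.

```python
import os
from typing import Mapping, Optional


def colab_tpu_grpc_address(environ: Optional[Mapping[str, str]] = None) -> str:
    """Build the gRPC address of the Colab TPU, failing with a clear message if none is attached.

    Hypothetical helper: assumes the TPU host:port is published in the
    COLAB_TPU_ADDR environment variable, as the main snippet assumes.
    """
    env = os.environ if environ is None else environ
    addr = env.get("COLAB_TPU_ADDR")
    if not addr:
        # Assumed case: no TPU runtime attached, so the variable is absent or empty.
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set; switch the runtime to TPU first "
            "(Runtime > Change runtime type)."
        )
    return "grpc://" + addr
```

Its return value could then be passed as `tpu=colab_tpu_grpc_address()` to `TPUClusterResolver` in place of the inline string concatenation.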
```python
# Don't forget to switch the runtime to TPU first (Runtime > Change runtime type).
# This snippet uses the TensorFlow 1.x tf.contrib API (tf.contrib was removed in TF 2.0).
import os

import tensorflow as tf

# input_shape, x_train and y_train are assumed to be defined elsewhere in the notebook.

# Define the model-building function (it is called later within the TPU strategy scope).
# Make sure that input_shape or input_dim is given in the first layer.
def createmodel():
    return tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(128, ..., input_shape=input_shape),
        # ...
    ])

# Set up the resolver for the Colab TPU (usually 8 TPU cores).
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

# Create and compile the model within the TPU strategy scope.
with strategy.scope():
    model = createmodel()
    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer=tf.train.AdamOptimizer(),  # note that Keras optimizers do not work here yet
                  metrics=['accuracy'])

# Check that the model looks OK.
model.summary()

# Train the model (make sure that steps_per_epoch is an exact divisor of the total
# number of training samples, so that all TPUs are utilized).
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=50)
```
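The last comment in the snippet asks for a `steps_per_epoch` that divides the total number of training samples exactly. Below is a minimal sketch for listing such values with a hypothetical helper, `candidate_steps_per_epoch`. Beyond what the gist states, it also assumes that the resulting per-step batch should split evenly across the 8 TPU cores; the `num_cores` parameter encodes that assumption and can be set to 1 to apply only the gist's own rule.

```python
from typing import List


def candidate_steps_per_epoch(num_samples: int, num_cores: int = 8) -> List[int]:
    """Hypothetical helper: steps_per_epoch values that use every sample exactly once.

    A value qualifies when it divides num_samples with no remainder and, as an
    extra assumption, the per-step batch (num_samples // steps) is a multiple
    of num_cores so that each TPU core receives the same number of samples.
    """
    candidates = []
    for steps in range(1, num_samples + 1):
        if num_samples % steps:
            continue  # a remainder would leave a ragged, partially filled step
        batch = num_samples // steps
        if batch % num_cores == 0:
            candidates.append(steps)
    return candidates


print(candidate_steps_per_epoch(400))  # [1, 2, 5, 10, 25, 50]
```

With 400 training samples, for example, the snippet's `steps_per_epoch=50` qualifies: each step then carries a batch of 8 samples, one per core under the 8-core assumption.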
@sravanch1287

Hi, I tried exactly the same thing with a simple DQN for the CartPole environment.
Please take a look at this Colab notebook: (https://colab.research.google.com/drive/1I1KdH9HD2F_G5-vpCSB2jSN8qQtCuKcB#scrollTo=kZ7ve1ihiLN7)

Whenever the code calls the predict method, it raises the following error:

[Screenshot of the error]

I also went through the following example:
https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/classification_iris_data_with_keras.ipynb#scrollTo=ZhnrwcSe3KER&uniqifier=1

It uses a Sequential model similar to the one I'm trying to evaluate.

I couldn't find any differences, yet my code doesn't work while the example does.

Could you please let me know if there is any step I'm missing?
