@Rishit-dagli
Created April 24, 2021 04:12
Demonstrate mixed precision training with TensorFlow
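import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Compute in float16 while keeping the layer variables in float32.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

inputs = keras.Input(shape=(784,))
x = layers.Dense(4096, activation='relu')(inputs)
x = layers.Dense(4096, activation='relu')(x)
x = layers.Dense(10)(x)
# Run the final softmax in float32 so the output probabilities are numerically stable.
outputs = layers.Activation('softmax', dtype='float32')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(...)  # pass the loss, optimizer and metrics here
model.fit(...)      # pass the training data here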
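The snippet leaves the `compile` and `fit` arguments unspecified. Below is a minimal end-to-end sketch under stated assumptions: random synthetic inputs and labels stand in for a real dataset, sparse categorical cross-entropy with RMSprop is an assumed choice of compile settings, and the hidden layers are shrunk from 4096 to 256 units so it runs quickly on a CPU. The helper `build_model` and its `hidden_units` parameter are hypothetical names introduced only for this illustration. After training, it inspects the layer dtypes to confirm that the policy took effect (float16 compute, float32 variables, float32 model output). The TensorFlow mixed precision guide notes that `Model.fit` handles loss scaling for you under `mixed_float16`, so no explicit loss-scaling code appears here.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Same policy as the gist: float16 compute, float32 variables.
tf.keras.mixed_precision.set_global_policy('mixed_float16')


def build_model(hidden_units=4096):
    """Hypothetical helper reproducing the gist's architecture."""
    inputs = keras.Input(shape=(784,))
    x = layers.Dense(hidden_units, activation='relu')(inputs)
    x = layers.Dense(hidden_units, activation='relu')(x)
    x = layers.Dense(10)(x)
    # Keep the softmax in float32 so the probabilities are not computed in float16.
    outputs = layers.Activation('softmax', dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=outputs)


# Assumed stand-in data: random 784-feature vectors and labels, NOT a real dataset.
rng = np.random.default_rng(0)
x_train = rng.random((256, 784), dtype=np.float32)
y_train = rng.integers(0, 10, size=(256,))

# Smaller than the gist's 4096 units so the sketch runs quickly.
model = build_model(hidden_units=256)

# Assumed compile settings; the gist does not specify them.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=keras.optimizers.RMSprop(),
    metrics=['accuracy'],
)
history = model.fit(x_train, y_train, batch_size=64, epochs=1, verbose=0)

# layers[0] is the InputLayer; layers[1] is the first Dense layer.
first_dense = model.layers[1]
print('compute dtype:', first_dense.compute_dtype)
print('variable dtype:', first_dense.dtype)
print('model output dtype:', model.output.dtype)
```

On a GPU with float16 support (compute capability 7.0 or higher, per the TensorFlow guide), this policy is where the speedup comes from. On a CPU the sketch still runs, but it is expected to be slower than plain float32.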