Skip to content

Instantly share code, notes, and snippets.

@a46554
Created May 13, 2020 06:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a46554/5ca3000d2fc2a2c72d1237d9d149b773 to your computer and use it in GitHub Desktop.
Save a46554/5ca3000d2fc2a2c72d1237d9d149b773 to your computer and use it in GitHub Desktop.
import tensorflow as tf
# Callback that stops training once the target accuracy is reached.
class RayCallback(tf.keras.callbacks.Callback):
    """Keras callback that halts training when training accuracy exceeds 99%."""

    def on_epoch_end(self, epoch, logs=None):
        """Check the finished epoch's metrics and request an early stop.

        Args:
            epoch: Index of the epoch that just ended (unused here).
            logs: Metric dict supplied by Keras. May be None or lack the
                'accuracy' key, so we default the lookup to 0 — the original
                `logs.get('accuracy') > 0.99` raised TypeError (None > float)
                whenever the key was absent, and `logs={}` was a mutable
                default argument.
        """
        logs = logs or {}
        if logs.get('accuracy', 0) > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True
# Fetch the MNIST handwritten-digit dataset: 60k training and 10k test
# grayscale 28x28 images with integer labels 0-9.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Instantiate the early-stopping callback and scale pixel values from
# the raw 0-255 range into [0, 1] for stable training.
callbacks = RayCallback()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Build a 3-layer model incrementally: Flatten -> 128-unit ReLU -> 10-way softmax.
layers = tf.keras.layers
model = tf.keras.models.Sequential()
model.add(layers.Flatten())
model.add(layers.Dense(128, activation=tf.nn.relu))
model.add(layers.Dense(10, activation=tf.nn.softmax))
# Configure training: Adam optimizer, integer-label cross-entropy loss,
# and accuracy as the monitored metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Train for up to 15 epochs; the callback stops early once accuracy > 99%.
model.fit(x_train, y_train, epochs=15, callbacks=[callbacks])
# Report loss and accuracy on the held-out test set.
model.evaluate(x_test, y_test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment