Skip to content

Instantly share code, notes, and snippets.

@a46554
Created May 17, 2020 08:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a46554/29ad285e6c40e1a53a1cbd5e0759d62f to your computer and use it in GitHub Desktop.
Save a46554/29ad285e6c40e1a53a1cbd5e0759d62f to your computer and use it in GitHub Desktop.
import tensorflow as tf
# Callback function to check model accuracy
class RayCallback(tf.keras.callbacks.Callback):
    """Keras callback that halts training once accuracy exceeds 99.8%."""

    def on_epoch_end(self, epoch, logs=None):
        """Request a training stop when the epoch's accuracy is above 0.998.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Mapping of metric names to values for the finished epoch.
                May be None or may lack an 'accuracy' entry, in which case
                nothing happens.
        """
        # `logs` defaults to None instead of a shared mutable dict, and a
        # missing metric must not be compared against a float (None > 0.998
        # raises TypeError).
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy > 0.998:
            print("\nReached 99.8% accuracy so cancelling training!")
            self.model.stop_training = True
# Load the MNIST handwritten-digit data set.
mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# Callback that stops training early once accuracy passes 99.8%.
callbacks = RayCallback()

# Reshape the training images to (samples, 28, 28, 1) so they match the
# Conv2D input_shape below, then normalize by dividing by 255.0.
# The sample count is inferred with -1 instead of being hard-coded to 60000.
training_images = training_images.reshape(-1, 28, 28, 1)
training_images = training_images / 255.0

# Create a 5-layer model:
#   Conv2D: 16 filters of size 3 x 3 applied to each image
#   MaxPooling2D: 2 x 2 pooling, shrinking each feature map to 1/4 of its area
#   Flatten -> Dense with 128 units -> Dense with 10 outputs (softmax)
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Set the optimizer, the loss function, and the metric the callback reads.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train for at most 10 epochs; the callback stops earlier if accuracy > 99.8%.
model.fit(training_images, training_labels, epochs=10, callbacks=[callbacks])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment