Created
May 17, 2020 08:16
-
-
Save a46554/29ad285e6c40e1a53a1cbd5e0759d62f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
# Callback that checks training accuracy at the end of every epoch.
class RayCallback(tf.keras.callbacks.Callback):
    """Stop training once an epoch's training accuracy exceeds a threshold.

    Args:
        threshold: Accuracy (a fraction, e.g. 0.998 for 99.8%) that must be
            strictly exceeded at the end of an epoch for training to be
            cancelled. Defaults to 0.998, the value this script targets.
    """

    def __init__(self, threshold=0.998):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Request a training stop if this epoch's accuracy beats the threshold.

        Args:
            epoch: Index of the epoch that just finished (not used here).
            logs: Dict of metric values supplied by Keras; may be None.
        """
        # Avoid a shared mutable default; treat a missing dict as "no metrics".
        logs = logs or {}
        # The key is 'accuracy' in current tf.keras; older Keras versions
        # reported the same metric as 'acc', so accept either.
        accuracy = logs.get('accuracy', logs.get('acc'))
        if accuracy is None:
            # Accuracy is not being tracked: nothing to compare, so keep
            # training rather than crashing on `None > float`.
            return
        if accuracy > self.threshold:
            print(f"\nReached {self.threshold:.1%} accuracy so cancelling training!")
            # Keras checks this flag after the epoch and ends fit() early.
            self.model.stop_training = True
# Load the MNIST handwritten-digit data set.
# NOTE(review): the test split is loaded but not used below; if you add
# model.evaluate(test_images, test_labels), give test_images the same
# reshape/normalization as training_images first.
mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# Callback that cancels training once training accuracy exceeds 99.8%.
callbacks = RayCallback()

# Add the explicit channel axis Conv2D expects -> (samples, 28, 28, 1).
# -1 lets NumPy infer the sample count instead of hard-coding 60000.
# Casting to float32 before scaling to [0, 1] uses half the memory of the
# float64 array that dividing raw uint8 pixels by 255.0 would produce.
training_images = training_images.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Five-layer model:
#   Conv2D: 16 filters of size 3x3 applied to each image (28x28 -> 26x26 maps)
#   MaxPooling2D 2x2: shrinks each feature map to 1/4 of its area (-> 13x13)
#   Flatten -> Dense(128, relu) -> Dense(10, softmax) over the 10 digit classes
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Integer labels (0-9) pair with sparse categorical cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train for at most 10 epochs; the callback stops earlier if training
# accuracy exceeds 99.8%.
model.fit(training_images, training_labels, epochs=10, callbacks=[callbacks])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment