Created January 15, 2020 11:58
-
-
Save Rishit-dagli/d0fa03d617703234bf270f3e18079270 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_mnist_conv(*, target_accuracy=0.998, epochs=20):
    """Train a small convolutional network on MNIST with early stopping.

    Training stops at the end of the first epoch whose training accuracy
    is strictly greater than ``target_accuracy``.

    Args:
        target_accuracy: Training-accuracy threshold (a fraction, e.g. 0.998
            for 99.8%) that triggers early stopping. Defaults to 0.998.
        epochs: Maximum number of epochs passed to ``model.fit``.
            Defaults to 20.

    Returns:
        A tuple ``(epochs_run, final_accuracy)``: ``history.epoch`` (the
        list of epoch indices that were run) and the training accuracy
        recorded for the last of those epochs.

    Raises:
        KeyError: if the training History contains neither an 'accuracy'
            nor an 'acc' entry.

    Note:
        Expects ``tf`` (TensorFlow) to be imported at module level; the
        data comes from ``tf.keras.datasets.mnist.load_data()``.
    """
    # Please write your code only where you are indicated.
    # please do not remove model fitting inline comments.

    def _training_accuracy(metrics):
        """Return the accuracy entry of a Keras metrics mapping, or None.

        TF 1.x Keras reports ``metrics=['accuracy']`` under the key 'acc',
        while TF 2.x reports it under 'accuracy'. Accepting either keeps the
        early-stopping check and the return value working on both versions.
        """
        if not metrics:
            return None
        if 'accuracy' in metrics:
            return metrics['accuracy']
        return metrics.get('acc')

    # YOUR CODE STARTS HERE
    class AccuracyThresholdCallback(tf.keras.callbacks.Callback):
        """Set ``stop_training`` once an epoch's accuracy exceeds the target."""

        def on_epoch_end(self, epoch, logs=None):
            # ``None`` rather than ``{}`` avoids a shared mutable default.
            accuracy = _training_accuracy(logs)
            # If the metric is missing, skip the check instead of comparing
            # None with a float (which raises TypeError).
            if accuracy is not None and accuracy > target_accuracy:
                print(f"\nReached {target_accuracy:.1%} accuracy so cancelling training!")
                self.model.stop_training = True
    # YOUR CODE ENDS HERE
    mnist = tf.keras.datasets.mnist
    (training_images, training_labels), (test_images, test_labels) = mnist.load_data()
    # YOUR CODE STARTS HERE
    callbacks = AccuracyThresholdCallback()
    # Add the trailing channel axis Conv2D expects. -1 infers the sample count
    # instead of hard-coding 60000. Dividing by 255.0 scales the 8-bit pixel
    # values into [0, 1].
    training_images = training_images.reshape(-1, 28, 28, 1) / 255.0
    # The test split is not used by this function, so it is not preprocessed.
    # YOUR CODE ENDS HERE
    model = tf.keras.models.Sequential([
        # YOUR CODE STARTS HERE
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
        # YOUR CODE ENDS HERE
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # model fitting
    history = model.fit(
        # YOUR CODE STARTS HERE
        training_images,
        training_labels,
        epochs=epochs,
        callbacks=[callbacks],
        # YOUR CODE ENDS HERE
    )
    # model fitting
    accuracies = _training_accuracy(history.history)
    if accuracies is None:
        raise KeyError("training accuracy ('accuracy' or 'acc') missing from History")
    return history.epoch, accuracies[-1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.