Skip to content

Instantly share code, notes, and snippets.

@emuccino
Last active August 14, 2019 17:03
Show Gist options
  • Save emuccino/80d861bda4f5ba82bd5f3e802ddc0d4b to your computer and use it in GitHub Desktop.
Save emuccino/80d861bda4f5ba82bd5f3e802ddc0d4b to your computer and use it in GitHub Desktop.
mnist_classifier
import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Dropout, Flatten ,Input
from keras.layers import Conv2D, MaxPooling2D, Reshape, Add
from keras.metrics import categorical_accuracy
from keras.regularizers import l1_l2, l2, l1
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Activation
from keras.utils.generic_utils import get_custom_objects
from tensorflow.python.keras import backend as K
from keras.preprocessing.image import array_to_img,img_to_array
import matplotlib.pyplot as plt
import numpy as np
# Load the MNIST digits as (train, test) splits of (images, labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Images: add a trailing channel axis so every sample is 28x28x1 (the shape
# given to the model's Input layer), convert to float32, then divide by 255 so
# pixel intensities land in [0, 1] (assuming 8-bit 0-255 source pixels).
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.

# Labels: integer class ids -> one-hot vectors over the 10 digit classes,
# the target format expected by the categorical_crossentropy loss.
y_train, y_test = (keras.utils.to_categorical(labels, 10) for labels in (y_train, y_test))
# Build and compile the CNN: two 3x3 convolutions, 2x2 max-pooling and dropout
# extract features; a dense head ending in a 10-unit softmax classifies them.
inputs = Input(shape=(28, 28, 1))

# Convolutional feature extractor.
net = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
net = Conv2D(64, kernel_size=(3, 3), activation='relu')(net)
net = MaxPooling2D(pool_size=(2, 2))(net)
net = Dropout(0.25)(net)

# Classifier head: flatten the feature maps, one hidden dense layer, then one
# softmax output unit per digit class.
net = Flatten()(net)
net = Dense(128, activation='relu')(net)
net = Dropout(0.5)(net)
outputs = Dense(10, activation='softmax')(net)

mnist_model = Model(inputs=inputs, outputs=outputs, name='classification_model')

# Nadam optimizer, cross-entropy against the one-hot targets, and accuracy
# tracked as the reported metric.
mnist_model.compile(
    optimizer='nadam',
    loss='categorical_crossentropy',
    metrics=[categorical_accuracy],
)
# Train the MNIST classifier, then report [loss, categorical_accuracy] on the
# training and test sets.
#
# EarlyStopping with restore_best_weights=True picks which epoch's weights are
# kept, based on the monitored validation metric. That makes the validation data
# part of model selection, so it must NOT be the test set: monitoring x_test
# here would leak it into training and inflate the final test score. Instead,
# Keras' validation_split holds out the last 10% of the training samples as the
# validation set, and x_test/y_test are used only for the final evaluation.
earlyStop = EarlyStopping(monitor='val_categorical_accuracy', min_delta=0, patience=10, verbose=0, mode='auto',
                          baseline=None, restore_best_weights=True)
mnist_model.fit(x_train, y_train, batch_size=128, epochs=100, verbose=0, validation_split=0.1,
                callbacks=[earlyStop])

# evaluate() returns the loss followed by the compiled metrics,
# i.e. [loss, categorical_accuracy].
print(mnist_model.evaluate(x_train, y_train))
print(mnist_model.evaluate(x_test, y_test))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment