Skip to content

Instantly share code, notes, and snippets.

@drorgl
Created June 6, 2018 09:01
Show Gist options
  • Save drorgl/3dca097b11320e9cf21cececdb46b6c0 to your computer and use it in GitHub Desktop.
Save drorgl/3dca097b11320e9cf21cececdb46b6c0 to your computer and use it in GitHub Desktop.
mnist sequential tiny model
# 3. Import libraries and modules
import numpy as np
np.random.seed(123) # for reproducibility
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D,AveragePooling2D
from keras.utils import np_utils
from keras.datasets import mnist
from matplotlib import pyplot as plt
#from keras.utils.vis_utils import plot_model
# 4. Load the MNIST train/test split shipped with Keras
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Keep a handle on the test images exactly as loaded (before reshaping and
# scaling); only the optional image preview near the end of the script uses it.
orig_x_test = X_test

# 5. Preprocess input data:
#    add a trailing channel axis -> (samples, 28, 28, 1), cast to float32,
#    then divide by 255.
X_train = X_train.reshape(len(X_train), 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(len(X_test), 28, 28, 1).astype('float32') / 255
# Optional sanity check of a single training image:
# plt.imshow(X_train[0])
# plt.show()

# 6. Preprocess class labels: integer labels -> 10-way categorical (one-hot) rows
Y_train, Y_test = (np_utils.to_categorical(labels, 10)
                   for labels in (y_train, y_test))
# 7. Define model architecture
# Two conv -> dropout -> average-pool stages, followed by a small dense
# classifier head that ends in a 10-way softmax.
model = Sequential([
    # Stage 1: 64 filters of 4x4 over the (28, 28, 1) input
    Convolution2D(64, (4, 4), activation='relu', input_shape=(28, 28, 1)),
    Dropout(0.25),
    AveragePooling2D(pool_size=2, strides=2),
    # Stage 2: 16 filters of 4x4
    Convolution2D(16, (4, 4), activation='relu'),
    Dropout(0.25),
    AveragePooling2D(pool_size=2, strides=2),
    # Classifier head
    Flatten(),
    Dropout(0.15),
    Dense(70, activation='relu'),
    Dense(10, activation='softmax'),
])

# 8. Compile model: categorical cross-entropy loss, Adam optimizer, and
#    accuracy reported as an extra metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()
# Optional: render the architecture to an image (needs the commented-out
# plot_model import at the top of the file).
#plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)
# 9. Fit model on training data
# The test split is passed as validation data, so the `val_loss` values
# recorded in the history are per-epoch losses on the test set.
# NOTE: `epochs` is the Keras 2 keyword for the number of passes over the
# data. The legacy Keras 1 spelling `nb_epoch` is deprecated and is rejected
# by newer Keras releases, while the rest of this script already uses the
# Keras 2 API.
history = model.fit(
    X_train, Y_train,
    batch_size=16, epochs=20, verbose=1, shuffle=True,
    validation_data=(X_test, Y_test),
)
print("history", history.history)

# Plot the per-epoch training and validation loss and save the figure to
# 'training.png' (the interactive window is left disabled).
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.savefig('training.png')
#plt.show()
# 10. Evaluate model on test data
# The model is compiled with metrics=['accuracy'], so `score` holds the loss
# at index 0 and the accuracy at index 1.
score = model.evaluate(X_test, Y_test, verbose=0)

# Optional debugging aids: predict a single test sample and preview the
# corresponding unprocessed image.
#prediction = model.predict(X_test[0:1])
#print("prediction", prediction)
#plt.imshow(orig_x_test[0])
#plt.show()

print("score", score)
# Classification error in percent = 100 - accuracy (in percent)
error_percent = 100 - score[1] * 100
print("Large CNN Error: %.2f%%" % error_percent)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment