Skip to content

Instantly share code, notes, and snippets.

@ireneweng
Last active February 15, 2024 20:17
Show Gist options
  • Save ireneweng/09e7717fe527ba960ae53b245eb57787 to your computer and use it in GitHub Desktop.
Save ireneweng/09e7717fe527ba960ae53b245eb57787 to your computer and use it in GitHub Desktop.
Image classification with the MNIST dataset
import matplotlib.pyplot as plt
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
def renderImage(x_train, y_train, num):
    """Draw one sample in grayscale and print the label stored for it.

    Args:
        x_train: Indexable collection of images; ``x_train[num]`` is passed
            to ``plt.imshow``.
        y_train: Indexable collection of labels; ``y_train[num]`` is printed.
        num: Index of the sample to show.
    """
    plt.imshow(x_train[num], cmap='gray')
    label = y_train[num]
    print(f'This is a {label}')
def prepareTrainingData():
    """Load MNIST and return it flattened, scaled and categorically encoded.

    Returns:
        A tuple ``(x_train, y_train, x_valid, y_valid)``. The image arrays
        are reshaped to one row per sample and divided by 255; the label
        arrays are passed through ``keras.utils.to_categorical`` with 10
        categories.
    """
    # The data, split between train and validation sets.
    (x_train, y_train), (x_valid, y_valid) = mnist.load_data()

    # Flatten every image into a single row. The sample count is taken from
    # the array itself and the row width is inferred (-1) rather than
    # hard-coding 60000/10000/784, so a differently sized split still works.
    x_train = x_train.reshape(x_train.shape[0], -1)
    x_valid = x_valid.reshape(x_valid.shape[0], -1)

    # Normalize the image data (presumably 8-bit pixels, so 255 is the
    # maximum value -- confirm against the dataset documentation).
    x_train = x_train / 255
    x_valid = x_valid / 255

    # Categorically encode the labels.
    num_categories = 10
    y_train = keras.utils.to_categorical(y_train, num_categories)
    y_valid = keras.utils.to_categorical(y_valid, num_categories)

    return (x_train, y_train, x_valid, y_valid)
def createModel():
    """Build, summarize and compile the three-layer dense classifier.

    The stack is two 512-unit ReLU layers (the first declaring an input
    shape of ``(784,)``) followed by a 10-unit softmax layer. The summary is
    printed before the model is compiled with categorical cross-entropy loss
    and an accuracy metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    network = Sequential([
        # input layer
        Dense(units=512, activation='relu', input_shape=(784,)),
        # hidden layer
        Dense(units=512, activation='relu'),
        # output layer
        Dense(units=10, activation='softmax'),
    ])
    network.summary()
    network.compile(loss='categorical_crossentropy', metrics=['accuracy'])
    return network
def trainModel(model, x_train, y_train, x_valid, y_valid):
    """Fit ``model`` for five epochs and hand back the result of ``fit``.

    Args:
        model: Object exposing a Keras-style ``fit`` method.
        x_train: Training inputs, passed to ``fit`` positionally.
        y_train: Training targets, passed to ``fit`` positionally.
        x_valid: Validation inputs.
        y_valid: Validation targets.

    Returns:
        Whatever ``model.fit`` returns. It was previously bound to an unused
        local and dropped, so callers could not inspect the training run.
    """
    return model.fit(
        x_train,
        y_train,
        epochs=5,
        verbose=1,
        validation_data=(x_valid, y_valid),
    )
def clearMemory():
    """Call ``do_shutdown(True)`` on the kernel of the IPython application.

    If the application instance has no ``kernel`` attribute (for example
    when the file is run as a plain script rather than inside a notebook
    kernel) there is nothing to shut down and the function returns quietly
    instead of raising ``AttributeError``.
    """
    import IPython

    app = IPython.Application.instance()
    kernel = getattr(app, "kernel", None)
    if kernel is None:
        # No kernel attached to this application: nothing to shut down.
        return
    # NOTE(review): the positional True is presumably the "restart" flag of
    # ipykernel's do_shutdown -- confirm against the ipykernel docs.
    kernel.do_shutdown(True)
def run():
    """Create the model, prepare the data, train, then call ``clearMemory``."""
    network = createModel()
    # (x_train, y_train, x_valid, y_valid), forwarded in that order.
    dataset = prepareTrainingData()
    trainModel(network, *dataset)
    clearMemory()
# Only start the pipeline when executed directly (script or notebook cell),
# not as a side effect of importing this module.
if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment