Skip to content

Instantly share code, notes, and snippets.

@nb312
Created May 5, 2023 00:32
Show Gist options
  • Save nb312/f3b851a1503b5649352d3eae42054e41 to your computer and use it in GitHub Desktop.
Save nb312/f3b851a1503b5649352d3eae42054e41 to your computer and use it in GitHub Desktop.
from tensorflow import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
# Train and evaluate a small convolutional network on the MNIST digits.
#
# NOTE(review): this script mixes `tensorflow.keras` (the `keras` name, used for
# utils/losses/optimizers below) with layers/models imported from the standalone
# `keras` package. Depending on the installed TF/Keras versions these may be
# different implementations -- confirm they resolve to the same package.

# Load the dataset as (images, labels) tuples for the train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Settings.
num_classes = 10             # number of label classes (one per digit)
img_rows, img_cols = 28, 28  # image height and width used for reshaping
batch_size = 128             # samples per gradient update
epochs = 100                 # full passes over the training data

# Add an explicit single-channel axis, placed where the backend's image data
# format expects it (channels before or after the spatial dimensions).
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Normalize: cast to float32 and divide by 255 (presumably mapping 8-bit pixel
# values in [0, 255] to [0, 1]).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Convert integer labels to one-hot vectors, as categorical_crossentropy expects.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Build the model: two conv/pool stages, then a dense classifier head.
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))  # randomly drop 25% of activations in training to reduce overfitting
model.add(Flatten())  # flatten the feature maps into one vector per sample
model.add(Dense(128, activation='relu'))  # fully connected hidden layer
model.add(Dropout(0.25))  # dropout again before the output layer
model.add(Dense(num_classes, activation='softmax'))  # per-class probabilities

# Compile the model.
# Adadelta's default learning_rate in TF2 Keras is 0.001, which makes it learn
# extremely slowly; 1.0 matches the original Adadelta formulation (and the
# classic Keras MNIST CNN example this script follows).
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
              metrics=['accuracy'])

# Train the model.
# NOTE: the test split is also passed as validation data, so the per-epoch
# validation metrics are not independent of the final evaluation below.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

# Evaluate on the test split; with metrics=['accuracy'] the result is
# [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment