Last active
October 8, 2021 06:47
-
-
Save yumaueno/9b458cfe0c92b4e3e14430e05add52c0 to your computer and use it in GitHub Desktop.
MnistデータをCNNで分類
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
import tensorflow as tf | |
from tensorflow.keras.datasets import mnist | |
from tensorflow import keras | |
from tensorflow.keras import layers, models | |
from tensorflow.keras.utils import to_categorical | |
# Download the handwritten-digit image data bundled with Keras.
np.random.seed(0)
(X_train_base, labels_train_base), (test_x, test_y) = mnist.load_data()

# Split the training set 8:2 into training and validation data.
# BUG FIX: np.random.seed does NOT seed sklearn's splitter — pass
# random_state explicitly so the split is actually reproducible.
train_x, valid_x, train_y, valid_y = train_test_split(
    X_train_base, labels_train_base, test_size=0.2, random_state=0)

# Reshape each image to (height, width, channels). Using -1 lets NumPy
# infer the sample count instead of hard-coding 48000/12000/10000,
# which would break if the split ratio ever changed.
train_x = train_x.reshape((-1, 28, 28, 1))
valid_x = valid_x.reshape((-1, 28, 28, 1))
test_x = test_x.reshape((-1, 28, 28, 1))

# Normalize pixel values from [0, 255] to [0, 1] as float32.
train_x = train_x.astype('float32') / 255
valid_x = valid_x.astype('float32') / 255
test_x = test_x.astype('float32') / 255

# One-hot encode the training/validation labels for the softmax +
# categorical_crossentropy head. test_y stays as integer labels because
# the accuracy check at the end of the script compares against argmax.
train_y = to_categorical(train_y)
valid_y = to_categorical(valid_y)
# CNN for 28x28x1 MNIST images: three conv stages (two with 2x2 max
# pooling) feeding a small dense classifier with a 10-way softmax head.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Build the model: Adam with learning rate 0.01, cross-entropy loss on
# the one-hot labels, tracking accuracy during training.
model.compile(optimizer=tf.optimizers.Adam(0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Fit with early stopping: stop once val_loss has not improved for 10
# consecutive epochs. FIX: restore_best_weights=True so the model keeps
# the weights from the best epoch — without it, Keras keeps the final
# (post-patience, likely overfit) weights.
log = model.fit(
    train_x, train_y,
    epochs=100,
    batch_size=10,
    verbose=True,
    callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                             min_delta=0,
                                             patience=10,
                                             restore_best_weights=True,
                                             verbose=1)],
    validation_data=(valid_x, valid_y))

# Predict a digit (0-9) for each test image as the argmax over the
# softmax outputs, then report test accuracy against the integer labels.
# FIX: the original last line was a bare expression whose value is
# silently discarded when run as a script — print it instead.
pred_test = np.argmax(model.predict(test_x), axis=1)
print(sum(pred_test == test_y) / len(pred_test))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment