# Created July 31, 2018 13:19
# Source: GitHub Gist koshian2/3d99ff54715de586f3ac050b32fa1402
# Achieved 90% CIFAR-10 validation accuracy with a 10-layer CNN.
import math
import pickle

import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                          Dropout, Flatten, Input, MaxPool2D)
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
# Load CIFAR-10; the held-out split returned second is used as the
# validation set. Integer class labels are one-hot encoded with
# num_classes passed explicitly, so both splits always get exactly 10
# columns (matching the 10-unit softmax) instead of max(label) + 1.
(X_train, y_train), (X_val, y_val) = cifar10.load_data()
y_train = to_categorical(y_train, num_classes=10)
y_val = to_categorical(y_val, num_classes=10)
# Model: three stages of (1x1 -> 3x3 -> 5x5) Conv-BN-ReLU units, each stage
# followed by Dropout(0.25); only the first stage is followed by 2x2 max
# pooling. With Keras' default "valid" padding the spatial size shrinks
# 32 -> 26 -> (pool) 13 -> 7 -> 1, so Flatten yields 256 features.
# 9 conv layers + 1 dense layer = the "10-layer" CNN.
def _conv_bn_relu(x, filters, kernel_size):
    """Apply Conv2D -> BatchNormalization -> ReLU to tensor ``x``."""
    x = Conv2D(filters, kernel_size)(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


# Named `inputs` (not `input`) to avoid shadowing the Python builtin.
inputs = Input(shape=(32, 32, 3))
X = inputs
# (filters, apply max pooling after the stage's dropout)
for filters, pool in ((64, True), (128, False), (256, False)):
    for kernel_size in ((1, 1), (3, 3), (5, 5)):
        X = _conv_bn_relu(X, filters, kernel_size)
    X = Dropout(0.25)(X)
    if pool:
        X = MaxPool2D((2, 2))(X)
X = Flatten()(X)
output = Dense(10, activation="softmax")(X)
model = Model(inputs, output)
# Compile: Adam with its default hyperparameters, categorical cross-entropy
# (labels were one-hot encoded with to_categorical), reporting accuracy.
model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])
# Data augmentation. Both generators rescale pixel values by 1/255
# (presumably uint8 in [0, 255] -> [0, 1]); only the training generator
# applies random perturbations.
datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=20,        # random rotation range, in degrees
    width_shift_range=0.2,    # random horizontal shift, fraction of width
    height_shift_range=0.2,   # random vertical shift, fraction of height
    channel_shift_range=50,   # random channel-intensity shift range
    horizontal_flip=True,     # random left-right flips
)
# Validation images are only rescaled, never augmented, so the validation
# metric reflects the unperturbed images.
validationgen = ImageDataGenerator(rescale=1. / 255)
# Fit.
# NOTE: ImageDataGenerator.fit() only computes statistics for the
# featurewise_center / featurewise_std_normalization / zca_whitening options.
# None of them is enabled above, so datagen.fit(X_train) and
# validationgen.fit(X_val) would be no-ops (costing a full float copy of the
# data) and are omitted. If you enable those options, fit BOTH generators on
# X_train -- fitting on X_val would leak validation statistics.
batch_size = 128
epochs = 700
# Keras documents steps_per_epoch as an integer; len / batch_size is a float
# in Python 3. Ceil so the final partial batch is still used each epoch.
steps_per_epoch = math.ceil(len(X_train) / batch_size)
history = model.fit_generator(
    datagen.flow(X_train, y_train, batch_size=batch_size),
    steps_per_epoch=steps_per_epoch,
    # Order does not change the averaged validation metrics; shuffle=False
    # just makes each validation pass deterministic.
    validation_data=validationgen.flow(X_val, y_val, shuffle=False),
    epochs=epochs,
).history
# Persist the History.history dict of per-epoch metrics for later plotting.
with open("history.dat", "wb") as fp:
    pickle.dump(history, fp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment