Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created July 31, 2018 13:19
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/3d99ff54715de586f3ac050b32fa1402 to your computer and use it in GitHub Desktop.
Save koshian2/3d99ff54715de586f3ac050b32fa1402 to your computer and use it in GitHub Desktop.
Achieved 90% CIFAR-10 validation accuracy with a 10-layer CNN
import math
import pickle

import matplotlib.pyplot as plt
from keras.datasets import cifar10
from keras.layers import Input, Conv2D, Activation, MaxPool2D, BatchNormalization, Flatten, Dense, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
# Data: load the CIFAR-10 train/test split (the test split is used for
# validation) and one-hot encode the integer class labels so they match the
# categorical cross-entropy loss used at compile time.
(X_train, y_train), (X_val, y_val) = cifar10.load_data()
y_train, y_val = to_categorical(y_train), to_categorical(y_val)
# Model
# Three stages, each made of (1x1 -> 3x3 -> 5x5) Conv2D-BatchNorm-ReLU units
# followed by Dropout(0.25); only the first stage is also followed by 2x2 max
# pooling. Conv2D uses Keras' default "valid" padding, so the spatial size
# shrinks 32 -> 32 -> 30 -> 26 -> (pool) 13 -> 13 -> 11 -> 7 -> 7 -> 5 -> 1,
# leaving a 1x1x256 map that Flatten turns into a 256-vector for the
# 10-way softmax classifier.


def _conv_bn_relu(x, filters, kernel_size):
    """Return ``x`` passed through Conv2D -> BatchNormalization -> ReLU."""
    x = Conv2D(filters, kernel_size)(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


# Named ``inputs`` rather than ``input`` so the Python builtin is not shadowed.
inputs = Input(shape=(32, 32, 3))
X = inputs
# (filters, whether a 2x2 max-pool follows the stage's dropout)
for filters, pool_after in ((64, True), (128, False), (256, False)):
    for kernel_size in ((1, 1), (3, 3), (5, 5)):
        X = _conv_bn_relu(X, filters, kernel_size)
    X = Dropout(0.25)(X)
    if pool_after:
        X = MaxPool2D((2, 2))(X)
X = Flatten()(X)
output = Dense(10, activation="softmax")(X)
model = Model(inputs, output)
# Compile: Adam with its default settings, categorical cross-entropy on the
# one-hot labels, and accuracy reported alongside the loss.
model.compile(
    optimizer=Adam(),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# Data augmentation
# Training images get random rotations, width/height shifts, per-channel
# intensity shifts and horizontal flips, then are rescaled to [0, 1];
# validation images are only rescaled.
BATCH_SIZE = 128
EPOCHS = 700

datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    channel_shift_range=50,
    horizontal_flip=True)
validationgen = ImageDataGenerator(
    rescale=1./255)

# Fit
# No datagen.fit()/validationgen.fit() calls: ImageDataGenerator.fit() only
# computes statistics for featurewise_center, featurewise_std_normalization or
# zca_whitening, none of which are enabled, so those calls merely built and
# discarded a float32 copy of each dataset.
history = model.fit_generator(
    datagen.flow(X_train, y_train, batch_size=BATCH_SIZE),
    # Keras expects an integer step count; round up so the final partial
    # batch is still drawn every epoch.
    steps_per_epoch=math.ceil(len(X_train) / BATCH_SIZE),
    validation_data=validationgen.flow(X_val, y_val),
    epochs=EPOCHS).history

# Persist the per-epoch metrics dict returned by Keras for later plotting.
with open("history.dat", "wb") as fp:
    pickle.dump(history, fp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment