Skip to content

Instantly share code, notes, and snippets.

@azybler
Created January 13, 2019 05:09
Show Gist options
  • Save azybler/16c56984be2fe53433a603b53143cb57 to your computer and use it in GitHub Desktop.
Save azybler/16c56984be2fe53433a603b53143cb57 to your computer and use it in GitHub Desktop.
Runnable Effnet training script (trains on CIFAR-10 dataset)
'''Train an EffNet CNN on the CIFAR10 small images dataset.
https://towardsdatascience.com/3-small-but-powerful-convolutional-networks-27ef86faa42d
'''
from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D, Input, LeakyReLU, BatchNormalization, DepthwiseConv2D
from keras.activations import *
from keras.callbacks import *
import os
def get_post(x_in):
    """Apply the shared post-convolution stage to a tensor.

    The input is passed through a LeakyReLU activation (default settings)
    and the activated tensor is then batch-normalized.

    Args:
        x_in: Keras tensor produced by the preceding layer.

    Returns:
        The tensor after LeakyReLU followed by BatchNormalization.
    """
    activated = LeakyReLU()(x_in)
    return BatchNormalization()(activated)
def get_block(x_in, ch_in, ch_out):
    """Build one EffNet block on top of ``x_in``.

    The block chains, in order: a 1x1 convolution with ``ch_in`` filters,
    a (1, 3) depthwise convolution, a (2, 1) max-pool, a (3, 1) depthwise
    convolution and a final (2, 1) convolution with ``ch_out`` filters.
    Every convolution is bias-free, uses 'same' padding and is followed by
    ``get_post``; the pooling layer is not.

    Args:
        x_in: Input Keras tensor.
        ch_in: Number of filters of the leading 1x1 convolution.
        ch_out: Number of filters of the closing convolution.

    Returns:
        The output tensor of the block.
    """
    # Layers are created in the same order as they are applied so that the
    # automatically generated layer names stay stable.
    pointwise = Conv2D(ch_in, kernel_size=(1, 1), padding='same', use_bias=False)
    features = get_post(pointwise(x_in))

    depthwise_1x3 = DepthwiseConv2D(kernel_size=(1, 3), padding='same', use_bias=False)
    features = get_post(depthwise_1x3(features))

    # Separable pooling: pool (and stride) by 2 along the first spatial
    # dimension only.
    features = MaxPooling2D(pool_size=(2, 1), strides=(2, 1))(features)

    depthwise_3x1 = DepthwiseConv2D(kernel_size=(3, 1), padding='same', use_bias=False)
    features = get_post(depthwise_3x1(features))

    # Closing convolution strides by 2 along the second spatial dimension.
    projection = Conv2D(ch_out,
                        kernel_size=(2, 1),
                        strides=(1, 2),
                        padding='same',
                        use_bias=False)
    return get_post(projection(features))
def EffNet(input_shape, num_classes, include_top=True, weights=None):
    """Assemble the EffNet model from three stacked blocks.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        num_classes: Number of units of the softmax classification layer
            (only used when ``include_top`` is true).
        include_top: When true, append ``Flatten`` and a softmax ``Dense``
            layer after the last block; otherwise the model outputs the
            last block's tensor.
        weights: Optional path handed to ``model.load_weights`` with
            ``by_name=True``. ``None`` skips weight loading.

    Returns:
        The constructed ``keras.models.Model``.
    """
    x_in = Input(shape=input_shape)

    # (ch_in, ch_out) for each consecutive block.
    block_channels = ((32, 64), (64, 128), (128, 256))
    x = x_in
    for ch_in, ch_out in block_channels:
        x = get_block(x, ch_in, ch_out)

    if include_top:
        flat = Flatten()(x)
        x = Dense(num_classes, activation='softmax')(flat)

    model = Model(inputs=x_in, outputs=x)
    if weights is not None:
        model.load_weights(weights, by_name=True)
    return model
# ---------------------------------------------------------------------------
# Script setup: configuration, data loading, model construction.
# ---------------------------------------------------------------------------
batch_size = 32
num_classes = 10
data_augmentation = False  # True selects the ImageDataGenerator training path

# Checkpoints are written to ./saved_models (created on first run).
save_dir = os.path.join(os.getcwd(), 'saved_models')
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# The per-sample shape (everything after the sample axis) is the model input.
model = EffNet(input_shape=x_train.shape[1:], num_classes=num_classes)

# Initiate the RMSprop optimizer. The documented class name ``RMSprop`` is
# used instead of the lowercase ``rmsprop`` alias, which is not available in
# later Keras releases.
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Configure the model for training.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Convert images to float32 and scale by 1/255
# (assumes raw pixel values in 0..255 -- TODO confirm against the dataset).
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
def _train_until_interrupted(fit_one_epoch):
    """Train one epoch at a time, forever, checkpointing every improvement.

    After each epoch the validation loss and accuracy are read from the
    returned history. Whenever the loss reaches a new minimum, or the
    accuracy a new maximum, the module-level ``model`` is saved into
    ``save_dir`` with the metric value embedded in the file name.

    The loop has no exit condition; stop training by interrupting the
    process.

    Args:
        fit_one_epoch: Zero-argument callable that trains ``model`` for a
            single epoch (with validation data) and returns the Keras
            ``History`` object of that run.
    """
    epoch = 1  # printed as "iter"; named so it does not shadow builtin iter()
    min_val_loss = float('inf')
    max_val_acc = -1
    while True:
        print("iter = {}".format(epoch))
        metrics = fit_one_epoch().history
        epoch += 1
        val_loss = metrics['val_loss'][0]
        # Keras < 2.3 reports the 'accuracy' metric as 'val_acc'; newer
        # releases report it as 'val_accuracy'. Accept either key.
        acc_key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'
        val_acc = metrics[acc_key][0]
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            model_path = os.path.join(save_dir, "keras_cifar10_mvl{}_trained_model.h5".format(min_val_loss))
            model.save(model_path)
        if val_acc > max_val_acc:
            max_val_acc = val_acc
            model_path = os.path.join(save_dir, "keras_cifar10_mva{}_trained_model.h5".format(max_val_acc))
            model.save(model_path)
        print("val_loss = {} min_val_loss = {}".format(val_loss, min_val_loss))
        print("val_acc = {} max_val_acc = {}".format(val_acc, max_val_acc))


if not data_augmentation:
    print('Not using data augmentation.')
    # Plain training on the in-memory arrays, validated on the test split.
    _train_until_interrupted(
        lambda: model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=1,
                          validation_data=(x_test, y_test),
                          shuffle=True))
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=5,  # randomly rotate images in the range (degrees, 0 to 180)
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False,  # randomly flip images
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)
    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)
    # Each epoch draws augmented batches from the generator; validation
    # still uses the un-augmented test split.
    _train_until_interrupted(
        lambda: model.fit_generator(datagen.flow(x_train, y_train,
                                                 batch_size=batch_size),
                                    epochs=1,
                                    validation_data=(x_test, y_test),
                                    workers=8,
                                    use_multiprocessing=True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment