Skip to content

Instantly share code, notes, and snippets.

@saurabhpal97
Last active May 31, 2019 07:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saurabhpal97/ade27a6bab0ee72b56a00db34d929001 to your computer and use it in GitHub Desktop.
Save saurabhpal97/ade27a6bab0ee72b56a00db34d929001 to your computer and use it in GitHub Desktop.
from keras.models import Sequential
from keras.layers import Input,Conv2D,BatchNormalization,MaxPooling2D,Dropout,Activation,Flatten,Dense
from keras import regularizers
from keras import models
from keras.callbacks import ModelCheckpoint
# Number of target classes in the dataset (the softmax output width).
num_classes = 10

# Functional-API input: 32x32 images with 3 channels, channels-last.
# NOTE(review): assumes x_train / x_val have shape (N, 32, 32, 3) — confirm
# against the data-loading code outside this snippet.
img_input = Input(shape=(32, 32, 3))

# Three convolutional stages. Each stage is two (Conv -> ELU -> BatchNorm)
# units followed by 2x2 max-pooling and dropout; the filter count doubles and
# the dropout rate rises with depth.
# The input shape is defined only by `img_input` above: an `input_shape=`
# kwarg on a layer that is called on an existing tensor has no effect, and
# previously it also forced `x_train` to exist before the model could be built.
x = img_input
for filters, dropout_rate in ((32, 0.2), (64, 0.3), (128, 0.4)):
    for _ in range(2):
        x = Conv2D(filters, (3, 3), padding='same')(x)
        x = Activation('elu')(x)
        x = BatchNormalization()(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(dropout_rate)(x)

# Classifier head: flatten the final feature maps directly into a softmax
# over `num_classes` outputs (no hidden dense layer).
x = Flatten()(x)
x = Dense(num_classes, activation='softmax')(x)

model = models.Model(img_input, x, name='CNN')
# Compile with Adam and categorical cross-entropy (which expects one-hot
# encoded labels).
# The metric is requested as 'acc' rather than 'accuracy' so that the logged
# validation key is 'val_acc' on every Keras 2.x / tf.keras version. Since
# Keras 2.3 (and in tf.keras), 'accuracy' is logged as 'val_accuracy'. That
# would leave the checkpoint below with no 'val_acc' to monitor and make the
# '{val_acc}' filename placeholder fail.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# Checkpoint file name pattern: epoch number and validation accuracy,
# e.g. "weights-improvement-07-0.81.hdf5".
filepath = "weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
# Write a checkpoint only when validation accuracy improves on the best so far.
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

# Train for 150 epochs, evaluating on (x_val, y_val) after every epoch.
# NOTE(review): x_train / y_train / x_val / y_val are defined outside this
# snippet. The returned History object keeps the per-epoch metrics.
history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=150, callbacks=callbacks_list)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment