Skip to content

Instantly share code, notes, and snippets.

@netsatsawat
Created August 24, 2020 11:53
Show Gist options
Save netsatsawat/6d10fed8be0309c8068efd3029eb0b28 to your computer and use it in GitHub Desktop.
Snippet of a CNN model trained with data augmentation
# Train a small CNN on the images under TRAIN_DIR, using on-the-fly data
# augmentation for the training subset.
# TRAIN_DIR, IMG_SIZE, N_CLASS and tfdocs are defined/imported elsewhere.
# NOTE(review): IMG_SIZE is assumed to be a (height, width) pair, since it is
# passed as flow_from_directory's target_size — confirm against its definition.

# Training images: random geometric augmentation plus rescaling to [0, 1].
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=40,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    rescale=1. / 255.,
    validation_split=0.2,  # must match valid_gen's validation_split
)
# Validation images are only rescaled. Augmenting them (as happened when the
# validation subset was drawn from train_gen) makes val_loss noisy, which is a
# poor signal for EarlyStopping. Reading the same directory with the same
# validation_split keeps the training/validation partition identical, because
# the split does not depend on the augmentation settings.
valid_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255.,
    validation_split=0.2,  # must match train_gen's validation_split
)

train_generator = train_gen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=32,
    class_mode='categorical',
    subset='training',
)
valid_generator = valid_gen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=32,
    class_mode='categorical',
    subset='validation',
)

# Three conv/pool stages followed by a dense classifier head.
# The input shape is derived from IMG_SIZE rather than hard-coded as (150, 150),
# so the model always matches what the generators produce (3 RGB channels,
# flow_from_directory's default color_mode).
cnn_model2 = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                           input_shape=(*IMG_SIZE, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(N_CLASS, activation='softmax'),
])
# One-hot labels (class_mode='categorical') + softmax output -> categorical CE.
cnn_model2.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Stop once val_loss has not improved for 10 epochs. restore_best_weights rolls
# the model back to the best epoch instead of keeping the last (worse) weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    restore_best_weights=True,
)
# Model.fit accepts generators directly; fit_generator is deprecated (TF 2.1+).
# epochs=500 is an upper bound — early stopping ends training sooner.
cnn_hist2 = cnn_model2.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=500,
    callbacks=[tfdocs.modeling.EpochDots(), early_stopping],
    verbose=0,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment