Skip to content

Instantly share code, notes, and snippets.

@netsatsawat
Created August 20, 2020 11:36
Show Gist options
Save netsatsawat/e5de36ee69a800f189eaa9b41611c5c4 to your computer and use it in GitHub Desktop.
Code snippet for a basic CNN implementation using TensorFlow
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_docs as tfdocs
import tensorflow_docs.modeling
import tensorflow_docs.plots
# Basic CNN classifier: three Conv2D(3x3, ReLU) + 2x2 max-pool stages, then a
# 128-unit dense layer and a softmax output over N_CLASS classes.
#
# NOTE(review): the following names are not defined in this snippet and are
# assumed to exist earlier in the notebook/script -- confirm:
#   train_imgs   -- image array; shape[1:4] is used as the per-sample input
#                   shape, presumably (height, width, channels)
#   train_labels -- integer class indices (SparseCategoricalCrossentropy
#                   expects integer labels, not one-hot vectors)
#   N_CLASS      -- number of output classes
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                           input_shape=train_imgs.shape[1:4]),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    # Softmax already yields probabilities, so the loss below keeps its
    # default from_logits=False.
    tf.keras.layers.Dense(N_CLASS, activation=tf.nn.softmax),
])

cnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Stop once validation loss has not improved for 10 consecutive epochs.
# restore_best_weights=True rolls the model back to the best-val_loss epoch;
# without it, cnn_model would keep the weights from the final epoch, which is
# `patience` epochs past the best one.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    restore_best_weights=True,
)

# epochs=500 is only an upper bound; early_stopping normally ends training
# sooner. verbose=0 silences Keras' per-epoch log; EpochDots prints
# lightweight progress instead.
cnn_hist = cnn_model.fit(
    train_imgs, train_labels,
    epochs=500,
    validation_split=0.2,  # Keras holds out the last 20% of samples (pre-shuffle)
    verbose=0,
    batch_size=64,
    callbacks=[tfdocs.modeling.EpochDots(), early_stopping],
)

# Plot the per-epoch 'accuracy' history of the fit, labelled 'Basic'.
plt.figure(figsize=(11, 7))
plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({'Basic': cnn_hist}, metric='accuracy')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment