Skip to content

Instantly share code, notes, and snippets.

@jeanmidevacc
Created April 14, 2020 01:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jeanmidevacc/4a4348ee94c625de0c905a2992d6fea5 to your computer and use it in GitHub Desktop.
import datetime
import math
import os

import tensorflow as tf
def _build_classifier():
    """Assemble the CNN image classifier used by this script.

    Four 3x3, 'same'-padded, ReLU conv stages with 32, 64, 128 and 256
    filters. The first three stages are followed by 2x2 max-pooling, and the
    first two additionally by 50% dropout. A global average pool feeds a
    Dense head with one unit per entry of CLASS_NAMES and no activation, so
    the model outputs raw logits.
    """
    kl = tf.keras.layers
    net = tf.keras.Sequential()
    for n_filters in (32, 64):
        # conv -> pool -> dropout
        net.add(kl.Conv2D(n_filters, (3, 3), padding='same', activation='relu'))
        net.add(kl.MaxPooling2D((2, 2)))
        net.add(kl.Dropout(rate=0.5))
    # Third stage: conv -> pool, no dropout.
    net.add(kl.Conv2D(128, (3, 3), padding='same', activation='relu'))
    net.add(kl.MaxPooling2D((2, 2)))
    # Last stage: conv, then collapse the spatial dimensions.
    net.add(kl.Conv2D(256, (3, 3), padding='same', activation='relu'))
    net.add(kl.GlobalAveragePooling2D())
    net.add(kl.Dense(len(CLASS_NAMES), activation=None))
    return net


model = _build_classifier()
# Declare the input signature: any batch size, IMG_HEIGHT x IMG_WIDTH, 3 channels.
model.build(input_shape=(None, IMG_HEIGHT, IMG_WIDTH, 3))
# from_logits=True matches the activation-free output layer.
# NOTE(review): SparseCategoricalCrossentropy expects integer class ids as
# labels. If train_data_gen comes from flow_from_directory with its default
# class_mode='categorical' (one-hot labels), this loss will not match — confirm.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
# TensorBoard run directory, e.g. logs/fit/pml_YYYYmmdd-HHMMSS.
# os.path.join uses the platform's separator. Hard-coded "\\" separators only
# nest directories on Windows; on POSIX they produce a single directory whose
# name contains backslashes, which `tensorboard --logdir logs/fit` won't find.
log_dir = os.path.join(
    "logs", "fit", "pml_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
)
# histogram_freq=1: record weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Step counts are rounded UP so that each epoch covers every image, including
# the final partial batch. Floor division leaves out up to BATCH_SIZE - 1
# images per pass, so validation metrics are computed on an incomplete set.
# It also yields 0 steps when a split has fewer images than BATCH_SIZE.
history = model.fit(
    train_data_gen,
    steps_per_epoch=math.ceil(image_count_train / BATCH_SIZE),
    epochs=EPOCHS,
    validation_data=val_data_gen,
    validation_steps=math.ceil(image_count_validation / BATCH_SIZE),
    callbacks=[tensorboard_callback],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment