@zredlined
Last active January 29, 2020 22:14
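import os

import tensorflow as tf

# Define loss function: the model outputs raw logits (no softmax layer),
# so from_logits=True tells Keras to apply the softmax inside the loss
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)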
# Compile model (`model` is the tf.keras model built earlier; it is not defined in this snippet)
model.compile(optimizer='adam', loss=loss)
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)
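With `from_logits=True`, the loss applies a softmax to the raw logits internally, so for a single position it reduces to the negative log-probability of the true character: the log-sum-exp of the logits minus the logit of the true label. The sketch below reproduces that per-example computation in plain Python to show what the loss measures; the function name `sparse_ce_from_logits` is hypothetical and is not part of TensorFlow.

```python
import math
from typing import Sequence


def sparse_ce_from_logits(label: int, logits: Sequence[float]) -> float:
    """Per-example sparse categorical cross-entropy from raw logits.

    Hypothetical helper, equivalent to -log(softmax(logits)[label]).
    The max logit is subtracted before exponentiating so that large
    logits do not overflow math.exp.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


if __name__ == "__main__":
    # Uniform logits over a 4-symbol vocabulary: the loss is log(4) for any label
    print(sparse_ce_from_logits(2, [0.0, 0.0, 0.0, 0.0]))
    # A confident, correct prediction gives a loss close to zero
    print(sparse_ce_from_logits(0, [10.0, 0.0, 0.0]))
```

Because the labels are integer character IDs rather than one-hot vectors, the "sparse" variant only needs to index the true label's logit instead of taking a dot product with a one-hot target.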
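In TF 2.x's tf.keras, `ModelCheckpoint` fills in the `{epoch}` placeholder of `filepath` at the end of each epoch using the 1-based epoch number, and any metric from that epoch's logs (such as `loss`) can also appear in the template. The sketch below mimics that formatting in plain Python so the resulting file names are easy to see; `checkpoint_path` is a hypothetical helper, not a Keras function.

```python
import os

# Same values as in the snippet above
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")


def checkpoint_path(template: str, epoch_index: int, **logs: float) -> str:
    """Hypothetical helper mimicking how ModelCheckpoint fills its template.

    Keras passes the 0-based epoch index plus one, along with any metric
    values logged for that epoch (e.g. loss).
    """
    return template.format(epoch=epoch_index + 1, **logs)


if __name__ == "__main__":
    # File prefixes written after the first three epochs
    for i in range(3):
        print(checkpoint_path(checkpoint_prefix, i))
    # A template that also records the epoch's training loss
    loss_template = os.path.join(checkpoint_dir, "ckpt_{epoch:02d}-{loss:.2f}")
    print(checkpoint_path(loss_template, 4, loss=1.2345))
```

In training, the callback is passed as `model.fit(..., callbacks=[checkpoint_callback])`; to resume from the newest weights in TF 2.x, a common pattern is `model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))`.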