Created May 26, 2021 03:20
-
-
Save eileen-code4fun/48aa1365c279d62f562ec9c914224968 to your computer and use it in GitHub Desktop.
CIFAR10 Model on Cloud
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _epoch_from_checkpoint(path):
    """Return the epoch number encoded as a trailing '-NNNN' in *path*, else 0.

    Checkpoints written without an epoch suffix (for example by an earlier
    version of this script) yield 0: their weights are still restored, but
    the epoch count starts from the beginning.
    """
    _, sep, suffix = path.rpartition('-')
    return int(suffix) if sep and suffix.isdecimal() else 0


# Use a distribution strategy to take advantage of the available hardware.
# With a single device this is effectively a no-op.
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = create_model()
    # Restore from the latest checkpoint if available, and recover how many
    # epochs it had already completed, so a restarted job resumes training
    # instead of running EPOCHS more epochs on top of the restored weights.
    initial_epoch = 0
    latest_ckpt = tf.train.latest_checkpoint(GCS_PATH_FOR_CHECKPOINTS)
    if latest_ckpt:
        model.load_weights(latest_ckpt)
        initial_epoch = _epoch_from_checkpoint(latest_ckpt)

# Save the weights at the end of every epoch. ModelCheckpoint fills in
# '{epoch:04d}' with the 1-based number of the epoch just finished, which is
# exactly the `initial_epoch` to resume from. This keeps one checkpoint per
# epoch rather than overwriting a single one. There is deliberately no
# `monitor`/`save_best_only`: resuming needs the latest weights, not the best.
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=GCS_PATH_FOR_CHECKPOINTS + CHECKPOINTS_PREFIX + '-{epoch:04d}',
    save_weights_only=True,
)
# `initial_epoch` skips the epochs already covered by the restored checkpoint;
# when it is >= EPOCHS, fit() performs no further training.
model.fit(
    train_dataset,
    epochs=EPOCHS,
    initial_epoch=initial_epoch,
    validation_data=val_dataset,
    callbacks=[ckpt_callback],
)
# Export the model to GCS.
model.save(GCS_PATH_FOR_SAVED_MODEL)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.