@eileen-code4fun
Created May 26, 2021 03:20
CIFAR10 Model on Cloud
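# create_model(), the datasets, and the upper-case constants are assumed to be defined elsewhere.
import tensorflow as tf

# A distribution strategy to take advantage of available hardware (e.g. multiple GPUs).
# Effectively a no-op when only a single device is present.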
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = create_model()
    # Restore from the latest checkpoint if available.
    latest_ckpt = tf.train.latest_checkpoint(GCS_PATH_FOR_CHECKPOINTS)
    if latest_ckpt:
        model.load_weights(latest_ckpt)
# Create a callback to store a checkpoint at the end of each epoch.
# Note: `monitor` only takes effect together with save_best_only=True.
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=GCS_PATH_FOR_CHECKPOINTS + CHECKPOINTS_PREFIX,
    monitor='val_loss',
    save_weights_only=True
)
model.fit(train_dataset, epochs=EPOCHS, validation_data=val_dataset, callbacks=[ckpt_callback])
# Export the model to GCS.
model.save(GCS_PATH_FOR_SAVED_MODEL)
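
The snippet builds the checkpoint `filepath` by plain string concatenation and leaves `GCS_PATH_FOR_CHECKPOINTS` and `CHECKPOINTS_PREFIX` undefined. The following is a minimal sketch of how those two constants might be set and how a templated path would expand per epoch. The bucket name, the `ckpt-{epoch:02d}` prefix, and the `checkpoint_path` helper are all hypothetical and not part of the original gist. The helper only imitates the named-placeholder formatting that `ModelCheckpoint` applies to `filepath`, using the standard library alone.

```python
# Hypothetical values; the original snippet does not define these names.
GCS_PATH_FOR_CHECKPOINTS = "gs://my-bucket/cifar10/checkpoints/"  # trailing slash matters
CHECKPOINTS_PREFIX = "ckpt-{epoch:02d}"


def checkpoint_path(directory: str, prefix: str, epoch: int) -> str:
    """Return the path a templated checkpoint name expands to for `epoch`."""
    # Plain concatenation, as in the snippet, only yields a valid path when
    # `directory` ends with "/"; normalize so "…/checkpoints" + "ckpt-01"
    # does not silently become "…/checkpointsckpt-01".
    if not directory.endswith("/"):
        directory += "/"
    # Fill in the named placeholder the same way str.format would.
    return (directory + prefix).format(epoch=epoch)


if __name__ == "__main__":
    print(checkpoint_path(GCS_PATH_FOR_CHECKPOINTS, CHECKPOINTS_PREFIX, 3))
```

If the directory constant already ends in a slash, the concatenation in the original snippet and this helper produce the same string; the normalization only guards against the missing-slash case.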