Skip to content

Instantly share code, notes, and snippets.

@orwa-te
Last active November 15, 2020 05:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save orwa-te/aca04e79beec6644c19fb1a235f1f2d2 to your computer and use it in GitHub Desktop.
Save orwa-te/aca04e79beec6644c19fb1a235f1f2d2 to your computer and use it in GitHub Desktop.
Training function submitted to the Horovod runner
def train_fn():
    """Train a Keras model with Horovod data parallelism.

    This function is meant to be submitted to a Horovod runner, which calls it
    once in every worker process.

    It relies on names that are NOT defined in this function and must exist in
    the enclosing module:
      * ``pa`` -- the pyarrow module (see the ARROW-3346 note below).
      * ``args`` -- an object with ``learning_rate``, ``batch_size`` and
        ``epochs`` attributes.
      * ``trainx`` / ``trainy_hot`` -- the training inputs and one-hot labels.

    Returns:
        The Keras ``History.history`` dict recorded by this worker's ``fit``.
    """
    # Make sure pyarrow is referenced before anything else to avoid segfault due to conflict
    # with TensorFlow libraries. Use `pa` package reference to ensure it's loaded before
    # functions like `deserialize_model` which are implemented at the top level.
    # See https://jira.apache.org/jira/browse/ARROW-3346
    pa

    import horovod.tensorflow.keras as hvd
    import os
    import tensorflow as tf
    import shutil

    def build_model(shape=(128, 128, 4)):
        """Build the Keras model for inputs of ``shape``.

        NOTE(review): the layer definitions are elided in this gist; they must
        create ``model`` here before it is returned.
        """
        # Add layers
        # ...
        # ...
        return model

    # Horovod: initialize Horovod inside the trainer (once per worker process).
    hvd.init()

    # Every worker builds a fresh model. BroadcastGlobalVariablesCallback below
    # makes all workers start from rank 0's initial weights.
    model = build_model()

    # Horovod: scale the learning rate by the number of workers, as recommended
    # for data-parallel training. `learning_rate` replaces the deprecated `lr`.
    opt = tf.keras.optimizers.Adam(
        learning_rate=args.learning_rate * hvd.size(), epsilon=1e-3
    )
    # Horovod: wrap the optimizer so gradients are averaged across workers.
    opt = hvd.DistributedOptimizer(opt)
    model.compile(opt, 'categorical_crossentropy', metrics=['accuracy'])

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other
        # workers, so training starts from identical weights everywhere.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    ]

    # Horovod: print per-epoch summary logs on the first worker only.
    verbose = 2 if hvd.rank() == 0 else 0

    # Divide each epoch's steps among the workers. Clamp to at least one step:
    # when batch_size * hvd.size() exceeds the dataset size, the plain quotient
    # is 0, which `fit` cannot use.
    steps_per_epoch = max(1, len(trainx) // (args.batch_size * hvd.size()))

    # Example input shapes from the original author:
    # trainx.shape = [256,128,128,4]
    # trainy_hot.shape = [256,128,128,9]
    result = model.fit(
        trainx,
        trainy_hot,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        verbose=verbose,
        epochs=args.epochs,
    )

    # Horovod: save only on rank 0, so workers do not race writing the same file.
    if hvd.rank() == 0:
        model.save('hvd_model.h5')

    # Dataset API usage currently displays a wall of errors upon termination.
    # This global model registration ensures clean termination.
    # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
    globals()['_DATASET_FINALIZATION_HACK'] = model

    return result.history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment