Function submitted to Horovod Runner
def train_fn():
    # Make sure pyarrow is referenced before anything else to avoid a segfault due to a
    # conflict with TensorFlow libraries. The `pa` name is pyarrow imported at module
    # level, so referencing it here ensures it is loaded before functions like
    # `deserialize_model` which are implemented at the top level.
    # See https://jira.apache.org/jira/browse/ARROW-3346
    pa

    # import atexit
    import horovod.tensorflow.keras as hvd
    import os
    import tensorflow as tf
    import tensorflow.keras.backend as K
    import shutil

    def build_model(shape=(128, 128, 4)):
        # Add layers
        # ...
        # ...
        return model

    # Horovod: initialize Horovod inside the trainer.
    hvd.init()

    # Horovod: to restore from a checkpoint, hvd.load_model would be used under the hood;
    # here a fresh model is built instead.
    model = build_model()

    # Horovod: add the Distributed Optimizer.
    # `args`, `trainx`, and `trainy_hot` are captured from the enclosing scope.
    opt = tf.keras.optimizers.Adam(lr=args.learning_rate, epsilon=1e-3)
    opt = hvd.DistributedOptimizer(opt)
    model.compile(opt, 'categorical_crossentropy', metrics=['accuracy'])

    # Horovod: scale the learning rate by the number of processes.
    scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
    K.set_value(model.optimizer.lr, scaled_lr)

    # Horovod: print summary logs on the first worker only.
    verbose = 2 if hvd.rank() == 0 else 0

    # Dataset shapes look like:
    #   trainx.shape     = [256, 128, 128, 4]
    #   trainy_hot.shape = [256, 128, 128, 9]
    result = model.fit(trainx,
                       trainy_hot,
                       steps_per_epoch=int(len(trainx) / args.batch_size / hvd.size()),
                       verbose=verbose,
                       epochs=args.epochs)

    model.save('hvd_model.h5')

    # Dataset API usage currently displays a wall of errors upon termination.
    # This global model registration ensures clean termination.
    # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
    globals()['_DATASET_FINALIZATION_HACK'] = model
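
For context, a function like this is typically submitted to HorovodRunner on Databricks roughly as shown below. This is a minimal sketch, assuming a Databricks cluster where the `sparkdl` package providing `HorovodRunner` is available; the process count `np=2` is an arbitrary example value.

from sparkdl import HorovodRunner

# Launch train_fn on 2 worker processes. HorovodRunner pickles the function
# (together with variables it closes over, such as `args`, `trainx`, and
# `trainy_hot`) and runs it on the workers under Horovod.
hr = HorovodRunner(np=2)
hr.run(train_fn)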