Skip to content

Instantly share code, notes, and snippets.

@orwa-te
Created August 14, 2020 14:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save orwa-te/ff60040f6628323dfbf44b17805ac064 to your computer and use it in GitHub Desktop.
Save orwa-te/ff60040f6628323dfbf44b17805ac064 to your computer and use it in GitHub Desktop.
Code snippet of a PySpark program running on 2 worker nodes
................
................
# Worker entry point executed on every Spark executor by TFCluster.run (below):
# `args` is the parsed command-line namespace forwarded by the driver, and
# `ctx` is the per-task context object supplied by TensorFlowOnSpark
# (presumably carrying job name / task index — confirm against TFOS docs).
# NOTE: indentation was lost when this snippet was pasted, and the "...." lines
# mark code elided by the author (e.g. where `trainx`/`trainy_hot` are loaded).
def main_fun(args, ctx):
batch_size=32
# `trainx` is created in the elided code above; the author observed 672 samples.
print(len(trainx)) # -----> 672
# 672/32 = 21
#Create distribute strategy
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
# NOTE(review): 21 steps assumes one worker consumes all 672 samples per epoch.
# With AutoShardPolicy.DATA (set below) and 2 workers, each worker only receives
# ~336 samples (~10 full batches), so steps_per_epoch=21 may exhaust the dataset
# mid-epoch ("ran out of data") — verify against observed training behavior.
steps_per_epoch = 21 # 672/32 === 21
#scale input data
# Normalize to float32 in [0, 1] by dividing by the global max pixel value.
trainx = trainx.astype(np.float32) / np.max(trainx)
train_dataset = tf.data.Dataset.from_tensor_slices((trainx, trainy_hot))
...................
...................
#set options for sharding policy
# DATA sharding: every worker reads the whole dataset but keeps only its
# own subset of elements, rather than sharding by input files.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
# Shuffle over the full dataset length, then batch for training.
train_datasets_unbatched = train_dataset.with_options(options).shuffle(len(trainx))
train_dataset=train_datasets_unbatched.batch(batch_size)
# Model must be built inside the strategy scope so its variables are mirrored
# across the workers participating in MultiWorkerMirroredStrategy.
with strategy.scope():
multi_worker_model = build_unet_model()
# `build_unet_model` is defined in code elided from this snippet.
history = multi_worker_model.fit(train_dataset, epochs=10, verbose=1, steps_per_epoch=steps_per_epoch)
# end of main_fun(args, ctx).............
# Driver-side entry point: builds the SparkContext, then hands main_fun to
# TensorFlowOnSpark to run on the executors. (Indentation lost in the paste;
# the "...." lines are code elided by the author, e.g. the argparse setup
# that produces `args`.)
if __name__ == '__main__':
import argparse
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster
# Cap the application at 8 cores total across the Spark cluster.
sc = SparkContext(conf=SparkConf().setAppName("unet_train_keras").set('spark.cores.max',8))
.............
.............
............
# Launch main_fun on args.cluster_size executors: no parameter servers
# (num_ps=0), one task designated 'chief', TensorBoard per args.tensorboard,
# logs/checkpoints under args.model_dir. InputMode.TENSORFLOW means each
# worker ingests data itself (via tf.data) rather than being fed Spark RDDs.
cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', log_dir=args.model_dir)
# Block until training finishes and tear the TF cluster down cleanly.
cluster.shutdown()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment