Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
def map_fun(args, ctx):
    """Run one node of a distributed TensorFlow job driven by TensorFlowOnSpark.

    Parameter-server nodes block in ``server.join()``. Worker nodes build the
    model, open a ``MonitoredTrainingSession`` and train on batches pulled
    from the data feed until the session or the feed asks to stop, or until
    the global step reaches ``args['steps']``.

    Args:
        args: Mapping read with the keys ``'save_dir'`` (checkpoint
            directory), ``'steps'`` (global step at which training stops) and
            ``'batch_size'`` (number of records requested per batch).
        ctx: Node context. This function reads ``ctx.job_name`` and
            ``ctx.task_index`` and calls ``ctx.start_cluster_server(1)`` and
            ``ctx.get_data_feed(train_mode=True)``.

    Returns:
        None. Any exception raised while running is logged with its traceback
        and not propagated, as in the original implementation.
    """
    # Bound before the try block so the except handler can always log, even
    # when the TensorFlow import itself is what failed.
    import logging
    logger = logging.getLogger()

    try:
        import tensorflow as tf
        from datetime import datetime
        import numpy as np

        tf.logging.set_verbosity(tf.logging.DEBUG)

        job_name = ctx.job_name
        task_index = ctx.task_index
        cluster, server = ctx.start_cluster_server(1)

        def get_next_batch(batch):
            """Split raw feed rows into a feature tensor and one-hot labels.

            Columns ``2:-1`` of every row are reshaped to
            ``(rows, timesteps, num_features)``; the last column is cast to
            int and passed to ``to_categorical``.

            NOTE(review): ``timesteps``, ``num_features``, ``num_classes`` and
            ``to_categorical`` are not defined in this function; they must be
            provided by the enclosing module -- confirm before running.
            """
            batch = np.array(batch)
            data = batch[:, 2:-1].reshape(
                (batch.shape[0], timesteps, num_features))
            labels = batch[:, -1].astype(int)
            return data, to_categorical(labels, num_classes=num_classes)

        if job_name == "ps":
            server.join()
        elif job_name == "worker":
            # https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
            # One task must be identified as chief; it is the one that
            # handles, for example, variable initialization.
            is_chiefing = (task_index == 0)

            with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % task_index,
                    cluster=cluster)):

                def build_model():
                    """Placeholder for the model definition.

                    Must return the 7-tuple ``(model_input, model_labels,
                    model_output, global_step, loss, optimizer, metrics)``
                    that is unpacked right below.
                    """
                    # Fail with a clear message instead of the opaque
                    # "cannot unpack NoneType" the empty stub would produce.
                    raise NotImplementedError(
                        "build_model() must return (model_input, "
                        "model_labels, model_output, global_step, loss, "
                        "optimizer, metrics)")

                (model_input,
                 model_labels,
                 model_output,
                 tf_global_step,
                 tf_loss,
                 tf_optimizer,
                 tf_metrics) = build_model()

            hooks = [tf.train.StepCounterHook()]
            with tf.train.MonitoredTrainingSession(
                    master=server.target,
                    is_chief=is_chiefing,
                    checkpoint_dir=args['save_dir'],
                    hooks=hooks,
                    save_checkpoint_secs=600.) as mon_sess:
                start_time = datetime.now()
                tf.logging.info(
                    "{0} session ready".format(start_time.isoformat()))

                # https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFSparkNode.py
                # see TFNode https://github.com/yahoo/TensorFlowOnSpark/blob/master/tensorflowonspark/TFNode.py
                tf_feed = ctx.get_data_feed(train_mode=True)
                step = 0
                while (not mon_sess.should_stop()
                       and not tf_feed.should_stop()
                       and step < args['steps']):
                    raw_batch = tf_feed.next_batch(args['batch_size'])
                    if len(raw_batch) == 0:
                        # An empty batch becomes a 1-D array, on which the
                        # column slicing in get_next_batch raises IndexError,
                        # so it has to be skipped before conversion.
                        continue
                    batch_data, batch_labels = get_next_batch(raw_batch)
                    feed = {model_input: batch_data,
                            model_labels: batch_labels}
                    _, _, step = mon_sess.run(
                        [tf_optimizer, tf_loss, tf_global_step],
                        feed_dict=feed)

                # Tell the feed to stop once training is over so that no
                # further data is pushed to this node.
                if mon_sess.should_stop() or step >= args['steps']:
                    tf_feed.terminate()

            logger.info("%s stopping supervisor", datetime.now().isoformat())
    except Exception as e:
        # Errors are reported rather than re-raised (original behaviour);
        # logger.exception keeps the traceback that logger.error(e) dropped.
        logger.exception(e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment