Skip to content

Instantly share code, notes, and snippets.

@benoitdescamps
Last active May 10, 2019 01:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save benoitdescamps/d5418ab69d7b8a631b53e0cb505c572c to your computer and use it in GitHub Desktop.
Save benoitdescamps/d5418ab69d7b8a631b53e0cb505c572c to your computer and use it in GitHub Desktop.
# Build the model under a replica device setter for this worker task.
# worker_device pins non-variable ops to "/job:worker/task:<task_index>";
# `cluster` describes the cluster the setter distributes variables across.
# `task_index` and `cluster` are defined elsewhere in the script.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % task_index,
        cluster=cluster)):

    def build_model():
        """Build the graph for a single dense-layer classifier.

        Reads names defined elsewhere in the script (not visible here):
        ``args`` (a dict read for 'num_features', 'num_classes' and
        'learning_rate') and ``weights`` (passed as ``pos_weight`` to the loss).

        Returns:
            A 7-tuple ``(model_input, model_labels, model_output,
            tf_global_step, tf_loss, tf_optimizer, tf_metrics)`` matching the
            unpacking performed by the caller below.
        """
        model_input = tf.placeholder(tf.float32, [None, args['num_features']])
        model_labels = tf.placeholder(tf.float32, [None, args['num_classes']])

        logits = tf.keras.layers.Dense(args['num_classes'])(model_input)
        model_output = tf.nn.softmax(logits)

        tf_global_step = tf.train.get_or_create_global_step()

        # NOTE(review): weighted_cross_entropy_with_logits is a sigmoid-based
        # loss, while model_output is a softmax -- confirm this pairing is
        # intended.
        tf_loss = tf.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(
                logits=logits, targets=model_labels, pos_weight=weights))

        # Reuse the step created above so minimize() increments that variable.
        tf_optimizer = tf.train.AdamOptimizer(
            learning_rate=args['learning_rate']).minimize(
                tf_loss, global_step=tf_global_step)

        # TODO(review): the original snippet never defined tf_metrics, although
        # the caller unpacks it. An empty dict keeps the contract (it is a
        # valid sess.run fetch); add metric ops here as needed.
        tf_metrics = {}

        # The original function fell off the end and returned None, which made
        # the 7-way unpacking below raise TypeError.
        return (model_input, model_labels, model_output,
                tf_global_step, tf_loss, tf_optimizer, tf_metrics)

    # Called inside the device scope so the setter's placement applies to the
    # ops and variables created by build_model().
    (model_input,
     model_labels,
     model_output,
     tf_global_step,
     tf_loss,
     tf_optimizer,
     tf_metrics) = build_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment