Skip to content

Instantly share code, notes, and snippets.

@hadifar
Created January 5, 2019 14:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hadifar/49e0f6e70eec5adf01a4aff2a20b228b to your computer and use it in GitHub Desktop.
Save hadifar/49e0f6e70eec5adf01a4aff2a20b228b to your computer and use it in GitHub Desktop.
def model_fn(features, labels, mode):
    """Assemble the ``tf.estimator.EstimatorSpec`` for the requested mode.

    Logits come from ``neural_net_model(features, mode)`` and the predicted
    class is their argmax over the last axis.  For TRAIN and EVAL a mean
    sparse softmax cross-entropy loss is built; TRAIN additionally gets an
    Adam minimisation step, and EVAL an ``'accuracy'`` metric.

    Args:
        features: Forwarded untouched to ``neural_net_model``.
        labels: Only read in TRAIN and EVAL.  Cast to int32 for the loss and
            passed as-is to the accuracy metric.  Presumably integer class
            indices -- confirm against the input function.
        mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying the class predictions.
        ``loss`` and ``train_op`` stay ``None`` and ``eval_metric_ops`` stays
        empty whenever the mode does not populate them.
    """
    mode_keys = tf.estimator.ModeKeys
    training = mode == mode_keys.TRAIN
    evaluating = mode == mode_keys.EVAL

    logits = neural_net_model(features, mode)
    predicted_classes = tf.argmax(logits, axis=-1)

    # Defaults used for any mode that does not fill these in (e.g. PREDICT).
    loss, train_op, metrics = None, None, {}

    if training or evaluating:
        # The loss is shared by training and evaluation.
        int_labels = tf.cast(labels, dtype=tf.int32)
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=int_labels,
            logits=logits,
        )
        loss = tf.reduce_mean(per_example_loss)

        if training:
            optimizer = tf.train.AdamOptimizer()
            step = tf.train.get_global_step()
            train_op = optimizer.minimize(loss, global_step=step)
        else:
            metrics = {
                'accuracy': tf.metrics.accuracy(
                    labels=labels,
                    predictions=predicted_classes,
                    name='accuracy',
                ),
            }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predicted_classes,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment