Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Created August 22, 2018 21:28
Show Gist options
Save zmjjmz/b84248a8a1ff21afd1c172e0abbba2e8 to your computer and use it in GitHub Desktop.
Example of mode-conditional outputs (PREDICT vs. TRAIN/EVAL) for a TensorFlow Estimator `model_fn` (partial)
# Partial body of a tf.estimator `model_fn` (TensorFlow 1.x API). The enclosing
# `def` line and all indentation were lost from this snippet, so it is not
# runnable as-is. The names model_class, model_artifacts, model_parameters,
# inp_placeholder, mode and labels are not defined here; presumably they are
# the enclosing function's arguments or come from its closure -- TODO confirm
# against the full source. (Original note: "just the fn from model_store that
# takes:" -- the argument list it introduced is missing from the snippet.)
#
# Behaviour visible here:
#   * mode == PREDICT       -> return an EstimatorSpec exposing the model's
#                              output tensors; no loss or train op is built.
#   * any other mode        -> build the loss, loss/accuracy summaries, an SGD
#                              train op and an accuracy metric.
#   * mode in (TRAIN, EVAL) -> return an EstimatorSpec with all of those.
# NOTE(review): for a mode outside PREDICT/TRAIN/EVAL no visible line returns,
# so (unless the missing remainder of the function handles it) the fn
# implicitly returns None.
#
# Instantiate the model_store class with its artifacts and keyword parameters.
model_builder = model_class(model_artifacts, **model_parameters)
# Build the model on the input placeholder; the return value is not used.
model_builder.build_model(inp_placeholder)
# give_outputs() returns a dict of named output tensors, e.g.
# {'softmax': softmax_layer, 'oov_code': oov_code}. The 'softmax' entry is
# required below by both the loss and the accuracy metric.
tensors = model_builder.give_outputs()
if mode == tensorflow.estimator.ModeKeys.PREDICT:
# Prediction/serving: expose every output tensor both as `predictions` and
# under the default serving signature. signature_constants and PredictOutput
# are assumed to be imported at file level (not shown in this snippet).
return tensorflow.estimator.EstimatorSpec(
mode=mode,
predictions=tensors,
export_outputs={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: PredictOutput(tensors)
},
)
# Cross-entropy between `labels` and the 'softmax' output. labels are assumed
# to be one-hot with classes on axis 1 (see argmax(labels, axis=1) below).
# NOTE(review): tf.losses.softmax_cross_entropy applies softmax to its second
# argument (it expects unnormalized logits). If tensors['softmax'] already
# holds softmax probabilities -- as its name suggests -- softmax is applied
# twice, which distorts the loss and its gradients; passing the pre-softmax
# logits would be correct. Confirm against model_class.
loss_tensor = tensorflow.losses.softmax_cross_entropy(
labels, tensors['softmax'])
# Log the loss under a summary tag prefixed with the current mode.
tensorflow.summary.scalar('{0}_loss'.format(mode), loss_tensor)
# Training op: plain SGD with a fixed learning rate of 0.01 that advances the
# global step. NOTE(review): every non-PREDICT mode reaches this line, so the
# op is also built in EVAL mode and handed to the EVAL EstimatorSpec below.
train_op = tensorflow.contrib.layers.optimize_loss(
loss=loss_tensor,
global_step=tensorflow.train.get_global_step(),
optimizer='sgd',
learning_rate=0.01
)
# Because softmax + cross entropy is already assumed, accuracy is hard-coded
# as the eval metric. argmax(..., axis=1) converts the categorical (one-hot)
# labels and the per-class scores into class indices before they are compared.
eval_metrics_ops = {
"accuracy": tensorflow.metrics.accuracy(
tensorflow.argmax(labels, axis=1),
tensorflow.argmax(tensors['softmax'], axis=1))
}
# tf.metrics.accuracy returns a (value, update_op) pair. The summary records
# element [1], the update op, so evaluating the summary also updates the
# metric's running totals. NOTE(review): the logged value is therefore a
# streaming (cumulative) accuracy, not a per-batch one -- confirm this is
# intended in TRAIN mode.
tensorflow.summary.scalar('{0}_accuracy'.format(mode), eval_metrics_ops[
'accuracy'][1]) # [1] is the update op, not the value tensor
# Training / evaluation: return the loss, train op and metric ops. For any
# other mode this branch is skipped (see the NOTE at the top of this block).
if mode in (tensorflow.estimator.ModeKeys.TRAIN,
tensorflow.estimator.ModeKeys.EVAL):
return tensorflow.estimator.EstimatorSpec(
mode=mode,
loss=loss_tensor,
train_op=train_op,
eval_metric_ops=eval_metrics_ops,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment