Created August 22, 2018 21:28
Save zmjjmz/b84248a8a1ff21afd1c172e0abbba2e8 to your computer and use it in GitHub Desktop.
Example of conditional outputs for estimators (partial)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Body of a TensorFlow Estimator model_fn (partial): the enclosing `def` line
# and its parameter list are not part of this snippet.
# It relies on names defined outside the snippet: model_class, model_artifacts,
# model_parameters, inp_placeholder, labels, mode, signature_constants and
# PredictOutput.
# NOTE(review): the trailing "| |" on each line is residue from the web page
# this was copied from, not Python; strip it before running.
# Instantiate the project's model builder and build its graph on the input.
model_builder = model_class(model_artifacts, **model_parameters) | |
model_builder.build_model(inp_placeholder) | |
# give_outputs() returns a dict of named tensors, e.g.
# {'softmax': softmax_layer, 'oov_code': oov_code}; the 'softmax' key is
# required below by both the loss and the accuracy metric.
tensors = model_builder.give_outputs() | |
# PREDICT: expose every output tensor both as predictions and under the
# default serving signature, and return before any code reads `labels`.
if mode == tensorflow.estimator.ModeKeys.PREDICT: | |
return tensorflow.estimator.EstimatorSpec( | |
mode=mode, | |
predictions=tensors, | |
export_outputs={ | |
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: PredictOutput(tensors) | |
}, | |
) | |
# NOTE(review): tf.losses.softmax_cross_entropy(onehot_labels, logits) applies
# softmax to its second argument itself; if tensors['softmax'] is already a
# softmax output rather than raw logits, softmax is applied twice -- confirm.
loss_tensor = tensorflow.losses.softmax_cross_entropy( | |
labels, tensors['softmax']) | |
# Loss summary, tagged with the current mode.
tensorflow.summary.scalar('{0}_loss'.format(mode), loss_tensor) | |
# Plain SGD with a fixed learning rate of 0.01, advancing the global step.
# NOTE(review): this is built unconditionally, so it is also created in EVAL
# mode, where only the loss and metrics are needed.
train_op = tensorflow.contrib.layers.optimize_loss( | |
loss=loss_tensor, | |
global_step=tensorflow.train.get_global_step(), | |
optimizer='sgd', | |
learning_rate=0.01 | |
) | |
# Since the loss already assumes softmax + cross entropy, accuracy is
# hard-coded as the evaluation metric. argmax over axis 1 turns the
# categorical (one-hot) labels and the per-class scores into integer class
# ids, which is what tf.metrics.accuracy compares.
eval_metrics_ops = { | |
"accuracy": tensorflow.metrics.accuracy( | |
tensorflow.argmax(labels, axis=1), | |
tensorflow.argmax(tensors['softmax'], axis=1)) | |
} | |
# tf.metrics.accuracy returns a (value, update_op) pair; index [1] selects
# the update op for the summary.
tensorflow.summary.scalar('{0}_accuracy'.format(mode), eval_metrics_ops[ | |
'accuracy'][1]) # get the update op | |
# TRAIN and EVAL return the same spec with loss, train_op and metrics attached.
# NOTE(review): any other mode falls through and the function returns None.
if mode in (tensorflow.estimator.ModeKeys.TRAIN, | |
tensorflow.estimator.ModeKeys.EVAL): | |
return tensorflow.estimator.EstimatorSpec( | |
mode=mode, | |
loss=loss_tensor, | |
train_op=train_op, | |
eval_metric_ops=eval_metrics_ops, | |
) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.