Skip to content

Instantly share code, notes, and snippets.

@ispamm
Created November 27, 2018 15:19
Show Gist options
  • Save ispamm/69ae87e88ca8f6c006fc3d7cca88ecb0 to your computer and use it in GitHub Desktop.
Save ispamm/69ae87e88ca8f6c006fc3d7cca88ecb0 to your computer and use it in GitHub Desktop.
# Adapted from here: https://www.tensorflow.org/tutorials/layers
def single_task_cnn_model_fn(features, labels, mode):
    """Estimator ``model_fn`` regressing two label columns from extracted features.

    Builds a 2-unit dense head on top of ``extract_features`` and trains it
    with mean squared error against columns 2 and 7 of ``labels``.

    Args:
        features: Input passed unchanged to ``extract_features`` (defined
            elsewhere); its structure is whatever that helper expects.
        labels: Label tensor indexed as ``labels[:, 2:8:5]``, so it must be at
            least 2-D. Presumably (batch, n_columns) with n_columns >= 8 --
            TODO confirm against the input pipeline. Not used in PREDICT mode.
        mode: A ``tf.estimator.ModeKeys`` value. PREDICT and TRAIN are handled
            explicitly; any other value is treated as evaluation.

    Returns:
        A ``tf.estimator.EstimatorSpec``:
          * PREDICT: ``predictions={"predictions": <dense head output>}``.
          * TRAIN: mean-squared-error ``loss`` and an Adam ``train_op`` that
            also increments the global step.
          * otherwise: ``loss`` plus an ``"rmse"`` entry in
            ``eval_metric_ops``.
    """
    # Shared feature extractor (defined elsewhere in the project).
    dense = extract_features(features)
    # No activation is passed, so this head is linear: one unit per target.
    predictions = tf.layers.dense(inputs=dense, units=2)
    outputs = {"predictions": predictions}

    # Inference only needs the predictions; labels are not touched here.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=outputs)

    # The slice 2:8:5 selects label columns 2 and 7 -- exactly two columns,
    # matching units=2 above. Computed once so loss and metric always use
    # the same targets.
    targets = labels[:, 2:8:5]
    loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)

    # Single optimization step; passing the global step makes minimize()
    # increment it after the update.
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(
            loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # Neither PREDICT nor TRAIN: evaluation, reporting RMSE alongside the loss.
    eval_metric_ops = {
        "rmse": tf.metrics.root_mean_squared_error(
            labels=targets, predictions=predictions),
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment