Skip to content

Instantly share code, notes, and snippets.

@ispamm
Created November 27, 2018 15:44
Show Gist options
Save ispamm/fabca7f172ba7c83a50425e568ff72f4 to your computer and use it in GitHub Desktop.
def multi_head_cnn_model_fn(features, labels, mode):
    """Build the EstimatorSpec for a jointly trained two-task network.

    A shared trunk (``extract_features``) feeds one linear output layer per
    task:

    * ``'head_nose'``: a regression head with ``label_dimension=2``.
    * ``'head_pose'``: a multi-class classification head with 5 classes.

    The two single heads are combined with
    ``tf.contrib.estimator.multi_head``. Both tasks are optimized together by
    a single Adam optimizer with default settings.

    Args:
        features: Model inputs. They are passed unchanged to
            ``extract_features`` (defined elsewhere) and to the multi-head
            spec builder.
        labels: Training targets, forwarded to the multi-head.
            NOTE(review): multi_head presumably expects a dict keyed by head
            name ('head_nose', 'head_pose'); confirm against the input_fn.
        mode: The Estimator mode, forwarded to the multi-head.

    Returns:
        The spec produced by ``multi_head.create_estimator_spec``.
    """
    # One head per task. Each head's logits_dimension (2 for the regression
    # head, 5 for the 5-class head) sizes its output layer below, so the layer
    # widths cannot drift out of sync with the heads' configuration.
    nose_head = tf.contrib.estimator.regression_head(
        name='head_nose', label_dimension=2)
    pose_head = tf.contrib.estimator.multi_class_head(
        name='head_pose', n_classes=5)
    joint_head = tf.contrib.estimator.multi_head([nose_head, pose_head])

    # Shared representation from the project's feature extractor.
    shared = extract_features(features)

    # The output layers are created in the same order as before (nose, then
    # pose), which keeps the auto-generated layer/variable names unchanged.
    nose_logits = tf.layers.dense(
        inputs=shared, units=nose_head.logits_dimension)
    pose_logits = tf.layers.dense(
        inputs=shared, units=pose_head.logits_dimension)

    # multi_head expects the logits keyed by each head's name.
    per_head_logits = {
        nose_head.name: nose_logits,
        pose_head.name: pose_logits,
    }

    return joint_head.create_estimator_spec(
        features=features,
        mode=mode,
        logits=per_head_logits,
        labels=labels,
        optimizer=tf.train.AdamOptimizer(),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment