@lxuechen
Created August 3, 2018 19:41
Wrapping RevNet in a tf.estimator model_fn
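```python
import tensorflow as tf  # TF 1.x API (tf.estimator, tf.train)

def model_fn(features, labels, mode, params):
  # RevNet: the model class from TensorFlow's RevNet example (not defined here).
  model = RevNet(params["hyperparameters"])
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Assumed params keys; the original snippet leaves these names undefined.
    optimizer = tf.train.MomentumOptimizer(
        params["learning_rate"], params["momentum"])
    logits, saved_hidden = model(features, training=True)
    # Manual backward pass that recomputes activations from saved_hidden.
    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
    # Make train_op also run the batch-norm moving-average updates.
    with tf.control_dependencies(model.get_updates_for(features)):
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables),
          global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```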
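The training branch calls `model.compute_gradients(saved_hidden, ...)` rather than `optimizer.minimize(loss)`. The idea behind this is that RevNet is built from reversible blocks, whose inputs can be recomputed exactly from their outputs, so the backward pass can rebuild activations instead of keeping them all in memory. Below is a minimal NumPy sketch of one such block using the additive coupling y1 = x1 + F(x2), y2 = x2 + G(y1). The functions `f` and `g` and the weights `W_f` and `W_g` are hypothetical stand-ins for RevNet's convolutional sub-networks, not code from the RevNet example.

```python
import numpy as np

# Hypothetical residual functions standing in for RevNet's F and G
# sub-networks (in the real model these are conv / batch-norm stacks).
rng = np.random.default_rng(0)
W_f = rng.standard_normal((4, 4))
W_g = rng.standard_normal((4, 4))


def f(x):
    return np.tanh(x @ W_f)


def g(x):
    return np.tanh(x @ W_g)


def rev_forward(x1, x2):
    """Reversible block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def rev_inverse(y1, y2):
    """Recover the block's inputs from its outputs, so they need not be stored."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


x1 = rng.standard_normal((2, 4))
x2 = rng.standard_normal((2, 4))
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)
print(np.allclose(r1, x1), np.allclose(r2, x2))  # True True
```

Because the inverse only evaluates `f` and `g` forward, no per-block activations have to be kept; the price is one extra forward evaluation of each block during backpropagation.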