Skip to content

Instantly share code, notes, and snippets.

@formigone
Created January 27, 2018 18:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save formigone/c5634c9bf11c8a7034188b0b6f0eff8f to your computer and use it in GitHub Desktop.
Save formigone/c5634c9bf11c8a7034188b0b6f0eff8f to your computer and use it in GitHub Desktop.
A simple RNN model using the TensorFlow 1.4 Estimator API
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` for a single-layer GRU classifier with sigmoid outputs.

    The input is embedded and run through a GRU. Its final state is
    projected to ``params['num_classes']`` logits. Each logit is trained
    with sigmoid cross-entropy, i.e. as an independent binary decision.

    Args:
        features: Input passed to ``embed_sequence``. Presumably a
            (batch, seq_len) tensor of integer token ids — TODO confirm
            against the input_fn. ``seq_len`` must be statically known,
            because ``tf.unstack`` along axis 1 needs a fixed size.
        labels: Targets with the same shape as the logits,
            (batch, num_classes). They are cast to the logits' dtype before
            the loss is computed.
        mode: One of ``tf.estimator.ModeKeys``.
        params: Dict with the keys ``'vocab_size'``, ``'embed_size'``,
            ``'num_classes'`` and ``'learning_rate'``.

    Returns:
        A ``tf.estimator.EstimatorSpec``. In PREDICT mode its predictions
        are the raw logits. In TRAIN/EVAL mode it carries the loss, an
        ``'accuracy'`` metric and, in TRAIN mode only, a train op.
    """
    word_vec = embed_sequence(features, vocab_size=params['vocab_size'], embed_dim=params['embed_size'])
    # [-1, seq_len, embed_size]
    word_list = tf.unstack(word_vec, axis=1)
    # [-1, embed_size]
    cell = tf.nn.rnn_cell.GRUCell(params['embed_size'])
    # Only the final GRU state is used as the sequence encoding.
    _, encoding = tf.nn.static_rnn(cell, word_list, dtype=tf.float32)
    logits = tf.layers.dense(encoding, params['num_classes'], activation=None)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=logits)

    # sigmoid_cross_entropy_with_logits requires labels and logits to share
    # a dtype. Casting lets integer 0/1 labels work too; it is a no-op for
    # float32 labels.
    float_labels = tf.cast(labels, logits.dtype)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=float_labels,
        logits=logits,
    )
    loss = tf.reduce_mean(cross_entropy)
    tf.summary.scalar('loss', loss)

    # Logits are unbounded, so rounding them directly does not give 0/1
    # predictions. Threshold the sigmoid probability at 0.5 instead.
    predicted_classes = tf.round(tf.sigmoid(logits))
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels=float_labels, predictions=predicted_classes),
    }
    # Index [1] is the metric's update op. Logging it keeps a running
    # accuracy summary during training.
    tf.summary.scalar('accuracy', eval_metric_ops['accuracy'][1])

    # The optimizer is only needed when training. EVAL mode accepts
    # train_op=None.
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment