Skip to content

Instantly share code, notes, and snippets.

@hongthaiphi
Created September 21, 2016 15:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hongthaiphi/dce3019667bf2bd9cc1d66a29b86f4eb to your computer and use it in GitHub Desktop.
Save hongthaiphi/dce3019667bf2bd9cc1d66a29b86f4eb to your computer and use it in GitHub Desktop.
def dual_encoder_model(hparams, mode, context, context_len, utterance, utterance_len, targets):
    """Build the dual-encoder graph scoring a (context, utterance) pair.

    Both inputs are embedded with one shared embedding matrix, encoded by
    the same LSTM, and scored with the bilinear form ``(c * M) . r``, where
    ``c`` and ``r`` are the final LSTM hidden states of the context and the
    utterance and ``M`` is a learned ``[rnn_dim, rnn_dim]`` matrix.

    Args:
        hparams: Hyper-parameter object. ``hparams.rnn_dim`` is read here;
            the whole object is also forwarded to ``get_embeddings``.
        mode: Accepted for interface compatibility; this function does not
            read it.
        context: Token ids of the contexts, used for the embedding lookup.
        context_len: Sequence lengths of the contexts, passed to the RNN.
        utterance: Token ids of the utterances, used for the embedding lookup.
        utterance_len: Sequence lengths of the utterances, passed to the RNN.
        targets: Binary labels, cast to float for the loss. May be ``None``
            when no labels are available (e.g. at inference time).

    Returns:
        A ``(probs, mean_loss)`` tuple. ``probs`` is the sigmoid of the
        logits. ``mean_loss`` is the mean sigmoid cross-entropy over all
        examples, or ``None`` when ``targets`` is ``None``.
    """
    # Initialize embeddings randomly or with pre-trained vectors if available
    embeddings_W = get_embeddings(hparams)

    # Embed the context and the utterance with the same embedding matrix
    context_embedded = tf.nn.embedding_lookup(
        embeddings_W, context, name="embed_context")
    utterance_embedded = tf.nn.embedding_lookup(
        embeddings_W, utterance, name="embed_utterance")

    # Build the RNN
    with tf.variable_scope("rnn"):
        # We use an LSTM cell
        cell = tf.nn.rnn_cell.LSTMCell(
            hparams.rnn_dim,
            forget_bias=2.0,
            use_peepholes=True,
            state_is_tuple=True)

        # Contexts and utterances are stacked along dimension 0 and run
        # through a single dynamic_rnn call, so both sides share the LSTM
        # weights. The stacked final state is split back into two halves.
        # NOTE(review): the even split assumes context and utterance have
        # the same batch size (and padded length) - confirm against caller.
        _, rnn_states = tf.nn.dynamic_rnn(
            cell,
            tf.concat(0, [context_embedded, utterance_embedded]),
            sequence_length=tf.concat(0, [context_len, utterance_len]),
            dtype=tf.float32)
        encoding_context, encoding_utterance = tf.split(0, 2, rnn_states.h)

    with tf.variable_scope("prediction"):
        M = tf.get_variable(
            "M",
            shape=[hparams.rnn_dim, hparams.rnn_dim],
            initializer=tf.truncated_normal_initializer())

        # "Predict" a response: c * M
        generated_response = tf.matmul(encoding_context, M)
        # Add a trailing axis to both sides so the per-example dot product
        # can be expressed as a batched matrix product.
        generated_response = tf.expand_dims(generated_response, 2)
        encoding_utterance = tf.expand_dims(encoding_utterance, 2)

        # Dot product between generated response and actual response
        # (c * M) * r; the third argument transposes the first operand.
        logits = tf.batch_matmul(generated_response, encoding_utterance, True)
        logits = tf.squeeze(logits, [2])

        # Apply sigmoid to convert logits to probabilities
        probs = tf.sigmoid(logits)

        # Without labels there is no loss to compute; still return the
        # predictions instead of failing inside tf.to_float(None).
        if targets is None:
            return probs, None

        # Calculate the binary cross-entropy loss
        losses = tf.nn.sigmoid_cross_entropy_with_logits(logits, tf.to_float(targets))

    # Mean loss across the batch of examples
    mean_loss = tf.reduce_mean(losses, name="mean_loss")
    return probs, mean_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment