#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf

_INT = tf.int32
_FLOAT = tf.float32

class SkipThought(object):
    def __init__(self, vocab_size, embedding_size, encoder_size, decoder_size):
        # All sequence placeholders are time-major: `max_steps x batch_size`
        # matrices of token ids, with per-example lengths given separately.
        inputs = tf.placeholder(_INT, [None, None], name='inputs')
        sequence_length = tf.placeholder(_INT, [None], name='sequence_length')
        previous_target_lengths = tf.placeholder(
            _INT, [None], name='previous_target_lengths')
        next_target_lengths = tf.placeholder(
            _INT, [None], name='next_target_lengths')
        previous_targets = tf.placeholder(
            _INT, [None, None], name='previous_targets')
        next_targets = tf.placeholder(_INT, [None, None], name='next_targets')

        # T (max_steps) and B (batch_size) from the time-major input.
        encoder_lengths, batch_size = tf.unstack(tf.shape(inputs))

        # Embedding
        with tf.name_scope('embedding'):
            embedding_matrix = tf.Variable(
                # V x E
                tf.random_uniform([vocab_size, embedding_size], -0.1, 0.1),
                name='embedding_matrix')
            # T x B x E
            embedding_inputs = tf.nn.embedding_lookup(embedding_matrix, inputs)
            # B; assumes token id 1 is EOS, which doubles as the decoder's
            # GO symbol below.
            eos_time_slice = tf.ones([batch_size], dtype=_INT, name='EOS')
            # B x E
            eos_step_embedded = tf.nn.embedding_lookup(embedding_matrix,
                                                       eos_time_slice)

        with tf.variable_scope('encoder'):
            encoder_cell = tf.nn.rnn_cell.GRUCell(encoder_size)
            # T x B x H1, B x H1
            encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
                cell=encoder_cell,
                inputs=embedding_inputs,
                sequence_length=sequence_length,
                dtype=_FLOAT,
                time_major=True)

        with tf.variable_scope('decoder'):
            # H2 x V, shared between the previous- and next-sentence decoders.
            output_weights = tf.get_variable(
                'weights',
                shape=[decoder_size, vocab_size],
                initializer=tf.contrib.layers.xavier_initializer())
            # V
            output_biases = tf.Variable(
                tf.constant(0.0, shape=[vocab_size]), name='biases')
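
            # `tf.nn.raw_rnn` drives the decoder cell one step at a time via
            # `loop_fn(time, cell_output, cell_state, loop_state)`. The first
            # call arrives with `cell_state is None` and must supply the
            # initial input and state; every call returns a 5-tuple
            # (elements_finished, next_input, next_cell_state, emit_output,
            # next_loop_state) describing the following step.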

            def decode(name, targets, target_lengths):
                # decoder_lengths = target_lengths + 1
                def loop_fn_transition(time, cell_output, cell_state,
                                       loop_state):
                    def get_next_input():
                        # Greedy decoding: feed the embedding of the model's
                        # own argmax prediction back in as the next input
                        # (no teacher forcing).
                        output_logits = tf.add(
                            tf.matmul(cell_output, output_weights),
                            output_biases)
                        predictions = tf.argmax(output_logits, axis=1)
                        next_input = tf.nn.embedding_lookup(
                            embedding_matrix, predictions)
                        return next_input

                    elements_finished = (time >= target_lengths)
                    emit_output = cell_output
                    next_cell_state = cell_state
                    next_loop_state = None
                    return (elements_finished, get_next_input(),
                            next_cell_state, emit_output, next_loop_state)

                def loop_fn(time, cell_output, cell_state, loop_state):
                    if cell_state is None:
                        # Initial step: start from the encoder's final state
                        # and feed the embedded EOS token as the GO input.
                        elements_finished = (0 >= target_lengths)
                        next_input = eos_step_embedded
                        next_cell_state = encoder_final_state
                        emit_output = None
                        next_loop_state = None
                        return (elements_finished, next_input, next_cell_state,
                                emit_output, next_loop_state)
                    else:
                        return loop_fn_transition(time, cell_output,
                                                  cell_state, loop_state)

                with tf.variable_scope(name):
                    decoder = tf.nn.rnn_cell.GRUCell(decoder_size)
                    decoder_outputs, _, _ = tf.nn.raw_rnn(decoder, loop_fn)
                    # T x B x H2
                    decoder_outputs = decoder_outputs.stack()
                    (decoder_max_steps, decoder_batch_size,
                     decoder_hidden_size) = tf.unstack(
                         tf.shape(decoder_outputs))
                    # TB x H2
                    decoder_outputs_flat = tf.reshape(
                        decoder_outputs, (-1, decoder_hidden_size))
                    # TB x V
                    decoder_logits_flat = tf.add(
                        tf.matmul(decoder_outputs_flat, output_weights),
                        output_biases)
                    # T x B x V
                    decoder_logits = tf.reshape(
                        decoder_logits_flat,
                        (decoder_max_steps, decoder_batch_size, vocab_size))
                    # T x B
                    decoder_predictions = tf.argmax(decoder_logits, axis=2)
                    # One-hot labels must be float to match the logits dtype.
                    # Note: the mean below also averages over padding
                    # positions; no sequence mask is applied.
                    loss = tf.nn.softmax_cross_entropy_with_logits(
                        labels=tf.one_hot(
                            targets, depth=vocab_size, dtype=_FLOAT),
                        logits=decoder_logits)
                    loss = tf.reduce_mean(loss)
                return (decoder, decoder_logits, decoder_predictions, loss)

            (previous_decoder, previous_decoder_logits,
             previous_decoder_predictions, previous_loss) = decode(
                 'previous_decoder', previous_targets, previous_target_lengths)
            (next_decoder, next_decoder_logits,
             next_decoder_predictions, next_loss) = decode(
                 'next_decoder', next_targets, next_target_lengths)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.inputs = inputs
        self.sequence_length = sequence_length
        self.previous_targets = previous_targets
        self.previous_target_lengths = previous_target_lengths
        self.next_targets = next_targets
        self.next_target_lengths = next_target_lengths
        self.previous_decoder = previous_decoder
        self.previous_decoder_logits = previous_decoder_logits
        self.previous_decoder_predictions = previous_decoder_predictions
        self.previous_loss = previous_loss
        self.next_decoder = next_decoder
        self.next_decoder_logits = next_decoder_logits
        self.next_decoder_predictions = next_decoder_predictions
        self.next_loss = next_loss
        # Total objective: sum of the two reconstruction losses.
        self.loss = self.previous_loss + self.next_loss
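

# --- Usage sketch (not part of the original gist) ---
# A minimal, illustrative training step under assumed conditions: time-major
# `max_steps x batch_size` int32 id matrices padded to a common length, an
# Adam optimizer, and toy random data. The vocabulary size, layer dimensions,
# and learning rate below are demo placeholders, not recommended settings.
if __name__ == '__main__':
    import numpy as np

    model = SkipThought(
        vocab_size=10000, embedding_size=128, encoder_size=256,
        decoder_size=256)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(model.loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # One toy batch: 3 sentences of length 5 plus their neighbours,
        # all padded to the same number of steps as their targets.
        feed = {
            model.inputs: np.random.randint(2, 10000, size=(5, 3)),
            model.sequence_length: [5, 5, 5],
            model.previous_targets: np.random.randint(2, 10000, size=(5, 3)),
            model.previous_target_lengths: [5, 5, 5],
            model.next_targets: np.random.randint(2, 10000, size=(5, 3)),
            model.next_target_lengths: [5, 5, 5],
        }
        loss, _ = sess.run([model.loss, train_op], feed_dict=feed)
        print('batch loss: %.4f' % loss)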