Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import tensorflow as tf
# Module-wide dtypes: _INT for token-id / length placeholders,
# _FLOAT for the encoder RNN computation.
_INT, _FLOAT = tf.int32, tf.float32
class SkipThought(object):
    """Skip-thought graph: one GRU encoder feeding two GRU decoders.

    The encoder reads a time-major (T x B) batch of token ids and its final
    state initializes two decoders, one scored against ``previous_targets``
    and one against ``next_targets``. Each decoder starts from an EOS token
    and at every later step feeds back the embedding of its own argmax
    prediction (no teacher forcing); the ground-truth targets are used only
    in the loss. ``loss`` is the sum of both decoders' mean losses.
    """

    def __init__(self, vocab_size, embedding_size, encoder_size, decoder_size,
                 eos_id=1):
        """Build the TF1 graph (placeholders, variables and loss ops).

        Args:
            vocab_size: number of token ids (V); size of the embedding table
                and of the output projection.
            embedding_size: embedding dimension (E).
            encoder_size: GRU units of the encoder (H1).
            decoder_size: GRU units of each decoder (H2). Must equal
                ``encoder_size`` because the encoder's final state is used
                directly as each decoder's initial state.
            eos_id: token id fed to both decoders at the first step.
                Defaults to 1, the value that used to be hard-coded.

        Raises:
            ValueError: if ``decoder_size != encoder_size``.
        """
        if decoder_size != encoder_size:
            raise ValueError(
                'decoder_size ({}) must equal encoder_size ({}) because the '
                'encoder final state is the initial state of both '
                'decoders'.format(decoder_size, encoder_size))

        # Sequence placeholders are time-major (T x B); lengths are B.
        inputs = tf.placeholder(_INT, [None, None], name='inputs')
        sequence_length = tf.placeholder(_INT, [None], name='sequence_length')
        previous_target_lengths = tf.placeholder(
            _INT, [None], name='previous_target_lengths')
        next_target_lengths = tf.placeholder(
            _INT, [None], name='next_target_lengths')
        previous_targets = tf.placeholder(
            _INT, [None, None], name='previous_targets')
        next_targets = tf.placeholder(_INT, [None, None], name='next_targets')

        # Only the batch size is needed; the leading (time) dim is unused.
        _, batch_size = tf.unstack(tf.shape(inputs))

        # Embedding
        with tf.name_scope('embedding'):
            # V x E
            embedding_matrix = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -0.1, 0.1),
                name='embedding_matrix')
            # T x B x E
            embedding_inputs = tf.nn.embedding_lookup(embedding_matrix, inputs)
            # B: the EOS id repeated for every batch element.
            eos_time_slice = tf.fill(
                [batch_size], tf.constant(eos_id, dtype=_INT), name='EOS')
            # B x E: first input of every decoder.
            eos_step_embedded = tf.nn.embedding_lookup(embedding_matrix,
                                                       eos_time_slice)

        with tf.variable_scope('encoder'):
            encoder_cell = tf.nn.rnn_cell.GRUCell(encoder_size)
            # T x B x H1, B x H1
            encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
                cell=encoder_cell,
                inputs=embedding_inputs,
                sequence_length=sequence_length,
                dtype=_FLOAT,
                time_major=True)

        with tf.variable_scope('decoder'):
            # H2 x V, shared by both decoders.
            output_weights = tf.get_variable(
                'weights',
                shape=[decoder_size, vocab_size],
                initializer=tf.contrib.layers.xavier_initializer())
            # V
            output_biases = tf.Variable(
                tf.constant(0.0, shape=[vocab_size]), name='biases')

            def project(hidden):
                """Map N x H2 decoder outputs to N x V logits."""
                return tf.add(tf.matmul(hidden, output_weights), output_biases)

            def decode(name, targets, target_lengths):
                """Build one greedy decoder and its loss against ``targets``.

                Returns (cell, logits T x B x V, predictions T x B, mean loss).
                raw_rnn runs max(target_lengths) steps, so ``targets`` is
                assumed to be padded to exactly that many time steps
                (NOTE(review): confirm against the input pipeline).
                """
                def loop_fn(time, cell_output, cell_state, loop_state):
                    # An element stops once `time` reaches its target length.
                    elements_finished = (time >= target_lengths)
                    if cell_state is None:
                        # Initialization call (time == 0): start from EOS
                        # with the encoder's final state; emit structure
                        # defaults to the cell's output size.
                        return (elements_finished, eos_step_embedded,
                                encoder_final_state, None, None)
                    # Feed back the embedding of the greedy prediction.
                    predictions = tf.argmax(project(cell_output), axis=1)
                    next_input = tf.nn.embedding_lookup(
                        embedding_matrix, predictions)
                    return (elements_finished, next_input, cell_state,
                            cell_output, None)

                with tf.variable_scope(name):
                    decoder = tf.nn.rnn_cell.GRUCell(decoder_size)
                    outputs_ta, _, _ = tf.nn.raw_rnn(decoder, loop_fn)
                    # T x B x H2
                    decoder_outputs = outputs_ta.stack()
                    (decoder_max_steps, decoder_batch_size,
                     decoder_hidden_size) = tf.unstack(tf.shape(decoder_outputs))
                    # TB x H2
                    decoder_outputs_flat = tf.reshape(
                        decoder_outputs, (-1, decoder_hidden_size))
                    # TB x V -> T x B x V
                    decoder_logits = tf.reshape(
                        project(decoder_outputs_flat),
                        (decoder_max_steps, decoder_batch_size, vocab_size))
                    # T x B
                    decoder_predictions = tf.argmax(decoder_logits, axis=2)
                    # Targets are integer ids, so use the sparse variant; the
                    # dense op requires float one-hot labels matching the
                    # float32 logits. Steps past an element's length (zero
                    # outputs from raw_rnn) are still included in the mean.
                    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=targets, logits=decoder_logits)
                    loss = tf.reduce_mean(loss)
                return (decoder, decoder_logits, decoder_predictions, loss)

            (previous_decoder, previous_decoder_logits,
             previous_decoder_predictions, previous_loss) = decode(
                 'previous_decoder', previous_targets, previous_target_lengths)
            (next_decoder, next_decoder_logits,
             next_decoder_predictions, next_loss) = decode(
                 'next_decoder', next_targets, next_target_lengths)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.eos_id = eos_id
        self.inputs = inputs
        self.sequence_length = sequence_length
        self.previous_targets = previous_targets
        self.previous_target_lengths = previous_target_lengths
        self.next_targets = next_targets
        self.next_target_lengths = next_target_lengths
        self.previous_decoder = previous_decoder
        self.previous_decoder_logits = previous_decoder_logits
        self.previous_decoder_predictions = previous_decoder_predictions
        self.previous_loss = previous_loss
        self.next_decoder = next_decoder
        self.next_decoder_logits = next_decoder_logits
        self.next_decoder_predictions = next_decoder_predictions
        self.next_loss = next_loss
        self.loss = self.previous_loss + self.next_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment