@igormq
Forked from nikitakit/tf_beam_decoder.py
Created June 27, 2016 21:03
Tensorflow Beam Search
import tensorflow as tf
# rnn_decoder is used below; in the TF 0.x releases this gist targets it lives in
# tensorflow.python.ops.seq2seq (also exposed as tf.nn.seq2seq).
from tensorflow.python.ops import seq2seq


def beam_decoder(decoder_inputs, initial_state, cell, loop_function, scope=None,
                 beam_size=7, done_token=0):
"""
Beam search decoder
Args:
decoder_inputs: A list of 2D Tensors [batch_size x input_size].
initial_state: 2D Tensor with shape [batch_size x cell.state_size].
cell: rnn_cell.RNNCell defining the cell function and size.
loop_function: This function will be applied to the i-th output
in order to generate the i+1-st input, and decoder_inputs will be ignored,
except for the first element ("GO" symbol).
Signature -- loop_function(prev_symbol, i) = next
* prev_symbol is a 1D Tensor of shape [batch_size*beam_size]
* i is an integer, the step number (when advanced control is needed),
* next is a 2D Tensor of shape [batch_size*beam_size, input_size].
scope: Passed to seq2seq.rnn_decoder
beam_size: An integer beam size to use for each example
done_token: An integer token that specifies the STOP symbol
Return:
A tensor of dimensions [batch_size, len(decoder_inputs)] that corresponds to
the 1-best beam for each batch.
Known limitations:
* The output sequence consisting of only a STOP symbol is not considered
(zero-length sequences are not very useful, so this wasn't implemented)
* The computation graph this creates is messy and not very well-optimized
"""
    decoder = BeamDecoder(decoder_inputs, initial_state, beam_size=beam_size,
                          done_token=done_token)
    _ = seq2seq.rnn_decoder(
        decoder.decoder_inputs,
        decoder.initial_state,
        cell=cell,
        loop_function=lambda prev, i: loop_function(decoder.take_step(prev, i), i),
        scope=scope
    )
    return decoder.finished_beams
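
# A minimal loop_function sketch (added in this fork for illustration; `embedding`
# is a hypothetical [num_classes x input_size] variable, not defined in this file):
#
#   def embedding_loop_function(prev_symbol, i):
#       # prev_symbol: [batch_size*beam_size] int32 ids emitted by take_step
#       # returns:     [batch_size*beam_size x input_size] inputs for step i
#       return tf.nn.embedding_lookup(embedding, prev_symbol)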

class BeamDecoder(object):
    """
    Main class implementing the beam decoder.
    """
    def __init__(self, decoder_inputs, initial_state, beam_size=7, done_token=0,
                 batch_size=None, num_classes=None):
        self.beam_size = beam_size
        self.batch_size = batch_size
        if self.batch_size is None:
            self.batch_size = tf.shape(decoder_inputs[0])[0]
        self.max_len = len(decoder_inputs)
        self.num_classes = num_classes
        self.done_token = done_token

        self.past_logprobs = None
        self.past_symbols = None

        self.finished_beams = tf.zeros((self.batch_size, self.max_len), dtype=tf.int32)
        self.logprobs_finished_beams = tf.ones((self.batch_size,), dtype=tf.float32) * -float('inf')

        # Only the first decoder input (the "GO" symbol) is used; the remaining
        # inputs are produced by the loop_function during decoding.
        self.decoder_inputs = [None] * len(decoder_inputs)
        self.decoder_inputs[0] = self.tile_along_beam(decoder_inputs[0])

        # Convert the state input to the decoder
        if isinstance(initial_state, tf.nn.rnn_cell.LSTMStateTuple):
            self.initial_state = tf.nn.rnn_cell.LSTMStateTuple(
                c=self.tile_along_beam(initial_state.c),
                h=self.tile_along_beam(initial_state.h)
            )
        else:
            self.initial_state = self.tile_along_beam(initial_state)

    def tile_along_beam(self, tensor):
        """
        Helps tile tensors for each beam.

        Args:
          tensor: a 2-D tensor, [batch_size x T]
        Returns:
          A [batch_size*beam_size x T] tensor, where each row of the input
          tensor is copied beam_size times in a row in the output.
        """
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size, 1])
        res = tf.reshape(res, [-1, tf.shape(tensor)[1]])
        try:
            new_first_dim = tensor.get_shape()[0] * self.beam_size
        except Exception:
            new_first_dim = None
        res.set_shape((new_first_dim, tensor.get_shape()[1]))
        return res
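
    # Illustrative example (added in this fork, not part of the original gist):
    # with beam_size=3, tile_along_beam turns
    #   [[1, 2],          [[1, 2],
    #    [3, 4]]   into    [1, 2],
    #                      [1, 2],
    #                      [3, 4],
    #                      [3, 4],
    #                      [3, 4]]
    # i.e. each row is repeated beam_size times consecutively.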

    def take_step(self, prev, i):
        logprobs = tf.nn.log_softmax(prev)
        if self.num_classes is None:
            self.num_classes = tf.shape(logprobs)[1]

        # Regroup the flat [batch_size*beam_size x num_classes] log-probs into
        # [batch_size x beam_size x num_classes].
        logprobs_batched = tf.reshape(logprobs, [-1, self.beam_size, self.num_classes])
        logprobs_batched.set_shape((None, self.beam_size, None))

        # Note: masking out entries to -inf plays poorly with top_k, so just
        # subtract out a large number.
        nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(self.num_classes), self.done_token), tf.float32) * -1e18,
            [1, 1, self.num_classes]
        )

        if self.past_logprobs is not None:
            # Accumulate log-probabilities along each beam, then keep the
            # beam_size best continuations out of beam_size*num_classes candidates.
            logprobs_batched = logprobs_batched + tf.expand_dims(self.past_logprobs, 2)
            self.past_logprobs, indices = tf.nn.top_k(
                tf.reshape(logprobs_batched + nondone_mask,
                           [-1, self.beam_size * self.num_classes]),
                self.beam_size
            )
        else:
            # First step: every beam copy was fed the same input, so only the
            # first beam's logits are needed to seed the beam_size hypotheses.
            self.past_logprobs, indices = tf.nn.top_k(
                (logprobs_batched + nondone_mask)[:, 0, :],
                self.beam_size
            )

        # For continuing to the next symbols: recover, from the flattened top_k
        # indices, which class was chosen and which parent beam it extends.
        symbols = indices % self.num_classes
        parent_refs = indices // self.num_classes

        if self.past_symbols is not None:
            # Offsets that map per-example beam indices into rows of the
            # batch-major past_symbols (row = batch_index*beam_size + beam_index).
            parent_refs_offsets = tf.reshape(
                (tf.range(self.batch_size * self.beam_size) // self.beam_size) * self.beam_size,
                [self.batch_size, self.beam_size]
            )
            past_symbols_batch_major = tf.reshape(self.past_symbols, [-1, i - 1])
            beam_past_symbols = tf.gather(past_symbols_batch_major,  # batch-major
                                          parent_refs + parent_refs_offsets)
            self.past_symbols = tf.concat(2, [beam_past_symbols, tf.expand_dims(symbols, 2)])

            # For finishing the beam here: take the beam most likely to emit the
            # STOP symbol at this step, pad its sequence with zeros up to max_len,
            # and keep it if it beats the best finished hypothesis seen so far.
            logprobs_done = logprobs_batched[:, :, self.done_token]
            done_parent_refs = tf.cast(tf.argmax(logprobs_done, 1), tf.int32)
            done_parent_refs_offsets = tf.range(self.batch_size) * self.beam_size
            done_past_symbols = tf.gather(past_symbols_batch_major,
                                          done_parent_refs + done_parent_refs_offsets)
            symbols_done = tf.concat(1, [done_past_symbols,
                                         tf.ones_like(done_past_symbols[:, 0:1]) * self.done_token,
                                         tf.tile(tf.zeros_like(done_past_symbols[:, 0:1]),
                                                 [1, self.max_len - i])
                                         ])
            logprobs_done_max = tf.reduce_max(logprobs_done, 1)
            self.finished_beams = tf.select(logprobs_done_max > self.logprobs_finished_beams,
                                            symbols_done,
                                            self.finished_beams)
            self.logprobs_finished_beams = tf.maximum(logprobs_done_max,
                                                      self.logprobs_finished_beams)
        else:
            self.past_symbols = tf.expand_dims(symbols, 2)
            # NOTE: outputting a zero-length sequence is not supported for simplicity reasons

        symbols_flat = tf.reshape(symbols, [-1])
        return symbols_flat
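
# ---------------------------------------------------------------------------
# Example usage (an illustrative sketch added in this fork, not part of the
# original decoder). All sizes and names below (vocab_size, embedding, GRUCell,
# etc.) are hypothetical. Note that the cell must output logits over the
# vocabulary, e.g. via OutputProjectionWrapper, because take_step applies
# log_softmax directly to the raw cell output.
# ---------------------------------------------------------------------------
# batch_size, max_len, vocab_size, embed_size, hidden_size = 16, 20, 1000, 64, 128
#
# embedding = tf.get_variable("embedding", [vocab_size, embed_size])
# cell = tf.nn.rnn_cell.OutputProjectionWrapper(
#     tf.nn.rnn_cell.GRUCell(hidden_size), vocab_size)
#
# go_ids = tf.zeros([batch_size], dtype=tf.int32)            # "GO" symbol ids
# go_input = tf.nn.embedding_lookup(embedding, go_ids)       # [batch_size x embed_size]
# decoder_inputs = [go_input] * max_len                       # only the first entry is used
# initial_state = cell.zero_state(batch_size, tf.float32)    # or e.g. an encoder's final state
#
# beams = beam_decoder(
#     decoder_inputs, initial_state, cell,
#     loop_function=lambda symbols, i: tf.nn.embedding_lookup(embedding, symbols),
#     beam_size=4, done_token=0)
# # beams: [batch_size x max_len] int32, the 1-best decoded sequence per example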