Tensorflow Beam Search
@bredfern, forked from nikitakit/tf_beam_decoder.py (created November 7, 2016)
"""
Beam decoder for tensorflow
Sample usage:
```
beam_decoder = BeamDecoder(NUM_CLASSES, beam_size=10, max_len=MAX_LEN)
_, final_state = tf.nn.seq2seq.rnn_decoder(
[beam_decoder.wrap_input(initial_input)] + [None] * (MAX_LEN - 1),
beam_decoder.wrap_state(initial_state),
beam_decoder.wrap_cell(my_cell),
loop_function = lambda prev_symbol, i: tf.nn.embedding_lookup(my_embedding, prev_symbol)
)
best_dense = beam_decoder.unwrap_output_dense(final_state) # Dense tensor output, right-aligned
best_sparse = beam_decoder.unwrap_output_sparse(final_state) # Output, this time as a sparse tensor
```
"""
import tensorflow as tf

try:
    from tensorflow.python.util import nest
except ImportError:
    # Backwards-compatibility
    from tensorflow.python.ops import rnn_cell

    class NestModule(object):
        pass
    nest = NestModule()
    nest.is_sequence = rnn_cell._is_sequence
    nest.flatten = rnn_cell._unpacked_state
    nest.pack_sequence_as = rnn_cell._packed_state


def nest_map(func, nested):
    if not nest.is_sequence(nested):
        return func(nested)
    flat = nest.flatten(nested)
    return nest.pack_sequence_as(nested, list(map(func, flat)))
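
# Illustration (not part of the original gist): nest_map applies `func` to
# every tensor in an arbitrarily nested structure while preserving that
# structure. For an LSTM-style state tuple (c, h) this is equivalent to
#     nest_map(f, (c, h)) == (f(c), f(h))
# which is how cell states are tiled and reshuffled throughout this file.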
class BeamDecoder(object):
    def __init__(self, num_classes, stop_token=0, beam_size=7, max_len=20):
        """
        num_classes: int. Number of output classes used
        stop_token: int.
        beam_size: int.
        max_len: int or scalar Tensor. If this cell is called recurrently more
                 than max_len times in a row, the outputs will not be valid!
        """
        self.num_classes = num_classes
        self.stop_token = stop_token
        self.beam_size = beam_size
        self.max_len = max_len
    @classmethod
    def _tile_along_beam(cls, beam_size, state):
        if nest.is_sequence(state):
            return nest_map(
                lambda val: cls._tile_along_beam(beam_size, val),
                state
            )

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state
        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        try:
            new_first_dim = tensor_shape[0] * beam_size
        except Exception:
            # First dimension is not statically known
            new_first_dim = None

        dynamic_tensor_shape = tf.unpack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, beam_size] + [1] * (tensor_shape.ndims - 1))
        res = tf.reshape(res, [-1] + list(dynamic_tensor_shape[1:]))
        res.set_shape([new_first_dim] + list(tensor_shape[1:]))
        return res
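
    # Illustration (not part of the original gist): tiling repeats each batch
    # entry beam_size times along the first dimension, keeping copies of the
    # same batch element adjacent. For beam_size=2:
    #     [[a], [b]]  (shape [2, 1])  ->  [[a], [a], [b], [b]]  (shape [4, 1])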
    def wrap_cell(self, cell):
        """
        Wraps a cell for use with the beam decoder
        """
        return BeamDecoderCellWrapper(cell, self.num_classes, self.max_len,
                                      self.stop_token, self.beam_size)

    def wrap_state(self, state):
        dummy = BeamDecoderCellWrapper(None, self.num_classes, self.max_len,
                                       self.stop_token, self.beam_size)
        if nest.is_sequence(state):
            batch_size = tf.shape(nest.flatten(state)[0])[0]
            dtype = nest.flatten(state)[0].dtype
        else:
            batch_size = tf.shape(state)[0]
            dtype = state.dtype
        return dummy._create_state(batch_size, dtype, cell_state=state)

    def wrap_input(self, input):
        """
        Wraps an input for use with the beam decoder.

        Should be used for the initial input at timestep zero, as well as any
        side-channel inputs that are per-batch (e.g. attention targets)
        """
        return self._tile_along_beam(self.beam_size, input)
    def unwrap_output_dense(self, final_state, include_stop_tokens=True):
        """
        Retrieve the beam search output from the final state.

        Returns a [batch_size, max_len]-sized Tensor.
        """
        res = final_state[0]
        if include_stop_tokens:
            res = tf.concat(1, [res[:, 1:],
                                tf.ones_like(res[:, 0:1]) * self.stop_token])
        return res

    def unwrap_output_sparse(self, final_state, include_stop_tokens=True):
        """
        Retrieve the beam search output from the final state.

        Returns a sparse tensor with underlying dimensions of
        [batch_size, max_len]
        """
        output_dense = final_state[0]
        mask = tf.not_equal(output_dense, self.stop_token)

        if include_stop_tokens:
            output_dense = tf.concat(1, [output_dense[:, 1:],
                                         tf.ones_like(output_dense[:, 0:1]) * self.stop_token])
            mask = tf.concat(1, [mask[:, 1:],
                                 tf.cast(tf.ones_like(mask[:, 0:1], dtype=tf.int8), tf.bool)])

        return sparse_boolean_mask(output_dense, mask)

    def unwrap_output_logprobs(self, final_state):
        """
        Retrieve the log-probabilities associated with the selected beams.
        """
        return final_state[1]
class BeamDecoderCellWrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, cell, num_classes, max_len, stop_token=0, beam_size=7):
        # TODO: determine if we can have dynamic shapes instead of pre-filling
        # up to max_len
        self.cell = cell
        self.num_classes = num_classes
        self.stop_token = stop_token
        self.beam_size = beam_size
        self.max_len = max_len

        # Note: masking out entries to -inf plays poorly with top_k, so just
        # subtract out a large number.
        # TODO: consider using slice+fill+concat instead of adding a mask
        # TODO: consider making the large negative constant dtype-dependent
        self._nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(self.num_classes), self.stop_token), tf.float32) * -1e18,
            [1, 1, self.num_classes]
        )
        self._nondone_mask = tf.reshape(
            tf.tile(self._nondone_mask, [1, self.beam_size, 1]),
            [-1, self.beam_size * self.num_classes]
        )
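
    # Illustration (not part of the original gist): _nondone_mask has shape
    # [1, beam_size * num_classes], holding -1e18 at every position that
    # corresponds to emitting stop_token and 0 elsewhere. Adding it before
    # top_k keeps hypotheses that end now from competing for beam slots;
    # those "done" hypotheses are routed to the candidate pool instead.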
    def __call__(self, inputs, state, scope=None):
        (
            past_cand_symbols,   # [batch_size, max_len]
            past_cand_logprobs,  # [batch_size]
            past_beam_symbols,   # [batch_size*self.beam_size, max_len], right-aligned!!!
            past_beam_logprobs,  # [batch_size*self.beam_size]
            past_cell_state,
        ) = state
        batch_size = tf.shape(past_cand_symbols)[0]  # TODO: get as int, if possible
        full_size = batch_size * self.beam_size

        cell_inputs = inputs
        cell_outputs, raw_cell_state = self.cell(cell_inputs, past_cell_state)

        logprobs = tf.nn.log_softmax(cell_outputs)

        logprobs_batched = tf.reshape(logprobs + tf.expand_dims(past_beam_logprobs, 1),
                                      [-1, self.beam_size * self.num_classes])
        logprobs_batched.set_shape((None, self.beam_size * self.num_classes))

        beam_logprobs, indices = tf.nn.top_k(
            tf.reshape(logprobs_batched + self._nondone_mask,
                       [-1, self.beam_size * self.num_classes]),
            self.beam_size
        )
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols: each flat index encodes a
        # (parent beam, symbol) pair as parent * num_classes + symbol
        symbols = indices % self.num_classes  # [batch_size, self.beam_size]
        parent_refs = tf.reshape(indices // self.num_classes, [-1])  # [batch_size*self.beam_size]

        # TODO: this technically doesn't need to be recalculated every loop
        parent_refs_offsets = (tf.range(batch_size * self.beam_size) // self.beam_size) * self.beam_size
        parent_refs = parent_refs + parent_refs_offsets

        symbols_history = tf.gather(past_beam_symbols, parent_refs)
        beam_symbols = tf.concat(1, [symbols_history[:, 1:], tf.reshape(symbols, [-1, 1])])

        # Handle the output and the cell state shuffling
        outputs = tf.reshape(symbols, [-1])  # [batch_size*beam_size]
        cell_state = nest_map(
            lambda element: tf.gather(element, parent_refs),
            raw_cell_state
        )

        # Handling for getting a done token
        logprobs_done = tf.reshape(logprobs_batched,
                                   [-1, self.beam_size, self.num_classes])[:, :, self.stop_token]
        done_parent_refs = tf.to_int32(tf.argmax(logprobs_done, 1))
        done_parent_refs_offsets = tf.range(batch_size) * self.beam_size
        done_symbols = tf.gather(past_beam_symbols, done_parent_refs + done_parent_refs_offsets)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)
        cand_symbols = tf.select(logprobs_done_max > past_cand_logprobs,
                                 done_symbols,
                                 past_cand_symbols)
        cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        return outputs, (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
            cell_state,
        )
    @property
    def state_size(self):
        return (self.max_len,
                1,
                self.max_len,
                1,
                self.cell.state_size
                )

    @property
    def output_size(self):
        return 1

    def _create_state(self, batch_size, dtype, cell_state=None):
        cand_symbols = tf.fill([batch_size, self.max_len],
                               tf.constant(self.stop_token, dtype=tf.int32))
        cand_logprobs = tf.ones((batch_size,), dtype=tf.float32) * -float('inf')

        if cell_state is None:
            cell_state = self.cell.zero_state(batch_size * self.beam_size, dtype=dtype)
        else:
            cell_state = BeamDecoder._tile_along_beam(self.beam_size, cell_state)

        full_size = batch_size * self.beam_size

        first_in_beam_mask = tf.equal(tf.range(full_size) % self.beam_size, 0)

        beam_symbols = tf.fill([full_size, self.max_len],
                               tf.constant(self.stop_token, dtype=tf.int32))
        beam_logprobs = tf.select(
            first_in_beam_mask,
            tf.fill([full_size], 0.0),
            tf.fill([full_size], -1e18)  # top_k does not play well with -inf
            # TODO: dtype-dependent value here
        )

        return (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
            cell_state,
        )
    def zero_state(self, batch_size_times_beam_size, dtype):
        """
        Instead of calling this manually, please use
        BeamDecoder.wrap_state(cell.zero_state(...)) instead
        """
        batch_size = batch_size_times_beam_size // self.beam_size
        return self._create_state(batch_size, dtype)

def sparse_boolean_mask(tensor, mask):
    """
    Creates a sparse tensor from masked elements of `tensor`

    Inputs:
      tensor: a 2-D tensor, [batch_size, T]
      mask: a 2-D mask, [batch_size, T]

    Output: a 2-D sparse tensor
    """
    mask_lens = tf.reduce_sum(tf.cast(mask, tf.int32), -1, keep_dims=True)
    mask_shape = tf.shape(mask)
    left_shifted_mask = tf.tile(
        tf.expand_dims(tf.range(mask_shape[1]), 0),
        [mask_shape[0], 1]
    ) < mask_lens
    return tf.SparseTensor(
        indices=tf.where(left_shifted_mask),
        values=tf.boolean_mask(tensor, mask),
        shape=tf.cast(tf.pack([mask_shape[0], tf.reduce_max(mask_lens)]), tf.int64)  # For 2D only
    )
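
# Illustration (not part of the original gist): the masked values are packed
# left-aligned into the sparse result. For example:
#     tensor = [[1, 2, 3],      mask = [[True, False, True],
#               [4, 5, 6]]              [False, False, True]]
# yields a sparse tensor with values [1, 3, 6] at indices
# [0, 0], [0, 1], [1, 0] and dense shape [2, 2].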
# ---------------------------------------------------------------------------
# Test script (a separate file in this gist; it imports the decoder above)
# ---------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

from tf_beam_decoder import BeamDecoder

sess = tf.InteractiveSession()
class MarkovChainCell(tf.nn.rnn_cell.RNNCell):
    """
    This cell type is only used for testing the beam decoder.

    It represents a Markov chain characterized by a probability table
    p(x_t|x_{t-1},x_{t-2}).
    """
    def __init__(self, table):
        """
        table[a,b,c] = p(x_t=c|x_{t-1}=b,x_{t-2}=a)
        """
        assert len(table.shape) == 3 and table.shape[0] == table.shape[1] == table.shape[2]
        self.log_table = tf.log(np.asarray(table, dtype=np.float32))
        self._output_size = table.shape[0]

    def __call__(self, inputs, state, scope=None):
        """
        inputs: [batch_size, 1] int tensor
        state: [batch_size, 1] int tensor
        """
        logits = tf.reshape(self.log_table, [-1, self.output_size])
        indices = state[0] * self.output_size + inputs
        return tf.gather(logits, tf.reshape(indices, [-1])), (inputs,)

    @property
    def state_size(self):
        return (1,)

    @property
    def output_size(self):
        return self._output_size
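
# Illustration (not part of the original gist): the cell flattens the 3-D
# table into rows indexed by the pair (x_{t-2}, x_{t-1}), so row a*n + b of
# `logits` is the log-distribution over x_t given the previous two symbols.
# The state carries the symbol from two steps back, and the input is the
# symbol chosen at the last step.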
# Test 1
table = np.array([[[0.0, 0.6, 0.4],
                   [0.0, 0.4, 0.6],
                   [0.0, 0.0, 1.0]]] * 3)
cell = MarkovChainCell(table)
initial_state = cell.zero_state(1, tf.int32)
initial_input = initial_state[0]

beam_decoder = BeamDecoder(num_classes=3, stop_token=2, beam_size=7, max_len=5)

_, final_state = tf.nn.seq2seq.rnn_decoder(
    [beam_decoder.wrap_input(initial_input)] + [None] * 4,
    beam_decoder.wrap_state(initial_state),
    beam_decoder.wrap_cell(cell),
    loop_function=lambda prev_symbol, i: tf.reshape(prev_symbol, [-1, 1])
)

best_dense = beam_decoder.unwrap_output_dense(final_state)
best_sparse = beam_decoder.unwrap_output_sparse(final_state)
best_logprobs = beam_decoder.unwrap_output_logprobs(final_state)

assert all(best_sparse.eval().values == [2])
assert np.isclose(np.exp(best_logprobs.eval())[0], 0.4)
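
# Why 0.4 (an explanatory note, not in the original gist): starting from
# (x_{t-2}, x_{t-1}) = (0, 0), emitting the stop token immediately has
# probability table[0, 0, 2] = 0.4. The best competing sequence ending in
# the stop token, [1, 2], has probability 0.6 * 0.6 = 0.36, so the single
# stop token wins.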

# Test 2
table = np.array([[[0.9, 0.1, 0.0],
                   [0.0, 0.9, 0.1],
                   [0.0, 0.0, 1.0]]] * 3)
cell = MarkovChainCell(table)
initial_state = cell.zero_state(1, tf.int32)
initial_input = initial_state[0]

beam_decoder = BeamDecoder(num_classes=3, stop_token=2, beam_size=10, max_len=3)

_, final_state = tf.nn.seq2seq.rnn_decoder(
    [beam_decoder.wrap_input(initial_input)] + [None] * 2,
    beam_decoder.wrap_state(initial_state),
    beam_decoder.wrap_cell(cell),
    loop_function=lambda prev_symbol, i: tf.reshape(prev_symbol, [-1, 1])
)

candidates, candidate_logprobs = sess.run((final_state[2], final_state[3]))

assert all(candidates[0, :] == [0, 0, 0])
assert np.isclose(np.exp(candidate_logprobs[0]), 0.9 * 0.9 * 0.9)
# Note that these three candidates all have the same score, and the sort order
# may change in the future
assert all(candidates[1, :] == [0, 0, 1])
assert np.isclose(np.exp(candidate_logprobs[1]), 0.9 * 0.9 * 0.1)
assert all(candidates[2, :] == [0, 1, 1])
assert np.isclose(np.exp(candidate_logprobs[2]), 0.9 * 0.1 * 0.9)
assert all(candidates[3, :] == [1, 1, 1])
assert np.isclose(np.exp(candidate_logprobs[3]), 0.1 * 0.9 * 0.9)
assert all(np.isclose(np.exp(candidate_logprobs[4:]), 0.0))
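
# A minimal follow-up sketch (not in the original gist): the same final state
# can also be unwrapped into the best finished sequence, using the helpers
# defined above.
best_dense = beam_decoder.unwrap_output_dense(final_state)
print(sess.run(best_dense))  # [batch_size, max_len] dense output, right-aligned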