-
-
Save bredfern/82b6709752ed4503a3a9cf7781ad011b to your computer and use it in GitHub Desktop.
Tensorflow Beam Search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Beam decoder for tensorflow | |
Sample usage: | |
``` | |
beam_decoder = BeamDecoder(NUM_CLASSES, beam_size=10, max_len=MAX_LEN) | |
_, final_state = tf.nn.seq2seq.rnn_decoder( | |
[beam_decoder.wrap_input(initial_input)] + [None] * (MAX_LEN - 1), | |
beam_decoder.wrap_state(initial_state), | |
beam_decoder.wrap_cell(my_cell), | |
loop_function = lambda prev_symbol, i: tf.nn.embedding_lookup(my_embedding, prev_symbol) | |
) | |
best_dense = beam_decoder.unwrap_output_dense(final_state) # Dense tensor output, right-aligned | |
best_sparse = beam_decoder.unwrap_output_sparse(final_state) # Output, this time as a sparse tensor | |
``` | |
""" | |
import tensorflow as tf
try:
    from tensorflow.python.util import nest
except ImportError:
    # Backwards-compatibility: very old TensorFlow releases did not ship
    # tensorflow.python.util.nest; the equivalent helpers lived as private
    # functions on rnn_cell. Build a stand-in object exposing them under the
    # modern names used throughout this file.
    from tensorflow.python.ops import rnn_cell
    class NestModule(object): pass
    nest = NestModule()
    nest.is_sequence = rnn_cell._is_sequence
    nest.flatten = rnn_cell._unpacked_state
    nest.pack_sequence_as = rnn_cell._packed_state
def nest_map(func, nested):
    """Apply `func` to every leaf of a possibly-nested structure.

    A non-sequence value is passed straight through `func`; a sequence is
    flattened, mapped leaf-by-leaf, and re-packed into its original shape.
    """
    if nest.is_sequence(nested):
        mapped_leaves = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped_leaves)
    return func(nested)
class BeamDecoder(object):
    """Wraps a cell, its state, and its inputs so that tf.nn.seq2seq.rnn_decoder
    performs beam search instead of greedy decoding.

    Typical usage: wrap the cell with `wrap_cell`, the initial state with
    `wrap_state`, and the initial (and any per-batch side-channel) inputs with
    `wrap_input`; then recover results with the `unwrap_output_*` methods.
    """

    def __init__(self, num_classes, stop_token=0, beam_size=7, max_len=20):
        """
        num_classes: int. Number of output classes used.
        stop_token: int. Class index that terminates a hypothesis.
        beam_size: int. Number of hypotheses kept per batch element.
        max_len: int or scalar Tensor. If this cell is called recurrently more
            than max_len times in a row, the outputs will not be valid!
        """
        self.num_classes = num_classes
        self.stop_token = stop_token
        self.beam_size = beam_size
        self.max_len = max_len

    @classmethod
    def _tile_along_beam(cls, beam_size, state):
        """Replicate every tensor in `state` beam_size times, folding the new
        beam axis into the batch axis: [batch, ...] -> [batch*beam_size, ...].
        Nested state structures are handled recursively via nest_map.
        """
        if nest.is_sequence(state):
            return nest_map(
                lambda val: cls._tile_along_beam(beam_size, val),
                state
            )
        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")
        tensor = state
        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        try:
            new_first_dim = tensor_shape[0] * beam_size
        except Exception:
            # First dimension is statically unknown; leave it undetermined.
            # (Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed.)
            new_first_dim = None
        dynamic_tensor_shape = tf.unpack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, beam_size] + [1] * (tensor_shape.ndims - 1))
        res = tf.reshape(res, [-1] + list(dynamic_tensor_shape[1:]))
        res.set_shape([new_first_dim] + list(tensor_shape[1:]))
        return res

    def wrap_cell(self, cell):
        """
        Wraps a cell for use with the beam decoder.
        """
        return BeamDecoderCellWrapper(cell, self.num_classes, self.max_len, self.stop_token, self.beam_size)

    def wrap_state(self, state):
        """Wraps an initial cell state, producing the full beam-search state
        tuple expected by the wrapped cell."""
        # The dummy wrapper only needs _create_state, which never touches the
        # underlying cell when an explicit cell_state is supplied.
        dummy = BeamDecoderCellWrapper(None, self.num_classes, self.max_len, self.stop_token, self.beam_size)
        if nest.is_sequence(state):
            batch_size = tf.shape(nest.flatten(state)[0])[0]
            dtype = nest.flatten(state)[0].dtype
        else:
            batch_size = tf.shape(state)[0]
            dtype = state.dtype
        return dummy._create_state(batch_size, dtype, cell_state=state)

    def wrap_input(self, input):
        """
        Wraps an input for use with the beam decoder.
        Should be used for the initial input at timestep zero, as well as any
        side-channel inputs that are per-batch (e.g. attention targets).
        """
        return self._tile_along_beam(self.beam_size, input)

    def unwrap_output_dense(self, final_state, include_stop_tokens=True):
        """
        Retrieve the beam search output from the final state.
        Returns a [batch_size, max_len]-sized Tensor.
        """
        res = final_state[0]
        if include_stop_tokens:
            # Beam symbols are right-aligned; shift left and append the stop
            # token so the sequence ends with it.
            res = tf.concat(1, [res[:, 1:], tf.ones_like(res[:, 0:1]) * self.stop_token])
        return res

    def unwrap_output_sparse(self, final_state, include_stop_tokens=True):
        """
        Retrieve the beam search output from the final state.
        Returns a sparse tensor with underlying dimensions of
        [batch_size, max_len].
        """
        output_dense = final_state[0]
        mask = tf.not_equal(output_dense, self.stop_token)
        if include_stop_tokens:
            output_dense = tf.concat(1, [output_dense[:, 1:], tf.ones_like(output_dense[:, 0:1]) * self.stop_token])
            mask = tf.concat(1, [mask[:, 1:], tf.cast(tf.ones_like(mask[:, 0:1], dtype=tf.int8), tf.bool)])
        return sparse_boolean_mask(output_dense, mask)

    def unwrap_output_logprobs(self, final_state):
        """
        Retrieve the log-probabilities associated with the selected beams.
        """
        return final_state[1]
class BeamDecoderCellWrapper(tf.nn.rnn_cell.RNNCell):
    """RNNCell wrapper that turns a step of the underlying cell into one step
    of beam search.

    The wrapped state is a 5-tuple:
        cand_symbols  [batch_size, max_len]             best finished hypothesis
        cand_logprobs [batch_size]                      its log-probability
        beam_symbols  [batch_size*beam_size, max_len]   live beams, right-aligned
        beam_logprobs [batch_size*beam_size]            their log-probabilities
        cell_state                                      the inner cell's state
    """

    def __init__(self, cell, num_classes, max_len, stop_token=0, beam_size=7):
        # TODO: determine if we can have dynamic shapes instead of pre-filling
        # up to max_len
        self.cell = cell
        self.num_classes = num_classes
        self.stop_token = stop_token
        self.beam_size = beam_size
        self.max_len = max_len
        # Note: masking out entries to -inf plays poorly with top_k, so just
        # subtract out a large number.
        # TODO: consider using slice+fill+concat instead of adding a mask
        # TODO: consider making the large negative constant dtype-dependent
        self._nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(self.num_classes), self.stop_token), tf.float32) * -1e18,
            [1, 1, self.num_classes]
        )
        self._nondone_mask = tf.reshape(tf.tile(self._nondone_mask, [1, self.beam_size, 1]),
                                        [-1, self.beam_size * self.num_classes])

    def __call__(self, inputs, state, scope=None):
        (
            past_cand_symbols,  # [batch_size, max_len]
            past_cand_logprobs, # [batch_size]
            past_beam_symbols,  # [batch_size*self.beam_size, max_len], right-aligned!!!
            past_beam_logprobs, # [batch_size*self.beam_size]
            past_cell_state,
        ) = state
        batch_size = tf.shape(past_cand_symbols)[0]  # TODO: get as int, if possible

        # Step the inner cell and accumulate per-beam log-probabilities.
        cell_outputs, raw_cell_state = self.cell(inputs, past_cell_state)
        logprobs = tf.nn.log_softmax(cell_outputs)
        logprobs_batched = tf.reshape(logprobs + tf.expand_dims(past_beam_logprobs, 1),
                                      [-1, self.beam_size * self.num_classes])
        logprobs_batched.set_shape((None, self.beam_size * self.num_classes))

        # Select the top beam_size continuations per batch element; the
        # stop-token column is masked out so finished hypotheses don't compete.
        beam_logprobs, indices = tf.nn.top_k(
            tf.reshape(logprobs_batched + self._nondone_mask, [-1, self.beam_size * self.num_classes]),
            self.beam_size
        )
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols: decode flat indices into
        # (parent beam, emitted symbol).
        symbols = indices % self.num_classes  # [batch_size, self.beam_size]
        parent_refs = tf.reshape(indices // self.num_classes, [-1])  # [batch_size*self.beam_size]
        # TODO: this technically doesn't need to be recalculated every loop
        parent_refs_offsets = (tf.range(batch_size * self.beam_size) // self.beam_size) * self.beam_size
        parent_refs = parent_refs + parent_refs_offsets

        # Shift each surviving beam's history left and append the new symbol
        # (histories stay right-aligned).
        symbols_history = tf.gather(past_beam_symbols, parent_refs)
        beam_symbols = tf.concat(1, [symbols_history[:, 1:], tf.reshape(symbols, [-1, 1])])

        # Handle the output and the cell state shuffling
        outputs = tf.reshape(symbols, [-1])  # [batch_size*beam_size]
        cell_state = nest_map(
            lambda element: tf.gather(element, parent_refs),
            raw_cell_state
        )

        # Handling for getting a done token: if emitting the stop token now
        # beats the best finished candidate so far, replace it.
        logprobs_done = tf.reshape(logprobs_batched, [-1, self.beam_size, self.num_classes])[:, :, self.stop_token]
        done_parent_refs = tf.to_int32(tf.argmax(logprobs_done, 1))
        done_parent_refs_offsets = tf.range(batch_size) * self.beam_size
        done_symbols = tf.gather(past_beam_symbols, done_parent_refs + done_parent_refs_offsets)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)
        cand_symbols = tf.select(logprobs_done_max > past_cand_logprobs,
                                 done_symbols,
                                 past_cand_symbols)
        cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        return outputs, (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
            cell_state,
        )

    @property
    def state_size(self):
        return (self.max_len,
                1,
                self.max_len,
                1,
                self.cell.state_size
                )

    @property
    def output_size(self):
        return 1

    def _create_state(self, batch_size, dtype, cell_state=None):
        """Build the initial beam-search state tuple.

        cand_* start as all-stop-token sequences with -inf score; only the
        first entry of each beam group gets log-prob 0 so every batch element
        starts with a single live hypothesis.
        """
        cand_symbols = tf.fill([batch_size, self.max_len], tf.constant(self.stop_token, dtype=tf.int32))
        cand_logprobs = tf.ones((batch_size,), dtype=tf.float32) * -float('inf')

        if cell_state is None:
            cell_state = self.cell.zero_state(batch_size * self.beam_size, dtype=dtype)
        else:
            cell_state = BeamDecoder._tile_along_beam(self.beam_size, cell_state)

        full_size = batch_size * self.beam_size
        first_in_beam_mask = tf.equal(tf.range(full_size) % self.beam_size, 0)

        beam_symbols = tf.fill([full_size, self.max_len], tf.constant(self.stop_token, dtype=tf.int32))
        beam_logprobs = tf.select(
            first_in_beam_mask,
            tf.fill([full_size], 0.0),
            tf.fill([full_size], -1e18),  # top_k does not play well with -inf
            # TODO: dtype-dependent value here
        )

        return (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
            cell_state,
        )

    def zero_state(self, batch_size_times_beam_size, dtype):
        """
        Instead of calling this manually, please use
        BeamDecoder.wrap_state(cell.zero_state(...)) instead
        """
        # BUG FIX: originally called the nonexistent self.create_zero_state
        # (the method is named _create_state), and used true division, which
        # yields a float under Python 3.
        batch_size = batch_size_times_beam_size // self.beam_size
        return self._create_state(batch_size, dtype)
def sparse_boolean_mask(tensor, mask):
    """
    Creates a sparse tensor from masked elements of `tensor`.

    Inputs:
      tensor: a 2-D tensor, [batch_size, T]
      mask: a 2-D mask, [batch_size, T]
    Output: a 2-D sparse tensor in which each row's surviving values are
    packed to the left.
    """
    # Number of kept elements per row, kept 2-D for broadcasting below.
    row_lengths = tf.reduce_sum(tf.cast(mask, tf.int32), -1, keep_dims=True)
    dims = tf.shape(mask)
    # Column indices 0..T-1 broadcast against row_lengths: true exactly for
    # the first row_lengths[i] positions of row i (a left-shifted mask).
    col_ids = tf.tile(
        tf.expand_dims(tf.range(dims[1]), 0),
        [dims[0], 1]
    )
    left_shifted_mask = col_ids < row_lengths
    return tf.SparseTensor(
        indices=tf.where(left_shifted_mask),
        values=tf.boolean_mask(tensor, mask),
        shape=tf.cast(tf.pack([dims[0], tf.reduce_max(row_lengths)]), tf.int64)  # For 2D only
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
import numpy as np
from tf_beam_decoder import BeamDecoder

# Interactive session so the Tensor.eval()/sess.run() calls below work
# without an explicit `with tf.Session():` block.
sess = tf.InteractiveSession()
class MarkovChainCell(tf.nn.rnn_cell.RNNCell):
    """
    This cell type is only used for testing the beam decoder.
    It represents a Markov chain characterized by a probability table
    p(x_t|x_{t-1},x_{t-2}).
    """

    def __init__(self, table):
        """
        table[a,b,c] = p(x_t=c|x_{t-1}=b,x_{t-2}=a)
        """
        assert len(table.shape) == 3 and table.shape[0] == table.shape[1] == table.shape[2]
        self.log_table = tf.log(np.asarray(table, dtype=np.float32))
        self._output_size = table.shape[0]

    def __call__(self, inputs, state, scope=None):
        """
        inputs: [batch_size, 1] int tensor
        state: [batch_size, 1] int tensor
        """
        # Flatten the table to [num_states**2, num_states] so a single row
        # index (x_{t-2} * N + x_{t-1}) selects the conditional distribution.
        flat_logits = tf.reshape(self.log_table, [-1, self.output_size])
        row_ids = state[0] * self.output_size + inputs
        # New state is just the current input (the chain remembers one step).
        return tf.gather(flat_logits, tf.reshape(row_ids, [-1])), (inputs,)

    @property
    def state_size(self):
        return (1,)

    @property
    def output_size(self):
        return self._output_size
# Test 1: from start state 0, p(1)=0.6 then p(2|1)=0.6 gives sequence 1,2 with
# prob 0.36, while emitting 2 immediately has prob 0.4 -- so the single best
# finished hypothesis is just the stop token with probability 0.4.
table = np.array([[[0.0, 0.6, 0.4],
                   [0.0, 0.4, 0.6],
                   [0.0, 0.0, 1.0]]] * 3)
cell = MarkovChainCell(table)
initial_state = cell.zero_state(1, tf.int32)
initial_input = initial_state[0]

beam_decoder = BeamDecoder(num_classes=3, stop_token=2, beam_size=7, max_len=5)

_, final_state = tf.nn.seq2seq.rnn_decoder(
    [beam_decoder.wrap_input(initial_input)] + [None] * 4,
    beam_decoder.wrap_state(initial_state),
    beam_decoder.wrap_cell(cell),
    # Symbols are fed back directly (no embedding lookup needed for this cell).
    loop_function = lambda prev_symbol, i: tf.reshape(prev_symbol, [-1, 1])
)

best_dense = beam_decoder.unwrap_output_dense(final_state)
best_sparse = beam_decoder.unwrap_output_sparse(final_state)
best_logprobs = beam_decoder.unwrap_output_logprobs(final_state)

# Best output is the lone stop token (2) with probability 0.4.
assert all(best_sparse.eval().values == [2])
assert np.isclose(np.exp(best_logprobs.eval())[0], 0.4)
# Test 2: check the live beams (final_state[2]/[3]) rather than the finished
# candidate. With beam_size=10 > 2^3 possible length-3 paths over symbols
# {0,1}, the beam enumerates every path; extra slots keep ~zero probability.
table = np.array([[[0.9, 0.1, 0],
                   [0, 0.9, 0.1],
                   [0, 0, 1.0]]] * 3)
cell = MarkovChainCell(table)
initial_state = cell.zero_state(1, tf.int32)
initial_input = initial_state[0]

beam_decoder = BeamDecoder(num_classes=3, stop_token=2, beam_size=10, max_len=3)

_, final_state = tf.nn.seq2seq.rnn_decoder(
    [beam_decoder.wrap_input(initial_input)] + [None] * 2,
    beam_decoder.wrap_state(initial_state),
    beam_decoder.wrap_cell(cell),
    loop_function = lambda prev_symbol, i: tf.reshape(prev_symbol, [-1, 1])
)

# final_state[2] = beam_symbols, final_state[3] = beam_logprobs.
candidates, candidate_logprobs = sess.run((final_state[2], final_state[3]))

assert all(candidates[0,:] == [0,0,0])
assert np.isclose(np.exp(candidate_logprobs[0]), 0.9 * 0.9 * 0.9)
# Note that these three candidates all have the same score, and the sort order
# may change in the future
assert all(candidates[1,:] == [0,0,1])
assert np.isclose(np.exp(candidate_logprobs[1]), 0.9 * 0.9 * 0.1)
assert all(candidates[2,:] == [0,1,1])
assert np.isclose(np.exp(candidate_logprobs[2]), 0.9 * 0.1 * 0.9)
assert all(candidates[3,:] == [1,1,1])
assert np.isclose(np.exp(candidate_logprobs[3]), 0.1 * 0.9 * 0.9)
assert all(np.isclose(np.exp(candidate_logprobs[4:]), 0.0))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment