@meowcat · Created April 15, 2020
Vectorized beam search decoder in TensorFlow
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 15 17:02:36 2020
@author: stravsm
"""
import numpy as np
import tensorflow as tf
from tensorflow import math as tm
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, TimeDistributed, RepeatVector
from tensorflow.keras.layers import concatenate
from tqdm import tqdm
n_states = 1000
n_dims_in = 30
n_tokens_out = 60
INITIAL_TOKEN = 13
# n: the number of parallel sequences to decode,
# k: beam size per sequence.
# In reality I am looking to do n=32, k=64, steps=128.
n = 32
k = 64
steps = 50

def build_model_sham(n_states, n_dims_in, n_tokens_out, steps=1, units=64):
    '''
    This is really just a dummy model [states_n, input] -> [states_n+1, output]
    that doesn't really do much, initialized with random values.
    In reality this is a quite complex LSTM + tricks model.
    During training it's used with teacher forcing and steps >> 1;
    during beam search we can only predict 1 step at a time.
    '''
    sequence = Input((steps, n_dims_in), dtype="float32")
    states = Input((n_states,), dtype="float32")
    sequence_dense_time = TimeDistributed(Dense(units,
                                                kernel_initializer='random_uniform',
                                                bias_initializer='zeros')
                                          )(sequence)
    states_dense = Dense(units,
                         kernel_initializer='random_uniform',
                         bias_initializer='zeros')(states)
    states_dense_time = RepeatVector(steps)(states_dense)
    concat_dense_time = concatenate([sequence_dense_time, states_dense_time], axis=2)
    sequence_out = TimeDistributed(Dense(n_tokens_out,
                                         kernel_initializer='random_uniform',
                                         bias_initializer='zeros',
                                         activation="softmax"))(concat_dense_time)
    states_out = Dense(n_states,
                       kernel_initializer='random_uniform',
                       bias_initializer='zeros')(states_dense)
    return Model(inputs=[sequence, states],
                 outputs=[sequence_out, states_out])
model_decode = build_model_sham(n_states, n_dims_in, n_tokens_out)
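
# Added illustration, not in the original gist: a quick shape check of the
# sham model. With the default steps=1, one decode call maps a single
# embedded token per beam to token probabilities plus the next state.
_seq_out, _states_out = model_decode([
    np.zeros((n * k, 1, n_dims_in), dtype="float32"),
    np.zeros((n * k, n_states), dtype="float32"),
])
assert _seq_out.shape == (n * k, 1, n_tokens_out)
assert _states_out.shape == (n * k, n_states)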
# Embedding matrix from n_tokens_out to n_dims_in
ytox_matrix = np.random.rand(n_tokens_out, n_dims_in).astype("float32")

@tf.function
def embed_y_to_x(y):
    return tf.expand_dims(tf.gather(ytox_matrix, y, axis=0), 1)
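
# Added illustration: embed_y_to_x looks each token id up in ytox_matrix and
# adds a length-1 time axis, so a batch of ids becomes a one-step sequence
# that feeds the steps=1 decoder directly.
assert embed_y_to_x(tf.constant([INITIAL_TOKEN, 0])).shape == (2, 1, n_dims_in)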
# Generate starting tensors
# Note: the "1" index instead of "0" below was intentionally introduced
# to ease debugging, to make sure the modulo divisions work appropriately
y_init = np.zeros((n,k), dtype='int32')
y_init[:,1] = INITIAL_TOKEN
y_init = np.reshape(y_init, (-1,))
# y_init = tf.convert_to_tensor(y_init, dtype="int32")
# All invalid inputs start with a score of -infinity, so they are only
# ever continued if there are not enough valid possibilities (i.e. in the
# first step, when k > tokens).
scores_init = np.full((n,k), -np.inf, dtype='float32')
scores_init[:,1] = 0
scores_init = np.reshape(scores_init, (-1,))
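
# Added illustration of the flattened beam layout: beam j of sequence i sits
# at flat index i * k + j, so only column 1 of the (n, k) view is live here.
assert np.isfinite(scores_init.reshape(n, k)[:, 1]).all()
assert not np.isfinite(np.delete(scores_init.reshape(n, k), 1, axis=1)).any()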
@tf.function
def decode_beam(states_init, scores_init, y_init, steps, k, n):
    states = states_init
    scores = scores_init
    xstep = embed_y_to_x(y_init)
    # Keep the results in TensorArrays
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)
    for i in range(steps):
        ystep, states = model_decode([xstep, states])
        y = ystep
        # Add this step's log-probabilities to the accumulated scores;
        # a sequence is killed (stays at -inf) once it has reached the end
        scores_y = tf.expand_dims(tf.reshape(scores, y.shape[:-1]), 2) \
            + tm.log(y)
        # Reshape into (n, k*tokens) and find the best k continuations for
        # each of the n candidates
        scores_y = tf.reshape(scores_y, [n, -1])
        top_k = tm.top_k(scores_y, k, sorted=False)
        # Transform the indices. I was using tf.unravel_index, but
        # `tf.debugging.set_log_device_placement(True)` indicated that it
        # would be placed on the CPU, so I rewrote it.
        top_k_index = tf.reshape(
            top_k[1] + tf.reshape(tf.range(n), (-1, 1)) * scores_y.shape[1], [-1])
        ysequence = top_k_index // y.shape[2]
        ymax = top_k_index % y.shape[2]
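        # Worked example of the index math (added): with n_tokens_out = 60,
        # flat index 125 decodes to parent beam 125 // 60 = 2 (an index into
        # the full n*k flattening) and token 125 % 60 = 5.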
        # This gives us two (n*k,) tensors with the parent sequence
        # (ysequence) and the chosen character (ymax) per sequence.
        # For continuation, pick the states, and "return" the scores
        states = tf.gather(states, ysequence)
        scores = tf.reshape(top_k[0], [-1])
        # Write the results into the TensorArrays,
        # and embed for the next step
        xstep = embed_y_to_x(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)
    # Done: stack up the results and return them
    sequences_final = sequences_chain.stack()
    y_final = y_chain.stack()
    scores_final = scores_chain.stack()
    return sequences_final, y_final, scores_final
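
# Added sketch, not part of the original gist: follow the parent pointers
# backwards to reconstruct the best-scoring token sequence per input.
# The helper name `backtrack_best` is mine; it assumes the (steps, n*k)
# layout of the tensors returned by decode_beam.
def backtrack_best(sequences_final, y_final, scores_final, n, k):
    seqs = sequences_final.numpy()   # (steps, n*k) parent beam per step
    ys = y_final.numpy()             # (steps, n*k) chosen token per step
    sc = scores_final.numpy()        # (steps, n*k) cumulative log-scores
    n_steps = ys.shape[0]
    # Best beam of each of the n sequences at the final step, as flat indices
    beam = np.argmax(sc[-1].reshape(n, k), axis=1) + np.arange(n) * k
    final_scores = sc[-1, beam]
    tokens = np.zeros((n_steps, n), dtype="int32")
    for t in range(n_steps - 1, -1, -1):
        tokens[t] = ys[t, beam]
        beam = seqs[t, beam]         # hop to the parent beams of step t-1
    return tokens.T, final_scores    # (n, steps) token ids, (n,) scores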
# The test:
# run the decoder first once for the graph optimization,
# then 50 times to time it
# Get n random starting states that are repeated k times
# and decode.
# The initial states are zeroes except for our valid input
# note: y_init and scores_init always stay the same, no need to redo them
for repeats in [1, 50]:
    for _ in tqdm(range(repeats)):
        states_init = np.random.rand(n, n_states).astype("float32")
        states_init = np.repeat(states_init, k, axis=0)
        beam_sequences = decode_beam(states_init, scores_init, y_init,
                                     steps=steps, k=k, n=n)
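
# Added usage sketch for the backtrack helper defined above:
sequences_final, y_final, scores_final = beam_sequences
best_tokens, best_scores = backtrack_best(sequences_final, y_final,
                                          scores_final, n, k)
print(best_tokens.shape, best_scores.shape)  # (32, 50) and (32,)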