Vectorized beam search decoder in TensorFlow

# -*- coding: utf-8 -*-
"""
Created on Wed Apr 15 17:02:36 2020

@author: stravsm
"""

import numpy as np
import tensorflow as tf
from tensorflow import math as tm
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, TimeDistributed, RepeatVector
from tensorflow.keras.layers import concatenate

from tqdm import tqdm

n_states = 1000
n_dims_in = 30
n_tokens_out = 60
INITIAL_TOKEN = 13

# n: the number of parallel sequences to decode,
# k: beam size per sequence.
# In reality I am looking to do n=32, k=64, steps=128.
n = 32
k = 64
steps = 50

def build_model_sham(n_states, n_dims_in, n_tokens_out, steps=1, units=64):
    '''
    This is really just a dummy model [states_n, input] -> [states_n+1, output]
    that doesn't really do much, initialized with random values.
    In reality this is a quite complex LSTM + tricks model.
    During training it's used with teacher forcing and steps >> 1;
    during beam search we can only predict 1 step at a time.
    '''
    sequence = Input((steps, n_dims_in), dtype="float32")
    states = Input((n_states,), dtype="float32")
    sequence_dense_time = TimeDistributed(Dense(units,
                                                kernel_initializer='random_uniform',
                                                bias_initializer='zeros')
                                          )(sequence)
    states_dense = Dense(units,
                         kernel_initializer='random_uniform',
                         bias_initializer='zeros')(states)
    states_dense_time = RepeatVector(steps)(states_dense)
    concat_dense_time = concatenate([sequence_dense_time, states_dense_time], axis=2)
    sequence_out = TimeDistributed(Dense(n_tokens_out,
                                         kernel_initializer='random_uniform',
                                         bias_initializer='zeros',
                                         activation="softmax"))(concat_dense_time)
    states_out = Dense(n_states,
                       kernel_initializer='random_uniform',
                       bias_initializer='zeros')(states_dense)
    return Model(inputs=[sequence, states],
                 outputs=[sequence_out, states_out])

model_decode = build_model_sham(n_states, n_dims_in, n_tokens_out)
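
# Quick shape check (my addition, not part of the original gist): a single
# decode step maps (batch, 1, n_dims_in) inputs plus (batch, n_states) states
# to a softmax over tokens and the next state vector.
_seq, _st = model_decode([np.zeros((2, 1, n_dims_in), dtype="float32"),
                          np.zeros((2, n_states), dtype="float32")])
assert _seq.shape == (2, 1, n_tokens_out) and _st.shape == (2, n_states)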

# Embedding matrix from n_tokens_out to n_dims_in
ytox_matrix = np.random.rand(n_tokens_out, n_dims_in).astype("float32")

@tf.function
def embed_y_to_x(y):
    return tf.expand_dims(tf.gather(ytox_matrix, y, axis=0), 1)
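
# Shape check (my addition): token ids of shape (batch,) come out as a
# single-step model input of shape (batch, 1, n_dims_in).
assert embed_y_to_x(tf.zeros((n * k,), dtype="int32")).shape == (n * k, 1, n_dims_in)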

# Generate starting tensors.
# Note: the "1" index instead of "0" below was chosen intentionally
# to ease debugging, to make sure the modulo divisions work appropriately.
y_init = np.zeros((n, k), dtype='int32')
y_init[:, 1] = INITIAL_TOKEN
y_init = np.reshape(y_init, (-1,))
# y_init = tf.convert_to_tensor(y_init, dtype="int32")

# All invalid inputs start with a score of -infinity, so they are only
# ever continued if there are not enough valid possibilities (i.e. in the
# first step when k > tokens).
scores_init = np.full((n, k), -np.inf, dtype='float32')
scores_init[:, 1] = 0
scores_init = np.reshape(scores_init, (-1,))
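
# Tiny illustration (my addition): with all slots at -inf except the valid
# one, the first top_k over the flattened candidates can only pick
# finite-score continuations of that single slot, so the k initially
# identical beams don't fill the beam with duplicates.
_toy = scores_init[:k, None] + np.log(np.random.rand(k, n_tokens_out)).astype("float32")
# only row 1 of _toy (the valid slot) is finite
assert np.isfinite(_toy[1]).all() and not np.isfinite(np.delete(_toy, 1, axis=0)).any()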

@tf.function
def decode_beam(states_init, scores_init, y_init, steps, k, n):
    states = states_init
    scores = scores_init
    xstep = embed_y_to_x(y_init)
    # Keep the results in TensorArrays
    y_chain = tf.TensorArray(dtype="int32", size=steps)
    sequences_chain = tf.TensorArray(dtype="int32", size=steps)
    scores_chain = tf.TensorArray(dtype="float32", size=steps)
    for i in range(steps):
        # Predict one step for all n*k sequences at once
        ystep, states = model_decode([xstep, states])
        y = ystep
        # Add the log-probabilities of step i to the running scores
        # (in the full model, finished sequences would be killed here)
        scores_y = tf.expand_dims(tf.reshape(scores, y.shape[:-1]), 2) \
            + tm.log(y)
        # Reshape into (n, k*tokens) and find the best k sequences to continue
        # for each of the n candidates
        scores_y = tf.reshape(scores_y, [n, -1])
        top_k = tm.top_k(scores_y, k, sorted=False)
        # Transform the indices. I was using tf.unravel_index, but
        # `tf.debugging.set_log_device_placement(True)` indicated that it
        # would be placed on the CPU, thus I rewrote it.
        top_k_index = tf.reshape(
            top_k[1] + tf.reshape(tf.range(n), (-1, 1)) * scores_y.shape[1], [-1])
        ysequence = top_k_index // y.shape[2]
        ymax = top_k_index % y.shape[2]
        # This gives us two (n*k,) tensors with parent sequence (ysequence)
        # and chosen token (ymax) per sequence.
        # For continuation, pick the states, and "return" the scores
        states = tf.gather(states, ysequence)
        scores = tf.reshape(top_k[0], [-1])
        # Write the results into the TensorArrays,
        # and embed for the next step
        xstep = embed_y_to_x(ymax)
        y_chain = y_chain.write(i, ymax)
        sequences_chain = sequences_chain.write(i, ysequence)
        scores_chain = scores_chain.write(i, scores)
    # Done: stack up the results and return them
    sequences_final = sequences_chain.stack()
    y_final = y_chain.stack()
    scores_final = scores_chain.stack()
    return sequences_final, y_final, scores_final
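
# Sanity-check sketch (my addition): the manual `//` / `%` arithmetic above
# is equivalent to tf.unravel_index over a (sequences, tokens) grid. With
# toy numbers (4 sequences, 4 tokens), assuming eager execution:
_flat = tf.reshape(tf.constant([[5, 2], [15, 8]]), [-1])   # global flat indices
_seq_manual, _tok_manual = _flat // 4, _flat % 4           # parent beam, token
_seq_tf, _tok_tf = tf.unravel_index(_flat, dims=(4, 4))    # same result
assert (_seq_manual.numpy() == _seq_tf.numpy()).all()
assert (_tok_manual.numpy() == _tok_tf.numpy()).all()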

# The test:
# run the decoder first once for the graph optimization,
# then 50 times to time it.
# Get n random starting states that are repeated k times and decode.
# The initial states are zeroes except for our valid input.
# Note: y_init and scores_init always stay the same, no need to redo them.
for repeats in [1, 50]:
    for _ in tqdm(range(repeats)):
        states_init = np.random.rand(n, n_states).astype("float32")
        states_init = np.repeat(states_init, k, axis=0)
        beam_sequences = decode_beam(states_init, scores_init, y_init,
                                     steps=steps, k=k, n=n)
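
# Sketch (my addition, not part of the original gist): decode_beam returns
# per-step parent pointers and tokens, so the actual sequences have to be
# reconstructed by walking the pointers backwards from the last step.
sequences_final, y_final, scores_final = beam_sequences
parents = sequences_final.numpy()          # (steps, n*k) parent indices
chosen = y_final.numpy()                   # (steps, n*k) chosen tokens
decoded = np.zeros((steps, n * k), dtype="int32")
pointer = np.arange(n * k)
for i in range(steps - 1, -1, -1):
    decoded[i] = chosen[i, pointer]
    pointer = parents[i, pointer]
# decoded[:, j] is now the j-th beam's token sequence; the final scores are
# scores_final[-1], reshaped to (n, k) per input sequence.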