

@udibr
Last active October 4, 2021 11:50
beam search for Keras RNN
# a variation of https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py
import numpy as np
from keras.preprocessing import sequence


def keras_rnn_predict(samples, empty, rnn_model, maxlen):
    """For every sample, calculate the probability of every possible next label.
    You need to supply your RNN model and `maxlen`, the length of sequences it can handle.
    Bind them before passing this function to `beamsearch`, e.g.
    `functools.partial(keras_rnn_predict, rnn_model=model, maxlen=maxlen)`.
    """
    # left-pad (Keras default 'pre' padding) every sample with the `empty` label
    data = sequence.pad_sequences(samples, maxlen=maxlen, value=empty)
    return rnn_model.predict(data, verbose=0)


def beamsearch(predict, k=1, maxsample=400, use_unk=False, oov=None, *, empty, eos):
    """Return k samples (beams) and their NLL scores. Each sample is a sequence of labels;
    all samples start with an `empty` label and end with `eos` or are truncated to a length of `maxsample`.
    You need to supply `predict`, which returns the label probabilities for each sample.
    `use_unk` allows the `oov` (out-of-vocabulary) label to appear in samples.
    """
    dead_k = 0  # number of samples that reached eos
    dead_samples = []
    dead_scores = []
    live_k = 1  # number of samples that have not yet reached eos
    live_samples = [[empty]]
    live_scores = [0]
    while live_k and dead_k < k:
        # for every live sample, calculate the probability of every possible next label
        probs = predict(live_samples, empty=empty)
        # the total score of every candidate is the sum of -log of its label probabilities
        cand_scores = np.array(live_scores)[:, None] - np.log(probs)
        if not use_unk and oov is not None:
            cand_scores[:, oov] = 1e20  # effectively forbid the oov label
        cand_flat = cand_scores.flatten()
        # find the best (lowest) scores over all live samples and all next labels
        ranks_flat = cand_flat.argsort()[:(k - dead_k)]
        live_scores = cand_flat[ranks_flat]
        # append each new label to the live sample it extends
        voc_size = probs.shape[1]
        live_samples = [live_samples[r // voc_size] + [r % voc_size] for r in ranks_flat]
        # live samples that should be dead (ended with eos or reached maxsample) are zombies
        zombie = [s[-1] == eos or len(s) >= maxsample for s in live_samples]
        # add zombies to the dead (the leading `empty` label is kept in each sample)
        dead_samples += [s for s, z in zip(live_samples, zombie) if z]
        dead_scores += [s for s, z in zip(live_scores, zombie) if z]
        dead_k = len(dead_samples)
        # remove zombies from the living
        live_samples = [s for s, z in zip(live_samples, zombie) if not z]
        live_scores = [s for s, z in zip(live_scores, zombie) if not z]
        live_k = len(live_samples)
    return dead_samples + live_samples, dead_scores + live_scores
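The search can be exercised without Keras. The sketch below is a compact restatement of the loop above (oov handling omitted), driven by a hypothetical `toy_predict` that looks up next-label probabilities in a made-up bigram table instead of calling an RNN. The vocabulary ids, `TRANSITIONS`, `toy_predict` and `beamsearch_sketch` are all illustrative assumptions. The table is chosen so that greedy decoding (`k=1`) follows the locally likely label "A" and ends with a worse total NLL than a beam of width `k=2`, which finds empty, "B", eos.

```python
import numpy as np

# Hypothetical toy vocabulary: 0 = empty (start), 1 = eos, 2 = "A", 3 = "B".
EMPTY, EOS = 0, 1
# Hypothetical bigram table: row i holds P(next label | last label == i).
TRANSITIONS = np.array([
    [0.01, 0.01, 0.58, 0.40],  # after empty: "A" is the locally best choice
    [0.25, 0.25, 0.25, 0.25],  # after eos (never extended by the search)
    [0.01, 0.30, 0.35, 0.34],  # after "A": probability mass is spread out
    [0.01, 0.90, 0.05, 0.04],  # after "B": eos is very likely
])


def toy_predict(samples, empty=EMPTY):
    """Stand-in for keras_rnn_predict: one row of label probabilities per sample."""
    return TRANSITIONS[[s[-1] for s in samples]]


def beamsearch_sketch(predict, k=1, maxsample=400, empty=EMPTY, eos=EOS):
    """Compact restatement of the dead/live beam loop (no oov handling)."""
    dead_samples, dead_scores = [], []
    live_samples, live_scores = [[empty]], np.zeros(1)
    while live_samples and len(dead_samples) < k:
        probs = predict(live_samples, empty=empty)
        # NLL of every (live sample, next label) pair, flattened
        cand_flat = (live_scores[:, None] - np.log(probs)).flatten()
        ranks = cand_flat.argsort()[: k - len(dead_samples)]
        voc_size = probs.shape[1]
        samples = [live_samples[r // voc_size] + [int(r % voc_size)] for r in ranks]
        scores = cand_flat[ranks]
        # finished beams: ended with eos or reached maxsample
        done = np.array([s[-1] == eos or len(s) >= maxsample for s in samples])
        dead_samples += [s for s, d in zip(samples, done) if d]
        dead_scores += list(scores[done])
        live_samples = [s for s, d in zip(samples, done) if not d]
        live_scores = scores[~done]
    return dead_samples + live_samples, dead_scores + list(live_scores)
```

With `maxsample=4`, `beamsearch_sketch(toy_predict, k=2, maxsample=4)` returns the samples `[[0, 3, 1], [0, 2, 2, 2]]` with NLLs of about 1.02 and 2.64, while `k=1` returns only `[[0, 2, 2, 2]]`: the greedy path commits to "A" after the first step and never recovers.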

ghost commented Jun 16, 2018

@jeetp465, is your issue resolved?


patelsmit566 commented Jan 7, 2019

@jeetp465, @andersonzhu, according to my understanding, beam search is not part of the model definition. It is the way we decode the output of the LSTM (RNN). So beam search is used when you are generating words from the predicted output of the LSTM: the output of full_model.predict() is passed to beam search to get the output captions.
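In the gist's code, the model enters the search only through the `predict` callable, so the trained model itself needs no changes. The sketch below shows that wiring under stated assumptions: `StubModel` is a hypothetical stand-in for a trained Keras model (it returns uniform probabilities), `left_pad` is a hypothetical NumPy stand-in for Keras `pad_sequences` with its default 'pre' padding and truncation, and `rnn_predict` mirrors the shape of `keras_rnn_predict`. `functools.partial` binds the model and `maxlen` so the result can be called as `predict(live_samples, empty=empty)`, the way the search loop calls it.

```python
import functools

import numpy as np


def left_pad(samples, maxlen, value):
    """Hypothetical NumPy stand-in for Keras pad_sequences with
    padding='pre' and truncating='pre'."""
    out = np.full((len(samples), maxlen), value, dtype="int32")
    for i, s in enumerate(samples):
        s = list(s)[-maxlen:]  # keep only the last maxlen labels
        if s:
            out[i, maxlen - len(s):] = s
    return out


class StubModel:
    """Hypothetical stand-in for a trained model: maps a padded batch of shape
    (batch, maxlen) to next-label probabilities of shape (batch, vocab)."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.calls = 0  # counts how often the decoder queries the model

    def predict(self, data, verbose=0):
        self.calls += 1
        return np.full((data.shape[0], self.vocab_size), 1.0 / self.vocab_size)


def rnn_predict(samples, empty, rnn_model, maxlen):
    """Same shape as keras_rnn_predict, using left_pad instead of Keras."""
    data = left_pad(samples, maxlen=maxlen, value=empty)
    return rnn_model.predict(data, verbose=0)


model = StubModel(vocab_size=5)
# Decoding only needs a callable with this signature; the model is untouched.
predict = functools.partial(rnn_predict, rnn_model=model, maxlen=3)
# One call scores every live beam at once: here, two beams of different lengths.
probs = predict([[0], [0, 2, 4, 1]], empty=0)
```

A beam search built this way issues one `predict` call per decoding step, batched over all live beams, so `probs` has one row per beam and one column per label.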
