Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Forked from udibr/beamsearch.py
Created July 8, 2016 07:34
Show Gist options
  • Save kastnerkyle/0edc9d569009b84f19265878344aa7f9 to your computer and use it in GitHub Desktop.
Save kastnerkyle/0edc9d569009b84f19265878344aa7f9 to your computer and use it in GitHub Desktop.
beam search for Keras RNN
# variation to https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py
def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen):
    """Ask a Keras RNN for next-label probabilities of every sample in one batch.

    Each sample is a list of label ids. The whole batch is passed through
    ``sequence.pad_sequences`` with ``maxlen`` and ``empty`` as the fill value,
    then handed to ``rnn_model.predict`` silently (``verbose=0``). Whatever the
    model predicts is returned as-is.

    ``rnn_model`` and ``maxlen`` must match the model in use and the sequence
    length it handles. Their defaults are the module-level ``model`` and
    ``maxlen``, captured when this function is defined.

    NOTE(review): beamsearch reads ``probs.shape[1]`` as the vocabulary size,
    so the result is presumably shaped (len(samples), vocabulary) — confirm
    against the model's output layer.
    """
    # Bring every sample to the fixed length the model expects.
    padded_batch = sequence.pad_sequences(samples, maxlen=maxlen, value=empty)
    predictions = rnn_model.predict(padded_batch, verbose=0)
    return predictions
def beamsearch(predict=keras_rnn_predict,
               k=1, maxsample=400, use_unk=False, oov=oov, empty=empty, eos=eos):
    """Run beam search and return up to ``k`` label sequences with their NLL scores.

    Every sequence starts with the ``empty`` label and grows by one label per
    step. It stops when it ends with ``eos`` or reaches ``maxsample`` labels;
    the leading ``empty`` label counts toward that length and is kept in the
    output.

    Args:
        predict: callable invoked as ``predict(samples, empty=empty)``. It must
            return a 2-D array of label probabilities: one row per sample, one
            column per label.
        k: beam width, i.e. how many finished sequences to collect.
        maxsample: maximum sequence length, including the leading ``empty``.
        use_unk: when False and ``oov`` is not None, the ``oov``
            (out-of-vocabulary) label is effectively banned by forcing its
            candidate score to 1e20.
        oov: label id of the out-of-vocabulary token, or None.
        empty: label id every sequence starts with; also forwarded to ``predict``.
        eos: label id that terminates a sequence.

    Returns:
        ``(samples, scores)``: parallel lists of label sequences and their
        accumulated negative log-likelihoods (lower is better).
        - Finished sequences come first, in the step they finished.
        - Within one step they are in ascending-score order.
        - If ``k < 1`` the loop never runs, and the initial ``[[empty]]`` beam
          is returned with score ``[0]``.
    """
    dead_k = 0  # number of finished samples (ended with eos or hit maxsample)
    dead_samples = []
    dead_scores = []
    live_k = 1  # number of samples still being extended
    live_samples = [[empty]]
    live_scores = [0]
    while live_k and dead_k < k:
        # For every live sample, get the probability of every possible next label.
        probs = predict(live_samples, empty=empty)
        # Candidate score = parent score - log(prob of the new label), so a
        # sample's total score is the sum of -log of its label probabilities.
        # A zero probability gives +inf, which sorts after every finite score.
        # That outcome is intended, so silence numpy's divide-by-zero warning.
        with np.errstate(divide='ignore'):
            cand_scores = np.array(live_scores)[:, None] - np.log(probs)
        if not use_unk and oov is not None:
            cand_scores[:, oov] = 1e20  # effectively forbid the oov label
        cand_flat = cand_scores.flatten()
        # Keep the best (lowest) candidates, only as many as there are open slots.
        ranks_flat = cand_flat.argsort()[:(k - dead_k)]
        live_scores = cand_flat[ranks_flat]
        # Flat index r encodes (parent sample, new label) as parent * voc_size + label.
        voc_size = probs.shape[1]
        live_samples = [live_samples[r // voc_size] + [r % voc_size] for r in ranks_flat]
        # Live samples that should be dead: they just emitted eos or reached maxsample.
        zombie = [s[-1] == eos or len(s) >= maxsample for s in live_samples]
        # Move zombies to the dead lists; the leading `empty` label is kept.
        dead_samples += [s for s, z in zip(live_samples, zombie) if z]
        dead_scores += [s for s, z in zip(live_scores, zombie) if z]
        dead_k = len(dead_samples)
        # Keep the remaining samples alive for the next step.
        live_samples = [s for s, z in zip(live_samples, zombie) if not z]
        live_scores = [s for s, z in zip(live_scores, zombie) if not z]
        live_k = len(live_samples)
    return dead_samples + live_samples, dead_scores + live_scores
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment