Skip to content

Instantly share code, notes, and snippets.

@standy66
Created March 13, 2018 17:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save standy66/4dd14086133b0500d5c0e3c21debbfb2 to your computer and use it in GitHub Desktop.
Save standy66/4dd14086133b0500d5c0e3c21debbfb2 to your computer and use it in GitHub Desktop.
"""
CTC prefix beam search decoder.

Based on the example at: https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
"""
import numpy as np
import math
import collections
NEG_INF = -float("inf")  # log(0): log-space representation of probability zero


def logsumexp(x, y):
    """
    Return log(exp(x) + exp(y)) for two log-space values, without overflow.

    Arguments:
        x: First log-space value.
        y: Second log-space value.
    Returns the log of the summed exponentials, or NEG_INF when both inputs
    are NEG_INF (the sum of two zero probabilities).
    """
    max_val = max(x, y)
    if max_val == NEG_INF:
        # Both terms are -inf; subtracting max_val would give -inf - -inf = nan.
        return NEG_INF
    # Factor out the larger term: log(e^x + e^y) = max + log(1 + e^-|x - y|).
    # log1p keeps the small correction term; log(1 + tiny) rounds it to 0.
    return max_val + math.log1p(math.exp(-abs(x - y)))
class BeamProb:
    """
    Log-space probabilities of one beam prefix, split by how it ends.

    Attributes:
        blank: log-probability of the prefix when it ends in a blank.
        label: log-probability of the prefix when it ends in an actual label.
    Both default to NEG_INF (probability zero).
    """

    __slots__ = ["blank", "label"]

    def __init__(self, blank=NEG_INF, label=NEG_INF):
        self.blank = blank
        self.label = label

    def __repr__(self):
        # Makes printed beams readable while debugging the decoder.
        return "{}(blank={!r}, label={!r})".format(
            type(self).__name__, self.blank, self.label
        )

    def update_label_prob(self, addendum):
        """Add log-probability `addendum` to the ends-in-label probability."""
        self.label = logsumexp(self.label, addendum)

    def update_blank_prob(self, addendum):
        """Add log-probability `addendum` to the ends-in-blank probability."""
        self.blank = logsumexp(self.blank, addendum)

    @property
    def total(self):
        """Log-probability of the prefix regardless of how it ends."""
        return logsumexp(self.blank, self.label)
def decode(probs, beam_size=100, blank=0):
    """
    Performs CTC prefix beam search over the given output probabilities.

    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Any array-like of shape (time x output dim).
            Zero entries are allowed (log-probability -inf).
        beam_size (int): Size of the beam to use during inference; >= 1.
        blank (int): Index of the CTC blank label; 0 <= blank < output dim.
    Returns a tuple (labels, nll): the best output label sequence as a
    tuple of ints (never containing `blank`) and the corresponding negative
    log-likelihood estimated by the decoder.
    Raises ValueError if `probs` is not 2-D, `beam_size` < 1, or `blank` is
    out of range.
    """
    log_probs = np.asarray(probs, dtype=float)
    if log_probs.ndim != 2:
        raise ValueError(
            "probs must be a 2-D (time x output dim) array, got shape {}".format(
                log_probs.shape
            )
        )
    T, S = log_probs.shape
    if beam_size < 1:
        raise ValueError("beam_size must be >= 1, got {}".format(beam_size))
    if not 0 <= blank < S:
        raise ValueError("blank must be in [0, {}), got {}".format(S, blank))
    # log(0) = -inf is a valid log-probability here; silence the warning.
    with np.errstate(divide="ignore"):
        log_probs = np.log(log_probs)

    # Elements in the beam are (prefix, BeamProb(p_blank, p_label)).
    # Initialize the beam with the empty sequence, a probability of 1 for
    # ending in blank and zero for ending in actual label (in log space).
    beam = [((), BeamProb(0.0, NEG_INF))]
    for t in range(T):
        # A default dictionary to store the next step candidates.
        next_beam = collections.defaultdict(BeamProb)
        # All updates in this step go into the fresh next_beam, so the
        # BeamProb objects in `beam` are fixed: compute each total once.
        totals = [prob.total for _, prob in beam]
        # Plain Python floats are much cheaper to index in the inner loop.
        row = log_probs[t].tolist()
        for s, log_p in enumerate(row):
            # prob.blank / prob.label are the probabilities of the prefix
            # ending in a blank / an actual label at the previous time step.
            for (prefix, prob), total in zip(beam, totals):
                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    next_beam[prefix].update_blank_prob(total + log_p)
                    continue
                n_prefix = prefix + (s,)
                if prefix and s != prefix[-1]:
                    # Extend the prefix by the new character s. Only the
                    # probability of not ending in blank gets updated.
                    next_beam[n_prefix].update_label_prob(total + log_p)
                else:
                    # Reached when s repeats the last label, and also for the
                    # empty prefix (whose label probability stays NEG_INF, so
                    # prob.blank equals its total and the second update is a
                    # no-op). CTC merges repeats not separated by a blank, so
                    # extending a repeat only counts paths ending in blank.
                    next_beam[n_prefix].update_label_prob(prob.blank + log_p)
                    # The repeated label collapses into the unchanged prefix.
                    next_beam[prefix].update_label_prob(prob.label + log_p)
                # NOTE: this would be a good place to include an LM score,
                # e.g. update_with_lm(next_beam, n_prefix).
        # Sort and trim the beam before moving on to the next time-step.
        beam = sorted(next_beam.items(), key=lambda item: -item[1].total)[:beam_size]
    labels, beam_prob = beam[0]
    return labels, -beam_prob.total
if __name__ == "__main__":
    # Smoke run: decode seeded random rows normalized to sum to 1 and report
    # the decoder's score and the length of the best label sequence.
    np.random.seed(3)
    n_frames, n_classes = 200, 35
    random_probs = np.random.rand(n_frames, n_classes)
    random_probs /= random_probs.sum(axis=1, keepdims=True)
    best_labels, nll = decode(random_probs)
    print("Score {:.3f} labels length: {}".format(nll, len(best_labels)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment