Example CTC Decoder in Python
"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details checkout either of these references:
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873
"""
import numpy as np
import math
import collections
NEG_INF = -float("inf")

def make_new_beam():
    fn = lambda: (NEG_INF, NEG_INF)
    return collections.defaultdict(fn)
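
# Illustrative note (added; not in the original gist): unseen prefixes
# default to (-inf, -inf), i.e. zero probability mass in log space for
# both the ends-in-blank and ends-in-non-blank cases:
# >>> make_new_beam()[("a",)]
# (-inf, -inf)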

def logsumexp(*args):
    """
    Stable log-sum-exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp
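
# Quick sanity check (added for illustration; not part of the original
# gist): adding probabilities 0.5 and 0.25 in log space gives log(0.75).
assert math.isclose(logsumexp(math.log(0.5), math.log(0.25)),
                    math.log(0.75))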

def decode(probs, beam_size=100, blank=0):
    """
    Performs inference for the given output probabilities.

    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.

    Returns: (1) the output label sequence
             (2) the corresponding negative log-likelihood
    """
    T, S = probs.shape  # time x len_alphabet
    probs = np.log(probs)  # work in log space for numerical stability

    # Elements in the beam are (prefix, (p_blank, p_no_blank)).
    # Initialize the beam with the empty sequence, a probability of
    # 1 for ending in blank and zero for ending in non-blank
    # (in log space).
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(T):  # Loop over time

        # A default dictionary to store the next step candidates.
        next_beam = make_new_beam()

        for s in range(S):  # Loop over the letters of the alphabet (aka the "vocab")
            p = probs[t, s]

            # The variables p_b and p_nb are respectively the
            # probabilities for the prefix given that it ends in a
            # blank and does not end in a blank at this time step.
            for prefix, (p_b, p_nb) in beam:  # Loop over beam

                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue

                # Extend the prefix by the new character s and add it to
                # the beam. Only the probability of not ending in blank
                # gets updated.
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # We don't include the previous probability of not
                    # ending in blank (p_nb) if s is repeated at the end.
                    # The CTC algorithm merges characters not separated
                    # by a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)

                # *NB* this would be a good place to include an LM score.
                next_beam[n_prefix] = (n_p_b, n_p_nb)

                # If s is repeated at the end we also update the unchanged
                # prefix. This is the merging case.
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)

        # Sort and trim the beam before moving on to the
        # next time-step.
        beam = sorted(next_beam.items(),
                      key=lambda x: logsumexp(*x[1]),
                      reverse=True)
        beam = beam[:beam_size]

    best = beam[0]
    return best[0], -logsumexp(*best[1])
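
# For comparison, a best-path (greedy) decoder: take the argmax label at
# each time step, then collapse repeats and drop blanks. This is an added
# illustration, not part of the original gist; it is cheaper than the
# prefix beam search above but can be less accurate, since it scores only
# a single alignment rather than summing over all of them.
def greedy_decode(probs, blank=0):
    best_path = np.argmax(probs, axis=1)
    labels = []
    prev = None
    for s in best_path:
        # Skip repeats and blanks while collapsing the alignment.
        if s != prev and s != blank:
            labels.append(int(s))
        prev = s
    return tuple(labels)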

if __name__ == "__main__":
    np.random.seed(3)

    time = 50
    output_dim = 20

    probs = np.random.rand(time, output_dim)
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print("Score {:.3f}".format(score))