A primitive forward pass of CTC loss
# Reimplementation of the forward pass from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LossCTC.cpp#L37
# A vectorized version is available at https://github.com/vadimkantorov/ctc
import torch

# Supports only reduction = 'none' and does not support zero_infinity = True
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0):
    # Build the extended target sequence: blanks interleaved with target symbols, i.e. [blank, t1, blank, t2, ..., blank]
    targets_ = torch.full((targets.shape[0], 2 * targets.shape[-1] + 1), blank, device = targets.device, dtype = targets.dtype)
    temporal_mask = torch.arange(targets.shape[-1], device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(0) < target_lengths.unsqueeze(1)
    targets_[:, 1::2] = temporal_mask * targets + (~temporal_mask) * targets_[:, 1::2]

    max_target_length = int(target_lengths.max())
    batch_size = targets.shape[0]

    # log_alpha[b, t, s]: log-probability of all partial alignments that emit the first s + 1 extended-target symbols within the first t + 1 frames
    log_alpha = torch.empty(batch_size, log_probs.shape[0], 2 * max_target_length + 1, device = log_probs.device, dtype = log_probs.dtype)
    neg_log_likelihood = torch.empty(batch_size, device = log_probs.device, dtype = log_probs.dtype)
    lpp = log_probs.permute(1, 0, 2)  # (T, B, C) -> (B, T, C)
    neginf = torch.as_tensor([float('-inf')], device = log_probs.device, dtype = log_probs.dtype)

    # At t = 0 only the first blank and the first target symbol are reachable; everything else starts at -inf
    log_alpha.narrow(1, 0, 1).fill_(neginf.sum())
    for b in range(batch_size):
        input_length = input_lengths[b]
        target_length = target_lengths[b]
        log_alpha_a = log_alpha[b]
        log_probs_a = lpp[b]
        get_target_prime = targets_[b]
        log_alpha_a[0, 0] = log_probs_a[0, blank]
        log_alpha_a[0, 1] = log_probs_a[0, get_target_prime[1]]
        # Dynamic programming over time steps and extended-target states
        for t in range(1, input_length):
            for s in range(0, 2 * target_length + 1):
                current_target_prime = get_target_prime[s]
                # Three allowed predecessors: stay in s, advance from s - 1, or skip from s - 2 (only across a blank between distinct symbols)
                la1 = log_alpha_a[t - 1, s]
                lamax = la1
                if s > 0:
                    la2 = log_alpha_a[t - 1, s - 1]
                    if la2 > lamax:
                        lamax = la2
                else:
                    la2 = neginf
                if s > 1 and get_target_prime[s - 2] != current_target_prime:
                    la3 = log_alpha_a[t - 1, s - 2]
                    if la3 > lamax:
                        lamax = la3
                else:
                    la3 = neginf
                if lamax == neginf:
                    lamax = 0
                # Log-sum-exp of the three predecessors (stabilized by lamax) plus the emission log-probability
                log_alpha_a[t, s] = torch.log(torch.exp(la1 - lamax) + torch.exp(la2 - lamax) + torch.exp(la3 - lamax)) + lamax + log_probs_a[t, current_target_prime]
        # The total likelihood combines the two valid final states: ending on the last symbol or on the trailing blank
        l1 = log_alpha_a[input_length - 1, target_length * 2]
        l2 = log_alpha_a[input_length - 1, target_length * 2 - 1]
        m = torch.max(l1, l2)
        m = 0 if m == neginf else m
        log_likelihood = torch.log(torch.exp(l1 - m) + torch.exp(l2 - m)) + m
        neg_log_likelihood[b] = -log_likelihood
    return neg_log_likelihood
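
# A minimal sanity check, not part of the original gist: it assumes the usual PyTorch CTC
# conventions (log_probs of shape (T, B, C) already log-softmaxed, targets of shape (B, S),
# blank index 0) and compares against torch.nn.functional.ctc_loss with reduction = 'none'.
if __name__ == '__main__':
    import torch.nn.functional as F

    T, B, C, S = 50, 4, 20, 10
    log_probs = torch.randn(T, B, C).log_softmax(dim = -1)
    targets = torch.randint(1, C, (B, S))                      # labels must not use the blank index 0
    input_lengths = torch.full((B,), T, dtype = torch.long)
    target_lengths = torch.randint(1, S + 1, (B,), dtype = torch.long)

    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0)
    loss_ref = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
    # Should print True if this reimplementation agrees with the reference kernel
    print(torch.allclose(loss, loss_ref, atol = 1e-4))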