@vadimkantorov
Last active April 9, 2024 03:03
An implementation of CTC re-formulation via cross-entropy with pseudo-labels, following "A Novel Re-weighting Method for Connectionist Temporal Classification"
# Vanilla CTC loss and CTC loss via cross-entropy are equal, and so are their gradients. In this reformulation it's easier to experiment with modifications of CTC.
# References on CTC regularization:
# "A Novel Re-weighting Method for Connectionist Temporal Classification", Li et al, https://arxiv.org/abs/1904.10619
# "Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets", Feng et al, https://www.hindawi.com/journals/complexity/2019/9345861/
# "Improved training for online end-to-end speech recognition systems", Kim et al, https://arxiv.org/abs/1711.02212
import torch
import torch.nn.functional as F
## generate example data
# generation is not very stable because of this bug https://github.com/pytorch/pytorch/issues/31557
torch.manual_seed(1)
B, C, T, t, blank = 16, 64, 32, 8, 0
logits = torch.randn(B, C, T).requires_grad_()
input_lengths = torch.full((B,), T, dtype = torch.long)
target_lengths = torch.full((B,), t, dtype = torch.long)
targets = torch.randint(blank + 1, C, (B, t), dtype = torch.long)
## compute CTC alignment targets
log_probs = F.log_softmax(logits, dim = 1)
ctc_loss = F.ctc_loss(log_probs.permute(2, 0, 1), targets, input_lengths, target_lengths, blank = blank, reduction = 'sum')
ctc_grad, = torch.autograd.grad(ctc_loss, (logits,), retain_graph = True)
temporal_mask = (torch.arange(logits.shape[-1], device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(0) < input_lengths.unsqueeze(1))[:, None, :]
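# for reduction = 'sum', the gradient of the CTC loss w.r.t. the logits is softmax(logits) - gamma, where gamma holds the per-frame label posteriors marginalized over all valid alignments; softmax - grad therefore recovers gamma, which serves as the soft pseudo-label target below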
alignment_targets = (log_probs.exp() * temporal_mask - ctc_grad).detach()
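## compute CTC loss via cross-entropy with the alignment targets and check that the gradients coincide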
ctc_loss_via_crossentropy = (-alignment_targets * log_probs).sum()
ctc_grad, = torch.autograd.grad(ctc_loss, logits, retain_graph = True)
ctc_grad_via_crossentropy, = torch.autograd.grad(ctc_loss_via_crossentropy, logits, retain_graph = True)
assert torch.allclose(ctc_grad, ctc_grad_via_crossentropy, rtol = 1e-3)
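## not part of the original gist: an illustrative sketch of how the cross-entropy form makes CTC modifications easy,
## e.g. a focal-style per-frame reweighting loosely inspired by Feng et al. above; focal_gamma and the exact form of
## the weight below are assumptions of this sketch, not the paper's formulation
focal_gamma = 0.5
focal_weight = (1 - log_probs.exp()).pow(focal_gamma).detach() # down-weight frames the model already predicts confidently
focal_ctc_loss = (-alignment_targets * focal_weight * log_probs).sum()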
@rbracco commented Aug 23, 2020

Thank you for sharing this, it's extremely interesting. I did find a few typos/mixups when I tried to run the code.

  1. There's a comma missing on line 13; it should read B, C, T, t, blank = 16, 64, 32, 8, 0
  2. ce_loss is not defined; the variable is named ctc_loss_via_crossentropy, so one of the names needs to change.

@vadimkantorov (Author)

Thanks for bringing this to my attention! Fixed!

@MohammedAljahdali

Thank you for the implementation. I was looking for a way to add per-class weights to the CTC loss. Could you explain how to do that in your code example? It would be of great help to my work. Thank you.

@vadimkantorov (Author) commented Mar 18, 2021

@MohammedAljahdali just modify https://gist.github.com/vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4#file-ctc_alignment_targets-py-L25, check the shapes, and add whatever weighting you'd like. I guess one could also use a loss other than log-loss (in that case the expected alignment would still be computed w.r.t. log-loss, but the actual training loss would be different).

@MohammedAljahdali
Copy link

Thanks for replying!

weights = torch.ones(1, 64, 1) / 64
ctc_loss_via_crossentropy = (-alignment_targets * log_probs * weights).sum() 

Is this the correct way to add weights, in case we have 64 classes?

@vadimkantorov
Copy link
Author

I guess so. If you checked the shapes and it makes sense, it should be correct. I haven't touched the code in a very long time, so you know it better than I do at this point :)
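Something along these lines could be used to double-check the shapes and broadcasting (a quick sketch reusing the variable names from the gist, not tested). Note that uniform weights of 1 / C only rescale the whole loss by a constant, so per-class weighting has an effect only when the weights differ across classes:

assert alignment_targets.shape == (B, C, T) and log_probs.shape == (B, C, T)
weights = torch.full((1, C, 1), 1.0 / C)  # one weight per class, broadcasts over batch and time
weights[0, 3, 0] = 2.0 / C                # e.g. hypothetically up-weight class 3
weighted_ctc_loss = (-alignment_targets * log_probs * weights).sum()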

@MohammedAljahdali

Thank you! I will test it.
