Last active April 9, 2024 03:03
vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4
An implementation of the CTC reformulation via cross-entropy with pseudo-labels, following "A Novel Re-weighting Method for Connectionist Temporal Classification"
# Vanilla CTC and CTC via cross-entropy have equal gradients (the loss values themselves differ). In this reformulation it's easier to experiment with modifications of CTC.
# References on CTC regularization:
# "A Novel Re-weighting Method for Connectionist Temporal Classification", Li et al, https://arxiv.org/abs/1904.10619
# "Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets", Feng et al, https://www.hindawi.com/journals/complexity/2019/9345861/
# "Improved training for online end-to-end speech recognition systems", Kim et al, https://arxiv.org/abs/1711.02212
import torch
import torch.nn.functional as F

## generate example data
# generation is not very stable because of this bug https://github.com/pytorch/pytorch/issues/31557
torch.manual_seed(1)
B, C, T, t, blank = 16, 64, 32, 8, 0  # batch size, classes, input frames, target length, blank index
logits = torch.randn(B, C, T).requires_grad_()
input_lengths = torch.full((B,), T, dtype = torch.long)
target_lengths = torch.full((B,), t, dtype = torch.long)
targets = torch.randint(blank + 1, C, (B, t), dtype = torch.long)

## compute CTC alignment targets
log_probs = F.log_softmax(logits, dim = 1)
ctc_loss = F.ctc_loss(log_probs.permute(2, 0, 1), targets, input_lengths, target_lengths, blank = blank, reduction = 'sum')
# retain_graph: the shared log_softmax node is needed again for the cross-entropy gradient below
ctc_grad, = torch.autograd.grad(ctc_loss, (logits,), retain_graph = True)
# mask of valid frames (frame index < input length), shape (B, 1, T)
temporal_mask = (torch.arange(logits.shape[-1], device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(0) < input_lengths.unsqueeze(1))[:, None, :]
# on valid frames d(ctc_loss)/d(logits) = softmax - alignment posterior, so the posterior is softmax - grad
alignment_targets = (log_probs.exp() * temporal_mask - ctc_grad).detach()
ctc_loss_via_crossentropy = (-alignment_targets * log_probs).sum()

## check that both formulations give the same gradient
ctc_grad_via_crossentropy, = torch.autograd.grad(ctc_loss_via_crossentropy, logits, retain_graph = True)
assert torch.allclose(ctc_grad, ctc_grad_via_crossentropy, rtol = 1e-3)
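Why the two gradients coincide, sketched for one sequence and its valid frames t < T_b (the symbols z, p, γ, π are notation introduced here, not taken from the gist): with logits z and p = softmax over classes, the standard CTC gradient is p - γ, where γ_{c,t} is the posterior probability that an alignment π of the target emits class c at frame t, and γ sums to 1 over classes. This is why `alignment_targets` is computed as softmax minus gradient. Differentiating the cross-entropy with γ held fixed (detached) gives the same expression, while the loss values differ by the entropy of the alignment posterior. On padded frames γ is 0 and both gradients vanish.

```latex
\begin{aligned}
\frac{\partial L_{\mathrm{CTC}}}{\partial z_{c,t}} &= p_{c,t} - \gamma_{c,t},
  \qquad \sum_{c} \gamma_{c,t} = 1, \\
L_{\mathrm{CE}} &= -\sum_{t} \sum_{c} \gamma_{c,t} \log p_{c,t}
  \quad (\gamma \text{ detached}), \\
\frac{\partial L_{\mathrm{CE}}}{\partial z_{c,t}}
  &= -\gamma_{c,t} + p_{c,t} \sum_{c'} \gamma_{c',t} = p_{c,t} - \gamma_{c,t}, \\
L_{\mathrm{CE}} &= \mathbb{E}_{\pi \sim P(\pi \mid y, x)}\bigl[-\log P(\pi \mid x)\bigr]
  = L_{\mathrm{CTC}} + H\bigl(P(\pi \mid y, x)\bigr).
\end{aligned}
```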
@MohammedAljahdali Just modify https://gist.github.com/vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4#file-ctc_alignment_targets-py-L25, check the shapes, and add whatever weighting you'd like. I guess one could also use a loss other than log-loss (in that case the expected alignment would still be computed with respect to the log-loss, but the actual training loss would be different).
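A rough sketch of the "different loss" idea: the alignment targets are still obtained from the CTC (log-loss) gradient, but each frame/class cross-entropy term is scaled by a focal-style factor (1 - p)^focal_gamma. The helpers `ctc_alignment_targets` and `focal_ctc_via_crossentropy` and the `focal_gamma` parameter are hypothetical names introduced here, and this frame-level modulation is an illustrative assumption, not necessarily the formulation of Feng et al. With `focal_gamma = 0` it reduces to plain CTC via cross-entropy, which gives a quick sanity check.

```python
import torch
import torch.nn.functional as F


def ctc_alignment_targets(logits, targets, input_lengths, target_lengths, blank=0):
    # Hypothetical helper: recovers the CTC alignment posteriors (B, C, T)
    # as softmax(logits) - d(ctc_loss)/d(logits), as in the gist above.
    log_probs = F.log_softmax(logits, dim=1)
    ctc_loss = F.ctc_loss(log_probs.permute(2, 0, 1), targets, input_lengths,
                          target_lengths, blank=blank, reduction='sum')
    ctc_grad, = torch.autograd.grad(ctc_loss, logits)
    # Valid-frame mask of shape (B, 1, T); padded frames get zero targets.
    frames = torch.arange(logits.shape[-1], device=logits.device)
    lengths = input_lengths.to(logits.device)
    temporal_mask = (frames.unsqueeze(0) < lengths.unsqueeze(1))[:, None, :]
    return (log_probs.detach().exp() * temporal_mask - ctc_grad).detach()


def focal_ctc_via_crossentropy(logits, targets, input_lengths, target_lengths,
                               blank=0, focal_gamma=2.0):
    # Hypothetical variant: targets come from CTC, but each term of the
    # cross-entropy is scaled by (1 - p) ** focal_gamma (not detached).
    alignment = ctc_alignment_targets(logits, targets, input_lengths, target_lengths, blank)
    log_probs = F.log_softmax(logits, dim=1)
    modulation = (1 - log_probs.exp()) ** focal_gamma
    return (-modulation * alignment * log_probs).sum()


# Small random example in the gist's (B, C, T) layout.
torch.manual_seed(0)
B, C, T, t, blank = 4, 10, 20, 5, 0
logits = torch.randn(B, C, T, requires_grad=True)
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), t, dtype=torch.long)
targets = torch.randint(blank + 1, C, (B, t), dtype=torch.long)

# Reference gradient of vanilla CTC.
ctc_loss = F.ctc_loss(F.log_softmax(logits, dim=1).permute(2, 0, 1), targets,
                      input_lengths, target_lengths, blank=blank, reduction='sum')
ctc_grad, = torch.autograd.grad(ctc_loss, logits)

# focal_gamma = 0: modulation is all ones, so the gradient should match CTC.
plain_loss = focal_ctc_via_crossentropy(logits, targets, input_lengths, target_lengths,
                                        blank, focal_gamma=0.0)
plain_grad, = torch.autograd.grad(plain_loss, logits)

# focal_gamma = 2: a genuinely different training signal.
focal_loss = focal_ctc_via_crossentropy(logits, targets, input_lengths, target_lengths,
                                        blank, focal_gamma=2.0)
focal_grad, = torch.autograd.grad(focal_loss, logits)
```

Whether to detach the modulation factor is a design choice; here it is left attached, as in the usual focal loss.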
Thanks for replying!
weights = torch.ones(1, 64, 1) / 64
ctc_loss_via_crossentropy = (-alignment_targets * log_probs * weights).sum()
Is this the correct way to add class weights when we have 64 classes?
I guess so. If you checked the shapes and it makes sense, it must be so. I haven't touched the code for a very long time, so you have better information here :)
Thank you! I will test it.
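A minimal sketch of the class weighting discussed above, assuming a per-class weight vector of shape (C,) that is viewed as (1, C, 1) so it broadcasts over the (B, C, T) alignment targets, exactly like `torch.ones(1, 64, 1) / 64`. The function names `ctc_alignment_targets` and `class_weighted_ctc_via_crossentropy` are hypothetical. With uniform weights 1/C the loss is just the cross-entropy form scaled by 1/C, so its gradient should equal the CTC gradient divided by C.

```python
import torch
import torch.nn.functional as F


def ctc_alignment_targets(logits, targets, input_lengths, target_lengths, blank=0):
    # Hypothetical helper: CTC alignment posteriors (B, C, T) recovered as
    # softmax(logits) - d(ctc_loss)/d(logits), masked to valid frames.
    log_probs = F.log_softmax(logits, dim=1)
    ctc_loss = F.ctc_loss(log_probs.permute(2, 0, 1), targets, input_lengths,
                          target_lengths, blank=blank, reduction='sum')
    ctc_grad, = torch.autograd.grad(ctc_loss, logits)
    frames = torch.arange(logits.shape[-1], device=logits.device)
    lengths = input_lengths.to(logits.device)
    temporal_mask = (frames.unsqueeze(0) < lengths.unsqueeze(1))[:, None, :]
    return (log_probs.detach().exp() * temporal_mask - ctc_grad).detach()


def class_weighted_ctc_via_crossentropy(logits, targets, input_lengths, target_lengths,
                                        class_weights, blank=0):
    # class_weights: shape (C,); reshaped to (1, C, 1) to broadcast over (B, C, T).
    alignment = ctc_alignment_targets(logits, targets, input_lengths, target_lengths, blank)
    log_probs = F.log_softmax(logits, dim=1)
    weights = class_weights.to(log_probs).view(1, -1, 1)
    return (-alignment * log_probs * weights).sum()


# Same sizes as the gist: 64 classes.
torch.manual_seed(0)
B, C, T, t, blank = 4, 64, 32, 8, 0
logits = torch.randn(B, C, T, requires_grad=True)
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), t, dtype=torch.long)
targets = torch.randint(blank + 1, C, (B, t), dtype=torch.long)

# Reference gradient of vanilla CTC.
ctc_loss = F.ctc_loss(F.log_softmax(logits, dim=1).permute(2, 0, 1), targets,
                      input_lengths, target_lengths, blank=blank, reduction='sum')
ctc_grad, = torch.autograd.grad(ctc_loss, logits)

# Uniform 1/C weights, as in the snippet above.
uniform = torch.full((C,), 1.0 / C)
uniform_loss = class_weighted_ctc_via_crossentropy(logits, targets, input_lengths,
                                                   target_lengths, uniform, blank)
uniform_grad, = torch.autograd.grad(uniform_loss, logits)

# Illustrative non-uniform weighting (an assumption): halve the blank class weight.
weights = torch.ones(C)
weights[blank] = 0.5
weighted_loss = class_weighted_ctc_via_crossentropy(logits, targets, input_lengths,
                                                    target_lengths, weights, blank)
```

With non-uniform weights w the gradient becomes -w_c γ_{c,t} + p_{c,t} Σ_{c'} w_{c'} γ_{c',t}, so the weights also change the push on classes that are not currently aligned, not only the aligned class.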
Thank you for the implementation. I was looking for a way to add class weights to the CTC loss; could you explain how to do that in your code example? This would be of great help to my work. Thank you.