Skip to content

Instantly share code, notes, and snippets.

@tadas-subonis
Created July 31, 2019 10:13
Show Gist options
  • Save tadas-subonis/795f20c5f2b4e549fa2aecc84d474db2 to your computer and use it in GitHub Desktop.
Save tadas-subonis/795f20c5f2b4e549fa2aecc84d474db2 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_baidu_ctc import ctc_loss  # NOTE(review): unused in this file; CTCLossT calls F.ctc_loss
class CTCLossT(nn.Module):
    """CTC loss over fixed-length targets, backed by ``F.ctc_loss``.

    No per-sample lengths are accepted. Every sample is scored over the full
    time axis of ``log_probs`` and the full width of ``targets``. As a result,
    padded (variable-length) targets would have their padding scored as real
    labels.

    Args:
        blank: Index of the CTC blank label. Defaults to 0.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, forwarded to
            ``F.ctc_loss``. Defaults to ``'sum'``.
        zero_infinity: If true, infinite losses (targets that cannot be
            aligned within the input) are replaced with zero. Defaults to
            ``True``.
    """

    def __init__(self, blank=0, reduction="sum", zero_infinity=True):
        super().__init__()
        self.blank = blank
        self.reduction = reduction
        self.zero_infinity = zero_infinity

    def forward(self, log_probs, targets):
        """Compute the CTC loss for a batch.

        Args:
            log_probs: Tensor of shape ``(T, N, C)`` (time, batch, classes)
                holding log-probabilities, as ``F.ctc_loss`` expects.
            targets: Tensor of shape ``(N, S)`` with label indices. Every row
                is treated as exactly ``S`` labels long.

        Returns:
            The loss reduced according to ``self.reduction``. This is a scalar
            for ``'sum'``/``'mean'`` and has shape ``(N,)`` for ``'none'``.
        """
        T = log_probs.size(0)
        N = log_probs.size(1)
        S = targets.size(1)
        # The input length is the number of time steps T, not the target
        # width S. Using S would cut every sequence down to its first S frames.
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
        target_lengths = torch.full(size=(N,), fill_value=S, dtype=torch.long)
        # cuDNN is disabled for this call so PyTorch's native CTC kernel runs.
        # NOTE(review): presumably this avoids cuDNN CTC restrictions or
        # nondeterminism; confirm the original motivation.
        with torch.backends.cudnn.flags(enabled=False):
            return F.ctc_loss(
                log_probs,
                targets,
                input_lengths,
                target_lengths,
                blank=self.blank,
                reduction=self.reduction,
                zero_infinity=self.zero_infinity,
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment