@rdp · Created February 24, 2018 01:50
import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape: seqLength x batchSize x alphabet_size (2 x 1 x 5 after the transpose);
# warp-ctc takes unnormalized activations and applies softmax internally
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))     # flat target sequence (index 0 is reserved for blank)
label_sizes = Variable(torch.IntTensor([2]))   # length of each target sequence in the batch
probs_sizes = Variable(torch.IntTensor([2]))   # length of each input sequence in the batch
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print "warp ctc is working!"