@thomwolf
Created October 3, 2017 10:55
Prepare a PyTorch PackedSequence for a batch of sequences
import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

# input_seqs is a batch of input sequences as a numpy array of integers (word indices in vocabulary) padded with zeros
input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
# First: order the batch by decreasing sequence length
input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])])
input_lengths, perm_idx = input_lengths.sort(0, descending=True)
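# Reorder the batch accordingly and trim trailing padding to the longest sequence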
input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]
# Then pack the sequences
packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
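For context, here is a minimal sketch of how such a packed batch might be consumed downstream. The embedding layer, the LSTM, and the sizes are illustrative assumptions that do not appear in the gist, and embedding the word indices before packing is just one common arrangement; the point is to show running the PackedSequence through a recurrent layer, unpacking with pad_packed_sequence, and restoring the original batch order by inverting perm_idx after the length-based sort above.

import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence

# Hypothetical sizes and layers, for illustration only (not part of the gist)
vocab_size, embed_dim, hidden_dim = 10000, 100, 200
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

# Embed the reordered, trimmed batch, then pack the embeddings with the same lengths
embedded = embedding(input_seqs)  # (batch, max_len, embed_dim)
packed_embedded = pack_padded_sequence(embedded, input_lengths.cpu().numpy(), batch_first=True)

# Run the recurrent layer directly on the PackedSequence
packed_output, (h_n, c_n) = lstm(packed_embedded)

# Unpack back to a padded tensor of shape (batch, max_len, hidden_dim)
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

# Restore the original (pre-sorting) order of the batch
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]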