@ereday
Forked from Tushar-N/pad_packed_demo.py
Created October 5, 2018 12:36
How to use pad_packed_sequence in pytorch
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np
import itertools
def flatten(l):
    return list(itertools.chain.from_iterable(l))
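# e.g. flatten(['ab', 'cd']) == ['a', 'b', 'c', 'd']; iterating a string yields its characters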
seqs = ['gigantic_string','tiny_str','medium_str']
# make <pad> idx 0
vocab = ['<pad>'] + sorted(list(set(flatten(seqs))))
# make model: 10-dim char embeddings feeding an LSTM with hidden size 5
embed = nn.Embedding(len(vocab), 10).cuda()
lstm = nn.LSTM(10, 5).cuda()
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
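# e.g. 'tiny_str' -> [vocab.index('t'), vocab.index('i'), vocab.index('n'), ...]
# (every "token" here is a single character)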
# get the length of each seq in your batch
seq_lengths = torch.cuda.LongTensor(list(map(len, vectorized_seqs)))
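# for the example batch these are [15, 8, 10]
# ('gigantic_string', 'tiny_str', 'medium_str' respectively)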
# dump padding everywhere, and place seqs on the left.
# NOTE: you only need a tensor as big as your longest sequence
seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long().cuda()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
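# seq_tensor is now (B, L) = (3, 15): each row holds one sequence's token indices,
# right-padded with the <pad> index 0 out to the longest length in the batch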
# SORT YOUR TENSORS BY LENGTH! pack_padded_sequence expects the batch sorted by
# decreasing length (newer PyTorch versions can skip this with enforce_sorted=False)
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# utils.rnn lets you give (B, L, D) tensors, where B is the batch size, L is the
# max length, and D is the embedding dim, if you use batch_first=True
# Otherwise, give (L, B, D) tensors
seq_tensor = seq_tensor.transpose(0, 1)  # (B, L) -> (L, B); the embedding below adds the feature dim D
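# A minimal sketch of the batch_first=True route instead (names ending in _bf are
# illustrative, not defined above): skip the transpose, keep (B, L) tensors, and
# pass batch_first=True consistently to the LSTM and to pack/pad:
#   lstm_bf = nn.LSTM(10, 5, batch_first=True).cuda()
#   packed_bf = pack_padded_sequence(embedded_bf, seq_lengths.cpu().numpy(), batch_first=True)
#   output_bf, _ = pad_packed_sequence(lstm_bf(packed_bf)[0], batch_first=True)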
# embed your sequences
seq_tensor = embed(seq_tensor)
# pack them up nicely
packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().numpy())
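# A PackedSequence stores a flat .data tensor plus .batch_sizes: with the sorted
# lengths [15, 10, 8], batch_sizes is [3]*8 + [2]*2 + [1]*5, i.e. how many
# sequences are still active at each timestep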
# throw them through your LSTM (remember to give batch_first=True here if you packed with it)
packed_output, (ht, ct) = lstm(packed_input)
# unpack your output if required
output, _ = pad_packed_sequence(packed_output)
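# output is (L, B, hidden) = (15, 3, 5) here; positions past a sequence's true length
# are zero-padded, and the ignored second return value is the tensor of lengths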
print(output)
# Or if you just want the final hidden state?
print(ht[-1])
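# Sanity check (a sketch, assuming this unidirectional single-layer LSTM): the final
# hidden state of each sequence should match the output row at its last valid timestep
for b, length in enumerate(seq_lengths.cpu().numpy()):
    assert (ht[-1][b].data - output[int(length) - 1, b].data).abs().max() < 1e-6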