Skip to content

Instantly share code, notes, and snippets.

Forked from Tushar-N/
Last active June 23, 2024 12:46
Show Gist options
  • Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.
Save HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec to your computer and use it in GitHub Desktop.
Minimal tutorial on packing (pack_padded_sequence) and unpacking (pad_packed_sequence) sequences in pytorch.
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
## We want to run an LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']
# Step 1: Construct Vocabulary
# Step 2: Load indexed data (list of instances, where each instance is list of character indices)
# Step 3: Make Model
# * Step 4: Pad instances with 0s till max length sequence
# * Step 5: Sort instances by sequence length in descending order
# * Step 6: Embed the instances
# * Step 7: Call pack_padded_sequence with embedded instances and sequence lengths
# * Step 8: Forward with LSTM
# * Step 9: Call pad_packed_sequence if required / or just pick last hidden vector
# * Summary of Shape Transformations
#
# NOTE: the numeric values shown in the comments below come from one example run.
# No random seed is set and the Embedding / LSTM weights are randomly initialised,
# so the numbers you see will differ; the shapes and the row ordering will not.

# We want to run an LSTM on a batch of the following 3 character sequences
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6

## Step 1: Construct Vocabulary ##
# make sure <pad> idx is 0
vocab = ['<pad>'] + sorted(set(char for seq in seqs for char in seq))
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##
# A dict gives O(1) lookups instead of scanning the vocab list for every character.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
vectorized_seqs = [[char_to_idx[tok] for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]

## Step 3: Make Model ##
embed = Embedding(len(vocab), 4)  # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)  # input_dim = 4, hidden_dim = 5

## Step 4: Pad instances with 0s till max length sequence ##
# get the length of each seq in your batch
seq_lengths = LongTensor([len(seq) for seq in vectorized_seqs])
# seq_lengths => [8, 4, 6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

# Allocate the padded index matrix directly as a long tensor
# (no Variable wrapper is needed; int() turns the 0-dim max into a plain size).
seq_tensor = torch.zeros((len(vectorized_seqs), int(seq_lengths.max())), dtype=torch.long)
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

# Copy every sequence into the left part of its row; the rest stays 0 (= <pad>).
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [12  5  8 14  0  0  0  0]   # tiny
#                [ 7  3  2  5 13  7  0  0]]  # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 5: Sort instances by sequence length in descending order ##
# pack_padded_sequence expects length-sorted input by default. Newer PyTorch versions
# can skip this step by passing enforce_sorted=False to pack_padded_sequence.
# perm_idx records the permutation; `_, unperm_idx = perm_idx.sort(0)` gives the indices
# that restore the original batch order afterwards.
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]   # long_str
#                [ 7  3  2  5 13  7  0  0]   # medium
#                [12  5  8 14  0  0  0  0]]  # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)

## Step 6: Embed the instances ##
embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
# [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]    l
#   [-0.23622951  2.0361056   0.15435742 -0.04513785]    o
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#   [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]    g
#   [-1.6334635  -0.6100042   1.7509955  -1.931793  ]    _
#   [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]    s
#   [ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#   [-0.22739866 -0.45782727 -0.6643252   0.25129375]]   r
#  [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#   [-0.7303265  -0.857339    0.58913064 -1.1068314 ]    e
#   [ 0.48159844 -1.4886451   0.92639893  0.76906884]    d
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#   [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]    u
#   [ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]    <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]   <pad>
#  [[ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#   [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#   [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#   [-1.284392    0.68294704  1.4064184  -0.42879772]    y
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]    <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]    <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]    <pad>
#   [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]  <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

## Step 7: Call pack_padded_sequence with embedded instances and sequence lengths ##
# The lengths are passed as a CPU tensor, which is one of the documented input types.
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu(), batch_first=True)
# packed_input (PackedSequence is a NamedTuple; the attributes used here are data and batch_sizes)
# packed_input.data is laid out time step by time step: first step 0 of every sequence,
# then step 1 of every sequence that is still running, and so on (pads are dropped).
# =>
# [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]    l
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#  [-0.23622951  2.0361056   0.15435742 -0.04513785]    o
#  [-0.7303265  -0.857339    0.58913064 -1.1068314 ]    e
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#  [ 0.48159844 -1.4886451   0.92639893  0.76906884]    d
#  [-0.6000342   1.1732816   0.19938554 -1.5976517 ]    n
#  [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]    g
#  [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]    i
#  [-1.284392    0.68294704  1.4064184  -0.42879772]    y
#  [-1.6334635  -0.6100042   1.7509955  -1.931793  ]    _
#  [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]    u
#  [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]    s
#  [ 0.16031227 -0.08209462 -0.16297023  0.48121014]    m
#  [ 0.64004815  0.45813003  0.3476034  -0.03451729]    t
#  [-0.22739866 -0.45782727 -0.6643252   0.25129375]]   r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
# packed_input.batch_sizes => [3, 3, 3, 3, 2, 2, 1, 1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

## Step 8: Forward with LSTM ##
packed_output, (ht, ct) = lstm(packed_input)
# packed_output (PackedSequence is a NamedTuple; the attributes used here are data and batch_sizes)
# packed_output.data uses the same time-step-major row order as packed_input.data
# =>
# [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]    l
#  [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]    m
#  [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]    t
#  [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]    o
#  [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]    e
#  [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]    i
#  [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]    n
#  [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]    d
#  [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]    n
#  [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]    g
#  [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]    i
#  [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]    y
#  [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]    _
#  [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]    u
#  [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]    s
#  [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]    m
#  [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]    t
#  [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]   r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)
# packed_output.batch_sizes => [3, 3, 3, 3, 2, 2, 1, 1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

## Step 9: Call pad_packed_sequence if required / or just pick last hidden vector ##
# unpack your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# input_sizes => [8, 6, 4] (the sorted sequence lengths)
# output =>
# [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]    l
#   [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]    o
#   [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]    n
#   [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]    g
#   [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]    _
#   [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]    s
#   [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]    t
#   [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]   r
#  [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]    m
#   [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]    e
#   [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]    d
#   [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]    i
#   [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]    u
#   [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]    m
#   [ 0.          0.          0.          0.          0.        ]    <pad>
#   [ 0.          0.          0.          0.          0.        ]]   <pad>
#  [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]    t
#   [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]    i
#   [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]    n
#   [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]    y
#   [ 0.          0.          0.          0.          0.        ]    <pad>
#   [ 0.          0.          0.          0.          0.        ]    <pad>
#   [ 0.          0.          0.          0.          0.        ]    <pad>
#   [ 0.          0.          0.          0.          0.        ]]]  <pad>
# output.shape : (batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or if you just want the final hidden state?
# ht has shape (num_layers * num_directions X batch_size X hidden_dim) = (1 X 3 X 5).
# For packed input it holds each sequence's state at its last real (non-pad) step,
# so ht[-1] is the last layer's final hidden vector per sequence, in the sorted order.
last_hidden = ht[-1]
# last_hidden.shape : (batch_size X hidden_dim) = (3 X 5)

## Summary of Shape Transformations ##
# (batch_size X max_seq_len) ---> Sort by seqlen ---> (batch_size X max_seq_len)
# (batch_size X max_seq_len) ---> Embed ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) ---> Pack ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim) ---> LSTM ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim) ---> UnPack ---> (batch_size X max_seq_len X hidden_dim)
Copy link

Great tutorial!
A question: in which situations might we want to use pad_packed_sequence? When calculating the loss, wouldn't it be simpler to work with packed scores (the LSTM outputs, without pads) and packed targets? Or do I sometimes have to pad the scores and targets using pad_packed_sequence? If so, when is it used?

Copy link

Thank you very much!

Copy link

Great work!

Copy link

Thanks a lot for putting this together.

Copy link

RudRho commented Jul 20, 2020

Line#146 is the icing on the cake.


Copy link

Great work!

Copy link

Great work.

Copy link

Thanks a lot! 👍

Copy link

thank you, it is very helpful!

Copy link

This is great! Congratulations!

Copy link

Bro where did the len object in line 51 come from?

Copy link

Perfectly explained! Was always confused on what data goes into the batch.

Copy link

laifi commented Nov 19, 2020

Why sort instances by sequence length in descending order step is needed?

pack_padded_sequence does not require pre-sorted input anymore; sorting is controlled by a parameter of the function (see the docs):

**enforce_sorted** (bool, optional) –if True, the input is expected to contain sequences sorted by length in a decreasing order. If False, the input will get sorted unconditionally. Default: True.

Copy link

rayryeng commented Nov 22, 2020

@jackfrost29 - len is a built-in Python function. Calling len(obj) invokes the __len__ method of whatever object is passed in and returns its length / size. In this case, map applies len to every token list in vectorized_seqs, so it yields the length of each sequence.

Copy link

Wonder why nobody complains about lines 120-138, as the packed sequence is clearly wrong.

Clearly, the first three rows in the packed sequence are not l, m, t but l, u, s for example. There are also too many closing brackets in line 132.

Copy link

tombosc commented Feb 25, 2021

Pretty helpful, thank you

Copy link

Thank you very much. It's a very important write-up.

Copy link

After you sort them, you need to restore the original order, right? I want to use the hidden state — is the following correct?
''' a_lengths, idx = text_length.sort(0, descending=True)
_, un_idx = t.sort(idx, dim=0)
seq = text[idx]

    seq = self.dropout(self.embedding(seq))

    a_packed_input = t.nn.utils.rnn.pack_padded_sequence(input=seq,'cpu'), batch_first=True)
    packed_output, (hidden, cell) = self.rnn(a_packed_input)
    out, _ = t.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
    hidden = self.dropout([-2, :, :], hidden[-1, :, :]), dim=1))

    hidden = t.index_select(hidden, 0, un_idx)

Copy link

just what i was looking for, thanks

Copy link

elch10 commented Dec 30, 2021

I can't find any performance comparison. Did anyone compare using pack_padded_sequence with just using the padded sequence?

Copy link

Copy link

Y-jiji commented Mar 19, 2022

Why sort instances by sequence length in descending order step is needed?

If you want to export this model as ONNX, enforce_sorted option must be True.
However, if this model is not to be used in production, you can set enforce_sorted=False to avoid sorting.

Copy link

jainraj commented Apr 16, 2022


Copy link

Very helpful!

Copy link

hungkien05 commented Jul 7, 2022

Most easy-to-understand explanation I have read !

Copy link

Awesome!

Copy link

gunnxx commented Dec 6, 2022

I think packed_output is wrong. It should be the same order as the packed_input. Otherwise, calling pad_packed_sequence will generate inconsistent behavior between packed_output and packed_input.

See this example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

# Minimal demo: with enforce_sorted=False the packed data is stored in length-sorted,
# time-step-major order, and pad_packed_sequence restores the original batch order.
# The numeric values in the comments are from one example run; the RNN weights are
# randomly initialised, so the numbers will differ between runs.

# Three identical padded sequences of one-hot vectors, shape (batch=3, time=3, features=3).
x = F.one_hot(torch.arange(9) % 3, num_classes=3).reshape(3, 3, 3).float()
# x =>
# tensor([[[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],

#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]],

#         [[1., 0., 0.],
#          [0., 1., 0.],
#          [0., 0., 1.]]])

# Number of real (non-pad) time steps per sequence, deliberately NOT sorted.
length = torch.tensor([3, 1, 2], dtype=torch.int64)

packed_x = pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
# PackedSequence(data=tensor([[1., 0., 0.],
#         [1., 0., 0.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))

m = nn.RNN(3, 1, batch_first=True, bias=False)
# weights of m in the example run:
# [[Parameter containing:
# tensor([[-0.7161,  0.8613, -0.8458]], requires_grad=True), Parameter containing:
# tensor([[0.2222]], requires_grad=True)]]

# Call the module itself rather than m.forward(...), so that nn.Module.__call__
# runs and any registered hooks are not silently skipped.
packed_lstm_out, _ = m(packed_x)
# PackedSequence(data=tensor([[-0.6145],
#         [-0.6145],
#         [-0.6145],
#         [ 0.6198],
#         [ 0.6198],
#         [-0.6094]], grad_fn=<CatBackward0>), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))

# Unpacking pads with zeros and returns the rows in the ORIGINAL order (lengths 3, 1, 2).
unpacked_lstm_out, unpacked_length = pad_packed_sequence(packed_lstm_out, batch_first=True)
# unpacked_lstm_out =>
# tensor([[[-0.6145],
#          [ 0.6198],
#          [-0.6094]],

#         [[-0.6145],
#          [ 0.0000],
#          [ 0.0000]],

#         [[-0.6145],
#          [ 0.6198],
#          [ 0.0000]]], grad_fn=<IndexSelectBackward0>)

Copy link

This is very helpful. Thank you.

Copy link

XianghengHee commented Feb 7, 2024

The explanation for packed_output, (ht, ct) = lstm(packed_input) seems not correct.

should be:

# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes
# :
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])

Copy link

This really helped me understand.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment