Skip to content

Instantly share code, notes, and snippets.

@dolaameng
Last active December 31, 2021 05:19
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save dolaameng/8918fefc05d0589f66279eab763d1e13 to your computer and use it in GitHub Desktop.
Save dolaameng/8918fefc05d0589f66279eab763d1e13 to your computer and use it in GitHub Desktop.
Variable-length sequences for an RNN in PyTorch: example
import torch
import torch.nn as nn
from torch.autograd import Variable
# Demonstrates feeding a zero-padded batch of variable-length sequences
# through nn.RNN using pack_padded_sequence / pad_packed_sequence.
# batch_first=True is used consistently, so batch_in is laid out as
# (batch_size, max_length, input_size) with input_size == 1.

# Raw sequences, sorted longest-first as pack_padded_sequence expects
# (unless enforce_sorted=False is passed on newer PyTorch versions).
sequences = [[1, 2, 3], [1, 2], [1]]
seq_lengths = [len(seq) for seq in sequences]  # [3, 2, 1]

# Derived from the data so they can never drift out of sync with it.
batch_size = len(sequences)
max_length = max(seq_lengths)
hidden_size = 2
n_layers = 1

# Zero-padded container: (batch, time, feature).
batch_in = torch.zeros((batch_size, max_length, 1))
for i, seq in enumerate(sequences):
    # Write each sequence down the time axis of feature 0. The previous
    # `batch_in[i] = FloatTensor([[...]])` copied a (1, 3) tensor into a
    # (3, 1) slot, which only worked via a removed legacy same-numel copy.
    batch_in[i, :len(seq), 0] = torch.FloatTensor(seq)
# Variable is a no-op wrapper on PyTorch >= 0.4; kept for older versions.
batch_in = Variable(batch_in)

pack = torch.nn.utils.rnn.pack_padded_sequence(batch_in, seq_lengths, batch_first=True)
print(batch_in.size())  # >>> torch.Size([3, 3, 1])

rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
# The initial hidden state is (n_layers, batch, hidden) even when batch_first=True.
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
out, _ = rnn(pack, h0)

# batch_first must match the packing call. Without it the result is
# (time, batch, hidden), and unpacked[2] below would be time step 2 rather
# than the third sequence.
unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
print(unpacked.size())  # >>> torch.Size([3, 3, 2])
print(unpacked_len)  # >>> [3, 2, 1]  (printed as a tensor on newer PyTorch)

# Outputs for the third sequence (length 1). Only its first time step is
# non-zero; padded steps are filled with 0. Exact values vary run to run
# because h0 and the RNN weights are randomly initialised.
print(unpacked[2, ...])
# >>> (example)
# -0.0818 -0.4678
#  0.0000  0.0000
#  0.0000  0.0000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment