Skip to content

Instantly share code, notes, and snippets.

@colesbury
Created November 9, 2017 17:00
Show Gist options
  • Save colesbury/1da5041f2c20745c3775074af73aa12b to your computer and use it in GitHub Desktop.
Save colesbury/1da5041f2c20745c3775074af73aa12b to your computer and use it in GitHub Desktop.
commit 8f491cadf7b01468acc37ab749cdcac8133b8856
Author: Sam Gross <sgross@fb.com>
Date: Thu Nov 9 08:34:32 2017 -0800
WIP: packed sequence
diff --git a/test/test_nn.py b/test/test_nn.py
index 3e65a25..995af82 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2100,7 +2100,8 @@ class TestNN(NNTestCase):
def pad(tensor, length):
return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])
lengths = [10, 8, 4, 2, 2, 2, 1]
- max_length = lengths[0]
+ random.shuffle(lengths)
+ max_length = max(lengths)
batch_sizes = [sum(map(bool, filter(lambda x: x >= i, lengths))) for i in range(1, max_length + 1)]
offset = 0
padded = torch.cat([pad(i * 100 + torch.arange(1, 5 * l + 1).view(l, 1, 5), max_length)
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 5f36278..9257ccd 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -122,7 +122,7 @@ class RNNBase(Module):
def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
if is_packed:
- input, batch_sizes = input
+ input, batch_sizes, indices = input
max_batch_size = batch_sizes[0]
else:
batch_sizes = None
@@ -159,7 +159,7 @@ class RNNBase(Module):
)
output, hidden = func(input, self.all_weights, hx)
if is_packed:
- output = PackedSequence(output, batch_sizes)
+ output = PackedSequence(output, batch_sizes, indices)
return output, hidden
def __repr__(self):
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 1a47bb3..95b6e3f 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable
-PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes'])
+PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes', 'indices'])
class PackedSequence(PackedSequence_):
@@ -19,6 +19,7 @@ class PackedSequence(PackedSequence_):
data (Variable): Variable containing packed sequence
batch_sizes (list[int]): list of integers holding information about
the batch size at each sequence step
+ indices (tuple[int]): original batch position of each sequence, listed in sorted (decreasing-length) order
"""
pass
@@ -58,13 +59,17 @@ def pack_padded_sequence(input, lengths, batch_first=False):
steps = []
batch_sizes = []
- lengths_iter = reversed(lengths)
batch_size = input.size(1)
if len(lengths) != batch_size:
raise ValueError("lengths array has incorrect size")
+ indices, lengths = zip(*sorted(enumerate(lengths), key=lambda x: x[1], reverse=True))
+ if list(indices) != list(range(len(lengths))):  # zip() yields tuples; a tuple never compares equal to a list
+ input = input[:, indices]
+ # .index_select(1, Variable(torch.LongTensor(indices)))
+
prev_l = 0
- for i, l in enumerate(lengths_iter):
+ for i, l in enumerate(reversed(lengths)):
if l > prev_l:
c_batch_size = batch_size - i
steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *input.size()[2:]))
@@ -73,7 +78,7 @@ def pack_padded_sequence(input, lengths, batch_first=False):
elif prev_l > l: # remember that new_length is the preceding length in the array
raise ValueError("lengths array has to be sorted in decreasing order")
- return PackedSequence(torch.cat(steps), batch_sizes)
+ return PackedSequence(torch.cat(steps), batch_sizes, indices)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0):
@@ -97,7 +102,7 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0):
Tuple of Variable containing the padded sequence, and a list of lengths
of each sequence in the batch.
"""
- var_data, batch_sizes = sequence
+ var_data, batch_sizes, indices = sequence
max_batch_size = batch_sizes[0]
output = var_data.data.new(len(batch_sizes), max_batch_size, *var_data.size()[1:]).fill_(padding_value)
output = Variable(output)
@@ -123,6 +128,16 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0):
lengths.extend((i + 1,) * batch_size)
lengths.reverse()
+ def as_long_tensor(indices):
+ v = Variable(torch.LongTensor(indices))
+ if sequence[0].is_cuda:
+ v = v.cuda()
+ return v
+
+ if list(indices) != list(range(len(indices))):  # indices may be a tuple; compare as lists so the sorted case is skipped
+ output = torch.zeros_like(output).index_add_(1, as_long_tensor(indices), output)
+ lengths = [l for _, l in sorted(zip(indices, lengths))]  # lengths[j] belongs to original slot indices[j]; invert the sort
+
if batch_first:
output = output.transpose(0, 1)
return output, lengths
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment