Created
November 9, 2017 17:00
-
-
Save colesbury/1da5041f2c20745c3775074af73aa12b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 8f491cadf7b01468acc37ab749cdcac8133b8856 | |
Author: Sam Gross <sgross@fb.com> | |
Date: Thu Nov 9 08:34:32 2017 -0800 | |
WIP: packed sequence | |
diff --git a/test/test_nn.py b/test/test_nn.py | |
index 3e65a25..995af82 100644 | |
--- a/test/test_nn.py | |
+++ b/test/test_nn.py | |
@@ -2100,7 +2100,8 @@ class TestNN(NNTestCase): | |
def pad(tensor, length): | |
return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()]) | |
lengths = [10, 8, 4, 2, 2, 2, 1] | |
- max_length = lengths[0] | |
+ random.shuffle(lengths) | |
+ max_length = max(lengths) | |
batch_sizes = [sum(map(bool, filter(lambda x: x >= i, lengths))) for i in range(1, max_length + 1)] | |
offset = 0 | |
padded = torch.cat([pad(i * 100 + torch.arange(1, 5 * l + 1).view(l, 1, 5), max_length) | |
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py | |
index 5f36278..9257ccd 100644 | |
--- a/torch/nn/modules/rnn.py | |
+++ b/torch/nn/modules/rnn.py | |
@@ -122,7 +122,7 @@ class RNNBase(Module): | |
def forward(self, input, hx=None): | |
is_packed = isinstance(input, PackedSequence) | |
if is_packed: | |
- input, batch_sizes = input | |
+ input, batch_sizes, indices = input | |
max_batch_size = batch_sizes[0] | |
else: | |
batch_sizes = None | |
@@ -159,7 +159,7 @@ class RNNBase(Module): | |
) | |
output, hidden = func(input, self.all_weights, hx) | |
if is_packed: | |
- output = PackedSequence(output, batch_sizes) | |
+ output = PackedSequence(output, batch_sizes, indices) | |
return output, hidden | |
def __repr__(self): | |
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py | |
index 1a47bb3..95b6e3f 100644 | |
--- a/torch/nn/utils/rnn.py | |
+++ b/torch/nn/utils/rnn.py | |
@@ -3,7 +3,7 @@ import torch | |
from torch.autograd import Variable | |
-PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes']) | |
+PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes', 'indices']) | |
class PackedSequence(PackedSequence_): | |
@@ -19,6 +19,7 @@ class PackedSequence(PackedSequence_): | |
data (Variable): Variable containing packed sequence | |
batch_sizes (list[int]): list of integers holding information about | |
the batch size at each sequence step | |
+ indices (tuple[int]): index in the original batch of each sequence, listed in sorted (decreasing-length) order | |
""" | |
pass | |
@@ -58,13 +59,17 @@ def pack_padded_sequence(input, lengths, batch_first=False): | |
steps = [] | |
batch_sizes = [] | |
- lengths_iter = reversed(lengths) | |
batch_size = input.size(1) | |
if len(lengths) != batch_size: | |
raise ValueError("lengths array has incorrect size") | |
+ indices, lengths = zip(*sorted(enumerate(lengths), key=lambda x: x[1], reverse=True)) | |
+ if indices != list(range(len(lengths))): | |
+ input = input[:, indices] | |
+ # .index_select(1, Variable(torch.LongTensor(indices))) | |
+ | |
prev_l = 0 | |
- for i, l in enumerate(lengths_iter): | |
+ for i, l in enumerate(reversed(lengths)): | |
if l > prev_l: | |
c_batch_size = batch_size - i | |
steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *input.size()[2:])) | |
@@ -73,7 +78,7 @@ def pack_padded_sequence(input, lengths, batch_first=False): | |
elif prev_l > l: # remember that new_length is the preceding length in the array | |
raise ValueError("lengths array has to be sorted in decreasing order") | |
- return PackedSequence(torch.cat(steps), batch_sizes) | |
+ return PackedSequence(torch.cat(steps), batch_sizes, indices) | |
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0): | |
@@ -97,7 +102,7 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0): | |
Tuple of Variable containing the padded sequence, and a list of lengths | |
of each sequence in the batch. | |
""" | |
- var_data, batch_sizes = sequence | |
+ var_data, batch_sizes, indices = sequence | |
max_batch_size = batch_sizes[0] | |
output = var_data.data.new(len(batch_sizes), max_batch_size, *var_data.size()[1:]).fill_(padding_value) | |
output = Variable(output) | |
@@ -123,6 +128,16 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0): | |
lengths.extend((i + 1,) * batch_size) | |
lengths.reverse() | |
+ def as_long_tensor(indices): | |
+ v = Variable(torch.LongTensor(indices)) | |
+ if sequence[0].is_cuda: | |
+ v = v.cuda() | |
+ return v | |
+ | |
+ if indices != list(range(len(indices))): | |
+ output = torch.zeros_like(output).index_add_(1, as_long_tensor(indices), output) | |
+ lengths = [lengths[i] for i in indices] | |
+ | |
if batch_first: | |
output = output.transpose(0, 1) | |
return output, lengths |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment