Skip to content

Instantly share code, notes, and snippets.

@Yevgnen
Created October 29, 2017 08:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Yevgnen/619a80ba6f1f06abaf8905804c707042 to your computer and use it in GitHub Desktop.
Save Yevgnen/619a80ba6f1f06abaf8905804c707042 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Illustrating pytorch's padding API
import torch
import numpy as np
# Toy problem sizes. Token ids are drawn from [1, vocab_size), so id 0 never
# occurs as a real token and is free to act as the padding value further down.
vocab_size, max_len, n_samples = 20, 10, 3
emb_size, hidden_size = 5, 6

# vocab_size + 1 embedding rows cover every id from 0 (padding) to vocab_size.
emb = torch.nn.Embedding(vocab_size + 1, emb_size)
gru = torch.nn.GRU(emb_size, hidden_size)

# n_samples random id sequences, each of random length in [1, max_len].
# The length is drawn before the ids for every sample.
seqs = [
    np.random.randint(1, vocab_size, size=np.random.randint(1, max_len + 1))
    for _ in range(n_samples)
]

# Order the batch longest-first, as pack_padded_sequence expects by default.
# sorted(..., reverse=True) is stable, so equal-length sequences keep their
# original relative order.
sorted_seqs = tuple(sorted(seqs, key=len, reverse=True))
lens = tuple(map(len, sorted_seqs))

# Right-pad each sequence with 0 up to max_len and lay the batch out
# time-major, shape (max_len, n_samples), matching the GRU's default
# (batch_first=False) input layout.
padded_seqs = np.stack(
    [np.pad(s, (0, max_len - len(s)), mode='constant') for s in sorted_seqs],
    axis=1,
)
var = torch.autograd.Variable(torch.LongTensor(padded_seqs))
embedded = emb(var)  # (max_len, n_samples, emb_size)
# Approach 1 (deliberately wrong): feed the padded batch straight to the GRU.
# The padding id 0 is looked up in the embedding like any real token, so every
# sequence shorter than max_len keeps stepping through padding positions and
# its final hidden state no longer corresponds to its real last token.
init_hidden = torch.autograd.Variable(torch.zeros(1, n_samples, hidden_size))
outputs, hidden = gru(embedded, hx=init_hidden)

# Approach 2: the padding API. Packing with the true lengths lets the GRU
# stop sequence i after lens[i] steps; unpacking turns the packed outputs
# back into a padded, time-major tensor.
ppseqs = torch.nn.utils.rnn.pack_padded_sequence(embedded, lens)
pp_outputs, pp_hidden = gru(ppseqs, hx=init_hidden)
pp_outputs_ = torch.nn.utils.rnn.pad_packed_sequence(pp_outputs)[0]
# Keep only each sequence's output at its last real step: time index
# lens[i] - 1 in batch column i, giving one row per sequence.
pp_outputs_ = pp_outputs_[
    [length - 1 for length in lens],
    list(range(n_samples)),
]
# Approach 3: the reference result. Run every sequence through the GRU on its
# own (a batch of one, no padding) from a zero hidden state, and collect the
# outputs and final hidden state of each.
ops = []
hds = []
for seq in sorted_seqs:
    # (len,) ids -> (len, 1): a time-major batch holding a single sequence.
    tokens = torch.autograd.Variable(torch.LongTensor(seq).view(-1, 1))
    h0 = torch.autograd.Variable(torch.zeros([1, 1, hidden_size]))
    o, h = gru(emb(tokens), h0)
    ops.append(o)
    hds.append(h)
# Concatenate along the batch dimension so the layout matches Approach 2:
# hds -> (1, n_samples, hidden_size); ops keeps only each sequence's last
# step -> (n_samples, hidden_size), like the gathered pp_outputs_.
hds = torch.cat(hds, 1)
ops = torch.cat([out[-1] for out in ops], 0)

# Packing must reproduce the per-sequence results. Check BOTH the final
# hidden states and the last-step outputs. A small tolerance is used instead
# of 1e-10 because the batched and unbatched runs are separate (default
# float32) computations that are not guaranteed to round identically.
tol = 1e-5
assert torch.norm(hds - pp_hidden, p=1).data.numpy() < tol, \
    'packed hidden states differ from per-sequence hidden states'
assert torch.norm(ops - pp_outputs_, p=1).data.numpy() < tol, \
    'packed last-step outputs differ from per-sequence outputs'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment