"""
Blog post:
Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health:
https://medium.com/@_willfalcon/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e
"""
import torch
import torch.nn.functional as F


# forward() method of the LSTM tagger from the blog post; it assumes the class
# defines self.word_embedding, self.lstm, self.hidden_to_tag, self.init_hidden()
# and self.nb_tags in its __init__.
def forward(self, X, X_lengths):
    # reset the LSTM hidden state. Must be done before you run a new batch.
    # Otherwise the LSTM will treat a new batch as a continuation of a sequence.
    self.hidden = self.init_hidden()

    # X is a padded batch of token indices, shape (batch_size, seq_len)
    batch_size, seq_len = X.size()
    # ---------------------
    # 1. embed the input
    # Dim transformation: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
    X = self.word_embedding(X)
    # ---------------------
    # 2. Run through RNN
    # TRICK 2 ********************************
    # Dim transformation: (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, nb_lstm_units)

    # pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM.
    # X_lengths must be sorted in decreasing order (or pass enforce_sorted=False on recent PyTorch).
    X = torch.nn.utils.rnn.pack_padded_sequence(X, X_lengths, batch_first=True)
    # now run through LSTM
    X, self.hidden = self.lstm(X, self.hidden)

    # undo the packing operation; total_length=seq_len keeps the time dimension
    # fixed so the reshape below always matches (batch_size, seq_len, ...)
    X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True, total_length=seq_len)
    # ---------------------
    # 3. Project to tag space
    # Dim transformation: (batch_size, seq_len, nb_lstm_units) -> (batch_size * seq_len, nb_tags)

    # this one is a bit tricky as well. First we need to reshape the data so it goes into the linear layer
    X = X.contiguous()
    X = X.view(-1, X.shape[2])

    # run through the actual linear layer
    X = self.hidden_to_tag(X)
    # ---------------------
    # 4. Create softmax activations because we're doing classification
    # Dim transformation: (batch_size * seq_len, nb_tags) -> (batch_size, seq_len, nb_tags)
    X = F.log_softmax(X, dim=1)

    # I like to reshape for mental sanity so we're back to (batch_size, seq_len, nb_tags)
    X = X.view(batch_size, seq_len, self.nb_tags)

    Y_hat = X
    return Y_hat
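

# ---------------------------------------------------------------------------
# Not part of the original gist: a minimal, self-contained sketch of the same
# pack -> LSTM -> unpack -> project -> log_softmax pipeline on a toy batch, so
# the shape bookkeeping above is easy to trace. Every name and size below
# (vocab_size, embedding_dim, nb_lstm_units, nb_tags, the toy sequences and
# lengths) is made up for illustration.
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F
    from torch import nn

    vocab_size, embedding_dim, nb_lstm_units, nb_tags = 10, 4, 8, 3

    # toy batch of 2 sequences padded to length 5 with index 0; true lengths 5 and 3,
    # sorted in decreasing order as pack_padded_sequence expects by default
    X = torch.tensor([[1, 2, 3, 4, 5],
                      [6, 7, 8, 0, 0]])
    X_lengths = [5, 3]
    batch_size, seq_len = X.size()

    word_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
    lstm = nn.LSTM(embedding_dim, nb_lstm_units, batch_first=True)
    hidden_to_tag = nn.Linear(nb_lstm_units, nb_tags)

    # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
    E = word_embedding(X)

    # pack so the LSTM never sees the padded timesteps
    packed = torch.nn.utils.rnn.pack_padded_sequence(E, X_lengths, batch_first=True)
    packed_out, _ = lstm(packed)

    # unpack; total_length keeps the time dimension equal to seq_len
    out, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)

    # (batch_size, seq_len, nb_lstm_units) -> (batch_size * seq_len, nb_tags)
    logits = hidden_to_tag(out.contiguous().view(-1, nb_lstm_units))

    # log-probabilities over tags, reshaped back to (batch_size, seq_len, nb_tags)
    Y_hat = F.log_softmax(logits, dim=1).view(batch_size, seq_len, nb_tags)
    print(Y_hat.shape)  # torch.Size([2, 5, 3])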