import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMBasic(nn.Module):
    """
    Dropout note: nn.LSTM does not apply dropout after the last layer, so with a single
    LSTM layer the built-in dropout argument has no effect.
    # TODO deprecate the built_in_dropout param for LSTM; turn it into a plain on/off switch for dropout
    # NOTE do not change the number of classes unless you want to switch off of binary classification (and change the criterion)
    """
    def __init__(self, args, num_classes=1, built_in_dropout=False):
        super(LSTMBasic, self).__init__()
        dropout = args.dropout_input
        self.hidden_dim = args.hidden_dim
        self.bi = args.bidirectional == 1
        self.num_layers = 1  # TODO remove param or deprecate?
        self.num_classes = num_classes
        self.embedding_depth = (
            (args.number_of_characters + len(args.extra_characters)) * args.max_embedding_length
            if args.use_char_encoding
            else args.embedding_depth * args.max_embedding_length
        )
        self.lstm = nn.LSTM(
            self.embedding_depth,
            hidden_size=self.hidden_dim,
            bidirectional=self.bi,
            batch_first=False,
            dropout=dropout,
        )
        # Linear layer that maps from hidden state space to output space
        self.hidden2out = nn.Linear(
            (2 if self.bi else 1) * self.hidden_dim, self.num_classes
        )
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Axes semantics: (num_layers * num_directions, minibatch_size, hidden_dim)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        h0 = torch.zeros(
            (2 if self.bi else 1) * self.num_layers, 1, self.hidden_dim
        ).to(device)
        c0 = torch.zeros(
            (2 if self.bi else 1) * self.num_layers, 1, self.hidden_dim
        ).to(device)
        return (h0, c0)

    def forward(self, sequence):
        print('Sequence shape:', sequence.shape)
        # Flatten each time step and add a batch dimension of 1 (batch_first=False).
        sequence = sequence.clone().view(len(sequence), 1, -1)
        print("flattened shape: ", sequence.shape)
        lstm_out, hidden = self.lstm(
            sequence, self.hidden
        )
        print(lstm_out.shape)
        # Index the single batch element; this yields one output per time step.
        out_space = self.hidden2out(lstm_out[:, -1])
        # NOTE: the hidden state is carried across forward calls; detach it or
        # call init_hidden() in the training loop to avoid backpropagating
        # through previous sequences.
        self.hidden = hidden
        print("hiddens")
        print(hidden[0].shape)
        print(hidden[1].shape)
        print(" out_space: ", out_space.shape)
        out_scores = torch.sigmoid(out_space)
        print("out_scores: ", out_scores.shape)
        out = out_scores.squeeze()
        print(out.shape)
        return out
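
# --- Minimal usage sketch (not part of the original gist; an illustration only) ---
# It assumes a hypothetical `args` namespace carrying the fields the constructor reads
# (dropout_input, hidden_dim, bidirectional, use_char_encoding, embedding_depth,
# max_embedding_length, number_of_characters, extra_characters); the concrete values
# below are placeholders, not values from the original author.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dropout_input=0.0,        # 0.0 avoids the single-layer dropout warning noted in the docstring
        hidden_dim=64,
        bidirectional=0,
        use_char_encoding=False,
        embedding_depth=32,
        max_embedding_length=16,
        number_of_characters=0,
        extra_characters=[],
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = LSTMBasic(args, num_classes=1).to(device)

    # A toy sequence of 10 steps, each flattened to embedding_depth * max_embedding_length.
    sequence = torch.randn(10, args.embedding_depth * args.max_embedding_length, device=device)

    model.hidden = model.init_hidden()  # reset state before feeding a new sequence
    scores = model(sequence)            # sigmoid scores; shape (10,) with these settings
    print(scores.shape)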