import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMBasic(nn.Module):
    """
    Dropout note: nn.LSTM applies dropout only between stacked layers, never
    after the last one, so with a single layer the built-in dropout is a no-op.
    #TODO deprecate the built_in_dropout param for LSTM; turn it into a plain dropout switch
    #NOTE do not change the number of classes unless you want to switch off of binary classification (and change the criterion)
    """
    def __init__(self, args, num_classes=1, built_in_dropout=False):
        super(LSTMBasic, self).__init__()
        # NOTE: built_in_dropout is currently unused (see the docstring TODO).
        dropout = args.dropout_input
        self.hidden_dim = args.hidden_dim
        self.bi = args.bidirectional == 1
        self.num_layers = 1  # TODO remove param or deprecate?
        self.num_classes = num_classes
        # Input feature size per LSTM step: character one-hot width (or dense
        # embedding depth) times the padded maximum embedding length.
        self.embedding_depth = (
            (args.number_of_characters + len(args.extra_characters)) * args.max_embedding_length
            if args.use_char_encoding
            else args.embedding_depth * args.max_embedding_length
        )
        # NOTE: with the default num_layers=1 this dropout argument has no
        # effect (PyTorch warns about it); see the docstring.
        self.lstm = nn.LSTM(
            self.embedding_depth,
            hidden_size=self.hidden_dim,
            bidirectional=self.bi,
            batch_first=False,
            dropout=dropout,
        )
        # Linear layer that maps from hidden state space to output space.
        self.hidden2out = nn.Linear(
            (2 if self.bi else 1) * self.hidden_dim, self.num_classes
        )
        self.hidden = self.init_hidden()
    def init_hidden(self):
        # Axes semantics: (num_layers * num_directions, batch_size, hidden_dim).
        # Batch size is fixed at 1 because forward() folds the batch into the
        # sequence dimension.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        h0 = torch.zeros(
            (2 if self.bi else 1) * self.num_layers, 1, self.hidden_dim
        ).to(device)
        c0 = torch.zeros(
            (2 if self.bi else 1) * self.num_layers, 1, self.hidden_dim
        ).to(device)
        return (h0, c0)
    def forward(self, sequence):
        print("Sequence shape:", sequence.shape)
        # Flatten each example to a single feature vector; the batch is fed to
        # the LSTM as a sequence of len(sequence) steps with batch size 1.
        sequence = sequence.clone().view(len(sequence), 1, -1)
        print("flattened shape:", sequence.shape)
        lstm_out, hidden = self.lstm(sequence, self.hidden)
        print("lstm_out shape:", lstm_out.shape)
        # Drop the singleton batch dimension and map to output space.
        out_space = self.hidden2out(lstm_out[:, -1])
        # Detach the carried-over state so gradients don't flow back through
        # previous batches.
        self.hidden = (hidden[0].detach(), hidden[1].detach())
        print("hidden shapes:", hidden[0].shape, hidden[1].shape)
        print("out_space:", out_space.shape)
        out_scores = torch.sigmoid(out_space)
        print("out_scores:", out_scores.shape)
        out = out_scores.squeeze()
        print("out shape:", out.shape)
        return out
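
For reference, a minimal smoke test of the module above: it builds a stand-in for the args namespace, runs a forward pass, and pairs the sigmoid output with binary cross-entropy as the docstring NOTE suggests. Every value here (hidden_dim=64, a 68-character vocabulary, and so on) is a made-up placeholder, not a setting from this gist.

from types import SimpleNamespace

# All field values below are hypothetical placeholders.
args = SimpleNamespace(
    dropout_input=0.0,      # keep 0.0: single-layer LSTM dropout is a no-op anyway
    hidden_dim=64,
    bidirectional=1,
    use_char_encoding=True,
    number_of_characters=68,
    extra_characters="",
    max_embedding_length=150,
    embedding_depth=0,      # unused when use_char_encoding is True
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LSTMBasic(args, num_classes=1).to(device)

# A batch of 4 examples, each a (max_len, n_chars) one-hot matrix.
x = torch.zeros(
    4,
    args.max_embedding_length,
    args.number_of_characters + len(args.extra_characters),
).to(device)

scores = model(x)  # shape (4,): one sigmoid score per example
criterion = nn.BCELoss()  # matches the sigmoid output; change it along with num_classes
loss = criterion(scores, torch.ones(4, device=device))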