Some experimental variations on LSTMs
LSTM variant experiments in PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = torch.cuda.is_available()


class anLSTMCell(nn.Module):
    def __init__(self, inout_size, hidden_sizes, output_size=None):
        """
        inout_size: [int]
            In the absence of output_size, the number of features in both the
            input and the output of the cell. If output_size is present, it is
            only the input size.
        hidden_sizes: [int or list]
            Sizes of the nested hidden states. These are also, respectively,
            the input and output sizes of the nested cells.
        output_size: [int (optional)]
            If given, changes the behaviour of inout_size.
        """
        super(anLSTMCell, self).__init__()
        self.input_size = inout_size
        # output size is the same as the input size by default
        self.output_size = output_size if output_size is not None else inout_size
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        self.hidden_sizes = list(hidden_sizes)  # normalise to a list for coherency
        self.depth = len(self.hidden_sizes)
        self.nested_cell = None
        # hidden[0] is the hidden state, hidden[1] the cell state (error carousel)
        self.hidden = [Variable(torch.zeros(1, self.hidden_sizes[0])),
                       Variable(torch.zeros(1, self.hidden_sizes[0]))]
        if len(self.hidden_sizes) > 1:
            # updates to the error carousel happen through a nested cell
            self.nested_cell = anLSTMCell(self.hidden_sizes[0], self.hidden_sizes[1:])
        self.ih = nn.Linear(self.input_size, 4 * self.hidden_sizes[0])
        self.hh = nn.Linear(self.hidden_sizes[0], 4 * self.hidden_sizes[0])
        self.h2o = nn.Linear(self.hidden_sizes[0], self.output_size)
        if USE_CUDA:
            self.hidden = [h.cuda() for h in self.hidden]

    def detach_hidden(self):
        # cut the autograd graph between sequences (truncated BPTT)
        for h in self.hidden:
            h.detach_()
        if self.nested_cell is not None:
            self.nested_cell.detach_hidden()

    def reset_hidden(self):
        for h in self.hidden:
            h.detach_()
            h.zero_()
        if self.nested_cell is not None:
            self.nested_cell.reset_hidden()

    def forward(self, input, nested_forget=1):
        gates = self.ih(input) + self.hh(self.hidden[0] * nested_forget)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * self.hidden[1]) + (ingate * cellgate)
        if self.nested_cell is not None:
            # the nested cell contributes an additive update to the carousel
            cy = cy + F.tanh(self.nested_cell(F.tanh(cy)))
        self.hidden[1] = cy
        self.hidden[0] = outgate * F.tanh(cy)
        oy = self.h2o(self.hidden[0])
        return oy
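A minimal smoke test for anLSTMCell, as a sketch rather than a definitive recipe: every size, the learning rate, and the use of F.mse_loss as an objective are made-up assumptions, and a CPU run is assumed (the cell moves its state to CUDA when available). It shows one truncated-BPTT step, with detach_hidden() cutting the graph before the next sequence.

# Hypothetical smoke test (all sizes and hyperparameters are made up).
cell = anLSTMCell(inout_size=8, hidden_sizes=[16, 12])
optimizer = torch.optim.SGD(cell.parameters(), lr=0.01)  # nested params register too

inputs = [Variable(torch.randn(1, 8)) for _ in range(5)]
targets = [Variable(torch.randn(1, 8)) for _ in range(5)]

cell.reset_hidden()
loss = 0
for x, t in zip(inputs, targets):
    loss = loss + F.mse_loss(cell(x), t)
loss.backward()
optimizer.step()
cell.detach_hidden()  # cut the autograd graph before the next sequence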
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class attentionalLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_variants):
        super(attentionalLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_variants = num_variants
        # num_variants parallel sets of LSTM gates
        self.ih = nn.Linear(self.input_size, 4 * self.hidden_size * self.num_variants)
        self.hh = nn.Linear(self.hidden_size, 4 * self.hidden_size * self.num_variants)
        # attention over the gate variants, computed from the cell state
        self.hhh = nn.Linear(self.hidden_size, self.num_variants)

    def forward(self, input, hidden):
        hx, cx = hidden
        gates = self.ih(input) + self.hh(hx)
        gates_weights = F.softmax(self.hhh(cx), dim=1)
        gates = gates.view(-1, 4 * self.hidden_size, self.num_variants)
        # mix the gate variants with the attention weights (batched matmul)
        gates = torch.matmul(gates, gates_weights.unsqueeze(2)).squeeze(2)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)
        return hy, cy
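Unlike the stateful cells in this gist, attentionalLSTMCell follows the torch.nn.LSTMCell convention of taking and returning its hidden state explicitly. A hypothetical usage sketch with made-up sizes, assuming the imports at the top of the file:

# Hypothetical usage (made-up sizes); the state is threaded through manually.
batch, in_size, hid_size, variants = 4, 8, 16, 3
cell = attentionalLSTMCell(in_size, hid_size, variants)

hx = Variable(torch.zeros(batch, hid_size))
cx = Variable(torch.zeros(batch, hid_size))
for x in [Variable(torch.randn(batch, in_size)) for _ in range(5)]:
    hx, cx = cell(x, (hx, cx))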
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = torch.cuda.is_available()


class nsLSTMCell(nn.Module):
    def __init__(self, inout_size, hidden_sizes, output_size=None):
        """
        inout_size: [int]
            In the absence of output_size, the number of features in both the
            input and the output of the cell. If output_size is present, it is
            only the input size.
        hidden_sizes: [int or list]
            Sizes of the nested hidden states. These are also, respectively,
            the input and output sizes of the nested cells.
        output_size: [int (optional)]
            If given, changes the behaviour of inout_size.
        """
        super(nsLSTMCell, self).__init__()
        self.input_size = inout_size
        # output size is the same as the input size by default
        self.output_size = output_size if output_size is not None else inout_size
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        self.hidden_sizes = list(hidden_sizes)  # normalise to a list for coherency
        self.depth = len(self.hidden_sizes)
        self.nested_cell = None
        self.hidden = [Variable(torch.zeros(1, self.hidden_sizes[0]))]
        if len(self.hidden_sizes) == 1:
            # the innermost cell keeps its own error carousel
            self.hidden.append(Variable(torch.zeros(1, self.hidden_sizes[0])))
            self.forget_size = self.hidden_sizes[0]
        else:
            # updates to the error carousel happen through a nested cell, so
            # the forget gate must match the nested cell's hidden size
            self.nested_cell = nsLSTMCell(self.hidden_sizes[0], self.hidden_sizes[1:])
            self.forget_size = self.hidden_sizes[1]
        self.ih = nn.Linear(self.input_size, 3 * self.hidden_sizes[0] + self.forget_size)
        self.hh = nn.Linear(self.hidden_sizes[0], 3 * self.hidden_sizes[0] + self.forget_size)
        self.h2o = nn.Linear(self.hidden_sizes[0], self.output_size)
        if USE_CUDA:
            self.hidden = [h.cuda() for h in self.hidden]

    def detach_hidden(self):
        # cut the autograd graph between sequences (truncated BPTT)
        for h in self.hidden:
            h.detach_()
        if self.nested_cell is not None:
            self.nested_cell.detach_hidden()

    def reset_hidden(self):
        for h in self.hidden:
            h.detach_()
            h.zero_()
        if self.nested_cell is not None:
            self.nested_cell.reset_hidden()

    def forward(self, input, nested_forget=1):
        gates = self.ih(input) + self.hh(self.hidden[0] * nested_forget)
        # the forget gate can differ in size from the other three gates
        n = self.hidden_sizes[0]
        ingate, cellgate, outgate = gates[:, :3 * n].chunk(3, 1)
        forgetgate = gates[:, 3 * n:]
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        if self.nested_cell is not None:
            # the carousel update is delegated to the nested cell; the forget
            # gate acts on the nested cell's hidden state instead of a local one
            cy = self.nested_cell(ingate * cellgate, forgetgate)
        else:
            cy = (forgetgate * self.hidden[1]) + (ingate * cellgate)
            self.hidden[1] = cy
        self.hidden[0] = outgate * F.tanh(cy)
        oy = self.h2o(self.hidden[0])
        return oy
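A hypothetical sketch of a three-level nsLSTMCell with made-up sizes (CPU run assumed), mainly to show how the outer forget gate is sized by the first nested hidden state:

# Hypothetical sizes: the outer forget gate has size 12 (the first nested
# hidden size), since forgetting is applied inside the nested cell.
cell = nsLSTMCell(inout_size=8, hidden_sizes=[16, 12, 10])
x = Variable(torch.randn(1, 8))
y = cell(x)          # shape (1, 8): output size defaults to the input size
cell.reset_hidden()  # zero every nested state before a new sequence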
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class sLSTMCell(nn.Module):
    """
    LSTM variant including an additional linear layer that decouples the
    hidden size from the output size. In particular, the output size is the
    same as the input size.
    """
    def __init__(self, inout_size, hidden_size):
        super(sLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.inout_size = inout_size
        self.ih = nn.Linear(self.inout_size, 4 * self.hidden_size)
        self.hh = nn.Linear(self.hidden_size, 4 * self.hidden_size)
        self.h2o = nn.Linear(self.hidden_size, self.inout_size)
        self.hx = Variable(torch.zeros(1, self.hidden_size))
        self.cx = Variable(torch.zeros(1, self.hidden_size))
        if torch.cuda.is_available():
            self.hx = self.hx.cuda()
            self.cx = self.cx.cuda()

    def detach_hidden(self):
        # cut the autograd graph between sequences (truncated BPTT)
        self.hx.detach_()
        self.cx.detach_()

    def reset_hidden(self):
        self.detach_hidden()
        self.hx.zero_()
        self.cx.zero_()

    def forward(self, input):
        gates = self.ih(input) + self.hh(self.hx)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * self.cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)
        oy = self.h2o(hy)
        self.hx = hy
        self.cx = cy
        return oy
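Because h2o maps the hidden state back to inout_size, the cell's output can be fed straight back in as the next input, e.g. for free-running generation. A hypothetical sketch with made-up sizes (CPU run assumed):

# Hypothetical free-running loop (made-up sizes).
cell = sLSTMCell(inout_size=8, hidden_size=32)
cell.reset_hidden()
x = Variable(torch.randn(1, 8))
for _ in range(10):
    x = cell(x)  # the output has the same size as the input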