LSTM variant experiments in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = torch.cuda.is_available()


class anLSTMCell(nn.Module):
    def __init__(self, inout_size, hidden_sizes, output_size=None):
        """
        inout_size: [int]
            In the absence of output_size, the number of features in both the
            input and the output of the cell. If output_size is present, it is
            only the input size.
        hidden_sizes: [int or list]
            Sizes of the nested hidden states. These are also, respectively,
            the input and output sizes of the nested cells.
        output_size: [int (optional)]
            If given, changes the behaviour of inout_size.
        """
        super(anLSTMCell, self).__init__()
        self.input_size = self.output_size = inout_size  # same size by default
        if output_size is not None:
            self.output_size = output_size
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        self.hidden_sizes = list(hidden_sizes)  # make a list for coherency
        self.depth = len(self.hidden_sizes)
        self.nested_cell = None
        self.hidden = [Variable(torch.zeros(1, self.hidden_sizes[0])),
                       Variable(torch.zeros(1, self.hidden_sizes[0]))]
        if len(self.hidden_sizes) > 1:
            # updates to the error carousel happen through a nested cell
            self.nested_cell = anLSTMCell(self.hidden_sizes[0], self.hidden_sizes[1:])
        self.ih = nn.Linear(self.input_size, 4 * self.hidden_sizes[0])
        self.hh = nn.Linear(self.hidden_sizes[0], 4 * self.hidden_sizes[0])
        self.h2o = nn.Linear(self.hidden_sizes[0], self.output_size)
        if USE_CUDA:
            self.hidden = [h.cuda() for h in self.hidden]
    def detach_hidden(self):
        for h in self.hidden:
            h.detach_()
        if self.nested_cell is not None:
            self.nested_cell.detach_hidden()

    def reset_hidden(self):
        for h in self.hidden:
            h.detach_()
            h.zero_()
        if self.nested_cell is not None:
            self.nested_cell.reset_hidden()
    def forward(self, input, nested_forget=1):
        gates = self.ih(input) + self.hh(self.hidden[0] * nested_forget)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * self.hidden[1]) + (ingate * cellgate)
        if self.nested_cell is not None:
            # the nested cell proposes an additive update to the error carousel
            cy = cy + F.tanh(self.nested_cell(F.tanh(cy)))
        self.hidden[1] = cy
        self.hidden[0] = outgate * F.tanh(cy)
        oy = self.h2o(self.hidden[0])
        return oy
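
A minimal usage sketch for anLSTMCell (the sizes and data below are illustrative assumptions, not part of the gist). The cell keeps its state internally, so you call it once per timestep and detach between truncated-BPTT windows:

    # hypothetical sizes: 16 input/output features, nested hidden states of 32 and 8
    cell = anLSTMCell(16, [32, 8])
    seq = Variable(torch.randn(20, 1, 16))  # 20 timesteps, batch of 1
    if USE_CUDA:
        cell, seq = cell.cuda(), seq.cuda()
    outputs = [cell(seq[t]) for t in range(seq.size(0))]
    cell.detach_hidden()  # truncate backprop through time between windows

Note that the internal hidden states are created with a batch dimension of 1, so the sketch sticks to batch-size-1 inputs.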
class attentionalLSTMCell(nn.Module):
    """
    LSTM variant that computes num_variants candidate gate sets and blends
    them with attention weights derived from the cell state.
    """
    def __init__(self, input_size, hidden_size, num_variants):
        super(attentionalLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_variants = num_variants
        self.ih = nn.Linear(self.input_size, 4 * self.hidden_size * self.num_variants)
        self.hh = nn.Linear(self.hidden_size, 4 * self.hidden_size * self.num_variants)
        self.hhh = nn.Linear(self.hidden_size, self.num_variants)

    def forward(self, input, hidden):
        hx, cx = hidden
        gates = self.ih(input) + self.hh(hx)
        # attention weights over the gate variants, computed from the cell state
        gates_weights = F.softmax(self.hhh(cx))
        gates = gates.view(-1, 4 * self.hidden_size, self.num_variants)
        gates = torch.matmul(gates, gates_weights.squeeze())
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)
        return hy, cy
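
Unlike the other cells in this gist, attentionalLSTMCell keeps the standard functional interface: the caller threads (hx, cx) through. A minimal sketch (sizes and data are assumptions):

    cell = attentionalLSTMCell(input_size=16, hidden_size=32, num_variants=4)
    hx = Variable(torch.zeros(1, 32))
    cx = Variable(torch.zeros(1, 32))
    seq = Variable(torch.randn(20, 1, 16))  # 20 timesteps, batch of 1
    for t in range(seq.size(0)):
        hx, cx = cell(seq[t], (hx, cx))

The squeeze() on the attention weights in forward reduces them to a vector, so the matmul blending as written also assumes a batch of one.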

Some experimental variations on LSTMs

class nsLSTMCell(nn.Module):
    def __init__(self, inout_size, hidden_sizes, output_size=None):
        """
        inout_size: [int]
            In the absence of output_size, the number of features in both the
            input and the output of the cell. If output_size is present, it is
            only the input size.
        hidden_sizes: [int or list]
            Sizes of the nested hidden states. These are also, respectively,
            the input and output sizes of the nested cells.
        output_size: [int (optional)]
            If given, changes the behaviour of inout_size.
        """
        super(nsLSTMCell, self).__init__()
        self.input_size = self.output_size = inout_size  # same size by default
        if output_size is not None:
            self.output_size = output_size
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        self.hidden_sizes = list(hidden_sizes)  # make a list for coherency
        self.depth = len(self.hidden_sizes)
        self.nested_cell = None
        self.hidden = [Variable(torch.zeros(1, self.hidden_sizes[0]))]
        if len(self.hidden_sizes) == 1:
            # leaf cell: keep the cell state (error carousel) locally
            self.hidden.append(Variable(torch.zeros(1, self.hidden_sizes[0])))
            self.forget_size = self.hidden_sizes[0]
        else:
            # updates to the error carousel happen through a nested cell
            self.nested_cell = nsLSTMCell(self.hidden_sizes[0], self.hidden_sizes[1:])
            self.forget_size = self.hidden_sizes[1]
        self.ih = nn.Linear(self.input_size, 3 * self.hidden_sizes[0] + self.forget_size)
        self.hh = nn.Linear(self.hidden_sizes[0], 3 * self.hidden_sizes[0] + self.forget_size)
        self.h2o = nn.Linear(self.hidden_sizes[0], self.output_size)
        if USE_CUDA:
            self.hidden = [h.cuda() for h in self.hidden]
    def detach_hidden(self):
        for h in self.hidden:
            h.detach_()
        if self.nested_cell is not None:
            self.nested_cell.detach_hidden()

    def reset_hidden(self):
        for h in self.hidden:
            h.detach_()
            h.zero_()
        if self.nested_cell is not None:
            self.nested_cell.reset_hidden()
    def forward(self, input, nested_forget=1):
        gates = self.ih(input) + self.hh(self.hidden[0] * nested_forget)
        # chunks of hidden_sizes[0] features; the trailing chunk (the forget
        # gate) has forget_size features, assuming forget_size <= hidden_sizes[0]
        ingate, cellgate, outgate, forgetgate = gates.split(self.hidden_sizes[0], 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        if self.nested_cell is not None:
            # the nested cell maintains the carousel; the local forget gate
            # is handed down as its nested_forget
            cy = self.nested_cell(ingate * cellgate, forgetgate)
        else:
            cy = (forgetgate * self.hidden[1]) + (ingate * cellgate)
            self.hidden[1] = cy
        self.hidden[0] = outgate * F.tanh(cy)
        oy = self.h2o(self.hidden[0])
        return oy
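
The difference from anLSTMCell: here the nested cell replaces the error carousel outright, and the parent's forget gate is passed down as nested_forget instead of being applied locally. A minimal sketch (the sizes below are assumptions):

    cell = nsLSTMCell(16, [32, 8])   # hypothetical sizes
    x = Variable(torch.randn(1, 16))
    if USE_CUDA:
        x = x.cuda()
    y = cell(x)          # (1, 16): the output size defaults to the input size
    cell.reset_hidden()  # zero every level's state before a new sequence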
class sLSTMCell(nn.Module):
    """
    LSTM variant with an additional linear layer that decouples the hidden
    size from the output size. In particular, the output size is the same
    as the input size.
    """
    def __init__(self, inout_size, hidden_size):
        super(sLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.inout_size = inout_size
        self.ih = nn.Linear(self.inout_size, 4 * self.hidden_size)
        self.hh = nn.Linear(self.hidden_size, 4 * self.hidden_size)
        self.h2o = nn.Linear(self.hidden_size, self.inout_size)
        self.hx = Variable(torch.zeros(1, self.hidden_size))
        self.cx = Variable(torch.zeros(1, self.hidden_size))
        if torch.cuda.is_available():
            self.hx = self.hx.cuda()
            self.cx = self.cx.cuda()
    def detach_hidden(self):
        self.hx.detach_()
        self.cx.detach_()

    def reset_hidden(self):
        self.detach_hidden()
        self.hx.zero_()
        self.cx.zero_()

    def forward(self, input):
        gates = self.ih(input) + self.hh(self.hx)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)
        cy = (forgetgate * self.cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)
        oy = self.h2o(hy)
        self.hx = hy
        self.cx = cy
        return oy
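
Because h2o maps the hidden state back to inout_size, the output of sLSTMCell has the same shape as its input, so the cell can be run autoregressively by feeding each output back in. A minimal sketch (sizes are assumptions):

    cell = sLSTMCell(inout_size=16, hidden_size=64)
    x = Variable(torch.randn(1, 16))
    if torch.cuda.is_available():
        x = x.cuda()
    for _ in range(5):   # feed each output back as the next input
        x = cell(x)
    cell.reset_hidden()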