@Kaixhin
Last active June 18, 2021
Collection of LSTMs
# Collection of LSTM cells (including forget gates)
# https://en.wikipedia.org/w/index.php?title=Long_short-term_memory&oldid=784163987
import math

import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.autograd import Variable
class LSTMCell(nn.LSTMCell):
  def forward(self, input, hx):
    h, c = hx
    wx = F.linear(input, self.weight_ih, self.bias_ih)  # Weights combined into one matrix
    wh = F.linear(h, self.weight_hh, self.bias_hh)
    wxh = wx + wh
    i = F.sigmoid(wxh[:, :self.hidden_size])  # Input gate
    f = F.sigmoid(wxh[:, self.hidden_size:2 * self.hidden_size])  # Forget gate
    g = F.tanh(wxh[:, 2 * self.hidden_size:3 * self.hidden_size])  # Cell gate (candidate cell state)
    o = F.sigmoid(wxh[:, 3 * self.hidden_size:])  # Output gate
    c = f * c + i * g  # Cell state
    h = o * F.tanh(c)  # Hidden state
    return h, (h, c)
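For reference, the update this cell computes, in the same [i, f, g, o] slicing order as the code above, is the standard LSTM with a forget gate:

$$\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}$$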
class PeepholeLSTMCell(nn.LSTMCell):
  def __init__(self, input_size, hidden_size, bias=True):
    super(PeepholeLSTMCell, self).__init__(input_size, hidden_size, bias)
    self.weight_ch = Parameter(torch.Tensor(3 * hidden_size, hidden_size))  # Peephole connections for i, f, o
    if bias:
      self.bias_ch = Parameter(torch.Tensor(3 * hidden_size))
    else:
      self.register_parameter('bias_ch', None)
    self.register_buffer('wc_blank', torch.zeros(hidden_size))
    self.reset_parameters()

  def forward(self, input, hx):
    h, c = hx
    wx = F.linear(input, self.weight_ih, self.bias_ih)
    wh = F.linear(h, self.weight_hh, self.bias_hh)
    wc = F.linear(c, self.weight_ch, self.bias_ch)
    wxhc = wx + wh + torch.cat((wc[:, :2 * self.hidden_size], Variable(self.wc_blank).expand_as(h), wc[:, 2 * self.hidden_size:]), 1)
    i = F.sigmoid(wxhc[:, :self.hidden_size])  # Input gate
    f = F.sigmoid(wxhc[:, self.hidden_size:2 * self.hidden_size])  # Forget gate
    g = F.tanh(wxhc[:, 2 * self.hidden_size:3 * self.hidden_size])  # Cell gate (no cell involvement)
    o = F.sigmoid(wxhc[:, 3 * self.hidden_size:])  # Output gate
    c = f * c + i * g
    h = o * F.tanh(c)
    return h, (h, c)
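For comparison, classical peephole connections (Gers & Schmidhuber) are diagonal, i.e. elementwise products with the cell state, and the output gate usually peeks at the updated cell c_t:

$$\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + p_i \odot c_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + p_f \odot c_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + p_o \odot c_t + b_o)
\end{aligned}$$

This cell instead applies a full linear map to the cell state (F.linear(c, self.weight_ch)) and uses c_{t-1} for all three gates; the Wikipedia peephole variant additionally drops the h_{t-1} terms (see the discussion below).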
class Conv2dLSTMCell(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    super(Conv2dLSTMCell, self).__init__()
    if in_channels % groups != 0:
      raise ValueError('in_channels must be divisible by groups')
    if out_channels % groups != 0:
      raise ValueError('out_channels must be divisible by groups')
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.padding_h = tuple(k // 2 for k, s, p, d in zip(kernel_size, stride, padding, dilation))  # 'Same' padding for the recurrent convolutions
    self.dilation = dilation
    self.groups = groups
    self.weight_ih = Parameter(torch.Tensor(4 * out_channels, in_channels // groups, *kernel_size))
    self.weight_hh = Parameter(torch.Tensor(4 * out_channels, out_channels // groups, *kernel_size))
    self.weight_ch = Parameter(torch.Tensor(3 * out_channels, out_channels // groups, *kernel_size))  # Peephole connections for i, f, o
    if bias:
      self.bias_ih = Parameter(torch.Tensor(4 * out_channels))
      self.bias_hh = Parameter(torch.Tensor(4 * out_channels))
      self.bias_ch = Parameter(torch.Tensor(3 * out_channels))
    else:
      self.register_parameter('bias_ih', None)
      self.register_parameter('bias_hh', None)
      self.register_parameter('bias_ch', None)
    self.register_buffer('wc_blank', torch.zeros(out_channels))
    self.reset_parameters()
  def reset_parameters(self):
    n = 4 * self.in_channels
    for k in self.kernel_size:
      n *= k
    stdv = 1. / math.sqrt(n)
    self.weight_ih.data.uniform_(-stdv, stdv)
    self.weight_hh.data.uniform_(-stdv, stdv)
    self.weight_ch.data.uniform_(-stdv, stdv)
    if self.bias_ih is not None:
      self.bias_ih.data.uniform_(-stdv, stdv)
      self.bias_hh.data.uniform_(-stdv, stdv)
      self.bias_ch.data.uniform_(-stdv, stdv)
  def forward(self, input, hx):
    h_0, c_0 = hx
    wx = F.conv2d(input, self.weight_ih, self.bias_ih, self.stride, self.padding, self.dilation, self.groups)
    wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride, self.padding_h, self.dilation, self.groups)
    # N.B. The original ConvLSTM paper uses a Hadamard product for the peephole connections; a convolution is used here
    wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride, self.padding_h, self.dilation, self.groups)
    wxhc = wx + wh + torch.cat((wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand(wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)), wc[:, 2 * self.out_channels:]), 1)
    i = F.sigmoid(wxhc[:, :self.out_channels])  # Input gate
    f = F.sigmoid(wxhc[:, self.out_channels:2 * self.out_channels])  # Forget gate
    g = F.tanh(wxhc[:, 2 * self.out_channels:3 * self.out_channels])  # Cell gate
    o = F.sigmoid(wxhc[:, 3 * self.out_channels:])  # Output gate
    c_1 = f * c_0 + i * g
    h_1 = o * F.tanh(c_1)
    return h_1, (h_1, c_1)
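The comment in forward above flags a real difference from the literature: in Shi et al.'s ConvLSTM the peephole terms are Hadamard products with per-channel weights, not convolutions. A rough sketch of that variant's step (hypothetical helper, not part of this gist) could look like:

# Sketch: Hadamard-product peepholes as in the ConvLSTM paper (an assumption, not this gist's method)
def conv_lstm_hadamard_step(x, h_0, c_0, cell, w_ci, w_cf, w_co):
  # w_ci/w_cf/w_co are assumed to be Parameters of shape (1, out_channels, 1, 1),
  # broadcasting over the batch and spatial dimensions
  wx = F.conv2d(x, cell.weight_ih, cell.bias_ih, cell.stride, cell.padding, cell.dilation, cell.groups)
  wh = F.conv2d(h_0, cell.weight_hh, cell.bias_hh, cell.stride, cell.padding_h, cell.dilation, cell.groups)
  wxh = wx + wh
  n = cell.out_channels
  i = F.sigmoid(wxh[:, :n] + w_ci * c_0)  # Hadamard peephole on input gate
  f = F.sigmoid(wxh[:, n:2 * n] + w_cf * c_0)  # Hadamard peephole on forget gate
  g = F.tanh(wxh[:, 2 * n:3 * n])  # Cell gate, no peephole
  c_1 = f * c_0 + i * g
  o = F.sigmoid(wxh[:, 3 * n:] + w_co * c_1)  # Output gate peeks at the new cell state
  h_1 = o * F.tanh(c_1)
  return h_1, c_1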
# Gradient checks
lstm = LSTMCell(2, 1)
peeplstm = PeepholeLSTMCell(2, 1)
convlstm = Conv2dLSTMCell(2, 1, (3, 5), stride=1, padding=(0, 1))

x = Variable(torch.ones(1, 2))
h = Variable(torch.ones(1, 1))
c = Variable(torch.ones(1, 1))
t = Variable(torch.zeros(1, 1))

y, hx = lstm(x, (h, c))
loss = (y - t).mean()
loss.backward()

y, hx = peeplstm(x, (h, c))
loss = (y - t).mean()
loss.backward()

# Reuse the same Variables with image-shaped data for the convolutional cell
x.data.resize_(1, 2, 22, 22).random_()
h.data.resize_(1, 1, 20, 20).random_()
c.data.resize_(1, 1, 20, 20).random_()
t.data.resize_(1, 1, 20, 20).random_()
y, hx = convlstm(x, (h, c))
loss = (y - t).mean()
loss.backward()
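The 20×20 hidden-state shape follows from the input convolution with kernel (3, 5) and padding (0, 1): the height is 22 - 3 + 1 = 20 and the width is 22 + 2·1 - 5 + 1 = 20, while padding_h = (1, 2) keeps the recurrent convolutions at 20×20.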
@leido commented Jul 13, 2017

Hi @Kaixhin, I found your code really helpful, but I don't understand the purpose of the _pair function in the Conv2dLSTMCell class (e.g. lines 65-68). Can you explain this? Thanks

@Kaixhin (Author) commented Jul 24, 2017

Ah sorry, I got _pair from the source code of Conv2d. I'll add the correct import now.
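(_pair simply broadcasts a scalar argument to a 2-tuple so that kernel sizes, strides, padding and dilation can be given either as an int or as a pair, e.g.:

from torch.nn.modules.utils import _pair

_pair(3)       # (3, 3)
_pair((3, 5))  # (3, 5) - already a pair, returned unchanged
)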

@devraj89 commented Jun 7, 2018

Hi @Kaixhin, I found your code to be quite helpful! Can you please tell me why, for the PeepholeLSTM, you have used the previous hidden states for computing each of the gates? Shouldn't wh = 0 (according to the Wikipedia article)?

Also, can you please tell me what the format of the input is? Based on the format here
https://pytorch.org/docs/master/nn.html?highlight=nn%20lstm#torch.nn.LSTM
the input has shape (seq_len, batch, input_size) - what is the first dimension signifying in your case, and how can I make your code work using seq_len?

Thanks in advance!

@Kaixhin (Author) commented Jun 28, 2018

I've done a fairly general formulation which includes both h and c, as is done in some of the original papers, but there are plenty of variants of these.
These are recurrent cells, so they only take a single timestep of input at once, i.e. (batch, input_size). I haven't considered making this work with a sequence, but I believe doing so efficiently would require a fair bit of work.
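That said, the straightforward (if unoptimised) way to run one of these cells over a (seq_len, batch, input_size) input is a plain Python loop; a minimal sketch:

# Sketch: unrolling a cell over an input of shape (seq_len, batch, input_size).
# This loops in Python, so it is simple but not optimised like nn.LSTM.
def run_sequence(cell, inputs, hx):
  outputs = []
  for x_t in inputs:  # x_t has shape (batch, input_size)
    y, hx = cell(x_t, hx)
    outputs.append(y)
  return torch.stack(outputs), hx  # (seq_len, batch, hidden_size), final (h, c)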
