-
-
Save Kaixhin/57901e91e5c5a8bac3eb0cbbdd3aba81 to your computer and use it in GitHub Desktop.
# Collection of LSTM cells, all with forget gates: a standard cell, a peephole cell and a 2-D convolutional cell
# Reference: https://en.wikipedia.org/w/index.php?title=Long_short-term_memory&oldid=784163987
import math

import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.autograd import Variable
class LSTMCell(nn.LSTMCell):
    """Standard LSTM cell (with forget gate), written out gate by gate.

    Reuses the parameters created by ``nn.LSTMCell`` (``weight_ih``,
    ``weight_hh``, ``bias_ih``, ``bias_hh``). Their rows are stacked in the
    gate order input, forget, cell (candidate), output.
    """

    def forward(self, input, hx=None):
        """Advance the cell by one timestep.

        Args:
            input: tensor whose last dimension has ``input_size`` features,
                e.g. ``(batch, input_size)``.
            hx: optional ``(h, c)`` pair whose last dimension has
                ``hidden_size`` features. When omitted, both states start at
                zero, as in ``nn.LSTMCell``.

        Returns:
            ``(h, (h, c))``: the new hidden state, then the new state pair to
            pass to the next step.
        """
        if hx is None:
            zeros = input.new_zeros(tuple(input.shape[:-1]) + (self.hidden_size,))
            hx = (zeros, zeros)
        h, c = hx
        # Weights for all four gates are combined into one matrix, so one
        # linear map per source yields the stacked gate pre-activations.
        gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(h, self.weight_hh, self.bias_hh)
        i, f, g, o = gates.chunk(4, -1)
        i = torch.sigmoid(i)  # Input gate
        f = torch.sigmoid(f)  # Forget gate
        g = torch.tanh(g)  # Candidate cell values
        o = torch.sigmoid(o)  # Output gate
        c = f * c + i * g  # New cell state
        h = o * torch.tanh(c)  # New hidden state
        return h, (h, c)
class PeepholeLSTMCell(nn.LSTMCell):
    """LSTM cell whose input, forget and output gates also see the cell state.

    On top of the ``nn.LSTMCell`` parameters, a full matrix ``weight_ch`` of
    shape ``(3 * hidden_size, hidden_size)`` and an optional bias ``bias_ch``
    map the previous cell state to peephole terms for the input, forget and
    output gates, in that order. The candidate (cell) gate gets no peephole
    term. All three peephole terms are computed from the previous cell state.
    The hidden-to-gate connections of a standard LSTM are kept as well.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(PeepholeLSTMCell, self).__init__(input_size, hidden_size, bias)
        self.weight_ch = Parameter(torch.empty(3 * hidden_size, hidden_size))
        if bias:
            self.bias_ch = Parameter(torch.empty(3 * hidden_size))
        else:
            self.register_parameter('bias_ch', None)
        # No longer used in forward(). Still registered so state_dict keys
        # stay compatible with checkpoints saved by earlier versions.
        self.register_buffer('wc_blank', torch.zeros(hidden_size))
        # Re-run the inherited initialisation now that the peephole
        # parameters exist. torch.empty() leaves them uninitialised.
        self.reset_parameters()

    def forward(self, input, hx=None):
        """Advance the cell by one timestep.

        Args:
            input: tensor whose last dimension has ``input_size`` features,
                e.g. ``(batch, input_size)``.
            hx: optional ``(h, c)`` pair whose last dimension has
                ``hidden_size`` features. Zeros are used when omitted.

        Returns:
            ``(h, (h, c))``: the new hidden state, then the new state pair.
        """
        if hx is None:
            zeros = input.new_zeros(tuple(input.shape[:-1]) + (self.hidden_size,))
            hx = (zeros, zeros)
        h, c = hx
        gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(h, self.weight_hh, self.bias_hh)
        i, f, g, o = gates.chunk(4, -1)
        # Peephole terms from the previous cell state, for the i, f and o gates.
        peep_i, peep_f, peep_o = F.linear(c, self.weight_ch, self.bias_ch).chunk(3, -1)
        i = torch.sigmoid(i + peep_i)
        f = torch.sigmoid(f + peep_f)
        g = torch.tanh(g)  # No cell involvement
        o = torch.sigmoid(o + peep_o)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, (h, c)
class Conv2dLSTMCell(nn.Module):
    """Convolutional LSTM cell with (convolutional) peephole connections.

    The input-to-gate transition is a 2-D convolution configured by the usual
    ``Conv2d`` arguments. The state-to-gate transitions (hidden -> gates and
    cell -> peepholes) use the same kernel size, dilation and groups, but
    stride 1 and symmetric padding ``dilation * (kernel_size - 1) // 2``. The
    hidden and cell states therefore keep the spatial size of the input
    convolution's output. When ``dilation * (kernel_size - 1)`` is odd (e.g.
    an even kernel with dilation 1), symmetric padding cannot preserve the
    size and the shapes will not line up.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the hidden and cell states.
        kernel_size, stride, padding, dilation, groups, bias: as for
            ``nn.Conv2d``. ``stride`` and ``padding`` apply only to the input
            convolution.

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dLSTMCell, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        # Accept ints or (height, width) pairs, like nn.Conv2d.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Symmetric padding for the stride-1 state convolutions (accounts for dilation).
        self.padding_h = tuple(d * (k - 1) // 2 for k, d in zip(kernel_size, dilation))
        self.dilation = dilation
        self.groups = groups
        # Rows are stacked in gate order: input, forget, cell (candidate), output.
        self.weight_ih = Parameter(torch.empty(4 * out_channels, in_channels // groups, *kernel_size))
        self.weight_hh = Parameter(torch.empty(4 * out_channels, out_channels // groups, *kernel_size))
        # Peephole weights for the input, forget and output gates only.
        self.weight_ch = Parameter(torch.empty(3 * out_channels, out_channels // groups, *kernel_size))
        if bias:
            self.bias_ih = Parameter(torch.empty(4 * out_channels))
            self.bias_hh = Parameter(torch.empty(4 * out_channels))
            self.bias_ch = Parameter(torch.empty(3 * out_channels))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
            self.register_parameter('bias_ch', None)
        # No longer used in forward(). Still registered so state_dict keys
        # stay compatible with checkpoints saved by earlier versions.
        self.register_buffer('wc_blank', torch.zeros(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw all weights and biases from U(-stdv, stdv), with stdv = 1 / sqrt(4 * in_channels * kh * kw)."""
        n = 4 * self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        with torch.no_grad():
            for weight in (self.weight_ih, self.weight_hh, self.weight_ch):
                weight.uniform_(-stdv, stdv)
            if self.bias_ih is not None:
                for b in (self.bias_ih, self.bias_hh, self.bias_ch):
                    b.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        """Advance the cell by one timestep.

        Args:
            input: batched tensor ``(batch, in_channels, height, width)``.
            hx: optional ``(h_0, c_0)`` pair, each shaped
                ``(batch, out_channels, H_out, W_out)``, where H_out and W_out
                are the spatial dims of the input convolution's output. Zeros
                are used when omitted.

        Returns:
            ``(h_1, (h_1, c_1))``: the new hidden state, then the new state pair.
        """
        wx = F.conv2d(input, self.weight_ih, self.bias_ih, self.stride, self.padding, self.dilation, self.groups)
        if hx is None:
            zeros = wx.new_zeros(wx.size(0), self.out_channels, wx.size(2), wx.size(3))
            hx = (zeros, zeros)
        h_0, c_0 = hx
        # State convolutions use stride 1 so the states keep their spatial size.
        wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, 1, self.padding_h, self.dilation, self.groups)
        # NOTE(review): peepholes here are convolutions; some ConvLSTM
        # formulations use a Hadamard product with the cell state instead.
        wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, 1, self.padding_h, self.dilation, self.groups)
        i, f, g, o = (wx + wh).chunk(4, 1)
        peep_i, peep_f, peep_o = wc.chunk(3, 1)
        i = torch.sigmoid(i + peep_i)
        f = torch.sigmoid(f + peep_f)
        g = torch.tanh(g)  # No cell involvement
        o = torch.sigmoid(o + peep_o)
        c_1 = f * c_0 + i * g
        h_1 = o * torch.tanh(c_1)
        return h_1, (h_1, c_1)
if __name__ == "__main__":
    # Smoke test: one forward/backward pass through each cell. Guarded so
    # that importing the cell classes does not run it.
    lstm = LSTMCell(2, 1)
    peeplstm = PeepholeLSTMCell(2, 1)
    convlstm = Conv2dLSTMCell(2, 1, (3, 5), stride=1, padding=(0, 1))

    # Fully connected cells: (batch, features) inputs and states.
    x = torch.ones(1, 2)
    h = torch.ones(1, 1)
    c = torch.ones(1, 1)
    t = torch.zeros(1, 1)
    y, hx = lstm(x, (h, c))
    loss = (y - t).mean()
    loss.backward()
    y, hx = peeplstm(x, (h, c))
    loss = (y - t).mean()
    loss.backward()

    # Convolutional cell: (batch, channels, height, width). A 22x22 input with
    # kernel (3, 5) and padding (0, 1) gives a 20x20 output, so the states
    # are 20x20. Fresh tensors are built because resizing `.data` in place
    # is rejected by recent PyTorch versions.
    x = torch.rand(1, 2, 22, 22)
    h = torch.rand(1, 1, 20, 20)
    c = torch.rand(1, 1, 20, 20)
    t = torch.rand(1, 1, 20, 20)
    y, hx = convlstm(x, (h, c))
    loss = (y - t).mean()
    loss.backward()
Ah, sorry — I copied `_pair` from the source code of `Conv2d`. I'll add the correct import now.
Hi @Kaixhin, I found your code quite helpful! Can you please tell me why, for the `PeepholeLSTMCell`, you use the previous hidden state when computing each of the gates? Shouldn't `wh = 0` (according to the Wikipedia article)?
Also, can you please tell me what format `input` should have? According to the documentation at https://pytorch.org/docs/master/nn.html?highlight=nn%20lstm#torch.nn.LSTM, the input has shape `(seq_len, batch, input_size)`. In your case, what does the first dimension represent, and how can I make your code work with `seq_len`?
Thanks in advance!
I've written a fairly general formulation that includes both `h` and `c`, as some of the original papers do, but there are plenty of variants of these.
These are recurrent cells, so they only take a single timestep of the input at once, i.e. `(batch, input_size)`. To process a sequence, call the cell once per timestep and feed the returned `(h, c)` into the next call. I haven't considered making this work on a whole sequence at once, but I believe doing that efficiently would require a fair bit of work.
Hi @Kaixhin, I found your code really helpful, but I don't understand the purpose of the `_pair` function in the `Conv2dLSTMCell` class (e.g. lines 65–68). Can you explain this? Thanks!