Created
November 20, 2017 14:52
-
-
Save vabh/2210ac54bc5cb202bfb1133df48ae58b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ConvLSTMCell(nn.Module):
    """Single-step convolutional LSTM cell with convolutional peepholes.

    All gate pre-activations are computed with 2-D convolutions: one over
    the input ``x``, one over the previous hidden state ``h`` and one over
    the previous cell state ``c`` (the peephole path). No bias terms are
    used.

    Args:
        in_channels: number of channels of the input ``x``.
        out_channels: number of channels of the hidden and cell states.
        kernel_size: side length of the (square) convolution kernels.
        padding: padding passed to every convolution. The same padding is
            used for ``x``, ``h`` and ``c`` and the results are summed, so
            it has to keep the spatial size unchanged, i.e.
            ``2 * padding == kernel_size - 1`` (the default of 1 matches
            ``kernel_size=3``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=1):
        super(ConvLSTMCell, self).__init__()
        self.k = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        # Input-to-gates kernels, stacked as [i, f, g, o] along dim 0.
        self.w_i = nn.Parameter(
            torch.empty(4 * out_channels, in_channels, kernel_size, kernel_size))
        # Hidden-to-gates kernels, stacked as [i, f, g, o]. They convolve
        # ``h``, which has ``out_channels`` channels, so that is their
        # input-channel count (not ``in_channels``).
        self.w_h = nn.Parameter(
            torch.empty(4 * out_channels, out_channels, kernel_size, kernel_size))
        # Peephole kernels, stacked as [i, f, o]. They convolve ``c``, which
        # also has ``out_channels`` channels.
        self.w_c = nn.Parameter(
            torch.empty(3 * out_channels, out_channels, kernel_size, kernel_size))
        # TODO include bias terms
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every kernel from U(-stdv, stdv), stdv = 1/sqrt(4*in_channels*k*k)."""
        n = 4 * self.in_channels * self.k * self.k
        stdv = 1. / math.sqrt(n)
        for weight in (self.w_i, self.w_h, self.w_c):
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, hx=None):
        """Advance the cell by one time step.

        Args:
            x: input tensor with ``in_channels`` channels in dim 1
                (batched, channels-first layout as expected by
                ``F.conv2d``).
            hx: tuple ``(h, c)`` of previous hidden and cell states, each
                with ``out_channels`` channels and the same spatial size as
                the convolved input. If ``None``, both start as zeros.

        Returns:
            ``(h_t, (h_t, c_t))``: the new hidden state, followed by the
            new ``(hidden, cell)`` state tuple to feed into the next step.
        """
        wx = F.conv2d(x, self.w_i, padding=self.padding)

        if hx is None:
            # Zero initial state shaped like one gate slice of ``wx``.
            zeros = wx.new_zeros(
                wx.size(0), self.out_channels, wx.size(2), wx.size(3))
            hx = (zeros, zeros)
        h, c = hx

        wh = F.conv2d(h, self.w_h, padding=self.padding)
        wc = F.conv2d(c, self.w_c, padding=self.padding)

        # Split the stacked pre-activations into per-gate slices.
        wx_i, wx_f, wx_g, wx_o = wx.chunk(4, dim=1)
        wh_i, wh_f, wh_g, wh_o = wh.chunk(4, dim=1)
        wc_i, wc_f, wc_o = wc.chunk(3, dim=1)

        i = torch.sigmoid(wx_i + wh_i + wc_i)
        f = torch.sigmoid(wx_f + wh_f + wc_f)
        g = torch.tanh(wx_g + wh_g)  # candidate has no peephole term
        c_t = f * c + i * g

        # NOTE(review): the output-gate peephole multiplies the convolution
        # of the *previous* cell state by the *new* cell state. This is kept
        # as originally written; a conventional peephole would use c_t only
        # -- confirm the intended formulation before changing it.
        o_t = torch.sigmoid(wx_o + wh_o + wc_o * c_t)
        h_t = o_t * torch.tanh(c_t)
        return h_t, (h_t, c_t)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment