Created July 17, 2020 21:33
Save holmdk/804281c24e36371c0141ac3aca8e7e00 to your computer and use it in GitHub Desktop.
convlstm_cell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch | |
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the
    channel axis and fed through one 2-D convolution that emits the
    pre-activations of all four gates at once (input, forget, output,
    candidate — in that channel order).
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel. A single int is used for both
            spatial dimensions. Both sizes must be odd so that the
            ``k // 2`` padding preserves the spatial size.
        bias: bool
            Whether or not to add the bias.

        Raises
        ------
        ValueError
            If ``kernel_size`` does not describe two dimensions, or if either
            kernel dimension is even.
        """
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)
        if len(kernel_size) != 2:
            raise ValueError(
                "kernel_size must be an int or a pair of ints, got {!r}".format(kernel_size)
            )
        if any(k % 2 == 0 for k in kernel_size):
            # With padding k // 2 an even kernel produces an output that is one
            # pixel larger than the input, so the new cell state could not be
            # combined with c_cur in forward().
            raise ValueError(
                "kernel_size must be odd in both dimensions, got {!r}".format(kernel_size)
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding for odd kernels: output height/width equal the input's.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution computes all four gates: 4 * hidden_dim output channels.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: torch.Tensor
            Input of shape (batch, input_dim, height, width).
        cur_state: (torch.Tensor, torch.Tensor)
            Previous ``(h_cur, c_cur)``, each of shape
            (batch, hidden_dim, height, width), e.g. from ``init_hidden``.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Next ``(h_next, c_next)``, same shape as the state tensors.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)

        # Split the 4 * hidden_dim channels into the four gate pre-activations.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialised ``(h, c)`` for a new sequence.

        Parameters
        ----------
        batch_size: int
            Number of sequences in the batch.
        image_size: (int, int)
            Spatial ``(height, width)`` of the state tensors.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Two zero tensors of shape (batch_size, hidden_dim, height, width)
            on the same device and with the same dtype as the conv weights, so
            the cell keeps working after ``.to(device)``, ``.half()`` etc.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.