-
-
Save napoler/c69931e9b9eb07443b6fa42cc83a177c to your computer and use it in GitHub Desktop.
convlstm_cell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch | |
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the
    channel axis and passed through one 2-D convolution that produces the
    pre-activations of all four gates at once.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel. A single int is expanded to a
            square kernel. Padding is ``k // 2`` per axis, so odd sizes keep
            the spatial size of the state unchanged.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        # Accept a bare int the same way nn.Conv2d does; indexing an int
        # below would otherwise raise TypeError.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution yields the four gate pre-activations stacked along
        # the channel axis, hence 4 * hidden_dim output channels.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: torch.Tensor
            Input for this step; channels are on axis 1.
        cur_state: (torch.Tensor, torch.Tensor)
            Current ``(h, c)`` pair, e.g. as returned by ``init_hidden``.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The next hidden state and the next cell state.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        # Channel chunks are, in order: input, forget, output gates and the
        # candidate cell update.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """
        Create an all-zero ``(h, c)`` state.

        Parameters
        ----------
        batch_size: int
            Size of the leading (batch) dimension.
        image_size: (int, int)
            Height and width of the state tensors.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Two zero tensors of shape
            ``(batch_size, hidden_dim, height, width)``.
        """
        height, width = image_size
        weight = self.conv.weight
        # Match the device AND dtype of the convolution weights so the state
        # stays compatible after .double() / .half() / .to(...) on the module.
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(*shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(*shape, device=weight.device, dtype=weight.dtype))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment