ConvLSTM model from https://arxiv.org/abs/1901.03559 without all the bells and whistles (encoded observation skip connection, top-down skip connection, pool-and-inject).
import torch
import torch.nn as nn
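
# For reference (a summary added here, not part of the original gist): each cell
# below computes the standard peephole-free ConvLSTM update of Shi et al. (2015),
# where * denotes convolution and ⊙ elementwise multiplication (bias terms,
# omitted for brevity, live inside the gate convolutions when bias=True):
#   i = sigmoid(Wxi * x + Whi * h)
#   f = sigmoid(Wxf * x + Whf * h)
#   o = sigmoid(Wxo * x + Who * h)
#   g = tanh(Wxg * x + Whg * h)
#   c' = f ⊙ c + i ⊙ g
#   h' = o ⊙ tanh(c')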
class ConvLSTMCell(nn.Module):
    def __init__(self,
                 input_size,
                 input_dim,
                 hidden_dim,
                 kernel_size,
                 bias):
        super().__init__()
        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        # input-to-hidden and hidden-to-hidden convolutions for the input (i),
        # forget (f), output (o) and candidate (g) gates; note bias is passed
        # by keyword, as the sixth positional argument of nn.Conv2d is
        # dilation, not bias
        self.Wxi = nn.Conv2d(self.input_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Wxf = nn.Conv2d(self.input_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Wxo = nn.Conv2d(self.input_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Wxg = nn.Conv2d(self.input_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Whi = nn.Conv2d(self.hidden_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Whf = nn.Conv2d(self.hidden_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Who = nn.Conv2d(self.hidden_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
        self.Whg = nn.Conv2d(self.hidden_dim, self.hidden_dim, self.kernel_size, 1, self.padding, bias=self.bias)
    def forward(self, x, h, c):
        #x = [batch, input dim, height, width]
        #h = [batch, hidden dim, height, width]
        #c = [batch, hidden dim, height, width]
        batch_size = x.shape[0]
        assert x.shape == (batch_size, self.input_dim, self.height, self.width)
        assert h.shape == (batch_size, self.hidden_dim, self.height, self.width)
        assert c.shape == (batch_size, self.hidden_dim, self.height, self.width)
        i = torch.sigmoid(self.Wxi(x) + self.Whi(h))
        f = torch.sigmoid(self.Wxf(x) + self.Whf(h))
        o = torch.sigmoid(self.Wxo(x) + self.Who(h))
        g = torch.tanh(self.Wxg(x) + self.Whg(h))
        #i/f/o/g = [batch, hidden dim, height, width]
        assert i.shape == (batch_size, self.hidden_dim, self.height, self.width)
        assert f.shape == (batch_size, self.hidden_dim, self.height, self.width)
        assert o.shape == (batch_size, self.hidden_dim, self.height, self.width)
        assert g.shape == (batch_size, self.hidden_dim, self.height, self.width)
        c = f * c + i * g
        h = o * torch.tanh(c)
        #h = [batch, hidden dim, height, width]
        #c = [batch, hidden dim, height, width]
        assert h.shape == (batch_size, self.hidden_dim, self.height, self.width)
        assert c.shape == (batch_size, self.hidden_dim, self.height, self.width)
        return h, c
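
# A quick single-step sanity check of the cell (the shapes here are
# illustrative assumptions, not from the original gist):
_cell = ConvLSTMCell(input_size=(8, 8), input_dim=3, hidden_dim=16, kernel_size=(3, 3), bias=True)
_h = torch.zeros(4, 16, 8, 8)
_c = torch.zeros(4, 16, 8, 8)
_h, _c = _cell(torch.randn(4, 3, 8, 8), _h, _c)
assert _h.shape == (4, 16, 8, 8) and _c.shape == (4, 16, 8, 8)
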
class ConvLSTM(nn.Module):
    def __init__(self,
                 input_size,
                 input_dim,
                 hidden_dim,
                 kernel_size,
                 n_layers,
                 bias):
        super().__init__()
        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.bias = bias
        # stack of cells: the first layer consumes the input, subsequent
        # layers consume the previous layer's hidden states
        convs = []
        for i in range(n_layers):
            layer_input_dim = input_dim if i == 0 else hidden_dim
            convs.append(ConvLSTMCell(input_size=(self.height, self.width),
                                      input_dim=layer_input_dim,
                                      hidden_dim=self.hidden_dim,
                                      kernel_size=self.kernel_size,
                                      bias=bias))
        self.convs = nn.ModuleList(convs)
    def forward(self, x):
        #x = [batch, length, input dim, height, width]
        batch_size = x.shape[0]
        seq_length = x.shape[1]
        assert x.shape == (batch_size, seq_length, self.input_dim, self.height, self.width)
        for i, conv in enumerate(self.convs):
            # zero initial states, created on x's device and dtype so the
            # model also works on GPU
            h = torch.zeros(batch_size, self.hidden_dim, self.height, self.width, device=x.device, dtype=x.dtype)
            c = torch.zeros(batch_size, self.hidden_dim, self.height, self.width, device=x.device, dtype=x.dtype)
            H = torch.zeros(batch_size, seq_length, self.hidden_dim, self.height, self.width, device=x.device, dtype=x.dtype)
            for t in range(seq_length):
                h, c = conv(x[:,t,:,:,:], h, c)
                #h = [batch, hidden dim, height, width]
                #c = [batch, hidden dim, height, width]
                assert h.shape == (batch_size, self.hidden_dim, self.height, self.width)
                assert c.shape == (batch_size, self.hidden_dim, self.height, self.width)
                H[:,t,:,:,:] = h
            #H = [batch, seq length, hidden dim, height, width]
            # this layer's hidden states become the next layer's input
            x = H
        assert H.shape == (batch_size, seq_length, self.hidden_dim, self.height, self.width)
        assert h.shape == (batch_size, self.hidden_dim, self.height, self.width)
        return H, h
image_height = 80
image_width = 80
input_channels = 3
hidden_dim = 64
kernel_height = 3
kernel_width = 3
n_layers = 3
bias = True
conv_lstm = ConvLSTM(input_size=(image_height, image_width),
                     input_dim=input_channels,
                     hidden_dim=hidden_dim,
                     kernel_size=(kernel_height, kernel_width),
                     n_layers=n_layers,
                     bias=bias)

#x = [batch, seq length, channels, height, width]
x = torch.randn(32, 5, 3, 80, 80)
Y, y = conv_lstm(x)
print(Y.shape, y.shape)
# the last per-step hidden state equals the returned final hidden state
assert torch.all(torch.eq(Y[:,-1,:,:,:], y))
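
# A hypothetical training step (the loss, target and optimiser choices are
# illustrative assumptions, not part of the original gist): regress the
# per-step hidden states onto a dummy target to check that gradients flow.
optimizer = torch.optim.Adam(conv_lstm.parameters(), lr=1e-3)
target = torch.randn_like(Y)
loss = nn.functional.mse_loss(Y, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())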