Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save napoler/90e1d9c6447cf26bf57b079cb067a578 to your computer and use it in GitHub Desktop.
Save napoler/90e1d9c6447cf26bf57b079cb067a578 to your computer and use it in GitHub Desktop.
convlstm_encdec
import torch
import torch.nn as nn
from models.ConvLSTMCell import ConvLSTMCell
class EncoderDecoderConvLSTM(nn.Module):
    """Sequence-to-sequence ConvLSTM model that predicts future frames.

    Architecture:
        * Encoder: two stacked ConvLSTM cells unrolled over the input sequence.
        * Encoder vector: the final hidden state of the second encoder cell.
        * Decoder: two stacked ConvLSTM cells unrolled for ``future_seq`` steps.
          The first step consumes the encoder vector; every later step consumes
          the previous step's top-level hidden state.
        * Head: a Conv3d with kernel (1, 3, 3) and padding (0, 1, 1) that maps the
          ``nf`` hidden channels to one channel (time/height/width preserved),
          followed by a sigmoid, so every output value lies in (0, 1).

    NOTE(review): ``ConvLSTMCell`` is defined in ``models.ConvLSTMCell`` and is
    not visible here. This class relies only on it accepting
    ``(input_tensor=..., cur_state=[h, c])``, returning ``(h, c)``, and exposing
    ``init_hidden(batch_size=..., image_size=...)``. The output shapes documented
    below assume the cell preserves spatial size; confirm against its source.

    The submodule attribute names form the ``state_dict`` keys. Keep them stable
    so existing checkpoints still load.
    """

    def __init__(self, nf, in_chan):
        """Build the encoder, decoder and output head.

        Args:
            nf: number of hidden channels in every ConvLSTM cell.
            in_chan: number of channels per input frame; should equal ``c`` of
                the ``x`` passed to ``forward``.
        """
        super().__init__()
        self.encoder_1_convlstm = ConvLSTMCell(input_dim=in_chan,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)
        self.encoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)
        self.decoder_1_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)
        self.decoder_2_convlstm = ConvLSTMCell(input_dim=nf,
                                               hidden_dim=nf,
                                               kernel_size=(3, 3),
                                               bias=True)
        # Kernel depth 1 means each predicted frame is decoded independently;
        # padding (0, 1, 1) keeps the time, height and width sizes unchanged.
        self.decoder_CNN = nn.Conv3d(in_channels=nf,
                                     out_channels=1,
                                     kernel_size=(1, 3, 3),
                                     padding=(0, 1, 1))

    @staticmethod
    def _check_future_steps(n):
        """Reject horizons < 1, for which there would be nothing to stack."""
        if n < 1:
            raise ValueError(
                f"future_seq must be >= 1 to produce any prediction, got {n}")

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):
        """Run the encoder over ``seq_len`` frames and decode ``future_step`` frames.

        Args:
            x: input tensor indexed as ``x[:, t]`` per time step; ``forward``
                passes a (b, t, c, h, w) tensor.
            seq_len: number of leading time steps of ``x`` to encode.
            future_step: number of frames to predict; must be >= 1.
            h_t, c_t: initial state of the first encoder cell.
            h_t2, c_t2: initial state of the second encoder cell.
            h_t3, c_t3: initial state of the first decoder cell.
            h_t4, c_t4: initial state of the second decoder cell.

        Returns:
            Tensor of shape (b, 1, future_step, H, W) with values in (0, 1),
            where H, W are the spatial sizes of the hidden states.

        Raises:
            ValueError: if ``future_step < 1``.
        """
        self._check_future_steps(future_step)

        # Encoder: feed each observed frame through both stacked cells.
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t],
                                               cur_state=[h_t, c_t])
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=h_t,
                                                 cur_state=[h_t2, c_t2])

        # Decoder: step 0 is driven by the encoder vector (final h_t2); each
        # later step is fed the previous step's top-level hidden state.
        decoder_input = h_t2
        outputs = []
        for _ in range(future_step):
            h_t3, c_t3 = self.decoder_1_convlstm(input_tensor=decoder_input,
                                                 cur_state=[h_t3, c_t3])
            h_t4, c_t4 = self.decoder_2_convlstm(input_tensor=h_t3,
                                                 cur_state=[h_t4, c_t4])
            decoder_input = h_t4
            outputs.append(h_t4)

        # (b, T, nf, H, W) -> (b, nf, T, H, W): Conv3d expects channels before time.
        outputs = torch.stack(outputs, dim=1).permute(0, 2, 1, 3, 4)
        return torch.sigmoid(self.decoder_CNN(outputs))

    def forward(self, x, future_seq=0, hidden_state=None):
        """Encode the input sequence and predict ``future_seq`` future frames.

        Args:
            x: 5-D tensor of shape (b, t, c, h, w) (batch, time, channel,
                height, width); ``c`` should equal the ``in_chan`` given at
                construction.
            future_seq: number of frames to predict; must be >= 1. The default
                of 0 is kept only for signature compatibility and is rejected.
            hidden_state: ignored. All recurrent states come from each cell's
                ``init_hidden``; the parameter is kept for interface compatibility.

        Returns:
            Tensor of shape (b, 1, future_seq, h, w) with values in (0, 1),
            assuming ConvLSTMCell preserves spatial size. Note that it is
            channel-first, unlike the (b, t, c, h, w) input layout.

        Raises:
            ValueError: if ``x`` is not 5-D or ``future_seq < 1``.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D input (b, t, c, h, w), got shape {tuple(x.shape)}")
        self._check_future_steps(future_seq)

        b, seq_len, _, height, width = x.size()
        image_size = (height, width)

        # Fresh recurrent state for every cell on each call.
        h_t, c_t = self.encoder_1_convlstm.init_hidden(batch_size=b, image_size=image_size)
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(batch_size=b, image_size=image_size)
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(batch_size=b, image_size=image_size)
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(batch_size=b, image_size=image_size)

        return self.autoencoder(x, seq_len, future_seq,
                                h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment