Last active
April 6, 2023 07:46
-
-
Save halochou/acbd669af86ecb8f988325084ba7a749 to your computer and use it in GitHub Desktop.
A simple implementation of Convolutional GRU cell in Pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspired by Alfredo Canziani (http://tinyurl.com/CortexNet/) | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as f | |
import torch.nn.init as init | |
from torch.autograd import Variable | |
class ConvGRUCell(nn.Module):
    """
    A convolutional GRU cell.

    Works like a standard GRU cell, but every gate is a 2-D convolution over
    a [batch, channel, height, width] tensor instead of a matrix multiply,
    so the hidden state keeps its spatial layout.

    Args:
        input_size: number of channels of the input tensor.
        hidden_size: number of channels of the hidden state.
        kernel_size: side length of the (square) convolution kernel; padding
            is kernel_size // 2 so spatial dimensions are preserved
            (exact only for odd kernel sizes).
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        # "same" padding so the hidden state keeps the input's spatial size
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each gate convolves the channel-wise concatenation [input, hidden].
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        # Orthogonal weights keep gate pre-activations well-conditioned at the
        # start of training (a common choice for recurrent nets); zero biases
        # are a neutral starting point. Use the in-place `_` variants — the
        # un-suffixed init.orthogonal/init.constant are deprecated/removed.
        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        init.constant_(self.reset_gate.bias, 0.)
        init.constant_(self.update_gate.bias, 0.)
        init.constant_(self.out_gate.bias, 0.)

    def forward(self, input_, prev_state):
        """
        Run one GRU step.

        Args:
            input_: tensor of shape [batch, input_size, height, width].
            prev_state: previous hidden state of shape
                [batch, hidden_size, height, width], or None to start from
                an all-zero state.

        Returns:
            New hidden state of shape [batch, hidden_size, height, width].
        """
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # Allocate on the input's device/dtype instead of hard-coding
            # .cuda(), so the cell also runs on CPU and follows the module's
            # placement. Variable() is a deprecated no-op and is dropped.
            prev_state = torch.zeros(state_size, dtype=input_.dtype, device=input_.device)

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat([input_, prev_state], dim=1)
        # f.sigmoid / f.tanh are deprecated; use the torch.* equivalents.
        update = torch.sigmoid(self.update_gate(stacked_inputs))
        reset = torch.sigmoid(self.reset_gate(stacked_inputs))
        # Candidate state sees the input plus the reset-modulated old state.
        out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        # Convex combination of old state and candidate, weighted by update gate.
        new_state = prev_state * (1 - update) + out_inputs * update
        return new_state
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why do we need here an orthogonal initialization of weights and zeroed biases?