Last active
April 6, 2023 07:46
-
-
Save halochou/acbd669af86ecb8f988325084ba7a749 to your computer and use it in GitHub Desktop.
A simple implementation of Convolutional GRU cell in Pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inspired by Alfredo Canziani (http://tinyurl.com/CortexNet/) | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as f | |
import torch.nn.init as init | |
from torch.autograd import Variable | |
class ConvGRUCell(nn.Module):
    """
    A convolutional GRU cell.

    Replaces the fully-connected gates of a standard GRU with 2D
    convolutions, so the hidden state keeps its spatial layout
    [batch, hidden_size, height, width].

    Args:
        input_size: number of channels of the input tensor.
        hidden_size: number of channels of the hidden state.
        kernel_size: side of the square convolution kernel; padding is
            kernel_size // 2 so spatial dimensions are preserved
            (assumes an odd kernel_size).
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        padding = kernel_size // 2  # "same" padding for odd kernels
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each gate convolves the channel-wise concatenation [input, state].
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        # In-place init variants: the non-underscore names (init.orthogonal,
        # init.constant) were deprecated in torch 0.4 and later removed.
        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        init.constant_(self.reset_gate.bias, 0.)
        init.constant_(self.update_gate.bias, 0.)
        init.constant_(self.out_gate.bias, 0.)

    def forward(self, input_, prev_state=None):
        """
        Run one GRU step.

        Args:
            input_: tensor of shape [batch, input_size, height, width].
            prev_state: previous hidden state of shape
                [batch, hidden_size, height, width], or None to start
                from a zero state (default).

        Returns:
            New hidden state of shape [batch, hidden_size, height, width].
        """
        # Get batch and spatial sizes from the input itself.
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # Generate a zero prev_state on the input's device/dtype if none is
        # provided (the original hard-coded .cuda(), breaking CPU runs).
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = torch.zeros(state_size, dtype=input_.dtype, device=input_.device)

        # Data layout is [batch, channel, height, width]; gates see the
        # channel-wise concatenation of input and previous state.
        stacked_inputs = torch.cat([input_, prev_state], dim=1)
        # torch.sigmoid/torch.tanh replace the removed f.sigmoid/f.tanh.
        update = torch.sigmoid(self.update_gate(stacked_inputs))
        reset = torch.sigmoid(self.reset_gate(stacked_inputs))
        # Candidate state: reset gate scales how much past state leaks in.
        out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        # Convex blend of previous state and candidate (as in the original
        # gist; the update-gate convention is flipped relative to the paper,
        # which is equivalent up to a sign of the learned gate weights).
        new_state = prev_state * (1 - update) + out_inputs * update
        return new_state
@AllenIrving that line is as per the original paper.
hi all, does anyone know how to train this GRU if I use a regular 2d network downstream of it?
Why do we need here an orthogonal initialization of weights and zeroed biases?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hi there. I was wondering what the difference is between
new_state = prev_state * (1 - update) + out_inputs * update
and
new_state = prev_state * update + out_inputs * (1 - update)
. According to the standard formulation of the GRU, the new hidden state should be calculated by the latter. I would really appreciate it if you could help me out.