Skip to content

Instantly share code, notes, and snippets.

@halochou
Last active April 6, 2023 07:46
Show Gist options
  • Save halochou/acbd669af86ecb8f988325084ba7a749 to your computer and use it in GitHub Desktop.
A simple implementation of Convolutional GRU cell in Pytorch
# Inspired by Alfredo Canziani (http://tinyurl.com/CortexNet/)
import torch
import torch.nn as nn
import torch.nn.functional as f
import torch.nn.init as init
from torch.autograd import Variable
class ConvGRUCell(nn.Module):
    """
    A convolutional GRU cell.

    Replaces the fully-connected gates of a standard GRU with 2-D
    convolutions, so the hidden state keeps its spatial layout
    (``[batch, hidden_size, height, width]``).

    Args:
        input_size: number of channels in the input tensor.
        hidden_size: number of channels in the hidden state.
        kernel_size: side length of the (square) convolution kernel;
            padding is set to ``kernel_size // 2`` so spatial size is
            preserved (assumes an odd kernel_size).
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each gate convolves the channel-wise concatenation of input and state.
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        # Orthogonal weights / zero biases, as in the original recipe.
        # (init.orthogonal / init.constant were removed; the trailing-underscore
        # in-place variants are the current API.)
        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        init.constant_(self.reset_gate.bias, 0.)
        init.constant_(self.update_gate.bias, 0.)
        init.constant_(self.out_gate.bias, 0.)

    def forward(self, input_, prev_state=None):
        """
        Run one GRU step.

        Args:
            input_: tensor of shape ``[batch, input_size, height, width]``.
            prev_state: previous hidden state of shape
                ``[batch, hidden_size, height, width]``, or ``None`` to
                start from a zero state.

        Returns:
            The new hidden state, shape ``[batch, hidden_size, height, width]``.
        """
        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # Generate an empty prev_state if none is provided. Create it on the
        # input's device/dtype instead of hard-coding .cuda(), so the cell
        # also works on CPU and under mixed precision.
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            prev_state = torch.zeros(state_size, device=input_.device, dtype=input_.dtype)

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked_inputs))
        reset = torch.sigmoid(self.reset_gate(stacked_inputs))
        # Candidate state: reset gate scales the previous state's contribution.
        out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        # Convex combination of previous state and candidate, per the GRU update.
        new_state = prev_state * (1 - update) + out_inputs * update
        return new_state
@varunagrawal
Copy link

@AllenIrving that line is as per the original paper.

@nabsabraham
Copy link

hi all, does anyone know how to train this GRU if I use a regular 2d network downstream of it?

@lukoshkin
Copy link

Why do we need an orthogonal initialization of the weights and zeroed biases here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment