Skip to content

Instantly share code, notes, and snippets.

@michael-iuzzolino
Created April 28, 2019 03:12
Show Gist options
  • Save michael-iuzzolino/49a921d78bf410c3f57f6858466634bd to your computer and use it in GitHub Desktop.
Save michael-iuzzolino/49a921d78bf410c3f57f6858466634bd to your computer and use it in GitHub Desktop.
Hello RNN GRU Cell
class GRUCell(nn.Module):
    """Single-layer GRU cell with a linear read-out, stepped one input at a time.

    The recurrent state lives on the module as ``self.h`` and is carried across
    ``forward`` calls; call ``init()`` to reset it (e.g. at the start of each
    sequence). If ``init()`` was never called, ``forward`` initializes it.

    Update rule (``*`` is elementwise, ``@`` is matrix product):

        z  = sigmoid(x @ Wxh_u + h @ Whh_u + bh_u)     # update gate
        r  = sigmoid(x @ Wxh_r + h @ Whh_r + bh_r)     # reset gate
        h~ = tanh(x @ Wxh + (r * h) @ Whh + bh)        # candidate state
        h  = (1 - z) * h + z * h~
        y  = h @ Why + by

    Args:
        num_chars: Size of the input vector and of the output vector.
        num_hidden: Size of the hidden state.
    """

    def __init__(self, num_chars: int, num_hidden: int) -> None:
        super().__init__()
        self.num_chars = num_chars
        self.num_hidden = num_hidden

        # Network parameters.
        # Candidate ("potential input") parameters.
        self.Wxh = nn.Parameter(torch.randn((num_chars, num_hidden)))
        self.Whh = nn.Parameter(torch.randn((num_hidden, num_hidden)))
        self.bh = nn.Parameter(torch.zeros((num_hidden)))

        # Update-gate parameters. Biases start at zero, like bh and by, instead
        # of injecting a random per-unit offset into the gate.
        self.Wxh_u = nn.Parameter(torch.randn_like(self.Wxh))
        self.Whh_u = nn.Parameter(torch.randn_like(self.Whh))
        self.bh_u = nn.Parameter(torch.zeros_like(self.bh))

        # Reset-gate parameters.
        self.Wxh_r = nn.Parameter(torch.randn_like(self.Wxh))
        self.Whh_r = nn.Parameter(torch.randn_like(self.Whh))
        self.bh_r = nn.Parameter(torch.zeros_like(self.bh))

        # Hidden -> output projection.
        self.Why = nn.Parameter(torch.randn((num_hidden, num_chars)))
        self.by = nn.Parameter(torch.zeros((num_chars)))

        # Activations.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        # Hidden state: None until init() is called or forward() first runs.
        self.h = None

    def init(self) -> None:
        """Reset the hidden state to a zero vector of size ``num_hidden``.

        The state is allocated with the parameters' device and dtype, so this
        keeps working after ``.to(device)`` / ``.double()``.
        """
        self.h = self.Whh.new_zeros(self.num_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Advance the cell by one step and return the output vector.

        Args:
            x: Input whose last dimension is ``num_chars``. A leading batch
               dimension broadcasts against the state (not validated here).

        Returns:
            ``h @ Why + by`` computed from the updated hidden state.
        """
        if self.h is None:
            self.init()
        h_prev = self.h

        # Gate updates.
        update_gate = self.sigmoid(x @ self.Wxh_u + h_prev @ self.Whh_u + self.bh_u)
        reset_gate = self.sigmoid(x @ self.Wxh_r + h_prev @ self.Whh_r + self.bh_r)

        # The reset gate scales the previous state elementwise *before* the
        # recurrent projection; without the product the candidate would not
        # depend on h at all.
        potential_input = self.tanh(
            x @ self.Wxh + (reset_gate * h_prev) @ self.Whh + self.bh
        )

        # Interpolate between the old state and the candidate.
        self.h = (1 - update_gate) * h_prev + update_gate * potential_input

        y_output = self.h @ self.Why + self.by
        return y_output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment