Skip to content

Instantly share code, notes, and snippets.

@michael-iuzzolino
Created April 28, 2019 03:12
Show Gist options
  • Save michael-iuzzolino/09c9224fdd991f13ed484968c36eb532 to your computer and use it in GitHub Desktop.
Hello RNN vanilla cell
class VanillaCell(nn.Module):
    """Single-layer "vanilla" (Elman-style) recurrent cell.

    The hidden state is stored on the module as ``self.h`` and is updated on
    every ``forward`` call::

        h_t = tanh(x_t @ Wxh + (h_{t-1} @ Whh + bh))
        y_t = h_t @ Why + by

    Call ``init()`` to reset the hidden state to zeros (e.g. before a new
    sequence). If ``forward`` is called before any ``init()``, the state is
    zero-initialised automatically.

    Args:
        num_chars: Size of the input and output vectors (presumably the
            character vocabulary size with one-hot inputs — confirm at caller).
        num_hidden: Size of the hidden state.
    """

    def __init__(self, num_chars, num_hidden):
        super().__init__()
        self.num_chars = num_chars
        self.num_hidden = num_hidden

        # Network parameters.
        # Input -> hidden and hidden -> hidden (recurrent) weights, plus bias.
        # NOTE(review): unscaled N(0, 1) initialisation may saturate tanh for
        # larger num_chars / num_hidden; kept as-is to preserve behaviour.
        self.Wxh = nn.Parameter(torch.randn((num_chars, num_hidden)))
        self.Whh = nn.Parameter(torch.randn((num_hidden, num_hidden)))
        self.bh = nn.Parameter(torch.zeros((num_hidden)))

        # Hidden -> output projection; no activation is applied to the output.
        self.Why = nn.Parameter(torch.randn((num_hidden, num_chars)))
        self.by = nn.Parameter(torch.zeros((num_chars)))

        # Activations
        self.tanh = nn.Tanh()

        # Hidden state: None until init() (or the first forward) creates it.
        # A plain attribute rather than a registered buffer so the module's
        # state_dict keys stay unchanged.
        self.h = None

    def init(self):
        """Reset the hidden state to a zero vector of length ``num_hidden``.

        The zeros are allocated on the same device and with the same dtype as
        the recurrent weights, so the cell keeps working after ``.to(...)``,
        ``.cuda()`` or ``.double()``. Replacing ``self.h`` also drops any
        autograd history carried by the previous state.
        """
        self.h = torch.zeros(
            self.num_hidden, device=self.Whh.device, dtype=self.Whh.dtype
        )

    def forward(self, x):
        """Advance the cell by one step and return the output projection.

        Args:
            x: Input tensor whose last dimension is ``num_chars``. A 1-D input
                yields a 1-D output; a ``(batch, num_chars)`` input broadcasts
                against the 1-D zero state, after which ``self.h`` takes the
                batch shape (so keep the batch size fixed until the next
                ``init()``).

        Returns:
            ``h @ Why + by`` for the updated hidden state ``h``.
        """
        if self.h is None:
            # Lazily create the zero state instead of failing with
            # AttributeError when init() was never called.
            self.init()
        # The new state overwrites self.h, so it carries the autograd graph
        # from step to step until init() is called again.
        self.h = self.tanh((x @ self.Wxh) + (self.h @ self.Whh + self.bh))
        y_output = self.h @ self.Why + self.by
        return y_output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment