@michael-iuzzolino
Last active April 28, 2019 02:39
Hello RNN - Model
import torch
import torch.nn as nn

class HelloRNN(nn.Module):
    def __init__(self, num_chars, num_hidden=10):
        super().__init__()
        self.num_chars = num_chars
        self.num_hidden = num_hidden

        # Network Parameters
        # Connection Matrices
        self.Wxh = nn.Parameter(torch.randn((num_chars, num_hidden)))
        self.Whh = nn.Parameter(torch.randn((num_hidden, num_hidden)))
        self.Why = nn.Parameter(torch.randn((num_hidden, num_chars)))

        # Biases
        self.bh = nn.Parameter(torch.zeros(num_hidden))
        self.by = nn.Parameter(torch.zeros(num_chars))

        self._init_weights()

    def _init_weights(self):
        for param in self.parameters():
            param.requires_grad_(True)
            if param.data.ndimension() >= 2:
                # Xavier initialization for the weight matrices
                nn.init.xavier_uniform_(param.data)
            else:
                # Zero initialization for the biases
                nn.init.zeros_(param.data)

    def forward(self, X):
        # Initialize hidden state to zero
        self.h = torch.zeros(self.num_hidden)

        # Setup outputs container
        outputs = torch.zeros_like(X)

        # Iterate through the sequence, one character vector at a time
        for i, x in enumerate(X):
            self.h = self.h + torch.tanh(x @ self.Wxh + self.h @ self.Whh + self.bh)
            outputs[i] = self.h @ self.Why + self.by

        return outputs
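
A minimal usage sketch (not part of the original gist), assuming the input X is a sequence of one-hot character vectors of shape (seq_len, num_chars); the "hello" string and the char_to_idx mapping below are illustrative only:

import torch

# Hypothetical vocabulary built from an example string
text = "hello"
chars = sorted(set(text))
char_to_idx = {c: i for i, c in enumerate(chars)}
num_chars = len(chars)

# One-hot encode the sequence: shape (seq_len, num_chars)
X = torch.zeros((len(text), num_chars))
for i, c in enumerate(text):
    X[i, char_to_idx[c]] = 1.0

model = HelloRNN(num_chars)
logits = model(X)                          # shape (seq_len, num_chars), one row of scores per timestep
probs = torch.softmax(logits, dim=-1)      # per-timestep distribution over characters
print(probs.shape)                         # torch.Size([5, 4]) for "hello" (4 unique characters)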