Skip to content

Instantly share code, notes, and snippets.

@michael-iuzzolino
Last active April 28, 2019 03:11
Show Gist options
  • Save michael-iuzzolino/25f685ce3d8dd78b26b8d4165180cac9 to your computer and use it in GitHub Desktop.
Save michael-iuzzolino/25f685ce3d8dd78b26b8d4165180cac9 to your computer and use it in GitHub Desktop.
Hello RNN lstm cell
class LSTMCell(nn.Module):
    """Single LSTM cell with a linear read-out, keeping its own recurrent state.

    The cell stores its hidden state ``h`` and cell state ``c`` as plain
    attributes and overwrites them on every ``forward`` call. Call ``init()``
    before each new sequence to reset them to zeros.

    Args:
        num_chars: Size of the last dimension of the input and of the output
            (presumably the vocabulary size for one-hot encoded characters --
            confirm against the caller).
        num_hidden: Size of the hidden and cell state vectors.
    """

    def __init__(self, num_chars, num_hidden):
        super().__init__()
        self.num_chars = num_chars
        self.num_hidden = num_hidden

        # NOTE: parameters are created in a fixed order so that a given RNG
        # seed always yields the same initial weights.

        # Candidate ("potential input") parameters
        self.Wxh = nn.Parameter(torch.randn((num_chars, num_hidden)))
        self.Whh = nn.Parameter(torch.randn((num_hidden, num_hidden)))
        self.bh = nn.Parameter(torch.zeros((num_hidden)))

        # Input gate parameters
        self.Wxh_i = nn.Parameter(torch.randn_like(self.Wxh))
        self.Whh_i = nn.Parameter(torch.randn_like(self.Whh))
        self.bh_i = nn.Parameter(torch.randn_like(self.bh))

        # Forget gate parameters
        self.Wxh_f = nn.Parameter(torch.randn_like(self.Wxh))
        self.Whh_f = nn.Parameter(torch.randn_like(self.Whh))
        self.bh_f = nn.Parameter(torch.randn_like(self.bh))

        # Output gate parameters
        self.Wxh_o = nn.Parameter(torch.randn_like(self.Wxh))
        self.Whh_o = nn.Parameter(torch.randn_like(self.Whh))
        self.bh_o = nn.Parameter(torch.randn_like(self.bh))

        # Hidden -> Output
        self.Why = nn.Parameter(torch.randn((num_hidden, num_chars)))
        self.by = nn.Parameter(torch.zeros((num_chars)))

        # Activations
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        # Recurrent state; allocated by init() or lazily on the first forward().
        self.h = None
        self.c = None

    def init(self):
        """Reset the hidden and cell state to zero vectors of size ``num_hidden``.

        The state is allocated with the dtype and device of the parameters so
        the cell keeps working after ``.to(...)``, ``.double()`` or ``.cuda()``.
        """
        dtype, device = self.bh.dtype, self.bh.device
        self.h = torch.zeros(self.num_hidden, dtype=dtype, device=device)  # Hidden state
        self.c = torch.zeros(self.num_hidden, dtype=dtype, device=device)  # Cell state

    def forward(self, x):
        """Advance the cell by one step and return the output projection.

        Args:
            x: Tensor whose last dimension is ``num_chars``.

        Returns:
            Tensor ``h @ Why + by`` whose last dimension is ``num_chars``.

        Side effects:
            Overwrites ``self.h`` and ``self.c`` with the new state. The new
            state is computed from the previous one, so it stays attached to
            the autograd graph of earlier steps until ``init()`` is called.
        """
        # Start from a zero state if the caller never called init().
        if self.h is None or self.c is None:
            self.init()

        h_prev, c_prev = self.h, self.c

        potential_input = self.tanh((x @ self.Wxh) + (h_prev @ self.Whh + self.bh))

        # Gate updates
        input_gate = self.sigmoid((x @ self.Wxh_i) + (h_prev @ self.Whh_i + self.bh_i))
        forget_gate = self.sigmoid((x @ self.Wxh_f) + (h_prev @ self.Whh_f + self.bh_f))
        output_gate = self.sigmoid((x @ self.Wxh_o) + (h_prev @ self.Whh_o + self.bh_o))

        # Update c and h
        self.c = c_prev * forget_gate + potential_input * input_gate
        self.h = output_gate * self.tanh(self.c)

        y_output = self.h @ self.Why + self.by
        return y_output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment