# LSTM cell implemented from scratch in PyTorch.
# Source: GitHub Gist by @bentrevett
# (gist id 412ffd6ca152b711db1b7af435137665), last active May 5, 2020.
import torch
import torch.nn as nn
class LSTM(nn.Module):
    """A single LSTM cell built from explicit per-gate projections.

    Each gate ``z`` in {forget ``f``, input ``i``, output ``o``, candidate ``g``}
    is computed as ``act(W_z x + U_z h + b_z)``, where ``act`` is sigmoid for
    f/i/o and tanh for g. One call to ``forward`` advances one time step.

    Args:
        in_dim: Number of input features in ``x``.
        hid_dim: Size of the hidden state ``h`` and the cell state ``c``.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        # Input-to-hidden (w_*) and hidden-to-hidden (u_*) projections for
        # each gate. The biases are separate parameters below, so the Linear
        # layers are created without their own bias terms.
        self.w_f = nn.Linear(in_dim, hid_dim, bias=False)
        self.u_f = nn.Linear(hid_dim, hid_dim, bias=False)
        self.w_i = nn.Linear(in_dim, hid_dim, bias=False)
        self.u_i = nn.Linear(hid_dim, hid_dim, bias=False)
        self.w_o = nn.Linear(in_dim, hid_dim, bias=False)
        self.u_o = nn.Linear(hid_dim, hid_dim, bias=False)
        self.w_g = nn.Linear(in_dim, hid_dim, bias=False)
        self.u_g = nn.Linear(hid_dim, hid_dim, bias=False)
        # One bias per hidden unit, initialised to zero. Note that
        # ``torch.FloatTensor([hid_dim])`` would instead create a 1-element
        # tensor holding the *value* hid_dim (e.g. 256.0), which saturates
        # every gate's activation.
        self.b_f = nn.Parameter(torch.zeros(hid_dim))
        self.b_i = nn.Parameter(torch.zeros(hid_dim))
        self.b_o = nn.Parameter(torch.zeros(hid_dim))
        self.b_g = nn.Parameter(torch.zeros(hid_dim))

    def forward(self, x, h, c):
        """Advance the cell by one time step.

        Args:
            x: Input, shape [batch, in dim].
            h: Previous hidden state, shape [batch, hid dim].
            c: Previous cell state, shape [batch, hid dim].

        Returns:
            Tuple ``(h, c)``: the new hidden and cell states, each of shape
            [batch, hid dim].
        """
        f = torch.sigmoid(self.w_f(x) + self.u_f(h) + self.b_f)
        i = torch.sigmoid(self.w_i(x) + self.u_i(h) + self.b_i)
        o = torch.sigmoid(self.w_o(x) + self.u_o(h) + self.b_o)
        # f/i/o = [batch, hid dim]
        # The candidate gate combines its two projections additively, exactly
        # like the other gates.
        g = torch.tanh(self.w_g(x) + self.u_g(h) + self.b_g)
        # g = [batch, hid dim]
        c = f * c + i * g
        # c = [batch, hid dim]
        h = o * torch.tanh(c)
        # h = [batch, hid dim]
        return h, c
# Example: run one time step of the cell on random data.
batch_size = 32
in_dim = 100
hid_dim = 256

lstm = LSTM(in_dim, hid_dim)

# Random input and initial hidden/cell states for a single time step.
x = torch.randn(batch_size, in_dim)
h_0 = torch.randn(batch_size, hid_dim)
c_0 = torch.randn(batch_size, hid_dim)

# h, c = [batch size, hid dim]
h, c = lstm(x, h_0, c_0)
# Originally published as a GitHub Gist; discussion lives on the gist page.