
@andrewliao11
Last active March 1, 2023 08:54
Simple LSTM cell with layer normalization (PyTorch 0.4.0)
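
This gist implements a single-timestep LSTM cell in the spirit of Ba et al.'s "Layer Normalization" (2016), with one variation worth noting: LayerNorm is applied per gate after the input and recurrent affine maps are summed and chunked, rather than to each map separately as in the paper. Writing σ for the logistic sigmoid, ⊙ for the elementwise product, and LN_* for the five LayerNorm modules, the recurrence the code below computes is:

    [a_i; a_f; a_g; a_o] = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh
    i = σ(LN_ingate(a_i))        f = σ(LN_forgetgate(a_f))
    g = tanh(LN_cellgate(a_g))   o = σ(LN_outgate(a_o))
    c_t = f ⊙ c_{t-1} + i ⊙ g
    h_t = o ⊙ tanh(LN_cy(c_t))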
# using pytorch==0.4.0
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.rnn import RNNCellBase


def _LayerNormLSTMCell(input, hidden, w_ih, w_hh, ln, b_ih=None, b_hh=None):
    hx, cx = hidden
    # Compute all four gate pre-activations in one shot, then split them.
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # Layer-normalize each gate's pre-activation before its nonlinearity.
    ingate = F.sigmoid(ln['ingate'](ingate))
    forgetgate = F.sigmoid(ln['forgetgate'](forgetgate))
    cellgate = F.tanh(ln['cellgate'](cellgate))
    outgate = F.sigmoid(ln['outgate'](outgate))

    cy = (forgetgate * cx) + (ingate * cellgate)
    # The new cell state is also normalized before producing the hidden state.
    hy = outgate * F.tanh(ln['cy'](cy))
    return hy, cy


# Register the cell function on the legacy thnn backend so that
# Module._backend can dispatch to it from forward() below.
backend = torch.nn.backends.thnn._get_thnn_function_backend()
backend.register_function('LayerNormLSTMCell', _LayerNormLSTMCell)


class LayerNormLSTMCell(RNNCellBase):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LayerNormLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
            self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        # Reset before creating the LayerNorm layers so that their default
        # affine init (weight=1, bias=0) is not overwritten by the uniform reset.
        self.reset_parameters()
        self.ln_ingate = nn.LayerNorm(hidden_size)
        self.ln_forgetgate = nn.LayerNorm(hidden_size)
        self.ln_cellgate = nn.LayerNorm(hidden_size)
        self.ln_outgate = nn.LayerNorm(hidden_size)
        self.ln_cy = nn.LayerNorm(hidden_size)
        # Plain dict for lookup inside _LayerNormLSTMCell; the attribute
        # assignments above already register each LayerNorm as a submodule.
        self.ln = {
            'ingate': self.ln_ingate,
            'forgetgate': self.ln_forgetgate,
            'cellgate': self.ln_cellgate,
            'outgate': self.ln_outgate,
            'cy': self.ln_cy,
        }

    def reset_parameters(self):
        # Standard LSTM init: uniform in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx):
        self.check_forward_input(input)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return self._backend.LayerNormLSTMCell(
            input, hx,
            self.weight_ih, self.weight_hh, self.ln,
            self.bias_ih, self.bias_hh,
        )


def test_layer_norm():
    rnn = LayerNormLSTMCell(10, 20)
    input = torch.randn(6, 3, 10)  # (seq_len, batch, input_size)
    hx = torch.randn(3, 20)
    cx = torch.randn(3, 20)
    output = []
    # Step the cell over the sequence one timestep at a time.
    for i in range(6):
        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)


if __name__ == '__main__':
    test_layer_norm()
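
The thnn backend hooks used above (torch.nn.backends.thnn and the _backend dispatch in forward) are internal APIs that were removed in later PyTorch releases, which is why the snippet is pinned to 0.4.0. Below is a minimal, backend-free sketch of the same cell for a recent PyTorch (>= 1.0); the class name LayerNormLSTMCellV2 and the use of nn.Linear in place of raw Parameters are illustrative choices, not part of the original gist.

import torch
import torch.nn as nn


class LayerNormLSTMCellV2(nn.Module):
    """Backend-free LayerNorm LSTM cell (same math as the gist above)."""

    def __init__(self, input_size, hidden_size, bias=True):
        super(LayerNormLSTMCellV2, self).__init__()
        self.hidden_size = hidden_size
        # One affine map per source; each emits all four gate pre-activations.
        self.ih = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        # One LayerNorm per gate, plus one for the new cell state.
        self.ln_ingate = nn.LayerNorm(hidden_size)
        self.ln_forgetgate = nn.LayerNorm(hidden_size)
        self.ln_cellgate = nn.LayerNorm(hidden_size)
        self.ln_outgate = nn.LayerNorm(hidden_size)
        self.ln_cy = nn.LayerNorm(hidden_size)

    def forward(self, input, state):
        hx, cx = state
        gates = self.ih(input) + self.hh(hx)
        a_i, a_f, a_g, a_o = gates.chunk(4, 1)
        ingate = torch.sigmoid(self.ln_ingate(a_i))
        forgetgate = torch.sigmoid(self.ln_forgetgate(a_f))
        cellgate = torch.tanh(self.ln_cellgate(a_g))
        outgate = torch.sigmoid(self.ln_outgate(a_o))
        cy = forgetgate * cx + ingate * cellgate
        hy = outgate * torch.tanh(self.ln_cy(cy))
        return hy, cy

The loop in test_layer_norm works unchanged with this class; the only behavioral difference is that nn.Linear's default initialization replaces the uniform reset_parameters above.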