@zjplab
Forked from calclavia/sru.py
Pure PyTorch Implementation of SRU
import torch
import torch.nn as nn


class SRU(nn.Module):
    """ Simple Recurrent Unit https://arxiv.org/pdf/1709.02755.pdf """

    def __init__(self, input_size, hidden_size, activation=torch.tanh):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.linear_transform = nn.Linear(input_size, hidden_size, bias=False)  # W, for x_tilde
        self.gate = nn.Linear(input_size, 2 * hidden_size, bias=True)  # Wf and Wr
        self.activation = activation
        self.gate_ln = nn.LayerNorm(2 * hidden_size)
        self.act_ln = nn.LayerNorm(hidden_size)
        # v_f and v_r: pointwise recurrence weights applied to c_{t-1} inside the gates
        self.v = nn.Parameter(torch.randn(2 * hidden_size))

    def forward(self, x, c=None):
        # x: (batch, time, input_size); c: (batch, hidden_size)
        if c is None:
            c = torch.zeros((x.size(0), self.hidden_size), dtype=x.dtype, device=x.device)
        x_tilde = self.linear_transform(x)  # (batch, time, hidden)
        gate_x = self.gate(x)  # (batch, time, 2 * hidden)
        cell_states = []
        reset_gates = []
        for t in range(x.size(1)):
            # The gates depend on c_{t-1} through the elementwise term v * [c; c],
            # so they are computed per timestep rather than precomputed for the whole sequence.
            gate = torch.sigmoid(self.gate_ln(gate_x[:, t] + self.v * torch.cat((c, c), dim=-1)))
            f = gate[:, :self.hidden_size]  # forget gate f_t
            r = gate[:, self.hidden_size:]  # reset gate r_t
            c = f * c + (1 - f) * x_tilde[:, t]  # c_t = f_t * c_{t-1} + (1 - f_t) * W x_t
            cell_states.append(c)
            reset_gates.append(r)
        all_c = torch.stack(cell_states, dim=1)
        all_r = torch.stack(reset_gates, dim=1)
        # Highway connection on the input; requires input_size == hidden_size
        h = all_r * self.activation(self.act_ln(all_c)) + (1 - all_r) * x
        return h, c
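
A minimal smoke test for the module above (the sizes are illustrative assumptions; note the highway connection requires input_size == hidden_size):

sru = SRU(input_size=16, hidden_size=16)
x = torch.randn(4, 10, 16)  # (batch, time, features); example shape only
h, c = sru(x)
print(h.shape, c.shape)  # torch.Size([4, 10, 16]) torch.Size([4, 16])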