@calclavia · Created July 22, 2018
Pure PyTorch implementation of the SRU (Simple Recurrent Unit)
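For reference, the SRU recurrence as given in the paper (paraphrased here; the two layer-normalization calls in this gist are an addition on top of these equations):

\tilde{x}_t = W x_t
f_t = \sigma(W_f x_t + b_f)
r_t = \sigma(W_r x_t + b_r)
c_t = f_t \odot c_{t-1} + (1 - f_t) \odot \tilde{x}_t
h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t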
import torch
import torch.nn as nn


class SRU(nn.Module):
    """Simple Recurrent Unit, https://arxiv.org/pdf/1709.02755.pdf"""

    def __init__(self, input_size, hidden_size, activation=torch.tanh):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Candidate transformation W x_t (no bias, as in the paper)
        self.linear_transform = nn.Linear(input_size, hidden_size, bias=False)
        # Forget and reset gates computed together: [f_t; r_t]
        self.gate = nn.Linear(input_size, 2 * hidden_size)
        self.activation = activation
        self.gate_ln = nn.LayerNorm(2 * hidden_size)
        self.act_ln = nn.LayerNorm(hidden_size)

    def forward(self, x, c):
        # x: (batch, seq_len, input_size); c: (batch, hidden_size) or None
        if c is None:
            c = torch.zeros((x.size(0), self.hidden_size), dtype=x.dtype, device=x.device)

        # All input-dependent terms are computed in parallel over the sequence
        x_tilde = self.linear_transform(x)
        gate = torch.sigmoid(self.gate_ln(self.gate(x)))
        f = gate[:, :, :self.hidden_size]   # forget gate f_t
        r = gate[:, :, self.hidden_size:]   # reset gate r_t
        new_data = (1 - f) * x_tilde

        # Only the lightweight cell-state update is sequential
        cell_states = []
        for t in range(x.size(1)):
            c = f[:, t] * c + new_data[:, t]
            cell_states.append(c)
        all_c = torch.stack(cell_states, dim=1)

        # Highway connection on the input; requires input_size == hidden_size
        h = r * self.activation(self.act_ln(all_c)) + (1 - r) * x
        return h, c
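A minimal usage sketch (the sizes and variable names below are illustrative, not part of the gist). Because the highway term (1 - r) * x adds the raw input to the output, input_size must equal hidden_size.

# Hypothetical usage example, assuming the SRU module defined above.
import torch

batch, seq_len, hidden = 4, 10, 32          # illustrative sizes
sru = SRU(input_size=hidden, hidden_size=hidden)
x = torch.randn(batch, seq_len, hidden)

h, c = sru(x, None)                         # c=None starts the cell state at zeros
print(h.shape)                              # torch.Size([4, 10, 32]) -- outputs for every timestep
print(c.shape)                              # torch.Size([4, 32])     -- final cell state, reusable as the next call's c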