bidirectional_rnn
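# A from-scratch multi-layer (optionally bidirectional) Elman RNN, written to
# mirror torch.nn.RNN with bias=False and nonlinearity='tanh'. Weights are
# kept as plain tensors rather than nn.Parameter, since this gist only checks
# the forward pass against PyTorch's reference implementation below.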
import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Input-to-hidden weights, one matrix per layer. Layer 0 consumes the
        # raw input; deeper layers consume the previous layer's output, which
        # is 2 * hidden_size wide when the network is bidirectional.
        self.w_ih = [torch.randn(hidden_size, input_size)]
        if bidirectional:
            self.w_ih_reverse = [torch.randn(hidden_size, input_size)]
        for layer in range(num_layers - 1):
            if bidirectional:
                self.w_ih_reverse.append(torch.randn(hidden_size, 2 * hidden_size))
                self.w_ih.append(torch.randn(hidden_size, 2 * hidden_size))
            else:
                self.w_ih.append(torch.randn(hidden_size, hidden_size))

        # Hidden-to-hidden weights are square in every layer, so they can be
        # stacked into a single tensor.
        self.w_hh = torch.randn(num_layers, hidden_size, hidden_size)
        if bidirectional:
            self.w_hh_reverse = torch.randn(num_layers, hidden_size, hidden_size)
    def forward(self, input, h_0=None):
        # input has shape (seq_len, batch, input_size), matching
        # torch.nn.RNN's default (non-batch-first) layout.
        if h_0 is None:
            if self.bidirectional:
                h_0 = torch.zeros(2, self.num_layers, input.shape[1], self.hidden_size)
            else:
                h_0 = torch.zeros(1, self.num_layers, input.shape[1], self.hidden_size)

        if self.bidirectional:
            output = torch.zeros(input.shape[0], input.shape[1], 2 * self.hidden_size)
        else:
            output = torch.zeros(input.shape[0], input.shape[1], self.hidden_size)

        inp = input
        for layer in range(self.num_layers):
            # Forward direction: left to right over the sequence.
            h_t = h_0[0, layer]
            for t in range(inp.shape[0]):
                h_t = torch.tanh(torch.matmul(inp[t], self.w_ih[layer].T) +
                                 torch.matmul(h_t, self.w_hh[layer].T))
                output[t, :, :self.hidden_size] = h_t

            if self.bidirectional:
                # Reverse direction: right to left, written into the second
                # half of the feature dimension.
                h_t_reverse = h_0[1, layer]
                for t in range(inp.shape[0]):
                    h_t_reverse = torch.tanh(torch.matmul(inp[-1 - t], self.w_ih_reverse[layer].T) +
                                             torch.matmul(h_t_reverse, self.w_hh_reverse[layer].T))
                    output[-1 - t, :, self.hidden_size:] = h_t_reverse

            # The output of this layer (both directions concatenated, when
            # bidirectional) is the input to the next layer.
            inp = output.clone()

        return output
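
# Each direction of each layer applies the bias-free Elman recurrence that
# torch.nn.RNN documents:
#
#   h_t = tanh(x_t @ W_ih.T + h_{t-1} @ W_hh.T)
#
# The reverse direction runs the same recurrence right-to-left with its own
# weights, and the two hidden-state sequences are concatenated along the
# feature dimension before feeding the next layer.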

if __name__ == '__main__':
    input_size = 10
    hidden_size = 12
    num_layers = 2
    batch_size = 2
    seq_len = 2
    bidirectional = True

    input = torch.randn(seq_len, batch_size, input_size)

    # Reference network to validate against: bias-free, tanh nonlinearity,
    # same layout as the manual implementation.
    rnn_val = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,
                           num_layers=num_layers, bias=False,
                           bidirectional=bidirectional, nonlinearity='tanh')
    rnn = RNN(input_size=input_size, hidden_size=hidden_size,
              num_layers=num_layers, bidirectional=bidirectional)

    # Copy PyTorch's randomly initialized weights into the manual RNN so
    # both networks compute the same function.
    for i in range(rnn_val.num_layers):
        rnn.w_ih[i] = rnn_val._parameters[f'weight_ih_l{i}'].data
        rnn.w_hh[i] = rnn_val._parameters[f'weight_hh_l{i}'].data
        if bidirectional:
            rnn.w_ih_reverse[i] = rnn_val._parameters[f'weight_ih_l{i}_reverse'].data
            rnn.w_hh_reverse[i] = rnn_val._parameters[f'weight_hh_l{i}_reverse'].data

    output_val, hn_val = rnn_val(input)
    output = rnn(input)

    print(output_val)
    print(output)
    print(torch.allclose(output, output_val, atol=1e-5))
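
# If the manual forward pass reproduces torch.nn.RNN, the allclose check on
# the last line should print True (up to the 1e-5 tolerance).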