@adamoudad
Created March 20, 2021 21:24
import torch
from torch import nn


class BiLSTM(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        # Embedding layer mapping token indices to dense vectors.
        self.encoder = nn.Embedding(input_dim, embedding_dim)
        # Two-layer bidirectional LSTM; forward and backward hidden states
        # are concatenated, hence hidden_dim * 2 features into the head.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=2, bidirectional=True)
        self.linear = nn.Linear(hidden_dim * 2, 1)
        self.activation = nn.Sigmoid()
        nn.init.xavier_uniform_(self.linear.weight)
        self.linear.bias.data.zero_()
        self.init_weights()

    def init_weights(self):
        # Xavier init for input-to-hidden weights, orthogonal init for
        # hidden-to-hidden (recurrent) weights, zeros for all biases.
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        self.encoder.weight.data.uniform_(-0.5, 0.5)
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in b:
            nn.init.constant_(t, 0)

    def forward(self, src):
        # src: (seq_len, batch_size) tensor of token indices.
        output = self.encoder(src)           # (seq_len, batch_size, embedding_dim)
        output, _ = self.lstm(output)        # (seq_len, batch_size, hidden_dim * 2)
        # Keep only the last time step, squashed through tanh.
        output = torch.tanh(output[-1])      # (batch_size, hidden_dim * 2)
        output = self.linear(output)
        output = self.activation(output)     # (batch_size, 1) probabilities
        # The second return value is a placeholder, kept so callers can
        # unpack (output, hidden) uniformly.
        return output, None
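
A minimal usage sketch follows; it is not part of the original gist, and the vocabulary size, batch shape, and target labels are illustrative assumptions.

# Usage sketch: illustrative dimensions, not from the original gist.
model = BiLSTM(input_dim=10_000, embedding_dim=128, hidden_dim=256)
src = torch.randint(0, 10_000, (20, 4))    # 20 time steps, batch of 4 sequences
probs, _ = model(src)                      # probs: torch.Size([4, 1])
# Binary-classification loss against hypothetical 0/1 labels.
targets = torch.randint(0, 2, (4, 1)).float()
loss = nn.BCELoss()(probs, targets)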