Skip to content

Instantly share code, notes, and snippets.

@adamoudad
Created March 20, 2021 21:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save adamoudad/ed2188b05855c7a9cfa3c7ebf0da449a to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class BiLSTM(nn.Module):
    """Bidirectional two-layer LSTM scorer over token-id sequences.

    Embeds integer token ids, runs them through a 2-layer bidirectional
    LSTM, and maps each time step's concatenated forward/backward hidden
    state through a linear layer and a sigmoid, yielding one probability
    per time step.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim):
        """
        Args:
            input_dim: vocabulary size (number of embedding rows).
            embedding_dim: dimensionality of each token embedding.
            hidden_dim: hidden size of each LSTM direction.
        """
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.encoder = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=2, bidirectional=True)
        # Bidirectional LSTM emits 2 * hidden_dim features per time step.
        self.linear = nn.Linear(hidden_dim * 2, 1)
        self.activation = nn.Sigmoid()

    def forward(self, src):
        """Return per-time-step sigmoid scores for a batch of sequences.

        Args:
            src: LongTensor of token ids. nn.LSTM's default layout is
                sequence-first, so this presumably has shape
                (seq_len, batch) — confirm against the caller.

        Returns:
            Tensor of shape (seq_len, batch, 1) with values in [0, 1].
        """
        # NOTE(review): the original computed an unused
        # `batch_size = src.size(1)`; removed as dead code.
        output = self.encoder(src)
        output, _ = self.lstm(output)
        output = self.linear(output)
        return self.activation(output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment