Skip to content

Instantly share code, notes, and snippets.

@ravishchawla
Created March 25, 2020 21:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ravishchawla/90d85d5c87f5d4369c6345ef1fc5d028 to your computer and use it in GitHub Desktop.
Save ravishchawla/90d85d5c87f5d4369c6345ef1fc5d028 to your computer and use it in GitHub Desktop.
torch_model_basic.py
class Model(nn.Module):
    """Two stacked bidirectional LSTMs over a trainable pretrained embedding table.

    Token ids are embedded, run through a bidirectional LSTM, projected by a
    Linear + ReLU layer, and run through a second bidirectional LSTM. The
    second LSTM's outputs are pooled over the sequence axis with both mean
    and max, concatenated, and mapped to one raw (un-activated) score per
    sequence.

    Args:
        embedding_matrix: 2-D array-like of shape (vocab_size, embedding_dim)
            used to initialise the embedding table. Tensors, NumPy arrays and
            nested lists are accepted; values are cast to torch's default
            floating dtype so they match the LSTM/Linear parameters. The
            embedding stays trainable.
        hidden_unit: Hidden size of each LSTM direction (default 64).

    Raises:
        ValueError: if ``embedding_matrix`` is not 2-D.
    """

    def __init__(self, embedding_matrix, hidden_unit=64):
        super().__init__()
        embeddings = torch.as_tensor(
            embedding_matrix, dtype=torch.get_default_dtype()
        )
        if embeddings.dim() != 2:
            raise ValueError(
                "embedding_matrix must be 2-D (vocab_size, embedding_dim), "
                f"got shape {tuple(embeddings.shape)}"
            )
        embedding_dim = embeddings.shape[1]

        # freeze=False keeps the pretrained vectors trainable; the attribute
        # name is unchanged so state_dict keys stay compatible.
        self.embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=False)
        # batch_first=True: forward() pools over dim 1 and concatenates on
        # dim 1, which only makes sense for (batch, seq, features) layout.
        self.lstm_1 = nn.LSTM(
            embedding_dim, hidden_unit, bidirectional=True, batch_first=True
        )
        self.fc_1 = nn.Linear(hidden_unit * 2, hidden_unit * 2)
        self.lstm_2 = nn.LSTM(
            hidden_unit * 2, hidden_unit, bidirectional=True, batch_first=True
        )
        # Input is [mean-pool, max-pool] of the 2*hidden_unit BiLSTM output.
        self.fc_2 = nn.Linear(hidden_unit * 2 * 2, 1)

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, 1) holding one raw score per sequence.
        """
        out = self.embedding_layer(x)      # (batch, seq, embedding_dim)
        out, _ = self.lstm_1(out)          # (batch, seq, 2 * hidden_unit)
        out = torch.relu(self.fc_1(out))   # (batch, seq, 2 * hidden_unit)
        out, _ = self.lstm_2(out)          # (batch, seq, 2 * hidden_unit)
        # Summarise each sequence by both its average and its strongest
        # activation per feature, then join them along the feature axis.
        out_avg = torch.mean(out, 1)
        out_max = torch.max(out, 1)[0]
        pooled = torch.cat((out_avg, out_max), 1)  # (batch, 4 * hidden_unit)
        return self.fc_2(pooled)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment