Skip to content

Instantly share code, notes, and snippets.

@arunm8489
Created August 1, 2020 14:34
Show Gist options
  • Save arunm8489/3edbcee2c4babebf29332802d75ca6e7 to your computer and use it in GitHub Desktop.
Save arunm8489/3edbcee2c4babebf29332802d75ca6e7 to your computer and use it in GitHub Desktop.
class Network(nn.Module):
    """LSTM text encoder fused with categorical and numeric side features.

    Text tokens are embedded with a frozen, pre-trained embedding matrix and
    run through a single-layer LSTM. The LSTM output for *every* time step is
    flattened and concatenated with five categorical embeddings and a linear
    projection of the numeric features. An MLP head then maps the result to 2
    raw (un-normalised) scores per sample; no softmax is applied here.

    Args:
        weight_matrix: Pre-trained embedding matrix indexed as
            ``[vocab_size, vector_dim]``. It is converted with
            ``torch.from_numpy``, so it must be a NumPy array. Defaults to the
            module-level ``embedding_matrix``, bound when the class is defined.
        hidden_dim: LSTM hidden size.
        seq_len: Number of tokens per text sample. ``forward`` flattens the
            whole LSTM output into ``linear1``, so every ``text`` batch must
            have exactly this many time steps.
    """

    def __init__(self, weight_matrix=embedding_matrix, hidden_dim=128, seq_len=440):
        super().__init__()
        vocab_size = weight_matrix.shape[0]
        vector_dim = weight_matrix.shape[1]
        self.seq_len = seq_len
        self.hidden_dim = hidden_dim

        # Text branch: frozen pre-trained embeddings -> single-layer LSTM.
        self.embedding = nn.Embedding(vocab_size, vector_dim)
        with torch.no_grad():
            # copy_ casts to the parameter's dtype, so a float64 NumPy matrix
            # still yields a float32 weight compatible with the LSTM.
            self.embedding.weight.copy_(torch.from_numpy(weight_matrix))
        self.embedding.weight.requires_grad = False  # keep embeddings frozen
        self.lstm = nn.LSTM(
            input_size=vector_dim,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True,
        )

        # Categorical branch: nn.Embedding(num_categories, embedding_dim).
        # Input indices must be < num_categories for each feature.
        self.state_embedding = nn.Embedding(51, 2)
        self.prefix_embedding = nn.Embedding(5, 3)
        self.cat_embedding = nn.Embedding(50, 26)
        self.sub_cat_embedding = nn.Embedding(401, 199)
        self.grade_embedding = nn.Embedding(4, 2)

        # Numeric branch: 4 input features projected to 12.
        self.numeric = nn.Linear(4, 12)

        # Width of the flattened side features:
        #   2 + 3 + 26 + 199 + 2 + 12 = 244
        # Derived from the layers instead of hard-coded, so it stays in sync
        # if any of them change. This assumes each categorical input carries a
        # single index per sample (TODO confirm against the data loader).
        categorical_width = sum(
            emb.embedding_dim
            for emb in (
                self.state_embedding,
                self.prefix_embedding,
                self.cat_embedding,
                self.sub_cat_embedding,
                self.grade_embedding,
            )
        )
        side_width = categorical_width + self.numeric.out_features

        # MLP head over [flattened LSTM output | side features].
        self.linear1 = nn.Linear(self.hidden_dim * self.seq_len + side_width, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 64)
        self.bn = nn.BatchNorm1d(64)
        self.linear4 = nn.Linear(64, 2)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, text, state, prefix, cat, sub_cat, grade, num):
        """Return a ``[batch, 2]`` tensor of raw scores.

        Args:
            text: Token indices, ``[batch, seq_len]`` (batch-first LSTM).
            state, prefix, cat, sub_cat, grade: Categorical indices for the
                matching embedding tables; presumably one index per sample,
                i.e. shape ``[batch]`` or ``[batch, 1]`` (TODO confirm).
            num: Numeric features whose last dimension is 4.
        """
        # batch_first=True: lstm_out is [batch, seq_len, hidden_dim].
        lstm_out, _ = self.lstm(self.embedding(text))
        # Keep every time step (not just the final hidden state).
        # flatten() reshapes, copying if the tensor is non-contiguous.
        text_features = lstm_out.flatten(start_dim=1)

        combined = torch.cat(
            (
                text_features,
                self.state_embedding(state).flatten(start_dim=1),
                self.prefix_embedding(prefix).flatten(start_dim=1),
                self.cat_embedding(cat).flatten(start_dim=1),
                self.sub_cat_embedding(sub_cat).flatten(start_dim=1),
                self.grade_embedding(grade).flatten(start_dim=1),
                self.numeric(num).flatten(start_dim=1),
            ),
            dim=1,
        )

        x = self.dropout(F.relu(self.linear1(combined)))
        x = self.dropout(F.relu(self.linear2(x)))
        # NOTE: in training mode BatchNorm1d needs more than one sample per batch.
        x = self.bn(F.relu(self.linear3(x)))
        return self.linear4(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment