@Muhammad4hmed
Created January 20, 2021 17:12
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionModel(nn.Module):
    def __init__(self, num_classes=5, embed_size=2560, LSTM_UNITS=64,
                 pretrained=True, BATCH_SIZE=4):
        super().__init__()
        self.batch_size = BATCH_SIZE  # kept for compatibility; forward() uses the actual batch size
        # EfficientNet-B7 backbone; forward_features() returns a (B, embed_size, H', W') map.
        self.cnn = timm.create_model('efficientnet_b7', pretrained=pretrained)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Two stacked bidirectional LSTMs over the pooled embedding, treated as a length-1 sequence.
        self.lstm1 = nn.LSTM(embed_size, LSTM_UNITS, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(LSTM_UNITS * 2, LSTM_UNITS, bidirectional=True, batch_first=True)
        # One scalar attention score per timestep for each LSTM's output.
        self.attention_layer1 = nn.Linear(2 * LSTM_UNITS, 1)
        self.attention_layer2 = nn.Linear(2 * LSTM_UNITS, 1)
        # Width of [pooled CNN embedding, attended LSTM output]: embed_size + 2*LSTM_UNITS
        # (2560 + 128 = 2688 for efficientnet_b7). To adapt to another backbone, change the
        # timm model name above and pass the matching embed_size.
        self.global_attn_f = embed_size + 2 * LSTM_UNITS
        self.linear1 = nn.Linear(self.global_attn_f, self.global_attn_f)
        self.linear2 = nn.Linear(self.global_attn_f, self.global_attn_f)
        self.linear_global = nn.Linear(self.global_attn_f, num_classes)

    def forward(self, x):
        # CNN feature map: (B, embed_size, H', W'), returned alongside the logits.
        embedding = self.cnn.forward_features(x)
        feats = embedding.clone()
        # Global-average-pool to (B, embed_size, 1, 1), then reshape to a length-1 sequence.
        embedding = self.avgpool(embedding)
        b, f, _, _ = embedding.shape
        embedding = embedding.reshape(b, 1, f)

        self.lstm1.flatten_parameters()
        h_lstm1, _ = self.lstm1(embedding)   # (B, T, 2*LSTM_UNITS)
        self.lstm2.flatten_parameters()
        h_lstm2, _ = self.lstm2(h_lstm1)     # (B, T, 2*LSTM_UNITS)

        # Soft attention: one score per timestep, softmax-normalised over T.
        _, T, _ = h_lstm1.shape
        attention_weights1 = [None] * T
        attention_weights2 = [None] * T
        for t in range(T):
            attention_weights1[t] = self.attention_layer1(h_lstm1[:, t, :])
            attention_weights2[t] = self.attention_layer2(h_lstm2[:, t, :])
        attention_weights_norm1 = F.softmax(torch.stack(attention_weights1, -1), -1)  # (B, 1, T)
        attention_weights_norm2 = F.softmax(torch.stack(attention_weights2, -1), -1)  # (B, 1, T)

        # (B, 1, T) x (B, T, 2*LSTM_UNITS) -> (B, 1, 2*LSTM_UNITS)
        attention1 = torch.bmm(attention_weights_norm1, h_lstm1).squeeze(1)
        attention2 = torch.bmm(attention_weights_norm2, h_lstm2).squeeze(1)
        embedding = embedding.squeeze(1)

        # Concatenate the pooled CNN embedding with each attended representation.
        h_lstm1 = torch.cat([embedding, attention1], dim=1)  # (B, global_attn_f)
        h_lstm2 = torch.cat([embedding, attention2], dim=1)  # (B, global_attn_f)
        h_conc_linear1 = F.relu(self.linear1(h_lstm1))
        h_conc_linear2 = F.relu(self.linear2(h_lstm2))

        # Residual-style sum of the raw and projected representations, then classify.
        hidden = h_lstm1 + h_lstm2 + h_conc_linear1 + h_conc_linear2
        output_global = self.linear_global(hidden)
        return output_global, feats
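
For reference, a minimal usage sketch, assuming timm is installed; pretrained=False avoids the weight download, and the 256x256 input size is only an illustrative choice:

# Minimal usage sketch: build the model and run a dummy batch through it.
model = AttentionModel(num_classes=5, pretrained=False)
model.eval()

dummy = torch.randn(2, 3, 256, 256)  # (batch, channels, height, width)
with torch.no_grad():
    logits, feature_map = model(dummy)

print(logits.shape)       # torch.Size([2, 5])
print(feature_map.shape)  # torch.Size([2, 2560, 8, 8]) for a 256x256 input

Note that since the pooled embedding forms a length-1 sequence (T = 1), the softmax weight is always 1, so each attention block currently passes its LSTM's single output straight through.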