import torch


class VLBertClassifier(VLBert):
    """VLBert backbone with an MLP classification head on top."""

    def __init__(self, cfg, args, tok, num_layers, num_outputs,
                 hidden_units=1024, dim_mlp=384):
        super(VLBertClassifier, self).__init__(cfg, args, tok)
        # Classification head: either a two-layer MLP or a single linear
        # layer producing one logit.
        if num_layers == 2:
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, hidden_units),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(hidden_units, num_outputs),
            )
        elif num_layers == 1:
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(0.1, inplace=False),
                torch.nn.Linear(dim_mlp, 1),
            )
        # Initialise the MLP weights: Xavier for weights, zeros for biases.
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, imgs, text, img_bboxes, attention_mask, img_lens,
                txt_lens, img_locs, txt_locs):
        lm_preds, vm_preds, input_pointing_pred, hidden_states, *_ = \
            super(VLBertClassifier, self).forward(
                imgs, text, img_bboxes, attention_mask=attention_mask,
                img_lens=img_lens, txt_lens=txt_lens,
                img_locs=img_locs, txt_locs=txt_locs)
        # Embedding at token position 12 of the final hidden states,
        # shape (batch, dim_mlp), e.g. torch.Size([16, 384]).
        txt_token_embedding = hidden_states[:, 12]
        # squeeze() drops the trailing dim when the head outputs one logit.
        output = self.final_mlp(txt_token_embedding).squeeze()
        return output
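

# A minimal, self-contained sketch of the classification head on dummy
# tensors. The VLBert backbone and its cfg/args/tok inputs are not part of
# this gist, so only the two-layer final_mlp is exercised standalone here;
# the batch size of 16, dim_mlp=384, and num_outputs=5 are assumptions for
# illustration, not values from the original code.
if __name__ == "__main__":
    dim_mlp, hidden_units, num_outputs, batch = 384, 1024, 5, 16
    head = torch.nn.Sequential(
        torch.nn.Dropout(0.1),
        torch.nn.Linear(dim_mlp, hidden_units),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(0.1),
        torch.nn.Linear(hidden_units, num_outputs),
    )
    head.eval()  # disable dropout so the check is deterministic

    # Stand-in for hidden_states[:, 12]: one dim_mlp-d embedding per example.
    txt_token_embedding = torch.randn(batch, dim_mlp)
    logits = head(txt_token_embedding).squeeze()
    print(logits.shape)  # torch.Size([16, 5])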